Skip to content

Commit c711cd7

Browse files
committed
auth httpx
1 parent b949581 commit c711cd7

File tree

4 files changed

+558
-130
lines changed

4 files changed

+558
-130
lines changed

examples/clients/simple-auth-client/mcp_simple_auth_client/main.py

Lines changed: 39 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,8 @@
22
"""
33
Simple MCP client example with OAuth authentication support.
44
5-
This client connects to an MCP server using streamable HTTP transport with OAuth authentication.
6-
It provides an interactive command-line interface to list tools and execute them.
5+
This client connects to an MCP server using streamable HTTP transport with OAuth.
6+
77
"""
88

99
import asyncio
@@ -21,16 +21,17 @@
2121
OAuthClientProvider,
2222
discover_oauth_metadata,
2323
)
24+
from mcp.client.oauth_auth import OAuthAuth
2425
from mcp.client.session import ClientSession
2526
from mcp.client.streamable_http import streamablehttp_client
2627
from mcp.shared.auth import OAuthClientInformationFull, OAuthClientMetadata, OAuthToken
27-
from pydantic import AnyHttpUrl
2828

2929

3030
class CallbackHandler(BaseHTTPRequestHandler):
3131
"""Simple HTTP handler to capture OAuth callback."""
3232

3333
authorization_code = None
34+
state = None
3435
error = None
3536

3637
def do_GET(self):
@@ -40,6 +41,7 @@ def do_GET(self):
4041

4142
if "code" in query_params:
4243
CallbackHandler.authorization_code = query_params["code"][0]
44+
CallbackHandler.state = query_params.get("state", [None])[0]
4345
self.send_response(200)
4446
self.send_header("Content-type", "text/html")
4547
self.end_headers()
@@ -116,8 +118,11 @@ class JsonSerializableOAuthClientMetadata(OAuthClientMetadata):
116118
"""OAuth client metadata that handles JSON serialization properly."""
117119

118120
def model_dump(self, **kwargs) -> dict[str, Any]:
119-
"""Override to ensure URLs are serialized as strings."""
121+
"""Override to ensure URLs are serialized as strings and exclude null values."""
122+
# Exclude null values by default
123+
kwargs.setdefault("exclude_none", True)
120124
data = super().model_dump(**kwargs)
125+
121126
# Convert AnyHttpUrl objects to strings
122127
if "redirect_uris" in data:
123128
data["redirect_uris"] = [str(url) for url in data["redirect_uris"]]
@@ -193,9 +198,7 @@ async def tokens(self) -> OAuthToken | None:
193198

194199
async def save_tokens(self, tokens: OAuthToken) -> None:
195200
self._tokens = tokens
196-
print(
197-
f"Saved OAuth tokens, access token starts with: {tokens.access_token[:10]}..."
198-
)
201+
print(f"Saved OAuth tokens: {tokens.access_token[:10]}...")
199202

200203
async def redirect_to_authorization(self, authorization_url: str) -> None:
201204
# Start callback server
@@ -252,66 +255,41 @@ async def connect(self):
252255
"""Connect to the MCP server."""
253256
print(f"🔗 Attempting to connect to {self.server_url}...")
254257

255-
# The streamable HTTP transport will handle the OAuth flow automatically
256-
# We just need to wait for it to complete successfully
257258
try:
258-
# Discover OAuth metadata first to set proper scopes
259-
await self.auth_provider._discover_and_update_metadata()
260-
261-
# Check if we already have tokens, if not do auth flow first
262-
existing_tokens = await self.auth_provider.tokens()
263-
if not existing_tokens:
264-
print("🔐 No existing tokens found, initiating OAuth flow...")
265-
await self.auth_provider._discover_and_update_metadata()
266-
267-
# Start the auth flow to get tokens
268-
from mcp.client.auth import auth
269-
270-
auth_result = await auth(
271-
self.auth_provider, server_url=self.server_url.replace("/mcp", "")
272-
)
273-
274-
if auth_result == "REDIRECT":
275-
print("🔄 Waiting for OAuth completion...")
276-
# Wait for authorization code to be set by the redirect handler
277-
timeout = 300 # 5 minutes
278-
start_time = time.time()
279-
while (
280-
not self.auth_provider._authorization_code
281-
and time.time() - start_time < timeout
282-
):
283-
await asyncio.sleep(0.1)
284-
285-
if not self.auth_provider._authorization_code:
286-
raise Exception("Timeout waiting for OAuth authorization")
287-
288-
# Now exchange the authorization code for tokens
289-
auth_result = await auth(
290-
self.auth_provider,
291-
server_url=self.server_url.replace("/mcp", ""),
292-
authorization_code=self.auth_provider._authorization_code,
293-
)
294-
295-
if auth_result != "AUTHORIZED":
296-
raise Exception("Failed to authorize with server")
297-
298-
# Verify we have tokens now
299-
tokens = await self.auth_provider.tokens()
300-
if not tokens:
301-
raise Exception("OAuth completed but no tokens were saved")
259+
# Set up callback server
260+
callback_server = CallbackServer(port=3000)
261+
callback_server.start()
262+
263+
async def callback_handler() -> tuple[str, str | None]:
264+
"""Wait for OAuth callback and return auth code and state."""
265+
print("⏳ Waiting for authorization callback...")
266+
try:
267+
auth_code = callback_server.wait_for_callback(timeout=300)
268+
return auth_code, CallbackHandler.state
269+
finally:
270+
callback_server.stop()
271+
272+
# Create OAuth authentication handler using the new interface
273+
oauth_auth = OAuthAuth(
274+
server_url=self.server_url.replace("/mcp", ""),
275+
client_metadata=self.auth_provider.client_metadata,
276+
storage=None, # Use in-memory storage
277+
redirect_handler=None, # Use default (open browser)
278+
callback_handler=callback_handler,
279+
)
302280

303-
print(
304-
f"✅ OAuth authorization successful! Access token: {tokens.access_token[:20]}..."
305-
)
281+
# Initialize the auth handler and ensure we have tokens
306282

307-
# Create streamable HTTP transport with auth
283+
# Create streamable HTTP transport with auth handler
308284
stream_context = streamablehttp_client(
309285
url=self.server_url,
310-
auth_provider=self.auth_provider,
311-
timeout=timedelta(seconds=60), # Longer timeout for OAuth flow
286+
auth=oauth_auth,
287+
timeout=timedelta(seconds=60),
312288
)
313289

314-
print("📡 Opening transport connection...")
290+
print(
291+
"📡 Opening transport connection (HTTPX handles auth automatically)..."
292+
)
315293
async with stream_context as (read_stream, write_stream, get_session_id):
316294
print("🤝 Initializing MCP session...")
317295
async with ClientSession(read_stream, write_stream) as session:
@@ -365,7 +343,7 @@ async def call_tool(self, tool_name: str, arguments: dict[str, Any] | None = Non
365343
print(f"\n🔧 Tool '{tool_name}' result:")
366344
if hasattr(result, "content"):
367345
for content in result.content:
368-
if hasattr(content, "text"):
346+
if content.type == "text":
369347
print(content.text)
370348
else:
371349
print(content)

0 commit comments

Comments
 (0)