2
2
"""
3
3
Simple MCP client example with OAuth authentication support.
4
4
5
- This client connects to an MCP server using streamable HTTP transport with OAuth authentication .
6
- It provides an interactive command-line interface to list tools and execute them.
5
+ This client connects to an MCP server using streamable HTTP transport with OAuth.
6
+
7
7
"""
8
8
9
9
import asyncio
21
21
OAuthClientProvider ,
22
22
discover_oauth_metadata ,
23
23
)
24
+ from mcp .client .oauth_auth import OAuthAuth
24
25
from mcp .client .session import ClientSession
25
26
from mcp .client .streamable_http import streamablehttp_client
26
27
from mcp .shared .auth import OAuthClientInformationFull , OAuthClientMetadata , OAuthToken
27
- from pydantic import AnyHttpUrl
28
28
29
29
30
30
class CallbackHandler (BaseHTTPRequestHandler ):
31
31
"""Simple HTTP handler to capture OAuth callback."""
32
32
33
33
authorization_code = None
34
+ state = None
34
35
error = None
35
36
36
37
def do_GET (self ):
@@ -40,6 +41,7 @@ def do_GET(self):
40
41
41
42
if "code" in query_params :
42
43
CallbackHandler .authorization_code = query_params ["code" ][0 ]
44
+ CallbackHandler .state = query_params .get ("state" , [None ])[0 ]
43
45
self .send_response (200 )
44
46
self .send_header ("Content-type" , "text/html" )
45
47
self .end_headers ()
@@ -116,8 +118,11 @@ class JsonSerializableOAuthClientMetadata(OAuthClientMetadata):
116
118
"""OAuth client metadata that handles JSON serialization properly."""
117
119
118
120
def model_dump (self , ** kwargs ) -> dict [str , Any ]:
119
- """Override to ensure URLs are serialized as strings."""
121
+ """Override to ensure URLs are serialized as strings and exclude null values."""
122
+ # Exclude null values by default
123
+ kwargs .setdefault ("exclude_none" , True )
120
124
data = super ().model_dump (** kwargs )
125
+
121
126
# Convert AnyHttpUrl objects to strings
122
127
if "redirect_uris" in data :
123
128
data ["redirect_uris" ] = [str (url ) for url in data ["redirect_uris" ]]
@@ -193,9 +198,7 @@ async def tokens(self) -> OAuthToken | None:
193
198
194
199
async def save_tokens (self , tokens : OAuthToken ) -> None :
195
200
self ._tokens = tokens
196
- print (
197
- f"Saved OAuth tokens, access token starts with: { tokens .access_token [:10 ]} ..."
198
- )
201
+ print (f"Saved OAuth tokens: { tokens .access_token [:10 ]} ..." )
199
202
200
203
async def redirect_to_authorization (self , authorization_url : str ) -> None :
201
204
# Start callback server
@@ -252,66 +255,41 @@ async def connect(self):
252
255
"""Connect to the MCP server."""
253
256
print (f"🔗 Attempting to connect to { self .server_url } ..." )
254
257
255
- # The streamable HTTP transport will handle the OAuth flow automatically
256
- # We just need to wait for it to complete successfully
257
258
try :
258
- # Discover OAuth metadata first to set proper scopes
259
- await self .auth_provider ._discover_and_update_metadata ()
260
-
261
- # Check if we already have tokens, if not do auth flow first
262
- existing_tokens = await self .auth_provider .tokens ()
263
- if not existing_tokens :
264
- print ("🔐 No existing tokens found, initiating OAuth flow..." )
265
- await self .auth_provider ._discover_and_update_metadata ()
266
-
267
- # Start the auth flow to get tokens
268
- from mcp .client .auth import auth
269
-
270
- auth_result = await auth (
271
- self .auth_provider , server_url = self .server_url .replace ("/mcp" , "" )
272
- )
273
-
274
- if auth_result == "REDIRECT" :
275
- print ("🔄 Waiting for OAuth completion..." )
276
- # Wait for authorization code to be set by the redirect handler
277
- timeout = 300 # 5 minutes
278
- start_time = time .time ()
279
- while (
280
- not self .auth_provider ._authorization_code
281
- and time .time () - start_time < timeout
282
- ):
283
- await asyncio .sleep (0.1 )
284
-
285
- if not self .auth_provider ._authorization_code :
286
- raise Exception ("Timeout waiting for OAuth authorization" )
287
-
288
- # Now exchange the authorization code for tokens
289
- auth_result = await auth (
290
- self .auth_provider ,
291
- server_url = self .server_url .replace ("/mcp" , "" ),
292
- authorization_code = self .auth_provider ._authorization_code ,
293
- )
294
-
295
- if auth_result != "AUTHORIZED" :
296
- raise Exception ("Failed to authorize with server" )
297
-
298
- # Verify we have tokens now
299
- tokens = await self .auth_provider .tokens ()
300
- if not tokens :
301
- raise Exception ("OAuth completed but no tokens were saved" )
259
+ # Set up callback server
260
+ callback_server = CallbackServer (port = 3000 )
261
+ callback_server .start ()
262
+
263
+ async def callback_handler () -> tuple [str , str | None ]:
264
+ """Wait for OAuth callback and return auth code and state."""
265
+ print ("⏳ Waiting for authorization callback..." )
266
+ try :
267
+ auth_code = callback_server .wait_for_callback (timeout = 300 )
268
+ return auth_code , CallbackHandler .state
269
+ finally :
270
+ callback_server .stop ()
271
+
272
+ # Create OAuth authentication handler using the new interface
273
+ oauth_auth = OAuthAuth (
274
+ server_url = self .server_url .replace ("/mcp" , "" ),
275
+ client_metadata = self .auth_provider .client_metadata ,
276
+ storage = None , # Use in-memory storage
277
+ redirect_handler = None , # Use default (open browser)
278
+ callback_handler = callback_handler ,
279
+ )
302
280
303
- print (
304
- f"✅ OAuth authorization successful! Access token: { tokens .access_token [:20 ]} ..."
305
- )
281
+ # Initialize the auth handler and ensure we have tokens
306
282
307
- # Create streamable HTTP transport with auth
283
+ # Create streamable HTTP transport with auth handler
308
284
stream_context = streamablehttp_client (
309
285
url = self .server_url ,
310
- auth_provider = self . auth_provider ,
311
- timeout = timedelta (seconds = 60 ), # Longer timeout for OAuth flow
286
+ auth = oauth_auth ,
287
+ timeout = timedelta (seconds = 60 ),
312
288
)
313
289
314
- print ("📡 Opening transport connection..." )
290
+ print (
291
+ "📡 Opening transport connection (HTTPX handles auth automatically)..."
292
+ )
315
293
async with stream_context as (read_stream , write_stream , get_session_id ):
316
294
print ("🤝 Initializing MCP session..." )
317
295
async with ClientSession (read_stream , write_stream ) as session :
@@ -365,7 +343,7 @@ async def call_tool(self, tool_name: str, arguments: dict[str, Any] | None = Non
365
343
print (f"\n 🔧 Tool '{ tool_name } ' result:" )
366
344
if hasattr (result , "content" ):
367
345
for content in result .content :
368
- if hasattr ( content , "text" ) :
346
+ if content . type == "text" :
369
347
print (content .text )
370
348
else :
371
349
print (content )
0 commit comments