7
7
"""
8
8
9
9
import asyncio
10
- import json
11
10
import os
12
11
import threading
13
12
import time
14
- import webbrowser
15
13
from datetime import timedelta
16
14
from http .server import BaseHTTPRequestHandler , HTTPServer
17
15
from typing import Any
18
16
from urllib .parse import parse_qs , urlparse
19
17
20
- from mcp .client .auth import (
21
- OAuthClientProvider ,
22
- discover_oauth_metadata ,
23
- )
24
18
from mcp .client .oauth_auth import OAuthAuth
25
19
from mcp .client .session import ClientSession
26
20
from mcp .client .streamable_http import streamablehttp_client
27
- from mcp .shared .auth import OAuthClientInformationFull , OAuthClientMetadata , OAuthToken
21
+ from mcp .shared .auth import OAuthClientMetadata
28
22
29
23
30
24
class CallbackHandler (BaseHTTPRequestHandler ):
@@ -114,141 +108,14 @@ def wait_for_callback(self, timeout=300):
114
108
raise Exception ("Timeout waiting for OAuth callback" )
115
109
116
110
117
- class JsonSerializableOAuthClientMetadata (OAuthClientMetadata ):
118
- """OAuth client metadata that handles JSON serialization properly."""
119
-
120
- def model_dump (self , ** kwargs ) -> dict [str , Any ]:
121
- """Override to ensure URLs are serialized as strings and exclude null values."""
122
- # Exclude null values by default
123
- kwargs .setdefault ("exclude_none" , True )
124
- data = super ().model_dump (** kwargs )
125
-
126
- # Convert AnyHttpUrl objects to strings
127
- if "redirect_uris" in data :
128
- data ["redirect_uris" ] = [str (url ) for url in data ["redirect_uris" ]]
129
-
130
- # Debug: print what we're sending
131
- print (f"🐛 Client metadata being sent: { json .dumps (data , indent = 2 )} " )
132
- return data
133
-
134
-
135
- class SimpleOAuthProvider (OAuthClientProvider ):
136
- """Simple OAuth client provider for demonstration purposes."""
137
-
138
- def __init__ (self , server_url : str , callback_port : int = 3000 ):
139
- self ._callback_port = callback_port
140
- self ._redirect_uri = f"http://localhost:{ callback_port } /callback"
141
- self ._server_url = server_url
142
- self ._callback_server = None
143
- print (f"🐛 OAuth provider initialized with redirect URI: { self ._redirect_uri } " )
144
- # Store the raw data for easy serialization - scope will be updated dynamically
145
- self ._client_metadata_dict = {
146
- "client_name" : "Simple Auth Client" ,
147
- "redirect_uris" : [self ._redirect_uri ],
148
- "grant_types" : ["authorization_code" , "refresh_token" ],
149
- "response_types" : ["code" ],
150
- "token_endpoint_auth_method" : "client_secret_post" , # Use client secret
151
- "scope" : "read" , # Default scope, will be updated
152
- }
153
- self ._client_info : OAuthClientInformationFull | None = None
154
- self ._tokens : OAuthToken | None = None
155
- self ._code_verifier : str | None = None
156
- self ._authorization_code : str | None = None
157
- self ._metadata_discovered = False
158
-
159
- @property
160
- def redirect_url (self ) -> str :
161
- return self ._redirect_uri
162
-
163
- async def _discover_and_update_metadata (self ):
164
- """Discover server OAuth metadata and update client scope accordingly."""
165
- if self ._metadata_discovered :
166
- return
167
-
168
- try :
169
- print ("🐛 Discovering OAuth metadata..." )
170
- metadata = await discover_oauth_metadata (self ._server_url )
171
- if metadata and metadata .scopes_supported :
172
- scope = " " .join (metadata .scopes_supported )
173
- self ._client_metadata_dict ["scope" ] = scope
174
- print (f"🐛 Updated scope to: { scope } " )
175
- self ._metadata_discovered = True
176
- except Exception as e :
177
- print (f"🐛 Failed to discover metadata: { e } , using default scope" )
178
- self ._metadata_discovered = True
179
-
180
- @property
181
- def client_metadata (self ) -> OAuthClientMetadata :
182
- # Create a fresh instance each time using our custom serializable version
183
- return JsonSerializableOAuthClientMetadata .model_validate (
184
- self ._client_metadata_dict
185
- )
186
-
187
- async def client_information (self ) -> OAuthClientInformationFull | None :
188
- return self ._client_info
189
-
190
- async def save_client_information (
191
- self , client_information : OAuthClientInformationFull
192
- ) -> None :
193
- self ._client_info = client_information
194
- print (f"Saved client information: { client_information .client_id } " )
195
-
196
- async def tokens (self ) -> OAuthToken | None :
197
- return self ._tokens
198
-
199
- async def save_tokens (self , tokens : OAuthToken ) -> None :
200
- self ._tokens = tokens
201
- print (f"Saved OAuth tokens: { tokens .access_token [:10 ]} ..." )
202
-
203
- async def redirect_to_authorization (self , authorization_url : str ) -> None :
204
- # Start callback server
205
- self ._callback_server = CallbackServer (self ._callback_port )
206
- self ._callback_server .start ()
207
-
208
- print ("\n 🌐 Opening authorization URL in your default browser..." )
209
- print (f"URL: { authorization_url } " )
210
- webbrowser .open (authorization_url )
211
-
212
- print ("⏳ Waiting for authorization callback..." )
213
- print ("(Complete the authorization in your browser)" )
214
-
215
- try :
216
- # Wait for the callback with authorization code
217
- authorization_code = self ._callback_server .wait_for_callback (timeout = 300 )
218
- print (f"✅ Received authorization code: { authorization_code [:20 ]} ..." )
219
-
220
- # Store the authorization code so auth() can handle token exchange
221
- self ._authorization_code = authorization_code
222
- print ("🎉 OAuth callback received successfully!" )
223
-
224
- except Exception as e :
225
- print (f"❌ OAuth flow failed: { e } " )
226
- raise
227
- finally :
228
- # Always stop the callback server
229
- if self ._callback_server :
230
- self ._callback_server .stop ()
231
- self ._callback_server = None
232
-
233
- async def save_code_verifier (self , code_verifier : str ) -> None :
234
- self ._code_verifier = code_verifier
235
-
236
- async def code_verifier (self ) -> str :
237
- if self ._code_verifier is None :
238
- raise ValueError ("No code verifier available" )
239
- return self ._code_verifier
240
-
241
-
242
111
class SimpleAuthClient :
243
112
"""Simple MCP client with auth support."""
244
113
245
114
def __init__ (self , server_url : str ):
246
115
self .server_url = server_url
247
116
# Extract base URL for auth server (remove /mcp endpoint for auth endpoints)
248
- auth_server_url = server_url .replace ("/mcp" , "" )
249
117
# Use default redirect URI - this is where the auth server will redirect the user
250
118
# The user will need to copy the authorization code from this callback URL
251
- self .auth_provider = SimpleOAuthProvider (auth_server_url )
252
119
self .session : ClientSession | None = None
253
120
254
121
async def connect (self ):
@@ -269,10 +136,21 @@ async def callback_handler() -> tuple[str, str | None]:
269
136
finally :
270
137
callback_server .stop ()
271
138
139
+ client_metadata_dict = {
140
+ "client_name" : "Simple Auth Client" ,
141
+ "redirect_uris" : ["http://localhost:3000/callback" ],
142
+ "grant_types" : ["authorization_code" , "refresh_token" ],
143
+ "response_types" : ["code" ],
144
+ "token_endpoint_auth_method" : "client_secret_post" , # Use client secret
145
+ "scope" : "read" , # Default scope, will be updated
146
+ }
147
+
272
148
# Create OAuth authentication handler using the new interface
273
149
oauth_auth = OAuthAuth (
274
150
server_url = self .server_url .replace ("/mcp" , "" ),
275
- client_metadata = self .auth_provider .client_metadata ,
151
+ client_metadata = OAuthClientMetadata .model_validate (
152
+ client_metadata_dict
153
+ ),
276
154
storage = None , # Use in-memory storage
277
155
redirect_handler = None , # Use default (open browser)
278
156
callback_handler = callback_handler ,
0 commit comments