@@ -90,6 +90,19 @@ async def dummy_callback_handler() -> tuple[str, Optional[str]]:
9090 self .client_secret = client_secret
9191 self .authorization_server_url = authorization_server_url
9292
93+ async def async_auth_flow (self , request ):
94+ """Override the parent's auth flow to use client credentials only."""
95+ await self .perform_client_credentials_flow ()
96+
97+ # Add the access token to the request
98+ if self .context .current_tokens and self .context .current_tokens .access_token :
99+ token_type = self .context .current_tokens .token_type or "Bearer"
100+ request .headers ["Authorization" ] = (
101+ f"{ token_type } { self .context .current_tokens .access_token } "
102+ )
103+
104+ yield request
105+
93106 async def perform_client_credentials_flow (self ) -> None :
94107 """Performs the client credentials OAuth flow to obtain access tokens."""
95108 try :
@@ -221,7 +234,9 @@ async def create_transport(self):
221234 logging .debug (f"Connecting to OAuth-protected MCP server: { self .server_url } " )
222235
223236 # Discover the required scope from the server
224- scope , authorization_server_url = await self ._discover_scope_and_auth_server (self .server_url )
237+ scope , authorization_server_url = await self ._discover_scope_and_auth_server (
238+ self .server_url
239+ )
225240
226241 # Get OAuth client configuration (handled by mcp_clients.py)
227242 if not self .client_id or not self .client_secret :
@@ -230,9 +245,8 @@ async def create_transport(self):
230245 # Create client metadata
231246 client_metadata = OAuthClientMetadata (
232247 client_name = f"MCP Client - { self .name } " ,
233- redirect_uris = ["http://localhost" ], # Required but not used in client credentials flow
234- grant_types = ["authorization_code" ], # Required format, though we use client_credentials
235- response_types = ["code" ], # Required for authorization_code grant type
248+ redirect_uris = ["http://localhost" ],
249+ grant_types = ["client_credentials" ],
236250 token_endpoint_auth_method = "client_secret_post" ,
237251 scope = scope ,
238252 )
@@ -336,6 +350,8 @@ async def _discover_scope_and_auth_server(self, server_url: str) -> tuple[str, s
336350 scope = " " .join (resource_metadata .scopes_supported )
337351
338352 logging .debug (f"Discovered scope: { scope } " )
339- logging .debug (f"Discovered authorization server: { authorization_server_url } " )
353+ logging .debug (
354+ f"Discovered authorization server: { authorization_server_url } "
355+ )
340356
341357 return scope , authorization_server_url
0 commit comments