Skip to content

Commit 78ba2c9

Browse files
committed
fix: Force client credentials flow for OAuth in e2e tests
1 parent 4e6ea99 commit 78ba2c9

File tree

1 file changed

+21
-5
lines changed

1 file changed

+21
-5
lines changed

e2e_tests/python/automated_oauth.py

Lines changed: 21 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,19 @@ async def dummy_callback_handler() -> tuple[str, Optional[str]]:
9090
self.client_secret = client_secret
9191
self.authorization_server_url = authorization_server_url
9292

93+
async def async_auth_flow(self, request):
94+
"""Override the parent's auth flow to use client credentials only."""
95+
await self.perform_client_credentials_flow()
96+
97+
# Add the access token to the request
98+
if self.context.current_tokens and self.context.current_tokens.access_token:
99+
token_type = self.context.current_tokens.token_type or "Bearer"
100+
request.headers["Authorization"] = (
101+
f"{token_type} {self.context.current_tokens.access_token}"
102+
)
103+
104+
yield request
105+
93106
async def perform_client_credentials_flow(self) -> None:
94107
"""Performs the client credentials OAuth flow to obtain access tokens."""
95108
try:
@@ -221,7 +234,9 @@ async def create_transport(self):
221234
logging.debug(f"Connecting to OAuth-protected MCP server: {self.server_url}")
222235

223236
# Discover the required scope from the server
224-
scope, authorization_server_url = await self._discover_scope_and_auth_server(self.server_url)
237+
scope, authorization_server_url = await self._discover_scope_and_auth_server(
238+
self.server_url
239+
)
225240

226241
# Get OAuth client configuration (handled by mcp_clients.py)
227242
if not self.client_id or not self.client_secret:
@@ -230,9 +245,8 @@ async def create_transport(self):
230245
# Create client metadata
231246
client_metadata = OAuthClientMetadata(
232247
client_name=f"MCP Client - {self.name}",
233-
redirect_uris=["http://localhost"], # Required but not used in client credentials flow
234-
grant_types=["authorization_code"], # Required format, though we use client_credentials
235-
response_types=["code"], # Required for authorization_code grant type
248+
redirect_uris=["http://localhost"],
249+
grant_types=["client_credentials"],
236250
token_endpoint_auth_method="client_secret_post",
237251
scope=scope,
238252
)
@@ -336,6 +350,8 @@ async def _discover_scope_and_auth_server(self, server_url: str) -> tuple[str, s
336350
scope = " ".join(resource_metadata.scopes_supported)
337351

338352
logging.debug(f"Discovered scope: {scope}")
339-
logging.debug(f"Discovered authorization server: {authorization_server_url}")
353+
logging.debug(
354+
f"Discovered authorization server: {authorization_server_url}"
355+
)
340356

341357
return scope, authorization_server_url

0 commit comments

Comments
 (0)