Skip to content

Commit 4e6ea99

Browse files
committed
feat: Migrate Python e2e tests to Strands Agents SDK
1 parent 182510f commit 4e6ea99

File tree

13 files changed

+650
-1513
lines changed

13 files changed

+650
-1513
lines changed
Lines changed: 341 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,341 @@
1+
"""
2+
Automated OAuth server client for MCP servers requiring OAuth authentication.
3+
4+
This client handles OAuth using client credentials grant (machine-to-machine authentication)
5+
without browser interaction.
6+
"""
7+
8+
import logging
9+
import time
10+
from datetime import timedelta
11+
from typing import Optional
12+
13+
import httpx
14+
15+
from mcp.client.auth import OAuthClientProvider, TokenStorage
16+
from mcp.client.streamable_http import streamablehttp_client, MCP_PROTOCOL_VERSION
17+
from mcp.shared.auth import (
18+
OAuthClientInformationFull,
19+
OAuthClientMetadata,
20+
OAuthMetadata,
21+
OAuthToken,
22+
ProtectedResourceMetadata,
23+
)
24+
from mcp.types import LATEST_PROTOCOL_VERSION
25+
26+
27+
class InMemoryTokenStorage(TokenStorage):
28+
"""
29+
Simple in-memory token storage implementation for automated OAuth.
30+
In production, you should persist tokens securely. However, for
31+
the demo chatbot, it's ok to ask the user to re-authenticate
32+
in the browser each time they run the chatbot.
33+
"""
34+
35+
def __init__(self):
36+
self._client_info: Optional[OAuthClientInformationFull] = None
37+
self._token: Optional[OAuthToken] = None
38+
39+
async def get_client_info(self) -> Optional[OAuthClientInformationFull]:
40+
return self._client_info
41+
42+
async def set_client_info(self, client_info: OAuthClientInformationFull) -> None:
43+
self._client_info = client_info
44+
45+
async def get_token(self) -> Optional[OAuthToken]:
46+
return self._token
47+
48+
async def set_token(self, token: OAuthToken) -> None:
49+
self._token = token
50+
51+
async def clear_token(self) -> None:
52+
self._token = None
53+
54+
55+
class AutomatedOAuthClientProvider(OAuthClientProvider):
56+
"""
57+
OAuth client provider for automated (client credentials) OAuth flows.
58+
This provider handles machine-to-machine authentication without user interaction.
59+
"""
60+
61+
def __init__(
62+
self,
63+
server_url: str,
64+
client_metadata: OAuthClientMetadata,
65+
storage: TokenStorage,
66+
client_id: str,
67+
client_secret: str,
68+
authorization_server_url: str,
69+
):
70+
# Create dummy handlers since they won't be used in client credentials flow
71+
async def dummy_redirect_handler(url: str) -> None:
72+
raise RuntimeError(
73+
"Redirect handler should not be called in automated OAuth flow"
74+
)
75+
76+
async def dummy_callback_handler() -> tuple[str, Optional[str]]:
77+
raise RuntimeError(
78+
"Callback handler should not be called in automated OAuth flow"
79+
)
80+
81+
super().__init__(
82+
server_url=server_url,
83+
client_metadata=client_metadata,
84+
storage=storage,
85+
redirect_handler=dummy_redirect_handler,
86+
callback_handler=dummy_callback_handler,
87+
)
88+
89+
self.client_id = client_id
90+
self.client_secret = client_secret
91+
self.authorization_server_url = authorization_server_url
92+
93+
async def perform_client_credentials_flow(self) -> None:
94+
"""Performs the client credentials OAuth flow to obtain access tokens."""
95+
try:
96+
# Check if we already have valid tokens
97+
current_tokens = await self.context.storage.get_token()
98+
if current_tokens and current_tokens.access_token:
99+
self.context.current_tokens = current_tokens
100+
self.context.update_token_expiry(current_tokens)
101+
if self.context.is_token_valid():
102+
logging.debug("Using existing valid access token")
103+
return
104+
105+
logging.debug("Performing client credentials flow...")
106+
107+
# Set auth server URL and discover OAuth metadata using upstream logic
108+
self.context.auth_server_url = self.authorization_server_url
109+
await self._discover_oauth_metadata()
110+
111+
if (
112+
not self.context.oauth_metadata
113+
or not self.context.oauth_metadata.token_endpoint
114+
):
115+
raise RuntimeError("No token endpoint found in OAuth metadata")
116+
117+
# Create client info and store it
118+
client_info = OAuthClientInformationFull(
119+
client_id=self.client_id,
120+
client_secret=self.client_secret,
121+
client_id_issued_at=int(time.time()),
122+
**self.context.client_metadata.model_dump(exclude_unset=True),
123+
)
124+
await self.context.storage.set_client_info(client_info)
125+
self.context.client_info = client_info
126+
127+
# Perform client credentials token request
128+
token_data = {
129+
"grant_type": "client_credentials",
130+
"client_id": self.client_id,
131+
"client_secret": self.client_secret,
132+
}
133+
134+
# Add scope if specified
135+
if self.context.client_metadata.scope:
136+
token_data["scope"] = self.context.client_metadata.scope
137+
138+
logging.debug(
139+
f"Making token request to: {self.context.oauth_metadata.token_endpoint}"
140+
)
141+
142+
async with httpx.AsyncClient() as client:
143+
response = await client.post(
144+
str(self.context.oauth_metadata.token_endpoint),
145+
data=token_data,
146+
headers={"Content-Type": "application/x-www-form-urlencoded"},
147+
)
148+
149+
if response.status_code != 200:
150+
error_text = response.text
151+
raise RuntimeError(
152+
f"Token request failed: HTTP {response.status_code} - {error_text}"
153+
)
154+
155+
token_response = response.json()
156+
157+
# Create and store tokens
158+
tokens = OAuthToken(
159+
access_token=token_response["access_token"],
160+
token_type=token_response.get("token_type", "Bearer"),
161+
expires_in=token_response.get("expires_in"),
162+
refresh_token=token_response.get("refresh_token"),
163+
scope=token_response.get("scope"),
164+
)
165+
166+
await self.context.storage.set_token(tokens)
167+
self.context.current_tokens = tokens
168+
self.context.update_token_expiry(tokens)
169+
170+
logging.debug(
171+
"Successfully obtained access token via client credentials flow"
172+
)
173+
174+
except Exception as error:
175+
logging.error(f"Client credentials flow failed: {error}")
176+
raise
177+
178+
async def _discover_oauth_metadata(self) -> None:
179+
"""Discover OAuth metadata using upstream MCP SDK discovery logic."""
180+
# Use the inherited _get_discovery_urls method from parent class
181+
discovery_urls = self._get_discovery_urls()
182+
183+
async with httpx.AsyncClient() as client:
184+
for metadata_url in discovery_urls:
185+
try:
186+
response = await client.get(metadata_url, follow_redirects=True)
187+
if response.status_code == 200:
188+
metadata = OAuthMetadata.model_validate_json(response.content)
189+
self.context.oauth_metadata = metadata
190+
logging.debug(
191+
f"Successfully discovered OAuth metadata from: {metadata_url}"
192+
)
193+
return
194+
except Exception as e:
195+
logging.debug(
196+
f"Failed to discover OAuth metadata from {metadata_url}: {e}"
197+
)
198+
continue
199+
200+
raise RuntimeError(
201+
"Failed to discover OAuth metadata from any well-known endpoint"
202+
)
203+
204+
205+
class AutomatedOAuthClient:
206+
"""
207+
Manages OAuth authentication for MCP servers requiring automated OAuth.
208+
209+
This client handles OAuth using client credentials grant (machine-to-machine authentication)
210+
without browser interaction.
211+
"""
212+
213+
def __init__(self, name: str, server_url: str, client_id: str, client_secret: str):
214+
self.name = name
215+
self.server_url = server_url
216+
self.client_id = client_id
217+
self.client_secret = client_secret
218+
219+
async def create_transport(self):
220+
"""Create OAuth-authenticated transport for MCP communication."""
221+
logging.debug(f"Connecting to OAuth-protected MCP server: {self.server_url}")
222+
223+
# Discover the required scope from the server
224+
scope, authorization_server_url = await self._discover_scope_and_auth_server(self.server_url)
225+
226+
# Get OAuth client configuration (handled by mcp_clients.py)
227+
if not self.client_id or not self.client_secret:
228+
raise ValueError("client_id and client_secret must be provided")
229+
230+
# Create client metadata
231+
client_metadata = OAuthClientMetadata(
232+
client_name=f"MCP Client - {self.name}",
233+
redirect_uris=["http://localhost"], # Required but not used in client credentials flow
234+
grant_types=["authorization_code"], # Required format, though we use client_credentials
235+
response_types=["code"], # Required for authorization_code grant type
236+
token_endpoint_auth_method="client_secret_post",
237+
scope=scope,
238+
)
239+
240+
# Create storage and OAuth provider
241+
storage = InMemoryTokenStorage()
242+
oauth_provider = AutomatedOAuthClientProvider(
243+
server_url=self.server_url,
244+
client_metadata=client_metadata,
245+
storage=storage,
246+
client_id=self.client_id,
247+
client_secret=self.client_secret,
248+
authorization_server_url=authorization_server_url,
249+
)
250+
251+
# Perform client credentials flow
252+
logging.debug("Starting automated OAuth flow...")
253+
await oauth_provider.perform_client_credentials_flow()
254+
255+
# Create transport with OAuth provider
256+
logging.debug("Creating transport with automated OAuth provider...")
257+
return streamablehttp_client(
258+
url=self.server_url,
259+
auth=oauth_provider,
260+
timeout=timedelta(seconds=60),
261+
)
262+
263+
async def _discover_scope_and_auth_server(self, server_url: str) -> tuple[str, str]:
264+
"""Discovers the required scope and authorization server from OAuth protected resource metadata."""
265+
logging.debug("Making initial request to discover OAuth metadata...")
266+
267+
async with httpx.AsyncClient() as client:
268+
headers = {MCP_PROTOCOL_VERSION: LATEST_PROTOCOL_VERSION}
269+
response = await client.post(
270+
server_url,
271+
headers=headers,
272+
follow_redirects=True,
273+
json={"jsonrpc": "2.0", "method": "ping", "id": 1},
274+
)
275+
276+
if response.status_code != 401:
277+
raise RuntimeError(
278+
f"Expected 401 response for OAuth discovery, got {response.status_code}"
279+
)
280+
281+
# Extract resource metadata URL from WWW-Authenticate header
282+
www_auth_header = response.headers.get("WWW-Authenticate")
283+
resource_metadata_url = None
284+
285+
if www_auth_header:
286+
# Simple extraction of resource_metadata URL
287+
import re
288+
289+
pattern = r'resource_metadata=(?:"([^"]+)"|([^\s,]+))'
290+
match = re.search(pattern, www_auth_header)
291+
if match:
292+
resource_metadata_url = match.group(1) or match.group(2)
293+
294+
if not resource_metadata_url:
295+
# Fallback to well-known discovery
296+
from urllib.parse import urlparse, urljoin
297+
298+
parsed = urlparse(server_url)
299+
base_url = f"{parsed.scheme}://{parsed.netloc}"
300+
resource_metadata_url = urljoin(
301+
base_url, "/.well-known/oauth-protected-resource"
302+
)
303+
304+
logging.debug(f"Discovered resource metadata URL: {resource_metadata_url}")
305+
306+
# Fetch protected resource metadata
307+
logging.debug("Fetching OAuth protected resource metadata...")
308+
metadata_response = await client.get(
309+
resource_metadata_url, headers=headers, follow_redirects=True
310+
)
311+
312+
if metadata_response.status_code != 200:
313+
raise RuntimeError(
314+
f"Failed to fetch resource metadata: HTTP {metadata_response.status_code}"
315+
)
316+
317+
resource_metadata = ProtectedResourceMetadata.model_validate_json(
318+
metadata_response.content
319+
)
320+
321+
# Extract authorization server
322+
if not resource_metadata.authorization_servers:
323+
raise RuntimeError(
324+
"No authorization server found in OAuth protected resource metadata"
325+
)
326+
327+
authorization_server_url = str(resource_metadata.authorization_servers[0])
328+
329+
# Extract scope
330+
if not resource_metadata.scopes_supported:
331+
logging.warning(
332+
"No scopes found in OAuth protected resource metadata. Using empty scope."
333+
)
334+
scope = ""
335+
else:
336+
scope = " ".join(resource_metadata.scopes_supported)
337+
338+
logging.debug(f"Discovered scope: {scope}")
339+
logging.debug(f"Discovered authorization server: {authorization_server_url}")
340+
341+
return scope, authorization_server_url

0 commit comments

Comments
 (0)