Skip to content

Commit efc1d04

Browse files
committed
Add support for get tokens method
1 parent fdb538b commit efc1d04

File tree

2 files changed

+100
-6
lines changed

2 files changed

+100
-6
lines changed

src/mcp/client/streamable_http.py

Lines changed: 41 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
from contextlib import asynccontextmanager
1212
from dataclasses import dataclass
1313
from datetime import timedelta
14-
from typing import Any
14+
from typing import Any, Protocol
1515

1616
import anyio
1717
import httpx
@@ -74,6 +74,18 @@ class RequestContext:
7474
sse_read_timeout: timedelta
7575

7676

77+
class AuthTokenProvider(Protocol):
78+
"""Protocol for providers that supply authentication tokens."""
79+
80+
async def get_token(self) -> str:
81+
"""Get an authentication token.
82+
83+
Returns:
84+
str: The authentication token.
85+
"""
86+
...
87+
88+
7789
class StreamableHTTPTransport:
7890
"""StreamableHTTP client transport implementation."""
7991

@@ -83,6 +95,7 @@ def __init__(
8395
headers: dict[str, Any] | None = None,
8496
timeout: timedelta = timedelta(seconds=30),
8597
sse_read_timeout: timedelta = timedelta(seconds=60 * 5),
98+
auth_token_provider: AuthTokenProvider | None = None,
8699
) -> None:
87100
"""Initialize the StreamableHTTP transport.
88101
@@ -102,6 +115,7 @@ def __init__(
102115
CONTENT_TYPE: JSON,
103116
**self.headers,
104117
}
118+
self.auth_token_provider = auth_token_provider
105119

106120
def _update_headers_with_session(
107121
self, base_headers: dict[str, str]
@@ -112,6 +126,24 @@ def _update_headers_with_session(
112126
headers[MCP_SESSION_ID] = self.session_id
113127
return headers
114128

129+
async def _update_headers_with_token(
130+
self, base_headers: dict[str, str]
131+
) -> dict[str, str]:
132+
"""Update headers with token if token provider is specified."""
133+
if self.auth_token_provider is None:
134+
return base_headers
135+
136+
token = await self.auth_token_provider.get_token()
137+
headers = base_headers.copy()
138+
headers["Authorization"] = f"Bearer {token}"
139+
return headers
140+
141+
async def _update_headers(self, base_headers: dict[str, str]) -> dict[str, str]:
142+
"""Update headers with session ID and token if available."""
143+
headers = self._update_headers_with_session(base_headers)
144+
headers = await self._update_headers_with_token(headers)
145+
return headers
146+
115147
def _is_initialization_request(self, message: JSONRPCMessage) -> bool:
116148
"""Check if the message is an initialization request."""
117149
return (
@@ -184,7 +216,7 @@ async def handle_get_stream(
184216
if not self.session_id:
185217
return
186218

187-
headers = self._update_headers_with_session(self.request_headers)
219+
headers = await self._update_headers(self.request_headers)
188220

189221
async with aconnect_sse(
190222
client,
@@ -206,7 +238,7 @@ async def handle_get_stream(
206238

207239
async def _handle_resumption_request(self, ctx: RequestContext) -> None:
208240
"""Handle a resumption request using GET with SSE."""
209-
headers = self._update_headers_with_session(ctx.headers)
241+
headers = await self._update_headers(ctx.headers)
210242
if ctx.metadata and ctx.metadata.resumption_token:
211243
headers[LAST_EVENT_ID] = ctx.metadata.resumption_token
212244
else:
@@ -241,7 +273,7 @@ async def _handle_resumption_request(self, ctx: RequestContext) -> None:
241273

242274
async def _handle_post_request(self, ctx: RequestContext) -> None:
243275
"""Handle a POST request with response processing."""
244-
headers = self._update_headers_with_session(ctx.headers)
276+
headers = await self._update_headers(ctx.headers)
245277
message = ctx.session_message.message
246278
is_initialization = self._is_initialization_request(message)
247279

@@ -405,7 +437,7 @@ async def terminate_session(self, client: httpx.AsyncClient) -> None:
405437
return
406438

407439
try:
408-
headers = self._update_headers_with_session(self.request_headers)
440+
headers = await self._update_headers(self.request_headers)
409441
response = await client.delete(self.url, headers=headers)
410442

411443
if response.status_code == 405:
@@ -427,6 +459,7 @@ async def streamablehttp_client(
427459
timeout: timedelta = timedelta(seconds=30),
428460
sse_read_timeout: timedelta = timedelta(seconds=60 * 5),
429461
terminate_on_close: bool = True,
462+
auth_token_provider: AuthTokenProvider | None = None,
430463
) -> AsyncGenerator[
431464
tuple[
432465
MemoryObjectReceiveStream[SessionMessage | Exception],
@@ -447,7 +480,9 @@ async def streamablehttp_client(
447480
- write_stream: Stream for sending messages to the server
448481
- get_session_id_callback: Function to retrieve the current session ID
449482
"""
450-
transport = StreamableHTTPTransport(url, headers, timeout, sse_read_timeout)
483+
transport = StreamableHTTPTransport(
484+
url, headers, timeout, sse_read_timeout, auth_token_provider
485+
)
451486

452487
read_stream_writer, read_stream = anyio.create_memory_object_stream[
453488
SessionMessage | Exception

tests/shared/test_streamable_http.py

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
import time
1010
from collections.abc import Generator
1111
from typing import Any
12+
from unittest.mock import AsyncMock
1213

1314
import anyio
1415
import httpx
@@ -1223,3 +1224,61 @@ async def sampling_callback(
12231224
captured_message_params.messages[0].content.text
12241225
== "Server needs client sampling"
12251226
)
1227+
1228+
1229+
class MockAuthTokenProvider:
1230+
"""Mock implementation of AuthTokenProvider for testing."""
1231+
1232+
def __init__(self, token: str):
1233+
self.token = token
1234+
1235+
async def get_token(self) -> str:
1236+
return self.token
1237+
1238+
1239+
@pytest.mark.anyio
1240+
async def test_auth_token_provider_headers(basic_server, basic_server_url):
1241+
"""Test that auth token provider correctly sets Authorization header."""
1242+
# Create a mock token provider
1243+
token_provider = MockAuthTokenProvider("test-token-123")
1244+
token_provider.get_token = AsyncMock(return_value="test-token-123")
1245+
1246+
# Create client with token provider
1247+
async with streamablehttp_client(
1248+
f"{basic_server_url}/mcp", auth_token_provider=token_provider
1249+
) as (read_stream, write_stream, _):
1250+
async with ClientSession(read_stream, write_stream) as session:
1251+
# Initialize the session
1252+
result = await session.initialize()
1253+
assert isinstance(result, InitializeResult)
1254+
1255+
# Make a request to verify headers
1256+
tools = await session.list_tools()
1257+
assert len(tools.tools) == 4
1258+
1259+
token_provider.get_token.assert_called()
1260+
1261+
1262+
@pytest.mark.anyio
1263+
async def test_auth_token_provider_token_update(basic_server, basic_server_url):
1264+
"""Test that auth token provider can return different tokens."""
1265+
# Create a dynamic token provider
1266+
token_provider = MockAuthTokenProvider("test-token-123")
1267+
token_provider.get_token = AsyncMock(return_value="test-token-123")
1268+
1269+
# Create client with dynamic token provider
1270+
async with streamablehttp_client(
1271+
f"{basic_server_url}/mcp", auth_token_provider=token_provider
1272+
) as (read_stream, write_stream, _):
1273+
async with ClientSession(read_stream, write_stream) as session:
1274+
# Initialize the session
1275+
result = await session.initialize()
1276+
assert isinstance(result, InitializeResult)
1277+
1278+
# Make multiple requests to verify token updates
1279+
for i in range(3):
1280+
tools = await session.list_tools()
1281+
assert len(tools.tools) == 4
1282+
await anyio.sleep(0.1) # Small delay to ensure token updates
1283+
1284+
token_provider.get_token.call_count > 1

0 commit comments

Comments
 (0)