-
Notifications
You must be signed in to change notification settings - Fork 2.3k
Dynamic Authorization in Streamable HTTP Client #700
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 7 commits
7d9a84a
289c03a
c99d4f7
785964e
d3f0dea
c4fb621
28ae4f7
8ab1d66
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -11,7 +11,7 @@ | |
from contextlib import asynccontextmanager | ||
from dataclasses import dataclass | ||
from datetime import timedelta | ||
from typing import Any | ||
from typing import Any, Protocol | ||
|
||
import anyio | ||
import httpx | ||
|
@@ -74,6 +74,20 @@ class RequestContext: | |
sse_read_timeout: timedelta | ||
|
||
|
||
class AuthClientProvider(Protocol): | ||
"""Base class that can be extended to implement custom client-to-server | ||
authentication""" | ||
|
||
async def get_auth_headers(self) -> dict[str, str]: | ||
"""Gets auth headers for authenticating to an MCP server. | ||
Clients may call this API multiple times per request to an MCP server. | ||
|
||
Returns: | ||
dict[str, str]: The authentication headers. | ||
""" | ||
... | ||
|
||
|
||
class StreamableHTTPTransport: | ||
"""StreamableHTTP client transport implementation.""" | ||
aravind-segu marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
|
@@ -83,6 +97,7 @@ def __init__( | |
headers: dict[str, Any] | None = None, | ||
timeout: timedelta = timedelta(seconds=30), | ||
sse_read_timeout: timedelta = timedelta(seconds=60 * 5), | ||
auth_client_provider: AuthClientProvider | None = None, | ||
) -> None: | ||
"""Initialize the StreamableHTTP transport. | ||
|
||
|
@@ -102,6 +117,7 @@ def __init__( | |
CONTENT_TYPE: JSON, | ||
**self.headers, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. BTW I would expect any auth headers passed in There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Good call, I added the behaviour to not override passed in headers, and add a test case as well |
||
} | ||
self.auth_client_provider = auth_client_provider | ||
|
||
def _update_headers_with_session( | ||
self, base_headers: dict[str, str] | ||
|
@@ -112,6 +128,24 @@ def _update_headers_with_session( | |
headers[MCP_SESSION_ID] = self.session_id | ||
return headers | ||
|
||
async def _update_headers_with_auth_headers( | ||
self, base_headers: dict[str, str] | ||
) -> dict[str, str]: | ||
"""Update headers with auth_headers if auth client provider is specified. | ||
The headers are merged giving precedence to the base_headers to | ||
avoid overwriting existing Authorization headers""" | ||
aravind-segu marked this conversation as resolved.
Show resolved
Hide resolved
|
||
if self.auth_client_provider is None: | ||
return base_headers | ||
|
||
auth_headers = await self.auth_client_provider.get_auth_headers() | ||
aravind-segu marked this conversation as resolved.
Show resolved
Hide resolved
|
||
return {**auth_headers, **base_headers} | ||
|
||
async def _update_headers(self, base_headers: dict[str, str]) -> dict[str, str]: | ||
"""Update headers with session ID and token if available.""" | ||
headers = self._update_headers_with_session(base_headers) | ||
headers = await self._update_headers_with_auth_headers(headers) | ||
return headers | ||
|
||
def _is_initialization_request(self, message: JSONRPCMessage) -> bool: | ||
"""Check if the message is an initialization request.""" | ||
return ( | ||
|
@@ -184,7 +218,7 @@ async def handle_get_stream( | |
if not self.session_id: | ||
return | ||
|
||
headers = self._update_headers_with_session(self.request_headers) | ||
headers = await self._update_headers(self.request_headers) | ||
|
||
async with aconnect_sse( | ||
client, | ||
|
@@ -206,7 +240,7 @@ async def handle_get_stream( | |
|
||
async def _handle_resumption_request(self, ctx: RequestContext) -> None: | ||
"""Handle a resumption request using GET with SSE.""" | ||
headers = self._update_headers_with_session(ctx.headers) | ||
headers = await self._update_headers(ctx.headers) | ||
if ctx.metadata and ctx.metadata.resumption_token: | ||
headers[LAST_EVENT_ID] = ctx.metadata.resumption_token | ||
else: | ||
|
@@ -216,7 +250,6 @@ async def _handle_resumption_request(self, ctx: RequestContext) -> None: | |
original_request_id = None | ||
if isinstance(ctx.session_message.message.root, JSONRPCRequest): | ||
original_request_id = ctx.session_message.message.root.id | ||
|
||
async with aconnect_sse( | ||
ctx.client, | ||
"GET", | ||
|
@@ -241,7 +274,7 @@ async def _handle_resumption_request(self, ctx: RequestContext) -> None: | |
|
||
async def _handle_post_request(self, ctx: RequestContext) -> None: | ||
"""Handle a POST request with response processing.""" | ||
headers = self._update_headers_with_session(ctx.headers) | ||
headers = await self._update_headers(ctx.headers) | ||
message = ctx.session_message.message | ||
is_initialization = self._is_initialization_request(message) | ||
|
||
|
@@ -268,7 +301,6 @@ async def _handle_post_request(self, ctx: RequestContext) -> None: | |
self._maybe_extract_session_id_from_response(response) | ||
|
||
content_type = response.headers.get(CONTENT_TYPE, "").lower() | ||
|
||
if content_type.startswith(JSON): | ||
await self._handle_json_response(response, ctx.read_stream_writer) | ||
elif content_type.startswith(SSE): | ||
|
@@ -405,7 +437,7 @@ async def terminate_session(self, client: httpx.AsyncClient) -> None: | |
return | ||
|
||
try: | ||
headers = self._update_headers_with_session(self.request_headers) | ||
headers = await self._update_headers(self.request_headers) | ||
response = await client.delete(self.url, headers=headers) | ||
|
||
if response.status_code == 405: | ||
|
@@ -427,6 +459,7 @@ async def streamablehttp_client( | |
timeout: timedelta = timedelta(seconds=30), | ||
sse_read_timeout: timedelta = timedelta(seconds=60 * 5), | ||
terminate_on_close: bool = True, | ||
auth_client_provider: AuthClientProvider | None = None, | ||
) -> AsyncGenerator[ | ||
tuple[ | ||
MemoryObjectReceiveStream[SessionMessage | Exception], | ||
|
@@ -441,13 +474,22 @@ async def streamablehttp_client( | |
`sse_read_timeout` determines how long (in seconds) the client will wait for a new | ||
event before disconnecting. All other HTTP operations are controlled by `timeout`. | ||
|
||
`auth_client_provider` instance of `AuthClientProvider` that can be passed to | ||
support client-to-server authentication. Before each request to the MCP Server, | ||
the auth_client_provider.get_token method is invoked to retrieve a fresh | ||
authentication token and update the request headers. Note that if the passed in | ||
aravind-segu marked this conversation as resolved.
Show resolved
Hide resolved
|
||
`headers` already contain an Authorization header, that header will take precedence | ||
over any tokens generated by this provider. | ||
|
||
Yields: | ||
Tuple containing: | ||
- read_stream: Stream for reading messages from the server | ||
- write_stream: Stream for sending messages to the server | ||
- get_session_id_callback: Function to retrieve the current session ID | ||
""" | ||
transport = StreamableHTTPTransport(url, headers, timeout, sse_read_timeout) | ||
transport = StreamableHTTPTransport( | ||
url, headers, timeout, sse_read_timeout, auth_client_provider | ||
) | ||
|
||
read_stream_writer, read_stream = anyio.create_memory_object_stream[ | ||
SessionMessage | Exception | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -9,6 +9,7 @@ | |
import time | ||
from collections.abc import Generator | ||
from typing import Any | ||
from unittest.mock import AsyncMock | ||
|
||
import anyio | ||
import httpx | ||
|
@@ -1223,3 +1224,64 @@ async def sampling_callback( | |
captured_message_params.messages[0].content.text | ||
== "Server needs client sampling" | ||
) | ||
|
||
|
||
class MockAuthClientProvider: | ||
aravind-segu marked this conversation as resolved.
Show resolved
Hide resolved
|
||
"""Mock implementation of AuthClientProvider for testing.""" | ||
|
||
def __init__(self, token: str): | ||
self.token = token | ||
|
||
async def get_auth_headers(self) -> dict[str, str]: | ||
return {"Authorization": "Bearer " + self.token} | ||
|
||
|
||
@pytest.mark.anyio | ||
async def test_auth_client_provider_headers(basic_server, basic_server_url): | ||
"""Test that auth token provider correctly sets Authorization header.""" | ||
# Create a mock token provider | ||
client_provider = MockAuthClientProvider("test-token-123") | ||
client_provider.get_auth_headers = AsyncMock( | ||
return_value={"Authorization": "Bearer test-token-123"} | ||
) | ||
|
||
# Create client with token provider | ||
async with streamablehttp_client( | ||
f"{basic_server_url}/mcp", auth_client_provider=client_provider | ||
) as (read_stream, write_stream, _): | ||
async with ClientSession(read_stream, write_stream) as session: | ||
# Initialize the session | ||
result = await session.initialize() | ||
assert isinstance(result, InitializeResult) | ||
|
||
# Make a request to verify headers | ||
tools = await session.list_tools() | ||
assert len(tools.tools) == 4 | ||
|
||
client_provider.get_auth_headers.assert_called() | ||
|
||
|
||
@pytest.mark.anyio | ||
async def test_auth_client_provider_called_per_request(basic_server, basic_server_url): | ||
"""Test that auth token provider can return different tokens.""" | ||
# Create a dynamic token provider | ||
client_provider = MockAuthClientProvider("test-token-123") | ||
client_provider.get_auth_headers = AsyncMock( | ||
return_value={"Authorization": "Bearer test-token-123"} | ||
) | ||
|
||
# Create client with dynamic token provider | ||
async with streamablehttp_client( | ||
f"{basic_server_url}/mcp", auth_client_provider=client_provider | ||
) as (read_stream, write_stream, _): | ||
async with ClientSession(read_stream, write_stream) as session: | ||
# Initialize the session | ||
result = await session.initialize() | ||
assert isinstance(result, InitializeResult) | ||
|
||
# Make multiple requests to verify token updates | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Dumb question, where do we verify the token is actually updated? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It is really hard in this testing environment to get the headers and verify the implementation. There is a mock server which hosts a list and set tools. We create a session, and add our messages to the write stream. This is then read by our transport layer and a request is sent to the server. I could not find a way to intercept or inspect this request object to verify the headers. So I just ensured the method was being called There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We could try to build a custom app, that looks at the header, then calls the server, and returns the auth headers in the response headers. I will wait for the maintainer to chime in if they have better ideas on how to test this. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. off the cuff I suspect you'd want to patch:
In the appropriate places to catch all calls with headers, and assert your headers from the provider are there. I think the mocks could just pass through to the original function |
||
for i in range(3): | ||
tools = await session.list_tools() | ||
assert len(tools.tools) == 4 | ||
|
||
client_provider.get_auth_headers.call_count > 1 | ||
aravind-segu marked this conversation as resolved.
Show resolved
Hide resolved
|
Uh oh!
There was an error while loading. Please reload this page.