Skip to content

Commit 8342340

Browse files
committed
Address comments
1 parent efc1d04 commit 8342340

File tree

2 files changed

+39
-4
lines changed

2 files changed

+39
-4
lines changed

src/mcp/client/streamable_http.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,9 @@ class RequestContext:
7575

7676

7777
class AuthTokenProvider(Protocol):
78-
"""Protocol for providers that supply authentication tokens."""
78+
"""Protocol that can be extended to implement custom client-to-server authentication
79+
The get_token method is invoked before each request to the MCP Server to retrieve a
80+
fresh authentication token and update the request headers."""
7981

8082
async def get_token(self) -> str:
8183
"""Get an authentication token.
@@ -129,8 +131,9 @@ def _update_headers_with_session(
129131
async def _update_headers_with_token(
130132
self, base_headers: dict[str, str]
131133
) -> dict[str, str]:
132-
"""Update headers with token if token provider is specified."""
133-
if self.auth_token_provider is None:
134+
"""Update headers with token if token provider is specified and authorization
135+
header is not present."""
136+
if self.auth_token_provider is None or "Authorization" in base_headers:
134137
return base_headers
135138

136139
token = await self.auth_token_provider.get_token()
@@ -474,6 +477,12 @@ async def streamablehttp_client(
474477
`sse_read_timeout` determines how long (in seconds) the client will wait for a new
475478
event before disconnecting. All other HTTP operations are controlled by `timeout`.
476479
480+
`auth_token_provider` is an optional protocol that can be extended to implement
481+
custom client-to-server authentication. Before each request to the MCP Server,
482+
the get_token method is invoked to retrieve a fresh authentication token and
483+
update the request headers. Note that if the passed in headers already
484+
contain an authorization header, this provider will not be called.
485+
477486
Yields:
478487
Tuple containing:
479488
- read_stream: Stream for reading messages from the server

tests/shared/test_streamable_http.py

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1279,6 +1279,32 @@ async def test_auth_token_provider_token_update(basic_server, basic_server_url):
12791279
for i in range(3):
12801280
tools = await session.list_tools()
12811281
assert len(tools.tools) == 4
1282-
await anyio.sleep(0.1) # Small delay to ensure token updates
12831282

12841283
token_provider.get_token.call_count > 1
1284+
1285+
1286+
@pytest.mark.anyio
1287+
async def test_auth_token_provider_headers_not_overridden(
1288+
basic_server, basic_server_url
1289+
):
1290+
"""Test that auth token provider correctly sets Authorization header."""
1291+
# Create a mock token provider
1292+
token_provider = MockAuthTokenProvider("test-token-123")
1293+
token_provider.get_token = AsyncMock(return_value="test-token-123")
1294+
1295+
# Create client with token provider
1296+
async with streamablehttp_client(
1297+
f"{basic_server_url}/mcp",
1298+
auth_token_provider=token_provider,
1299+
headers={"Authorization": "test-token-123"},
1300+
) as (read_stream, write_stream, _):
1301+
async with ClientSession(read_stream, write_stream) as session:
1302+
# Initialize the session
1303+
result = await session.initialize()
1304+
assert isinstance(result, InitializeResult)
1305+
1306+
# Make a request to verify headers
1307+
tools = await session.list_tools()
1308+
assert len(tools.tools) == 4
1309+
1310+
token_provider.get_token.assert_not_called()

0 commit comments

Comments
 (0)