Skip to content

Commit 153942f

Browse files
committed
Update Auth Provider to AuthClientProvider
1 parent 8342340 commit 153942f

File tree

2 files changed

+30
-29
lines changed

2 files changed

+30
-29
lines changed

src/mcp/client/streamable_http.py

Lines changed: 13 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -74,13 +74,14 @@ class RequestContext:
7474
sse_read_timeout: timedelta
7575

7676

77-
class AuthTokenProvider(Protocol):
78-
"""Protocol that can be extended to implement custom client-to-server authentication
79-
The get_token method is invoked before each request to the MCP Server to retrieve a
80-
fresh authentication token and update the request headers."""
77+
class AuthClientProvider(Protocol):
78+
"""Base class that can be extended to implement custom client-to-server
79+
authentication"""
8180

8281
async def get_token(self) -> str:
83-
"""Get an authentication token.
82+
"""Get a token for authenticating to an MCP server. The token is assumed to
83+
be short-lived; clients may call this API multiple times per
84+
request to an MCP server.
8485
8586
Returns:
8687
str: The authentication token.
@@ -97,7 +98,7 @@ def __init__(
9798
headers: dict[str, Any] | None = None,
9899
timeout: timedelta = timedelta(seconds=30),
99100
sse_read_timeout: timedelta = timedelta(seconds=60 * 5),
100-
auth_token_provider: AuthTokenProvider | None = None,
101+
auth_client_provider: AuthClientProvider | None = None,
101102
) -> None:
102103
"""Initialize the StreamableHTTP transport.
103104
@@ -117,7 +118,7 @@ def __init__(
117118
CONTENT_TYPE: JSON,
118119
**self.headers,
119120
}
120-
self.auth_token_provider = auth_token_provider
121+
self.auth_client_provider = auth_client_provider
121122

122123
def _update_headers_with_session(
123124
self, base_headers: dict[str, str]
@@ -133,10 +134,10 @@ async def _update_headers_with_token(
133134
) -> dict[str, str]:
134135
"""Update headers with token if token provider is specified and authorization
135136
header is not present."""
136-
if self.auth_token_provider is None or "Authorization" in base_headers:
137+
if self.auth_client_provider is None or "Authorization" in base_headers:
137138
return base_headers
138139

139-
token = await self.auth_token_provider.get_token()
140+
token = await self.auth_client_provider.get_token()
140141
headers = base_headers.copy()
141142
headers["Authorization"] = f"Bearer {token}"
142143
return headers
@@ -462,7 +463,7 @@ async def streamablehttp_client(
462463
timeout: timedelta = timedelta(seconds=30),
463464
sse_read_timeout: timedelta = timedelta(seconds=60 * 5),
464465
terminate_on_close: bool = True,
465-
auth_token_provider: AuthTokenProvider | None = None,
466+
auth_client_provider: AuthClientProvider | None = None,
466467
) -> AsyncGenerator[
467468
tuple[
468469
MemoryObjectReceiveStream[SessionMessage | Exception],
@@ -477,7 +478,7 @@ async def streamablehttp_client(
477478
`sse_read_timeout` determines how long (in seconds) the client will wait for a new
478479
event before disconnecting. All other HTTP operations are controlled by `timeout`.
479480
480-
`auth_token_provider` is an optional protocol that can be extended to implement
481+
`auth_client_provider` is an optional protocol that can be extended to implement
481482
custom client-to-server authentication. Before each request to the MCP Server,
482483
the get_token method is invoked to retrieve a fresh authentication token and
483484
update the request headers. Note that if the passed in headers already
@@ -490,7 +491,7 @@ async def streamablehttp_client(
490491
- get_session_id_callback: Function to retrieve the current session ID
491492
"""
492493
transport = StreamableHTTPTransport(
493-
url, headers, timeout, sse_read_timeout, auth_token_provider
494+
url, headers, timeout, sse_read_timeout, auth_client_provider
494495
)
495496

496497
read_stream_writer, read_stream = anyio.create_memory_object_stream[

tests/shared/test_streamable_http.py

Lines changed: 17 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1226,8 +1226,8 @@ async def sampling_callback(
12261226
)
12271227

12281228

1229-
class MockAuthTokenProvider:
1230-
"""Mock implementation of AuthTokenProvider for testing."""
1229+
class MockAuthClientProvider:
1230+
"""Mock implementation of AuthClientProvider for testing."""
12311231

12321232
def __init__(self, token: str):
12331233
self.token = token
@@ -1237,15 +1237,15 @@ async def get_token(self) -> str:
12371237

12381238

12391239
@pytest.mark.anyio
1240-
async def test_auth_token_provider_headers(basic_server, basic_server_url):
1240+
async def test_auth_client_provider_headers(basic_server, basic_server_url):
12411241
"""Test that auth token provider correctly sets Authorization header."""
12421242
# Create a mock token provider
1243-
token_provider = MockAuthTokenProvider("test-token-123")
1244-
token_provider.get_token = AsyncMock(return_value="test-token-123")
1243+
client_provider = MockAuthClientProvider("test-token-123")
1244+
client_provider.get_token = AsyncMock(return_value="test-token-123")
12451245

12461246
# Create client with token provider
12471247
async with streamablehttp_client(
1248-
f"{basic_server_url}/mcp", auth_token_provider=token_provider
1248+
f"{basic_server_url}/mcp", auth_client_provider=client_provider
12491249
) as (read_stream, write_stream, _):
12501250
async with ClientSession(read_stream, write_stream) as session:
12511251
# Initialize the session
@@ -1256,19 +1256,19 @@ async def test_auth_token_provider_headers(basic_server, basic_server_url):
12561256
tools = await session.list_tools()
12571257
assert len(tools.tools) == 4
12581258

1259-
token_provider.get_token.assert_called()
1259+
client_provider.get_token.assert_called()
12601260

12611261

12621262
@pytest.mark.anyio
1263-
async def test_auth_token_provider_token_update(basic_server, basic_server_url):
1263+
async def test_auth_client_provider_token_update(basic_server, basic_server_url):
12641264
"""Test that auth token provider can return different tokens."""
12651265
# Create a dynamic token provider
1266-
token_provider = MockAuthTokenProvider("test-token-123")
1267-
token_provider.get_token = AsyncMock(return_value="test-token-123")
1266+
client_provider = MockAuthClientProvider("test-token-123")
1267+
client_provider.get_token = AsyncMock(return_value="test-token-123")
12681268

12691269
# Create client with dynamic token provider
12701270
async with streamablehttp_client(
1271-
f"{basic_server_url}/mcp", auth_token_provider=token_provider
1271+
f"{basic_server_url}/mcp", auth_client_provider=client_provider
12721272
) as (read_stream, write_stream, _):
12731273
async with ClientSession(read_stream, write_stream) as session:
12741274
# Initialize the session
@@ -1280,22 +1280,22 @@ async def test_auth_token_provider_token_update(basic_server, basic_server_url):
12801280
tools = await session.list_tools()
12811281
assert len(tools.tools) == 4
12821282

1283-
token_provider.get_token.call_count > 1
1283+
client_provider.get_token.call_count > 1
12841284

12851285

12861286
@pytest.mark.anyio
1287-
async def test_auth_token_provider_headers_not_overridden(
1287+
async def test_auth_client_provider_headers_not_overridden(
12881288
basic_server, basic_server_url
12891289
):
12901290
"""Test that auth token provider correctly sets Authorization header."""
12911291
# Create a mock token provider
1292-
token_provider = MockAuthTokenProvider("test-token-123")
1293-
token_provider.get_token = AsyncMock(return_value="test-token-123")
1292+
client_provider = MockAuthClientProvider("test-token-123")
1293+
client_provider.get_token = AsyncMock(return_value="test-token-123")
12941294

12951295
# Create client with token provider
12961296
async with streamablehttp_client(
12971297
f"{basic_server_url}/mcp",
1298-
auth_token_provider=token_provider,
1298+
auth_client_provider=client_provider,
12991299
headers={"Authorization": "test-token-123"},
13001300
) as (read_stream, write_stream, _):
13011301
async with ClientSession(read_stream, write_stream) as session:
@@ -1307,4 +1307,4 @@ async def test_auth_token_provider_headers_not_overridden(
13071307
tools = await session.list_tools()
13081308
assert len(tools.tools) == 4
13091309

1310-
token_provider.get_token.assert_not_called()
1310+
client_provider.get_token.assert_not_called()

0 commit comments

Comments
 (0)