1111from contextlib import asynccontextmanager
1212from dataclasses import dataclass
1313from datetime import timedelta
14- from typing import Any
14+ from typing import Any , Protocol
1515
1616import anyio
1717import httpx
@@ -74,6 +74,18 @@ class RequestContext:
7474 sse_read_timeout : timedelta
7575
7676
77+ class AuthTokenProvider (Protocol ):
78+ """Protocol for providers that supply authentication tokens."""
79+
80+ async def get_token (self ) -> str :
81+ """Get an authentication token.
82+
83+ Returns:
84+ str: The authentication token.
85+ """
86+ ...
87+
88+
7789class StreamableHTTPTransport :
7890 """StreamableHTTP client transport implementation."""
7991
@@ -83,6 +95,7 @@ def __init__(
8395 headers : dict [str , Any ] | None = None ,
8496 timeout : timedelta = timedelta (seconds = 30 ),
8597 sse_read_timeout : timedelta = timedelta (seconds = 60 * 5 ),
98+ auth_token_provider : AuthTokenProvider | None = None ,
8699 ) -> None :
87100 """Initialize the StreamableHTTP transport.
88101
@@ -102,6 +115,7 @@ def __init__(
102115 CONTENT_TYPE : JSON ,
103116 ** self .headers ,
104117 }
118+ self .auth_token_provider = auth_token_provider
105119
106120 def _update_headers_with_session (
107121 self , base_headers : dict [str , str ]
@@ -112,6 +126,24 @@ def _update_headers_with_session(
112126 headers [MCP_SESSION_ID ] = self .session_id
113127 return headers
114128
129+ async def _update_headers_with_token (
130+ self , base_headers : dict [str , str ]
131+ ) -> dict [str , str ]:
132+ """Update headers with token if token provider is specified."""
133+ if self .auth_token_provider is None :
134+ return base_headers
135+
136+ token = await self .auth_token_provider .get_token ()
137+ headers = base_headers .copy ()
138+ headers ["Authorization" ] = f"Bearer { token } "
139+ return headers
140+
141+ async def _update_headers (self , base_headers : dict [str , str ]) -> dict [str , str ]:
142+ """Update headers with session ID and token if available."""
143+ headers = self ._update_headers_with_session (base_headers )
144+ headers = await self ._update_headers_with_token (headers )
145+ return headers
146+
115147 def _is_initialization_request (self , message : JSONRPCMessage ) -> bool :
116148 """Check if the message is an initialization request."""
117149 return (
@@ -184,7 +216,7 @@ async def handle_get_stream(
184216 if not self .session_id :
185217 return
186218
187- headers = self ._update_headers_with_session (self .request_headers )
219+ headers = await self ._update_headers (self .request_headers )
188220
189221 async with aconnect_sse (
190222 client ,
@@ -206,7 +238,7 @@ async def handle_get_stream(
206238
207239 async def _handle_resumption_request (self , ctx : RequestContext ) -> None :
208240 """Handle a resumption request using GET with SSE."""
209- headers = self ._update_headers_with_session (ctx .headers )
241+ headers = await self ._update_headers (ctx .headers )
210242 if ctx .metadata and ctx .metadata .resumption_token :
211243 headers [LAST_EVENT_ID ] = ctx .metadata .resumption_token
212244 else :
@@ -241,7 +273,7 @@ async def _handle_resumption_request(self, ctx: RequestContext) -> None:
241273
242274 async def _handle_post_request (self , ctx : RequestContext ) -> None :
243275 """Handle a POST request with response processing."""
244- headers = self ._update_headers_with_session (ctx .headers )
276+ headers = await self ._update_headers (ctx .headers )
245277 message = ctx .session_message .message
246278 is_initialization = self ._is_initialization_request (message )
247279
@@ -405,7 +437,7 @@ async def terminate_session(self, client: httpx.AsyncClient) -> None:
405437 return
406438
407439 try :
408- headers = self ._update_headers_with_session (self .request_headers )
440+ headers = await self ._update_headers (self .request_headers )
409441 response = await client .delete (self .url , headers = headers )
410442
411443 if response .status_code == 405 :
@@ -427,6 +459,7 @@ async def streamablehttp_client(
427459 timeout : timedelta = timedelta (seconds = 30 ),
428460 sse_read_timeout : timedelta = timedelta (seconds = 60 * 5 ),
429461 terminate_on_close : bool = True ,
462+ auth_token_provider : AuthTokenProvider | None = None ,
430463) -> AsyncGenerator [
431464 tuple [
432465 MemoryObjectReceiveStream [SessionMessage | Exception ],
@@ -447,7 +480,9 @@ async def streamablehttp_client(
447480 - write_stream: Stream for sending messages to the server
448481 - get_session_id_callback: Function to retrieve the current session ID
449482 """
450- transport = StreamableHTTPTransport (url , headers , timeout , sse_read_timeout )
483+ transport = StreamableHTTPTransport (
484+ url , headers , timeout , sse_read_timeout , auth_token_provider
485+ )
451486
452487 read_stream_writer , read_stream = anyio .create_memory_object_stream [
453488 SessionMessage | Exception
0 commit comments