11
11
from contextlib import asynccontextmanager
12
12
from dataclasses import dataclass
13
13
from datetime import timedelta
14
- from typing import Any
14
+ from typing import Any , Protocol
15
15
16
16
import anyio
17
17
import httpx
@@ -74,6 +74,18 @@ class RequestContext:
74
74
sse_read_timeout : timedelta
75
75
76
76
77
+ class AuthTokenProvider (Protocol ):
78
+ """Protocol for providers that supply authentication tokens."""
79
+
80
+ async def get_token (self ) -> str :
81
+ """Get an authentication token.
82
+
83
+ Returns:
84
+ str: The authentication token.
85
+ """
86
+ ...
87
+
88
+
77
89
class StreamableHTTPTransport :
78
90
"""StreamableHTTP client transport implementation."""
79
91
@@ -83,6 +95,7 @@ def __init__(
83
95
headers : dict [str , Any ] | None = None ,
84
96
timeout : timedelta = timedelta (seconds = 30 ),
85
97
sse_read_timeout : timedelta = timedelta (seconds = 60 * 5 ),
98
+ auth_token_provider : AuthTokenProvider | None = None ,
86
99
) -> None :
87
100
"""Initialize the StreamableHTTP transport.
88
101
@@ -102,6 +115,7 @@ def __init__(
102
115
CONTENT_TYPE : JSON ,
103
116
** self .headers ,
104
117
}
118
+ self .auth_token_provider = auth_token_provider
105
119
106
120
def _update_headers_with_session (
107
121
self , base_headers : dict [str , str ]
@@ -112,6 +126,24 @@ def _update_headers_with_session(
112
126
headers [MCP_SESSION_ID ] = self .session_id
113
127
return headers
114
128
129
+ async def _update_headers_with_token (
130
+ self , base_headers : dict [str , str ]
131
+ ) -> dict [str , str ]:
132
+ """Update headers with token if token provider is specified."""
133
+ if self .auth_token_provider is None :
134
+ return base_headers
135
+
136
+ token = await self .auth_token_provider .get_token ()
137
+ headers = base_headers .copy ()
138
+ headers ["Authorization" ] = f"Bearer { token } "
139
+ return headers
140
+
141
+ async def _update_headers (self , base_headers : dict [str , str ]) -> dict [str , str ]:
142
+ """Update headers with session ID and token if available."""
143
+ headers = self ._update_headers_with_session (base_headers )
144
+ headers = await self ._update_headers_with_token (headers )
145
+ return headers
146
+
115
147
def _is_initialization_request (self , message : JSONRPCMessage ) -> bool :
116
148
"""Check if the message is an initialization request."""
117
149
return (
@@ -184,7 +216,7 @@ async def handle_get_stream(
184
216
if not self .session_id :
185
217
return
186
218
187
- headers = self ._update_headers_with_session (self .request_headers )
219
+ headers = await self ._update_headers (self .request_headers )
188
220
189
221
async with aconnect_sse (
190
222
client ,
@@ -206,7 +238,7 @@ async def handle_get_stream(
206
238
207
239
async def _handle_resumption_request (self , ctx : RequestContext ) -> None :
208
240
"""Handle a resumption request using GET with SSE."""
209
- headers = self ._update_headers_with_session (ctx .headers )
241
+ headers = await self ._update_headers (ctx .headers )
210
242
if ctx .metadata and ctx .metadata .resumption_token :
211
243
headers [LAST_EVENT_ID ] = ctx .metadata .resumption_token
212
244
else :
@@ -241,7 +273,7 @@ async def _handle_resumption_request(self, ctx: RequestContext) -> None:
241
273
242
274
async def _handle_post_request (self , ctx : RequestContext ) -> None :
243
275
"""Handle a POST request with response processing."""
244
- headers = self ._update_headers_with_session (ctx .headers )
276
+ headers = await self ._update_headers (ctx .headers )
245
277
message = ctx .session_message .message
246
278
is_initialization = self ._is_initialization_request (message )
247
279
@@ -405,7 +437,7 @@ async def terminate_session(self, client: httpx.AsyncClient) -> None:
405
437
return
406
438
407
439
try :
408
- headers = self ._update_headers_with_session (self .request_headers )
440
+ headers = await self ._update_headers (self .request_headers )
409
441
response = await client .delete (self .url , headers = headers )
410
442
411
443
if response .status_code == 405 :
@@ -427,6 +459,7 @@ async def streamablehttp_client(
427
459
timeout : timedelta = timedelta (seconds = 30 ),
428
460
sse_read_timeout : timedelta = timedelta (seconds = 60 * 5 ),
429
461
terminate_on_close : bool = True ,
462
+ auth_token_provider : AuthTokenProvider | None = None ,
430
463
) -> AsyncGenerator [
431
464
tuple [
432
465
MemoryObjectReceiveStream [SessionMessage | Exception ],
@@ -447,7 +480,9 @@ async def streamablehttp_client(
447
480
- write_stream: Stream for sending messages to the server
448
481
- get_session_id_callback: Function to retrieve the current session ID
449
482
"""
450
- transport = StreamableHTTPTransport (url , headers , timeout , sse_read_timeout )
483
+ transport = StreamableHTTPTransport (
484
+ url , headers , timeout , sse_read_timeout , auth_token_provider
485
+ )
451
486
452
487
read_stream_writer , read_stream = anyio .create_memory_object_stream [
453
488
SessionMessage | Exception
0 commit comments