33from contextlib import AsyncExitStack
44from datetime import timedelta
55from types import TracebackType
6- from typing import Any , Generic , TypeVar
6+ from typing import Any , Generic , Protocol , TypeVar
77
88import anyio
99import httpx
2424 JSONRPCNotification ,
2525 JSONRPCRequest ,
2626 JSONRPCResponse ,
27+ ProgressNotification ,
28+ ProgressNotificationParams ,
29+ ProgressToken ,
2730 RequestParams ,
2831 ServerNotification ,
2932 ServerRequest ,
3942 "ReceiveNotificationT" , ClientNotification , ServerNotification
4043)
4144
45+
46+ class ProgressFnT (Protocol ):
47+ async def __call__ (
48+ self ,
49+ params : ProgressNotificationParams ,
50+ ) -> None : ...
51+
52+
4253RequestId = str | int
4354
4455
@@ -168,7 +179,9 @@ class BaseSession(
168179 RequestId , MemoryObjectSendStream [JSONRPCResponse | JSONRPCError ]
169180 ]
170181 _request_id : int
182+ _progress_id : int
171183 _in_flight : dict [RequestId , RequestResponder [ReceiveRequestT , SendResultT ]]
184+ _in_progress : dict [ProgressToken , ProgressFnT ]
172185
173186 def __init__ (
174187 self ,
@@ -187,6 +200,8 @@ def __init__(
187200 self ._receive_notification_type = receive_notification_type
188201 self ._session_read_timeout_seconds = read_timeout_seconds
189202 self ._in_flight = {}
203+ self ._progress_id = 0
204+ self ._in_progress = {}
190205 self ._exit_stack = AsyncExitStack ()
191206
192207 async def __aenter__ (self ) -> Self :
@@ -214,19 +229,44 @@ async def send_request(
214229 result_type : type [ReceiveResultT ],
215230 request_read_timeout_seconds : timedelta | None = None ,
216231 metadata : MessageMetadata = None ,
232+ progress_callback : ProgressFnT | None = None ,
217233 ) -> ReceiveResultT :
218234 """
219235 Sends a request and wait for a response. Raises an McpError if the
220236 response contains an error. If a request read timeout is provided, it
221237 will take precedence over the session read timeout.
222238
239+ If progress_callback is provided any progress notifications sent from the
240+ receiver will be passed back to the sender
241+
223242 Do not use this method to emit notifications! Use send_notification()
224243 instead.
225244 """
226245
227246 request_id = self ._request_id
228247 self ._request_id = request_id + 1
229248
249+ progress_id = None
250+ send_request = None
251+
252+ if progress_callback is not None :
253+ if request .root .params is not None :
254+ progress_id = self ._progress_id
255+ self ._progress_id = progress_id + 1
256+ new_params = request .root .params .model_copy (
257+ update = {"meta" : RequestParams .Meta (progressToken = progress_id )}
258+ )
259+ new_root = request .root .model_copy (update = {"params" : new_params })
260+ send_request = request .model_copy (update = {"root" : new_root })
261+ self ._in_progress [progress_id ] = progress_callback
262+ else :
263+ raise ValueError (
264+ f"{ type (request .root ).__name__ } does not support progress"
265+ )
266+
267+ if send_request is None :
268+ send_request = request
269+
230270 response_stream , response_stream_reader = anyio .create_memory_object_stream [
231271 JSONRPCResponse | JSONRPCError
232272 ](1 )
@@ -236,11 +276,11 @@ async def send_request(
236276 jsonrpc_request = JSONRPCRequest (
237277 jsonrpc = "2.0" ,
238278 id = request_id ,
239- ** request .model_dump (by_alias = True , mode = "json" , exclude_none = True ),
279+ ** send_request .model_dump (
280+ by_alias = True , mode = "json" , exclude_none = True
281+ ),
240282 )
241283
242- # TODO: Support progress callbacks
243-
244284 await self ._write_stream .send (
245285 SessionMessage (
246286 message = JSONRPCMessage (jsonrpc_request ), metadata = metadata
@@ -276,6 +316,8 @@ async def send_request(
276316
277317 finally :
278318 self ._response_streams .pop (request_id , None )
319+ if progress_id is not None :
320+ self ._in_progress .pop (progress_id , None )
279321 await response_stream .aclose ()
280322 await response_stream_reader .aclose ()
281323
@@ -364,6 +406,20 @@ async def _receive_loop(self) -> None:
364406 if cancelled_id in self ._in_flight :
365407 await self ._in_flight [cancelled_id ].cancel ()
366408 else :
409+ match notification .root :
410+ case ProgressNotification (params = params ):
411+ if params .progressToken in self ._in_progress :
412+ progress_callback = self ._in_progress [
413+ params .progressToken
414+ ]
415+ await progress_callback (params )
416+ else :
417+ logging .warning (
418+ "Unknown progress token %s" ,
419+ params .progressToken ,
420+ )
421+ case _:
422+ pass
367423 await self ._received_notification (notification )
368424 await self ._handle_incoming (notification )
369425 except Exception as e :
0 commit comments