1+ import logging
12from contextlib import AbstractAsyncContextManager
23from datetime import timedelta
3- from typing import Generic , TypeVar
4+ from typing import Any , Callable , Generic , TypeVar
45
56import anyio
67import anyio .lowlevel
1011
1112from mcp .shared .exceptions import McpError
1213from mcp .types import (
14+ CancelledNotification ,
1315 ClientNotification ,
1416 ClientRequest ,
1517 ClientResult ,
3840
3941
4042class RequestResponder (Generic [ReceiveRequestT , SendResultT ]):
43+ """Handles responding to MCP requests and manages request lifecycle.
44+
45+ This class MUST be used as a context manager to ensure proper cleanup and
46+ cancellation handling:
47+
48+ Example:
49+ with request_responder as resp:
50+ await resp.respond(result)
51+
52+ The context manager ensures:
53+ 1. Proper cancellation scope setup and cleanup
54+ 2. Request completion tracking
55+ 3. Cleanup of in-flight requests
56+ """
57+
4158 def __init__ (
4259 self ,
4360 request_id : RequestId ,
4461 request_meta : RequestParams .Meta | None ,
4562 request : ReceiveRequestT ,
4663 session : "BaseSession" ,
64+ on_complete : Callable [["RequestResponder[ReceiveRequestT, SendResultT]" ], Any ],
4765 ) -> None :
4866 self .request_id = request_id
4967 self .request_meta = request_meta
5068 self .request = request
5169 self ._session = session
52- self ._responded = False
70+ self ._completed = False
71+ self ._cancel_scope = anyio .CancelScope ()
72+ self ._on_complete = on_complete
73+ self ._entered = False # Track if we're in a context manager
74+
75+ def __enter__ (self ) -> "RequestResponder[ReceiveRequestT, SendResultT]" :
76+ """Enter the context manager, enabling request cancellation tracking."""
77+ self ._entered = True
78+ self ._cancel_scope = anyio .CancelScope ()
79+ self ._cancel_scope .__enter__ ()
80+ return self
81+
82+ def __exit__ (self , exc_type , exc_val , exc_tb ) -> None :
83+ """Exit the context manager, performing cleanup and notifying completion."""
84+ try :
85+ if self ._completed :
86+ self ._on_complete (self )
87+ finally :
88+ self ._entered = False
89+ if not self ._cancel_scope :
90+ raise RuntimeError ("No active cancel scope" )
91+ self ._cancel_scope .__exit__ (exc_type , exc_val , exc_tb )
5392
5493 async def respond (self , response : SendResultT | ErrorData ) -> None :
55- assert not self ._responded , "Request already responded to"
56- self ._responded = True
94+ """Send a response for this request.
95+
96+ Must be called within a context manager block.
97+ Raises:
98+ RuntimeError: If not used within a context manager
99+ AssertionError: If request was already responded to
100+ """
101+ if not self ._entered :
102+ raise RuntimeError ("RequestResponder must be used as a context manager" )
103+ assert not self ._completed , "Request already responded to"
104+
105+ if not self .cancelled :
106+ self ._completed = True
107+
108+ await self ._session ._send_response (
109+ request_id = self .request_id , response = response
110+ )
111+
112+ async def cancel (self ) -> None :
113+ """Cancel this request and mark it as completed."""
114+ if not self ._entered :
115+ raise RuntimeError ("RequestResponder must be used as a context manager" )
116+ if not self ._cancel_scope :
117+ raise RuntimeError ("No active cancel scope" )
57118
119+ self ._cancel_scope .cancel ()
120+ self ._completed = True # Mark as completed so it's removed from in_flight
121+ # Send an error response to indicate cancellation
58122 await self ._session ._send_response (
59- request_id = self .request_id , response = response
123+ request_id = self .request_id ,
124+ response = ErrorData (code = 0 , message = "Request cancelled" , data = None ),
60125 )
61126
127+ @property
128+ def in_flight (self ) -> bool :
129+ return not self ._completed and not self .cancelled
130+
131+ @property
132+ def cancelled (self ) -> bool :
133+ return self ._cancel_scope is not None and self ._cancel_scope .cancel_called
134+
62135
63136class BaseSession (
64137 AbstractAsyncContextManager ,
@@ -82,6 +155,7 @@ class BaseSession(
82155 RequestId , MemoryObjectSendStream [JSONRPCResponse | JSONRPCError ]
83156 ]
84157 _request_id : int
158+ _in_flight : dict [RequestId , RequestResponder [ReceiveRequestT , SendResultT ]]
85159
86160 def __init__ (
87161 self ,
@@ -99,6 +173,7 @@ def __init__(
99173 self ._receive_request_type = receive_request_type
100174 self ._receive_notification_type = receive_notification_type
101175 self ._read_timeout_seconds = read_timeout_seconds
176+ self ._in_flight = {}
102177
103178 self ._incoming_message_stream_writer , self ._incoming_message_stream_reader = (
104179 anyio .create_memory_object_stream [
@@ -219,27 +294,45 @@ async def _receive_loop(self) -> None:
219294 by_alias = True , mode = "json" , exclude_none = True
220295 )
221296 )
297+
222298 responder = RequestResponder (
223299 request_id = message .root .id ,
224300 request_meta = validated_request .root .params .meta
225301 if validated_request .root .params
226302 else None ,
227303 request = validated_request ,
228304 session = self ,
305+ on_complete = lambda r : self ._in_flight .pop (r .request_id , None ),
229306 )
230307
308+ self ._in_flight [responder .request_id ] = responder
231309 await self ._received_request (responder )
232- if not responder ._responded :
310+ if not responder ._completed :
233311 await self ._incoming_message_stream_writer .send (responder )
312+
234313 elif isinstance (message .root , JSONRPCNotification ):
235- notification = self ._receive_notification_type .model_validate (
236- message .root .model_dump (
237- by_alias = True , mode = "json" , exclude_none = True
314+ try :
315+ notification = self ._receive_notification_type .model_validate (
316+ message .root .model_dump (
317+ by_alias = True , mode = "json" , exclude_none = True
318+ )
319+ )
320+ # Handle cancellation notifications
321+ if isinstance (notification .root , CancelledNotification ):
322+ cancelled_id = notification .root .params .requestId
323+ if cancelled_id in self ._in_flight :
324+ await self ._in_flight [cancelled_id ].cancel ()
325+ else :
326+ await self ._received_notification (notification )
327+ await self ._incoming_message_stream_writer .send (
328+ notification
329+ )
330+ except Exception as e :
331+ # For other validation errors, log and continue
332+ logging .warning (
333+ f"Failed to validate notification: { e } . "
334+ f"Message was: { message .root } "
238335 )
239- )
240-
241- await self ._received_notification (notification )
242- await self ._incoming_message_stream_writer .send (notification )
243336 else : # Response or error
244337 stream = self ._response_streams .pop (message .root .id , None )
245338 if stream :
0 commit comments