16
16
from mcp .types import (
17
17
CONNECTION_CLOSED ,
18
18
INVALID_PARAMS ,
19
+ CallToolResult ,
19
20
CancelledNotification ,
20
21
ClientNotification ,
21
22
ClientRequest ,
22
23
ClientResult ,
23
24
ErrorData ,
25
+ GetOperationPayloadRequest ,
26
+ GetOperationPayloadResult ,
24
27
JSONRPCError ,
25
28
JSONRPCMessage ,
26
29
JSONRPCNotification ,
@@ -177,6 +180,7 @@ class BaseSession(
177
180
_request_id : int
178
181
_in_flight : dict [RequestId , RequestResponder [ReceiveRequestT , SendResultT ]]
179
182
_progress_callbacks : dict [RequestId , ProgressFnT ]
183
+ _operation_requests : dict [str , RequestId ]
180
184
181
185
def __init__ (
182
186
self ,
@@ -196,6 +200,7 @@ def __init__(
196
200
self ._session_read_timeout_seconds = read_timeout_seconds
197
201
self ._in_flight = {}
198
202
self ._progress_callbacks = {}
203
+ self ._operation_requests = {}
199
204
self ._exit_stack = AsyncExitStack ()
200
205
201
206
async def __aenter__ (self ) -> Self :
@@ -251,6 +256,7 @@ async def send_request(
251
256
# Store the callback for this request
252
257
self ._progress_callbacks [request_id ] = progress_callback
253
258
259
+ pop_progress : RequestId | None = request_id
254
260
try :
255
261
jsonrpc_request = JSONRPCRequest (
256
262
jsonrpc = "2.0" ,
@@ -285,11 +291,28 @@ async def send_request(
285
291
if isinstance (response_or_error , JSONRPCError ):
286
292
raise McpError (response_or_error .error )
287
293
else :
288
- return result_type .model_validate (response_or_error .result )
294
+ result = result_type .model_validate (response_or_error .result )
295
+ if isinstance (result , CallToolResult ) and result .operation is not None :
296
+ # Store mapping of operation token to request ID for async operations
297
+ self ._operation_requests [result .operation .token ] = request_id
298
+
299
+ # Don't pop the progress function if we were given one
300
+ pop_progress = None
301
+ elif isinstance (request , GetOperationPayloadRequest ) and isinstance (result , GetOperationPayloadResult ):
302
+ # Checked request and result to ensure no error
303
+ operation_token = request .params .token
304
+
305
+ # Pop the progress function for the original request
306
+ pop_progress = self ._operation_requests [operation_token ]
307
+
308
+ # Pop the token mapping since we know we won't need it anymore
309
+ self ._operation_requests .pop (operation_token , None )
310
+ return result
289
311
290
312
finally :
291
313
self ._response_streams .pop (request_id , None )
292
- self ._progress_callbacks .pop (request_id , None )
314
+ if pop_progress :
315
+ self ._progress_callbacks .pop (pop_progress , None )
293
316
await response_stream .aclose ()
294
317
await response_stream_reader .aclose ()
295
318
0 commit comments