@@ -77,6 +77,7 @@ async def main():
77
77
78
78
import anyio
79
79
import jsonschema
80
+ from anyio .abc import TaskGroup
80
81
from anyio .streams .memory import MemoryObjectReceiveStream , MemoryObjectSendStream
81
82
from pydantic import AnyUrl
82
83
from typing_extensions import TypeVar
@@ -252,7 +253,7 @@ def decorator(
252
253
253
254
wrapper = create_call_wrapper (func , types .ListPromptsRequest )
254
255
255
- async def handler (req : types .ListPromptsRequest ):
256
+ async def handler (req : types .ListPromptsRequest , _ : Any = None ):
256
257
result = await wrapper (req )
257
258
# Handle both old style (list[Prompt]) and new style (ListPromptsResult)
258
259
if isinstance (result , types .ListPromptsResult ):
@@ -272,7 +273,7 @@ def decorator(
272
273
):
273
274
logger .debug ("Registering handler for GetPromptRequest" )
274
275
275
- async def handler (req : types .GetPromptRequest ):
276
+ async def handler (req : types .GetPromptRequest , _ : Any = None ):
276
277
prompt_get = await func (req .params .name , req .params .arguments )
277
278
return types .ServerResult (prompt_get )
278
279
@@ -290,7 +291,7 @@ def decorator(
290
291
291
292
wrapper = create_call_wrapper (func , types .ListResourcesRequest )
292
293
293
- async def handler (req : types .ListResourcesRequest ):
294
+ async def handler (req : types .ListResourcesRequest , _ : Any = None ):
294
295
result = await wrapper (req )
295
296
# Handle both old style (list[Resource]) and new style (ListResourcesResult)
296
297
if isinstance (result , types .ListResourcesResult ):
@@ -308,7 +309,7 @@ def list_resource_templates(self):
308
309
def decorator (func : Callable [[], Awaitable [list [types .ResourceTemplate ]]]):
309
310
logger .debug ("Registering handler for ListResourceTemplatesRequest" )
310
311
311
- async def handler (_ : Any ):
312
+ async def handler (_1 : Any , _2 : Any = None ):
312
313
templates = await func ()
313
314
return types .ServerResult (types .ListResourceTemplatesResult (resourceTemplates = templates ))
314
315
@@ -323,7 +324,7 @@ def decorator(
323
324
):
324
325
logger .debug ("Registering handler for ReadResourceRequest" )
325
326
326
- async def handler (req : types .ReadResourceRequest ):
327
+ async def handler (req : types .ReadResourceRequest , _ : Any = None ):
327
328
result = await func (req .params .uri )
328
329
329
330
def create_content (data : str | bytes , mime_type : str | None ):
@@ -379,7 +380,7 @@ def set_logging_level(self):
379
380
def decorator (func : Callable [[types .LoggingLevel ], Awaitable [None ]]):
380
381
logger .debug ("Registering handler for SetLevelRequest" )
381
382
382
- async def handler (req : types .SetLevelRequest ):
383
+ async def handler (req : types .SetLevelRequest , _ : Any = None ):
383
384
await func (req .params .level )
384
385
return types .ServerResult (types .EmptyResult ())
385
386
@@ -392,7 +393,7 @@ def subscribe_resource(self):
392
393
def decorator (func : Callable [[AnyUrl ], Awaitable [None ]]):
393
394
logger .debug ("Registering handler for SubscribeRequest" )
394
395
395
- async def handler (req : types .SubscribeRequest ):
396
+ async def handler (req : types .SubscribeRequest , _ : Any = None ):
396
397
await func (req .params .uri )
397
398
return types .ServerResult (types .EmptyResult ())
398
399
@@ -405,7 +406,7 @@ def unsubscribe_resource(self):
405
406
def decorator (func : Callable [[AnyUrl ], Awaitable [None ]]):
406
407
logger .debug ("Registering handler for UnsubscribeRequest" )
407
408
408
- async def handler (req : types .UnsubscribeRequest ):
409
+ async def handler (req : types .UnsubscribeRequest , _ : Any = None ):
409
410
await func (req .params .uri )
410
411
return types .ServerResult (types .EmptyResult ())
411
412
@@ -423,7 +424,7 @@ def decorator(
423
424
424
425
wrapper = create_call_wrapper (func , types .ListToolsRequest )
425
426
426
- async def handler (req : types .ListToolsRequest ):
427
+ async def handler (req : types .ListToolsRequest , _ : Any = None ):
427
428
result = await wrapper (req )
428
429
429
430
# Handle both old style (list[Tool]) and new style (ListToolsResult)
@@ -493,7 +494,7 @@ def decorator(
493
494
):
494
495
logger .debug ("Registering handler for CallToolRequest" )
495
496
496
- async def handler (req : types .CallToolRequest ):
497
+ async def handler (req : types .CallToolRequest , server_scope : TaskGroup ):
497
498
try :
498
499
tool_name = req .params .name
499
500
arguments = req .params .arguments or {}
@@ -563,20 +564,20 @@ async def execute_async():
563
564
logger .exception (f"Async execution failed for { tool_name } " )
564
565
self .async_operations .fail_operation (operation .token , str (e ))
565
566
566
- async with anyio .create_task_group () as tg :
567
- tg .start_soon (execute_async )
568
-
569
- # Return operation result with immediate content
570
- logger .info (f"Returning async operation result for { tool_name } " )
571
- return types .ServerResult (
572
- types .CallToolResult (
573
- content = immediate_content ,
574
- operation = types .AsyncResultProperties (
575
- token = operation .token ,
576
- keepAlive = operation .keep_alive ,
577
- ),
578
- )
567
+ # Dispatch in concurrency scope of the server to run between requests
568
+ server_scope .start_soon (execute_async )
569
+
570
+ # Return operation result with immediate content
571
+ logger .info (f"Returning async operation result for { tool_name } " )
572
+ return types .ServerResult (
573
+ types .CallToolResult (
574
+ content = immediate_content ,
575
+ operation = types .AsyncResultProperties (
576
+ token = operation .token ,
577
+ keepAlive = operation .keep_alive ,
578
+ ),
579
579
)
580
+ )
580
581
581
582
# tool call
582
583
results = await func (tool_name , arguments )
@@ -690,7 +691,7 @@ def decorator(
690
691
):
691
692
logger .debug ("Registering handler for ProgressNotification" )
692
693
693
- async def handler (req : types .ProgressNotification ):
694
+ async def handler (req : types .ProgressNotification , _ : Any = None ):
694
695
await func (
695
696
req .params .progressToken ,
696
697
req .params .progress ,
@@ -718,7 +719,7 @@ def decorator(
718
719
):
719
720
logger .debug ("Registering handler for CompleteRequest" )
720
721
721
- async def handler (req : types .CompleteRequest ):
722
+ async def handler (req : types .CompleteRequest , _ : Any = None ):
722
723
completion = await func (req .params .ref , req .params .argument , req .params .context )
723
724
return types .ServerResult (
724
725
types .CompleteResult (
@@ -754,7 +755,7 @@ def get_operation_status(self):
754
755
def decorator (func : Callable [[str ], Awaitable [types .GetOperationStatusResult ]]):
755
756
logger .debug ("Registering handler for GetOperationStatusRequest" )
756
757
757
- async def handler (req : types .GetOperationStatusRequest ):
758
+ async def handler (req : types .GetOperationStatusRequest , _ : Any = None ):
758
759
# Validate token and get operation
759
760
operation = self ._validate_operation_token (req .params .token )
760
761
@@ -776,7 +777,7 @@ def get_operation_result(self):
776
777
def decorator (func : Callable [[str ], Awaitable [types .GetOperationPayloadResult ]]):
777
778
logger .debug ("Registering handler for GetOperationPayloadRequest" )
778
779
779
- async def handler (req : types .GetOperationPayloadRequest ):
780
+ async def handler (req : types .GetOperationPayloadRequest , _ : Any = None ):
780
781
# Validate token and get operation
781
782
operation = self ._validate_operation_token (req .params .token )
782
783
@@ -878,6 +879,7 @@ async def run(
878
879
session ,
879
880
lifespan_context ,
880
881
raise_exceptions ,
882
+ tg ,
881
883
)
882
884
finally :
883
885
# Cancel session operations and stop cleanup task
@@ -892,13 +894,16 @@ async def _handle_message(
892
894
session : ServerSession ,
893
895
lifespan_context : LifespanResultT ,
894
896
raise_exceptions : bool = False ,
897
+ server_scope : TaskGroup | None = None ,
895
898
):
896
899
with warnings .catch_warnings (record = True ) as w :
897
900
# TODO(Marcelo): We should be checking if message is Exception here.
898
901
match message : # type: ignore[reportMatchNotExhaustive]
899
902
case RequestResponder (request = types .ClientRequest (root = req )) as responder :
900
903
with responder :
901
- await self ._handle_request (message , req , session , lifespan_context , raise_exceptions )
904
+ await self ._handle_request (
905
+ message , req , session , lifespan_context , raise_exceptions , server_scope
906
+ )
902
907
case types .ClientNotification (root = notify ):
903
908
await self ._handle_notification (notify )
904
909
@@ -912,6 +917,7 @@ async def _handle_request(
912
917
session : ServerSession ,
913
918
lifespan_context : LifespanResultT ,
914
919
raise_exceptions : bool ,
920
+ server_scope : TaskGroup | None = None ,
915
921
):
916
922
logger .info ("Processing request of type %s" , type (req ).__name__ )
917
923
if handler := self .request_handlers .get (type (req )): # type: ignore
@@ -936,7 +942,7 @@ async def _handle_request(
936
942
request = request_data ,
937
943
)
938
944
)
939
- response = await handler (req )
945
+ response = await handler (req , server_scope )
940
946
941
947
# Track async operations for cancellation
942
948
if isinstance (req , types .CallToolRequest ):
@@ -985,5 +991,5 @@ async def _handle_notification(self, notify: Any):
985
991
logger .exception ("Uncaught exception in notification handler" )
986
992
987
993
988
- async def _ping_handler (request : types .PingRequest ) -> types .ServerResult :
994
+ async def _ping_handler (request : types .PingRequest , _ : Any = None ) -> types .ServerResult :
989
995
return types .ServerResult (types .EmptyResult ())
0 commit comments