Skip to content

Commit 76f135e

Browse files
committed
Use server TaskGroup to fix operations blocking CallTool requests
1 parent 2ed562e commit 76f135e

File tree

1 file changed

+36
-30
lines changed

1 file changed

+36
-30
lines changed

src/mcp/server/lowlevel/server.py

Lines changed: 36 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,7 @@ async def main():
7777

7878
import anyio
7979
import jsonschema
80+
from anyio.abc import TaskGroup
8081
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
8182
from pydantic import AnyUrl
8283
from typing_extensions import TypeVar
@@ -252,7 +253,7 @@ def decorator(
252253

253254
wrapper = create_call_wrapper(func, types.ListPromptsRequest)
254255

255-
async def handler(req: types.ListPromptsRequest):
256+
async def handler(req: types.ListPromptsRequest, _: Any = None):
256257
result = await wrapper(req)
257258
# Handle both old style (list[Prompt]) and new style (ListPromptsResult)
258259
if isinstance(result, types.ListPromptsResult):
@@ -272,7 +273,7 @@ def decorator(
272273
):
273274
logger.debug("Registering handler for GetPromptRequest")
274275

275-
async def handler(req: types.GetPromptRequest):
276+
async def handler(req: types.GetPromptRequest, _: Any = None):
276277
prompt_get = await func(req.params.name, req.params.arguments)
277278
return types.ServerResult(prompt_get)
278279

@@ -290,7 +291,7 @@ def decorator(
290291

291292
wrapper = create_call_wrapper(func, types.ListResourcesRequest)
292293

293-
async def handler(req: types.ListResourcesRequest):
294+
async def handler(req: types.ListResourcesRequest, _: Any = None):
294295
result = await wrapper(req)
295296
# Handle both old style (list[Resource]) and new style (ListResourcesResult)
296297
if isinstance(result, types.ListResourcesResult):
@@ -308,7 +309,7 @@ def list_resource_templates(self):
308309
def decorator(func: Callable[[], Awaitable[list[types.ResourceTemplate]]]):
309310
logger.debug("Registering handler for ListResourceTemplatesRequest")
310311

311-
async def handler(_: Any):
312+
async def handler(_1: Any, _2: Any = None):
312313
templates = await func()
313314
return types.ServerResult(types.ListResourceTemplatesResult(resourceTemplates=templates))
314315

@@ -323,7 +324,7 @@ def decorator(
323324
):
324325
logger.debug("Registering handler for ReadResourceRequest")
325326

326-
async def handler(req: types.ReadResourceRequest):
327+
async def handler(req: types.ReadResourceRequest, _: Any = None):
327328
result = await func(req.params.uri)
328329

329330
def create_content(data: str | bytes, mime_type: str | None):
@@ -379,7 +380,7 @@ def set_logging_level(self):
379380
def decorator(func: Callable[[types.LoggingLevel], Awaitable[None]]):
380381
logger.debug("Registering handler for SetLevelRequest")
381382

382-
async def handler(req: types.SetLevelRequest):
383+
async def handler(req: types.SetLevelRequest, _: Any = None):
383384
await func(req.params.level)
384385
return types.ServerResult(types.EmptyResult())
385386

@@ -392,7 +393,7 @@ def subscribe_resource(self):
392393
def decorator(func: Callable[[AnyUrl], Awaitable[None]]):
393394
logger.debug("Registering handler for SubscribeRequest")
394395

395-
async def handler(req: types.SubscribeRequest):
396+
async def handler(req: types.SubscribeRequest, _: Any = None):
396397
await func(req.params.uri)
397398
return types.ServerResult(types.EmptyResult())
398399

@@ -405,7 +406,7 @@ def unsubscribe_resource(self):
405406
def decorator(func: Callable[[AnyUrl], Awaitable[None]]):
406407
logger.debug("Registering handler for UnsubscribeRequest")
407408

408-
async def handler(req: types.UnsubscribeRequest):
409+
async def handler(req: types.UnsubscribeRequest, _: Any = None):
409410
await func(req.params.uri)
410411
return types.ServerResult(types.EmptyResult())
411412

@@ -423,7 +424,7 @@ def decorator(
423424

424425
wrapper = create_call_wrapper(func, types.ListToolsRequest)
425426

426-
async def handler(req: types.ListToolsRequest):
427+
async def handler(req: types.ListToolsRequest, _: Any = None):
427428
result = await wrapper(req)
428429

429430
# Handle both old style (list[Tool]) and new style (ListToolsResult)
@@ -493,7 +494,7 @@ def decorator(
493494
):
494495
logger.debug("Registering handler for CallToolRequest")
495496

496-
async def handler(req: types.CallToolRequest):
497+
async def handler(req: types.CallToolRequest, server_scope: TaskGroup):
497498
try:
498499
tool_name = req.params.name
499500
arguments = req.params.arguments or {}
@@ -563,20 +564,20 @@ async def execute_async():
563564
logger.exception(f"Async execution failed for {tool_name}")
564565
self.async_operations.fail_operation(operation.token, str(e))
565566

566-
async with anyio.create_task_group() as tg:
567-
tg.start_soon(execute_async)
568-
569-
# Return operation result with immediate content
570-
logger.info(f"Returning async operation result for {tool_name}")
571-
return types.ServerResult(
572-
types.CallToolResult(
573-
content=immediate_content,
574-
operation=types.AsyncResultProperties(
575-
token=operation.token,
576-
keepAlive=operation.keep_alive,
577-
),
578-
)
567+
# Dispatch in concurrency scope of the server to run between requests
568+
server_scope.start_soon(execute_async)
569+
570+
# Return operation result with immediate content
571+
logger.info(f"Returning async operation result for {tool_name}")
572+
return types.ServerResult(
573+
types.CallToolResult(
574+
content=immediate_content,
575+
operation=types.AsyncResultProperties(
576+
token=operation.token,
577+
keepAlive=operation.keep_alive,
578+
),
579579
)
580+
)
580581

581582
# tool call
582583
results = await func(tool_name, arguments)
@@ -690,7 +691,7 @@ def decorator(
690691
):
691692
logger.debug("Registering handler for ProgressNotification")
692693

693-
async def handler(req: types.ProgressNotification):
694+
async def handler(req: types.ProgressNotification, _: Any = None):
694695
await func(
695696
req.params.progressToken,
696697
req.params.progress,
@@ -718,7 +719,7 @@ def decorator(
718719
):
719720
logger.debug("Registering handler for CompleteRequest")
720721

721-
async def handler(req: types.CompleteRequest):
722+
async def handler(req: types.CompleteRequest, _: Any = None):
722723
completion = await func(req.params.ref, req.params.argument, req.params.context)
723724
return types.ServerResult(
724725
types.CompleteResult(
@@ -754,7 +755,7 @@ def get_operation_status(self):
754755
def decorator(func: Callable[[str], Awaitable[types.GetOperationStatusResult]]):
755756
logger.debug("Registering handler for GetOperationStatusRequest")
756757

757-
async def handler(req: types.GetOperationStatusRequest):
758+
async def handler(req: types.GetOperationStatusRequest, _: Any = None):
758759
# Validate token and get operation
759760
operation = self._validate_operation_token(req.params.token)
760761

@@ -776,7 +777,7 @@ def get_operation_result(self):
776777
def decorator(func: Callable[[str], Awaitable[types.GetOperationPayloadResult]]):
777778
logger.debug("Registering handler for GetOperationPayloadRequest")
778779

779-
async def handler(req: types.GetOperationPayloadRequest):
780+
async def handler(req: types.GetOperationPayloadRequest, _: Any = None):
780781
# Validate token and get operation
781782
operation = self._validate_operation_token(req.params.token)
782783

@@ -878,6 +879,7 @@ async def run(
878879
session,
879880
lifespan_context,
880881
raise_exceptions,
882+
tg,
881883
)
882884
finally:
883885
# Cancel session operations and stop cleanup task
@@ -892,13 +894,16 @@ async def _handle_message(
892894
session: ServerSession,
893895
lifespan_context: LifespanResultT,
894896
raise_exceptions: bool = False,
897+
server_scope: TaskGroup | None = None,
895898
):
896899
with warnings.catch_warnings(record=True) as w:
897900
# TODO(Marcelo): We should be checking if message is Exception here.
898901
match message: # type: ignore[reportMatchNotExhaustive]
899902
case RequestResponder(request=types.ClientRequest(root=req)) as responder:
900903
with responder:
901-
await self._handle_request(message, req, session, lifespan_context, raise_exceptions)
904+
await self._handle_request(
905+
message, req, session, lifespan_context, raise_exceptions, server_scope
906+
)
902907
case types.ClientNotification(root=notify):
903908
await self._handle_notification(notify)
904909

@@ -912,6 +917,7 @@ async def _handle_request(
912917
session: ServerSession,
913918
lifespan_context: LifespanResultT,
914919
raise_exceptions: bool,
920+
server_scope: TaskGroup | None = None,
915921
):
916922
logger.info("Processing request of type %s", type(req).__name__)
917923
if handler := self.request_handlers.get(type(req)): # type: ignore
@@ -936,7 +942,7 @@ async def _handle_request(
936942
request=request_data,
937943
)
938944
)
939-
response = await handler(req)
945+
response = await handler(req, server_scope)
940946

941947
# Track async operations for cancellation
942948
if isinstance(req, types.CallToolRequest):
@@ -985,5 +991,5 @@ async def _handle_notification(self, notify: Any):
985991
logger.exception("Uncaught exception in notification handler")
986992

987993

988-
async def _ping_handler(request: types.PingRequest) -> types.ServerResult:
994+
async def _ping_handler(request: types.PingRequest, _: Any = None) -> types.ServerResult:
989995
return types.ServerResult(types.EmptyResult())

0 commit comments

Comments
 (0)