Skip to content

Commit 40cf77e

Browse files
committed
Convert ServerAsyncOperationManager into async context manager
1 parent 5f422e7 commit 40cf77e

File tree

5 files changed

+148
-64
lines changed

5 files changed

+148
-64
lines changed

src/mcp/server/fastmcp/server.py

Lines changed: 41 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
from __future__ import annotations as _annotations
44

5+
import contextlib
56
import inspect
67
import re
78
from collections.abc import AsyncIterator, Awaitable, Callable, Collection, Iterable, Sequence
@@ -845,14 +846,21 @@ def decorator(
845846

846847
return decorator
847848

849+
@contextlib.asynccontextmanager
850+
async def _stdio_lifespan(self) -> AsyncIterator[None]:
851+
"""Lifespan that manages stdio operations."""
852+
async with self._async_operations.run():
853+
yield
854+
848855
async def run_stdio_async(self) -> None:
849856
"""Run the server using stdio transport."""
850857
async with stdio_server() as (read_stream, write_stream):
851-
await self._mcp_server.run(
852-
read_stream,
853-
write_stream,
854-
self._mcp_server.create_initialization_options(),
855-
)
858+
async with self._stdio_lifespan():
859+
await self._mcp_server.run(
860+
read_stream,
861+
write_stream,
862+
self._mcp_server.create_initialization_options(),
863+
)
856864

857865
async def run_sse_async(self, mount_path: str | None = None) -> None:
858866
"""Run the server using SSE transport."""
@@ -910,6 +918,12 @@ def _normalize_path(self, mount_path: str, endpoint: str) -> str:
910918
# Combine paths
911919
return mount_path + endpoint
912920

921+
@contextlib.asynccontextmanager
922+
async def _sse_lifespan(self) -> AsyncIterator[None]:
923+
"""Lifespan that manages SSE operations."""
924+
async with self._async_operations.run():
925+
yield
926+
913927
def sse_app(self, mount_path: str | None = None) -> Starlette:
914928
"""Return an instance of the SSE server app."""
915929
from starlette.middleware import Middleware
@@ -1040,7 +1054,16 @@ async def sse_endpoint(request: Request) -> Response:
10401054
routes.extend(self._custom_starlette_routes)
10411055

10421056
# Create Starlette app with routes and middleware
1043-
return Starlette(debug=self.settings.debug, routes=routes, middleware=middleware)
1057+
return Starlette(
1058+
debug=self.settings.debug, routes=routes, middleware=middleware, lifespan=lambda app: self._sse_lifespan()
1059+
)
1060+
1061+
@contextlib.asynccontextmanager
1062+
async def _streamable_http_lifespan(self) -> AsyncIterator[None]:
1063+
"""Lifespan that manages Streamable HTTP operations."""
1064+
async with self.session_manager.run():
1065+
async with self._async_operations.run():
1066+
yield
10441067

10451068
def streamable_http_app(self) -> Starlette:
10461069
"""Return an instance of the StreamableHTTP server app."""
@@ -1135,7 +1158,7 @@ def streamable_http_app(self) -> Starlette:
11351158
debug=self.settings.debug,
11361159
routes=routes,
11371160
middleware=middleware,
1138-
lifespan=lambda app: self.session_manager.run(),
1161+
lifespan=lambda app: self._streamable_http_lifespan(),
11391162
)
11401163

11411164
async def list_prompts(self) -> list[MCPPrompt]:
@@ -1337,12 +1360,17 @@ async def log(
13371360
logger_name: Optional logger name
13381361
**extra: Additional structured data to include
13391362
"""
1340-
await self.request_context.session.send_log_message(
1341-
level=level,
1342-
data=message,
1343-
logger=logger_name,
1344-
related_request_id=self.request_id,
1345-
)
1363+
try:
1364+
await self.request_context.session.send_log_message(
1365+
level=level,
1366+
data=message,
1367+
logger=logger_name,
1368+
related_request_id=self.request_id,
1369+
)
1370+
except Exception:
1371+
# Session might be closed (e.g., client disconnected)
1372+
logger.warning(f"Failed to send log message to client (session closed?): {message}")
1373+
pass
13461374

13471375
@property
13481376
def client_id(self) -> str | None:

src/mcp/server/lowlevel/server.py

Lines changed: 17 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,7 @@ async def main():
9595
from mcp.types import Operation, RequestId
9696

9797
logger = logging.getLogger(__name__)
98+
logger.setLevel(logging.DEBUG)
9899

99100
LifespanResultT = TypeVar("LifespanResultT", default=Any)
100101
RequestT = TypeVar("RequestT", default=Any)
@@ -564,8 +565,11 @@ async def execute_async():
564565
logger.exception(f"Async execution failed for {tool_name}")
565566
self.async_operations.fail_operation(operation.token, str(e))
566567

567-
# Dispatch in concurrency scope of the server to run between requests
568-
server_scope.start_soon(execute_async)
568+
# Start task directly in independent task group
569+
current_request_context = request_ctx.get()
570+
self.async_operations.start_task(
571+
operation.token, execute_async, current_request_context, request_ctx
572+
)
569573

570574
# Return operation result with immediate content
571575
logger.info(f"Returning async operation result for {tool_name}")
@@ -866,26 +870,17 @@ async def run(
866870
)
867871

868872
async with anyio.create_task_group() as tg:
869-
tg.start_soon(self.async_operations.cleanup_loop)
870-
871-
try:
872-
async for message in session.incoming_messages:
873-
logger.debug("Received message: %s", message)
874-
875-
tg.start_soon(
876-
self._handle_message,
877-
message,
878-
session,
879-
lifespan_context,
880-
raise_exceptions,
881-
tg,
882-
)
883-
finally:
884-
# Stop cleanup loop before task group exits
885-
await self.async_operations.stop_cleanup_loop()
886-
887-
# Cancel all remaining tasks in the task group (cleanup loop and potentially LROs)
888-
tg.cancel_scope.cancel()
873+
async for message in session.incoming_messages:
874+
logger.debug("Received message: %s", message)
875+
876+
tg.start_soon(
877+
self._handle_message,
878+
message,
879+
session,
880+
lifespan_context,
881+
raise_exceptions,
882+
tg,
883+
)
889884

890885
async def _handle_message(
891886
self,

src/mcp/shared/async_operations.py

Lines changed: 63 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,18 +2,22 @@
22

33
from __future__ import annotations
44

5+
import contextlib
56
import logging
67
import secrets
78
import time
8-
from collections.abc import Callable
9+
from collections.abc import AsyncIterator, Awaitable, Callable
910
from dataclasses import dataclass
1011
from typing import Any, Generic, TypeVar
1112

1213
import anyio
14+
from anyio.abc import TaskGroup
1315

1416
import mcp.types as types
1517
from mcp.types import AsyncOperationStatus
1618

19+
logger = logging.getLogger(__name__)
20+
1721

1822
@dataclass
1923
class ClientAsyncOperation:
@@ -120,7 +124,7 @@ async def cleanup_loop(self) -> None:
120124
await anyio.sleep(self._cleanup_interval)
121125
count = self.cleanup_expired()
122126
if count > 0:
123-
logging.debug(f"Cleaned up {count} expired operations")
127+
logger.debug(f"Cleaned up {count} expired operations")
124128

125129

126130
class ClientAsyncOperationManager(BaseOperationManager[ClientAsyncOperation]):
@@ -145,6 +149,63 @@ def get_tool_name(self, token: str) -> str | None:
145149
class ServerAsyncOperationManager(BaseOperationManager[ServerAsyncOperation]):
146150
"""Manages async tool operations with token-based tracking."""
147151

152+
def __init__(self, *, token_generator: Callable[[str | None], str] | None = None):
153+
super().__init__(token_generator=token_generator)
154+
self._task_group: TaskGroup | None = None
155+
self._run_lock = anyio.Lock()
156+
self._running = False
157+
158+
@contextlib.asynccontextmanager
159+
async def run(self) -> AsyncIterator[None]:
160+
"""Run the async operations manager with its own task group."""
161+
# Thread-safe check to ensure run() is only called once
162+
async with self._run_lock:
163+
if self._running:
164+
raise RuntimeError("ServerAsyncOperationManager.run() is already running.")
165+
self._running = True
166+
167+
async with anyio.create_task_group() as tg:
168+
self._task_group = tg
169+
logger.info("ServerAsyncOperationManager started")
170+
# Start cleanup loop
171+
tg.start_soon(self.cleanup_loop)
172+
try:
173+
yield
174+
finally:
175+
logger.info("ServerAsyncOperationManager shutting down")
176+
# Stop cleanup loop gracefully
177+
await self.stop_cleanup_loop()
178+
# Cancel task group to stop all spawned tasks
179+
tg.cancel_scope.cancel()
180+
self._task_group = None
181+
self._running = False
182+
183+
def start_task(
184+
self,
185+
token: str,
186+
task_func: Callable[[], Awaitable[None]],
187+
request_context: Any = None,
188+
request_ctx_var: Any = None,
189+
) -> None:
190+
"""Start an async task immediately in the independent task group."""
191+
if self._task_group is None:
192+
raise RuntimeError("Task group not started. Call run() first.")
193+
194+
async def run_task_with_context():
195+
context_token = None
196+
try:
197+
if request_context and request_ctx_var:
198+
context_token = request_ctx_var.set(request_context)
199+
await task_func()
200+
except Exception:
201+
# Handle task failures gracefully
202+
pass
203+
finally:
204+
if context_token and request_ctx_var:
205+
request_ctx_var.reset(context_token)
206+
207+
self._task_group.start_soon(run_task_with_context, name=f"lro_{token}")
208+
148209
def create_operation(
149210
self,
150211
tool_name: str,

src/mcp/shared/memory.py

Lines changed: 27 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -71,30 +71,31 @@ async def create_connected_server_and_client_session(
7171
server_read, server_write = server_streams
7272

7373
# Create a cancel scope for the server task
74-
async with anyio.create_task_group() as tg:
75-
tg.start_soon(
76-
lambda: server.run(
77-
server_read,
78-
server_write,
79-
server.create_initialization_options(),
80-
raise_exceptions=raise_exceptions,
74+
async with server.async_operations.run():
75+
async with anyio.create_task_group() as tg:
76+
tg.start_soon(
77+
lambda: server.run(
78+
server_read,
79+
server_write,
80+
server.create_initialization_options(),
81+
raise_exceptions=raise_exceptions,
82+
)
8183
)
82-
)
83-
84-
try:
85-
async with ClientSession(
86-
read_stream=client_read,
87-
write_stream=client_write,
88-
read_timeout_seconds=read_timeout_seconds,
89-
sampling_callback=sampling_callback,
90-
list_roots_callback=list_roots_callback,
91-
logging_callback=logging_callback,
92-
message_handler=message_handler,
93-
client_info=client_info,
94-
elicitation_callback=elicitation_callback,
95-
protocol_version=protocol_version,
96-
) as client_session:
97-
await client_session.initialize()
98-
yield client_session
99-
finally:
100-
tg.cancel_scope.cancel()
84+
85+
try:
86+
async with ClientSession(
87+
read_stream=client_read,
88+
write_stream=client_write,
89+
read_timeout_seconds=read_timeout_seconds,
90+
sampling_callback=sampling_callback,
91+
list_roots_callback=list_roots_callback,
92+
logging_callback=logging_callback,
93+
message_handler=message_handler,
94+
client_info=client_info,
95+
elicitation_callback=elicitation_callback,
96+
protocol_version=protocol_version,
97+
) as client_session:
98+
await client_session.initialize()
99+
yield client_session
100+
finally:
101+
tg.cancel_scope.cancel()

tests/server/fastmcp/test_integration.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1013,7 +1013,6 @@ async def test_immediate_result_backward_compatibility(server_transport: str, se
10131013
await anyio.sleep(0.5)
10141014
else:
10151015
pytest.fail("Async operation timed out")
1016-
await anyio.sleep(0.01)
10171016

10181017

10191018
# Test async progress notifications

0 commit comments

Comments
 (0)