Skip to content

Commit 17bef50

Browse files
committed
Fully switch AsyncOperationManager to anyio
1 parent c2f8bb1 commit 17bef50

File tree

4 files changed

+33
-116
lines changed

4 files changed

+33
-116
lines changed

src/mcp/client/session.py

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import logging
22
from datetime import timedelta
3-
from typing import Any, Protocol
3+
from typing import Any, Protocol, Self
44

55
import anyio
66
import anyio.lowlevel
@@ -139,6 +139,12 @@ def __init__(
139139
self._tool_output_schemas: dict[str, dict[str, Any] | None] = {}
140140
self._operation_manager = ClientAsyncOperationManager()
141141

142+
async def __aenter__(self) -> Self:
143+
await super().__aenter__()
144+
self._task_group.start_soon(self._operation_manager.cleanup_loop)
145+
self._exit_stack.push_async_callback(lambda: self._operation_manager.stop_cleanup_loop())
146+
return self
147+
142148
async def initialize(self) -> types.InitializeResult:
143149
sampling = types.SamplingCapability() if self._sampling_callback is not _default_sampling_callback else None
144150
elicitation = (
@@ -176,15 +182,8 @@ async def initialize(self) -> types.InitializeResult:
176182

177183
await self.send_notification(types.ClientNotification(types.InitializedNotification()))
178184

179-
# Start cleanup task for operations
180-
await self._operation_manager.start_cleanup_task()
181-
182185
return result
183186

184-
async def close(self) -> None:
185-
"""Clean up resources."""
186-
await self._operation_manager.stop_cleanup_task()
187-
188187
async def send_ping(self) -> types.EmptyResult:
189188
"""Send a ping request."""
190189
return await self.send_request(

src/mcp/server/lowlevel/server.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -865,11 +865,10 @@ async def run(
865865
)
866866
)
867867

868-
# Start async operations cleanup task
869-
await self.async_operations.start_cleanup_task()
868+
async with anyio.create_task_group() as tg:
869+
tg.start_soon(self.async_operations.cleanup_loop)
870870

871-
try:
872-
async with anyio.create_task_group() as tg:
871+
try:
873872
async for message in session.incoming_messages:
874873
logger.debug("Received message: %s", message)
875874

@@ -881,9 +880,12 @@ async def run(
881880
raise_exceptions,
882881
tg,
883882
)
884-
finally:
885-
# Stop cleanup task
886-
await self.async_operations.stop_cleanup_task()
883+
finally:
884+
# Stop cleanup loop before task group exits
885+
await self.async_operations.stop_cleanup_loop()
886+
887+
# Cancel all remaining tasks in the task group (cleanup loop and potentially LROs)
888+
tg.cancel_scope.cancel()
887889

888890
async def _handle_message(
889891
self,

src/mcp/shared/async_operations.py

Lines changed: 17 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,15 @@
22

33
from __future__ import annotations
44

5-
import asyncio
5+
import logging
66
import secrets
77
import time
88
from collections.abc import Callable
99
from dataclasses import dataclass
1010
from typing import Any, Generic, TypeVar
1111

12+
import anyio
13+
1214
import mcp.types as types
1315
from mcp.types import AsyncOperationStatus
1416

@@ -66,9 +68,9 @@ class BaseOperationManager(Generic[OperationT]):
6668

6769
def __init__(self, *, token_generator: Callable[[str | None], str] | None = None):
6870
self._operations: dict[str, OperationT] = {}
69-
self._cleanup_task: asyncio.Task[None] | None = None
7071
self._cleanup_interval = 60 # Cleanup every 60 seconds
7172
self._token_generator = token_generator or self._default_token_generator
73+
self._running = False
7274

7375
def _default_token_generator(self, session_id: str | None = None) -> str:
7476
"""Default token generation using random tokens."""
@@ -105,31 +107,20 @@ def cleanup_expired(self) -> int:
105107
self._remove_operation(token)
106108
return len(expired_tokens)
107109

108-
async def start_cleanup_task(self) -> None:
109-
"""Start the background cleanup task."""
110-
if self._cleanup_task is None:
111-
self._cleanup_task = asyncio.create_task(self._cleanup_loop())
112-
113-
async def stop_cleanup_task(self) -> None:
114-
"""Stop the background cleanup task."""
115-
if self._cleanup_task:
116-
self._cleanup_task.cancel()
117-
try:
118-
await self._cleanup_task
119-
except asyncio.CancelledError:
120-
pass
121-
self._cleanup_task = None
122-
123-
async def _cleanup_loop(self) -> None:
110+
async def stop_cleanup_loop(self) -> None:
111+
self._running = False
112+
113+
async def cleanup_loop(self) -> None:
124114
"""Background task to clean up expired operations."""
125-
while True:
126-
try:
127-
await asyncio.sleep(self._cleanup_interval)
128-
count = self.cleanup_expired()
129-
if count > 0:
130-
print(f"Cleaned up {count} expired operations")
131-
except asyncio.CancelledError:
132-
break
115+
if self._running:
116+
return
117+
self._running = True
118+
119+
while self._running:
120+
await anyio.sleep(self._cleanup_interval)
121+
count = self.cleanup_expired()
122+
if count > 0:
123+
logging.debug(f"Cleaned up {count} expired operations")
133124

134125

135126
class ClientAsyncOperationManager(BaseOperationManager[ClientAsyncOperation]):
@@ -292,32 +283,3 @@ def mark_input_completed(self, token: str) -> bool:
292283

293284
operation.status = "working"
294285
return True
295-
296-
async def start_cleanup_task(self) -> None:
297-
"""Start the background cleanup task."""
298-
if self._cleanup_task is not None:
299-
return
300-
301-
self._cleanup_task = asyncio.create_task(self._cleanup_loop())
302-
303-
async def stop_cleanup_task(self) -> None:
304-
"""Stop the background cleanup task."""
305-
if self._cleanup_task is not None:
306-
self._cleanup_task.cancel()
307-
try:
308-
await self._cleanup_task
309-
except asyncio.CancelledError:
310-
pass
311-
self._cleanup_task = None
312-
313-
async def _cleanup_loop(self) -> None:
314-
"""Background cleanup loop."""
315-
while True:
316-
try:
317-
await asyncio.sleep(self._cleanup_interval)
318-
self.cleanup_expired_operations()
319-
except asyncio.CancelledError:
320-
break
321-
except Exception:
322-
# Log error but continue cleanup loop
323-
pass

tests/shared/test_async_operations.py

Lines changed: 0 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,6 @@
55
from typing import Any, cast
66
from unittest.mock import Mock
77

8-
import pytest
9-
108
import mcp.types as types
119
from mcp.shared.async_operations import ServerAsyncOperation, ServerAsyncOperationManager
1210
from mcp.types import AsyncOperationStatus
@@ -213,50 +211,6 @@ def test_concurrent_operations(self):
213211
removed_count = manager.cleanup_expired_operations()
214212
assert removed_count == 25 and len(manager._operations) == 25
215213

216-
@pytest.mark.anyio
217-
async def test_cleanup_task_lifecycle(self):
218-
"""Test background cleanup task management."""
219-
manager = ServerAsyncOperationManager()
220-
221-
await manager.start_cleanup_task()
222-
assert manager._cleanup_task is not None and not manager._cleanup_task.done()
223-
224-
# Starting again should be no-op
225-
await manager.start_cleanup_task()
226-
227-
await manager.stop_cleanup_task()
228-
assert manager._cleanup_task is None
229-
230-
def test_dependency_injection_and_integration(self):
231-
"""Test AsyncOperationManager dependency injection and server integration."""
232-
from mcp.server.fastmcp import FastMCP
233-
from mcp.server.lowlevel import Server
234-
235-
# Test custom manager injection
236-
custom_manager = ServerAsyncOperationManager()
237-
operation = custom_manager.create_operation("shared_tool", {"data": "shared"}, session_id="session1")
238-
239-
# Test FastMCP integration
240-
fastmcp = FastMCP("FastMCP", async_operations=custom_manager)
241-
assert fastmcp._async_operations is custom_manager
242-
assert fastmcp._async_operations.get_operation(operation.token) is operation
243-
244-
# Test lowlevel Server integration
245-
lowlevel = Server("LowLevel", async_operations=custom_manager)
246-
assert lowlevel.async_operations is custom_manager
247-
assert lowlevel.async_operations.get_operation(operation.token) is operation
248-
249-
# Test default creation
250-
default_fastmcp = FastMCP("Default")
251-
default_server = Server("Default")
252-
assert isinstance(default_fastmcp._async_operations, ServerAsyncOperationManager)
253-
assert isinstance(default_server.async_operations, ServerAsyncOperationManager)
254-
assert default_fastmcp._async_operations is not custom_manager
255-
256-
# Test shared manager between servers
257-
new_op = fastmcp._async_operations.create_operation("new_tool", {}, session_id="session2")
258-
assert lowlevel.async_operations.get_operation(new_op.token) is new_op
259-
260214

261215
class TestAsyncOperation:
262216
"""Test AsyncOperation dataclass."""

0 commit comments

Comments
 (0)