|
1 | 1 | from typing import Any, cast
|
| 2 | +from unittest.mock import patch |
2 | 3 |
|
3 | 4 | import anyio
|
4 | 5 | import pytest
|
|
10 | 11 | from mcp.server.models import InitializationOptions
|
11 | 12 | from mcp.server.session import ServerSession
|
12 | 13 | from mcp.shared.context import RequestContext
|
| 14 | +from mcp.shared.memory import create_connected_server_and_client_session |
13 | 15 | from mcp.shared.progress import progress
|
14 | 16 | from mcp.shared.session import BaseSession, RequestResponder, SessionMessage
|
15 | 17 |
|
@@ -320,3 +322,69 @@ async def handle_client_message(
|
320 | 322 | assert server_progress_updates[3]["progress"] == 100
|
321 | 323 | assert server_progress_updates[3]["total"] == 100
|
322 | 324 | assert server_progress_updates[3]["message"] == "Processing results..."
|
| 325 | + |
| 326 | + |
| 327 | +@pytest.mark.anyio |
| 328 | +async def test_progress_callback_exception_logging(): |
| 329 | + """Test that exceptions in progress callbacks are logged and \ |
| 330 | + don't crash the session.""" |
| 331 | + # Track logged warnings |
| 332 | + logged_errors: list[str] = [] |
| 333 | + |
| 334 | + def mock_log_error(msg: str, *args: Any) -> None: |
| 335 | + logged_errors.append(msg % args if args else msg) |
| 336 | + |
| 337 | + # Create a progress callback that raises an exception |
| 338 | + async def failing_progress_callback(progress: float, total: float | None, message: str | None) -> None: |
| 339 | + raise ValueError("Progress callback failed!") |
| 340 | + |
| 341 | + # Create a server with a tool that sends progress notifications |
| 342 | + server = Server(name="TestProgressServer") |
| 343 | + |
| 344 | + @server.call_tool() |
| 345 | + async def handle_call_tool(name: str, arguments: Any) -> list[types.TextContent]: |
| 346 | + if name == "progress_tool": |
| 347 | + # Send a progress notification |
| 348 | + await server.request_context.session.send_progress_notification( |
| 349 | + progress_token=server.request_context.request_id, |
| 350 | + progress=50.0, |
| 351 | + total=100.0, |
| 352 | + message="Halfway done", |
| 353 | + ) |
| 354 | + return [types.TextContent(type="text", text="progress_result")] |
| 355 | + raise ValueError(f"Unknown tool: {name}") |
| 356 | + |
| 357 | + @server.list_tools() |
| 358 | + async def handle_list_tools() -> list[types.Tool]: |
| 359 | + return [ |
| 360 | + types.Tool( |
| 361 | + name="progress_tool", |
| 362 | + description="A tool that sends progress notifications", |
| 363 | + inputSchema={}, |
| 364 | + ) |
| 365 | + ] |
| 366 | + |
| 367 | + # Test with mocked logging |
| 368 | + with patch("mcp.shared.session.logging.error", side_effect=mock_log_error): |
| 369 | + async with create_connected_server_and_client_session(server) as client_session: |
| 370 | + # Send a request with a failing progress callback |
| 371 | + result = await client_session.send_request( |
| 372 | + types.ClientRequest( |
| 373 | + types.CallToolRequest( |
| 374 | + method="tools/call", |
| 375 | + params=types.CallToolRequestParams(name="progress_tool", arguments={}), |
| 376 | + ) |
| 377 | + ), |
| 378 | + types.CallToolResult, |
| 379 | + progress_callback=failing_progress_callback, |
| 380 | + ) |
| 381 | + |
| 382 | + # Verify the request completed successfully despite the callback failure |
| 383 | + assert len(result.content) == 1 |
| 384 | + content = result.content[0] |
| 385 | + assert isinstance(content, types.TextContent) |
| 386 | + assert content.text == "progress_result" |
| 387 | + |
| 388 | + # Check that a warning was logged for the progress callback exception |
| 389 | + assert len(logged_errors) > 0 |
| 390 | + assert any("Progress callback raised an exception" in warning for warning in logged_errors) |
0 commit comments