|  | 
| 1 | 1 | from typing import Any, cast | 
|  | 2 | +from unittest.mock import patch | 
| 2 | 3 | 
 | 
| 3 | 4 | import anyio | 
| 4 | 5 | import pytest | 
|  | 
| 10 | 11 | from mcp.server.models import InitializationOptions | 
| 11 | 12 | from mcp.server.session import ServerSession | 
| 12 | 13 | from mcp.shared.context import RequestContext | 
|  | 14 | +from mcp.shared.memory import create_connected_server_and_client_session | 
| 13 | 15 | from mcp.shared.progress import progress | 
| 14 | 16 | from mcp.shared.session import BaseSession, RequestResponder, SessionMessage | 
| 15 | 17 | 
 | 
| @@ -320,3 +322,69 @@ async def handle_client_message( | 
| 320 | 322 |     assert server_progress_updates[3]["progress"] == 100 | 
| 321 | 323 |     assert server_progress_updates[3]["total"] == 100 | 
| 322 | 324 |     assert server_progress_updates[3]["message"] == "Processing results..." | 
|  | 325 | + | 
|  | 326 | + | 
|  | 327 | +@pytest.mark.anyio | 
|  | 328 | +async def test_progress_callback_exception_logging(): | 
|  | 329 | +    """Test that exceptions in progress callbacks are logged and \ | 
|  | 330 | +        don't crash the session.""" | 
|  | 331 | +    # Track logged warnings | 
|  | 332 | +    logged_errors: list[str] = [] | 
|  | 333 | + | 
|  | 334 | +    def mock_log_error(msg: str, *args: Any) -> None: | 
|  | 335 | +        logged_errors.append(msg % args if args else msg) | 
|  | 336 | + | 
|  | 337 | +    # Create a progress callback that raises an exception | 
|  | 338 | +    async def failing_progress_callback(progress: float, total: float | None, message: str | None) -> None: | 
|  | 339 | +        raise ValueError("Progress callback failed!") | 
|  | 340 | + | 
|  | 341 | +    # Create a server with a tool that sends progress notifications | 
|  | 342 | +    server = Server(name="TestProgressServer") | 
|  | 343 | + | 
|  | 344 | +    @server.call_tool() | 
|  | 345 | +    async def handle_call_tool(name: str, arguments: Any) -> list[types.TextContent]: | 
|  | 346 | +        if name == "progress_tool": | 
|  | 347 | +            # Send a progress notification | 
|  | 348 | +            await server.request_context.session.send_progress_notification( | 
|  | 349 | +                progress_token=server.request_context.request_id, | 
|  | 350 | +                progress=50.0, | 
|  | 351 | +                total=100.0, | 
|  | 352 | +                message="Halfway done", | 
|  | 353 | +            ) | 
|  | 354 | +            return [types.TextContent(type="text", text="progress_result")] | 
|  | 355 | +        raise ValueError(f"Unknown tool: {name}") | 
|  | 356 | + | 
|  | 357 | +    @server.list_tools() | 
|  | 358 | +    async def handle_list_tools() -> list[types.Tool]: | 
|  | 359 | +        return [ | 
|  | 360 | +            types.Tool( | 
|  | 361 | +                name="progress_tool", | 
|  | 362 | +                description="A tool that sends progress notifications", | 
|  | 363 | +                inputSchema={}, | 
|  | 364 | +            ) | 
|  | 365 | +        ] | 
|  | 366 | + | 
|  | 367 | +    # Test with mocked logging | 
|  | 368 | +    with patch("mcp.shared.session.logging.error", side_effect=mock_log_error): | 
|  | 369 | +        async with create_connected_server_and_client_session(server) as client_session: | 
|  | 370 | +            # Send a request with a failing progress callback | 
|  | 371 | +            result = await client_session.send_request( | 
|  | 372 | +                types.ClientRequest( | 
|  | 373 | +                    types.CallToolRequest( | 
|  | 374 | +                        method="tools/call", | 
|  | 375 | +                        params=types.CallToolRequestParams(name="progress_tool", arguments={}), | 
|  | 376 | +                    ) | 
|  | 377 | +                ), | 
|  | 378 | +                types.CallToolResult, | 
|  | 379 | +                progress_callback=failing_progress_callback, | 
|  | 380 | +            ) | 
|  | 381 | + | 
|  | 382 | +            # Verify the request completed successfully despite the callback failure | 
|  | 383 | +            assert len(result.content) == 1 | 
|  | 384 | +            content = result.content[0] | 
|  | 385 | +            assert isinstance(content, types.TextContent) | 
|  | 386 | +            assert content.text == "progress_result" | 
|  | 387 | + | 
|  | 388 | +            # Check that a warning was logged for the progress callback exception | 
|  | 389 | +            assert len(logged_errors) > 0 | 
|  | 390 | +            assert any("Progress callback raised an exception" in warning for warning in logged_errors) | 
0 commit comments