diff --git a/src/mcp/shared/session.py b/src/mcp/shared/session.py index b2f49fc8b..4e774984d 100644 --- a/src/mcp/shared/session.py +++ b/src/mcp/shared/session.py @@ -392,11 +392,17 @@ async def _receive_loop(self) -> None: # call it with the progress information if progress_token in self._progress_callbacks: callback = self._progress_callbacks[progress_token] - await callback( - notification.root.params.progress, - notification.root.params.total, - notification.root.params.message, - ) + try: + await callback( + notification.root.params.progress, + notification.root.params.total, + notification.root.params.message, + ) + except Exception as e: + logging.error( + "Progress callback raised an exception: %s", + e, + ) await self._received_notification(notification) await self._handle_incoming(notification) except Exception as e: diff --git a/tests/shared/test_progress_notifications.py b/tests/shared/test_progress_notifications.py index d3aabba20..600972272 100644 --- a/tests/shared/test_progress_notifications.py +++ b/tests/shared/test_progress_notifications.py @@ -1,4 +1,5 @@ from typing import Any, cast +from unittest.mock import patch import anyio import pytest @@ -10,6 +11,7 @@ from mcp.server.models import InitializationOptions from mcp.server.session import ServerSession from mcp.shared.context import RequestContext +from mcp.shared.memory import create_connected_server_and_client_session from mcp.shared.progress import progress from mcp.shared.session import BaseSession, RequestResponder, SessionMessage @@ -320,3 +322,69 @@ async def handle_client_message( assert server_progress_updates[3]["progress"] == 100 assert server_progress_updates[3]["total"] == 100 assert server_progress_updates[3]["message"] == "Processing results..." + + +@pytest.mark.anyio +async def test_progress_callback_exception_logging(): + """Test that exceptions in progress callbacks are logged and \ + don't crash the session.""" + # Track logged warnings + logged_errors: list[str] = [] + + def mock_log_error(msg: str, *args: Any) -> None: + logged_errors.append(msg % args if args else msg) + + # Create a progress callback that raises an exception + async def failing_progress_callback(progress: float, total: float | None, message: str | None) -> None: + raise ValueError("Progress callback failed!") + + # Create a server with a tool that sends progress notifications + server = Server(name="TestProgressServer") + + @server.call_tool() + async def handle_call_tool(name: str, arguments: Any) -> list[types.TextContent]: + if name == "progress_tool": + # Send a progress notification + await server.request_context.session.send_progress_notification( + progress_token=server.request_context.request_id, + progress=50.0, + total=100.0, + message="Halfway done", + ) + return [types.TextContent(type="text", text="progress_result")] + raise ValueError(f"Unknown tool: {name}") + + @server.list_tools() + async def handle_list_tools() -> list[types.Tool]: + return [ + types.Tool( + name="progress_tool", + description="A tool that sends progress notifications", + inputSchema={}, + ) + ] + + # Test with mocked logging + with patch("mcp.shared.session.logging.error", side_effect=mock_log_error): + async with create_connected_server_and_client_session(server) as client_session: + # Send a request with a failing progress callback + result = await client_session.send_request( + types.ClientRequest( + types.CallToolRequest( + method="tools/call", + params=types.CallToolRequestParams(name="progress_tool", arguments={}), + ) + ), + types.CallToolResult, + progress_callback=failing_progress_callback, + ) + + # Verify the request completed successfully despite the callback failure + assert len(result.content) == 1 + content = result.content[0] + assert isinstance(content, types.TextContent) + assert content.text == "progress_result" + + # Check that a warning was logged for the progress callback exception + assert len(logged_errors) > 0 + assert any("Progress callback raised an exception" in warning for warning in logged_errors)