Skip to content

Commit b4e50aa

Browse files
Handles message type Exception in lowlevel/server.py _handle_message function. Mentioned as TODO on line 528. (modelcontextprotocol#786)
Co-authored-by: Felix Weinberger <[email protected]> Co-authored-by: Felix Weinberger <[email protected]>
1 parent 0e29cc4 commit b4e50aa

File tree

2 files changed

+84
-2
lines changed

2 files changed

+84
-2
lines changed

src/mcp/server/lowlevel/server.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -642,13 +642,21 @@ async def _handle_message(
642642
raise_exceptions: bool = False,
643643
):
644644
with warnings.catch_warnings(record=True) as w:
645-
# TODO(Marcelo): We should be checking if message is Exception here.
646-
match message: # type: ignore[reportMatchNotExhaustive]
645+
match message:
647646
case RequestResponder(request=types.ClientRequest(root=req)) as responder:
648647
with responder:
649648
await self._handle_request(message, req, session, lifespan_context, raise_exceptions)
650649
case types.ClientNotification(root=notify):
651650
await self._handle_notification(notify)
651+
case Exception():
652+
logger.error(f"Received exception from stream: {message}")
653+
await session.send_log_message(
654+
level="error",
655+
data="Internal Server Error",
656+
logger="mcp.server.exception_handler",
657+
)
658+
if raise_exceptions:
659+
raise message
652660

653661
for warning in w:
654662
logger.info("Warning: %s: %s", warning.category.__name__, warning.message)
Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
from unittest.mock import AsyncMock, Mock
2+
3+
import pytest
4+
5+
import mcp.types as types
6+
from mcp.server.lowlevel.server import Server
7+
from mcp.server.session import ServerSession
8+
from mcp.shared.session import RequestResponder
9+
10+
11+
@pytest.mark.anyio
12+
async def test_exception_handling_with_raise_exceptions_true():
13+
"""Test that exceptions are re-raised when raise_exceptions=True"""
14+
server = Server("test-server")
15+
session = Mock(spec=ServerSession)
16+
session.send_log_message = AsyncMock()
17+
18+
test_exception = RuntimeError("Test error")
19+
20+
with pytest.raises(RuntimeError, match="Test error"):
21+
await server._handle_message(test_exception, session, {}, raise_exceptions=True)
22+
23+
session.send_log_message.assert_called_once()
24+
25+
26+
@pytest.mark.anyio
27+
@pytest.mark.parametrize(
28+
"exception_class,message",
29+
[
30+
(ValueError, "Test validation error"),
31+
(RuntimeError, "Test runtime error"),
32+
(KeyError, "Test key error"),
33+
(Exception, "Basic error"),
34+
],
35+
)
36+
async def test_exception_handling_with_raise_exceptions_false(exception_class: type[Exception], message: str):
37+
"""Test that exceptions are logged when raise_exceptions=False"""
38+
server = Server("test-server")
39+
session = Mock(spec=ServerSession)
40+
session.send_log_message = AsyncMock()
41+
42+
test_exception = exception_class(message)
43+
44+
await server._handle_message(test_exception, session, {}, raise_exceptions=False)
45+
46+
# Should send log message
47+
session.send_log_message.assert_called_once()
48+
call_args = session.send_log_message.call_args
49+
50+
assert call_args.kwargs["level"] == "error"
51+
assert call_args.kwargs["data"] == "Internal Server Error"
52+
assert call_args.kwargs["logger"] == "mcp.server.exception_handler"
53+
54+
55+
@pytest.mark.anyio
56+
async def test_normal_message_handling_not_affected():
57+
"""Test that normal messages still work correctly"""
58+
server = Server("test-server")
59+
session = Mock(spec=ServerSession)
60+
61+
# Create a mock RequestResponder
62+
responder = Mock(spec=RequestResponder)
63+
responder.request = types.ClientRequest(root=types.PingRequest(method="ping"))
64+
responder.__enter__ = Mock(return_value=responder)
65+
responder.__exit__ = Mock(return_value=None)
66+
67+
# Mock the _handle_request method to avoid complex setup
68+
server._handle_request = AsyncMock()
69+
70+
# Should handle normally without any exception handling
71+
await server._handle_message(responder, session, {}, raise_exceptions=False)
72+
73+
# Verify _handle_request was called
74+
server._handle_request.assert_called_once()

0 commit comments

Comments
 (0)