Skip to content

Commit f1a964c

Browse files
Add unit test + remove try-except-else + add try-except for JSONRPCRequest
1 parent be40a3e commit f1a964c

File tree

2 files changed

+121
-31
lines changed

2 files changed

+121
-31
lines changed

src/mcp/shared/session.py

Lines changed: 41 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -351,11 +351,21 @@ async def _receive_loop(self) -> None:
351351
if isinstance(message, Exception):
352352
await self._handle_incoming(message)
353353
elif isinstance(message.message.root, JSONRPCRequest):
354-
validated_request = self._receive_request_type.model_validate(
355-
message.message.root.model_dump(
356-
by_alias=True, mode="json", exclude_none=True
354+
try:
355+
validated_request = self._receive_request_type.model_validate(
356+
message.message.root.model_dump(
357+
by_alias=True, mode="json", exclude_none=True
358+
)
357359
)
358-
)
360+
except Exception as e:
361+
# For other validation errors, log and continue
362+
logging.warning(
363+
"Failed to validate request: %s. Message was: %s",
364+
e,
365+
message.message.root,
366+
)
367+
continue
368+
359369
responder = RequestResponder(
360370
request_id=message.message.root.id,
361371
request_meta=validated_request.root.params.meta
@@ -386,33 +396,33 @@ async def _receive_loop(self) -> None:
386396
e,
387397
message.message.root,
388398
)
389-
else: # Notification is valid
390-
# Handle cancellation notifications
391-
if isinstance(notification.root, CancelledNotification):
392-
cancelled_id = notification.root.params.requestId
393-
if cancelled_id in self._in_flight:
394-
await self._in_flight[cancelled_id].cancel()
395-
else:
396-
# Handle progress notifications callback
397-
if isinstance(notification.root, ProgressNotification):
398-
progress_token = notification.root.params.progressToken
399-
# If there is a progress callback for this token,
400-
# call it with the progress information
401-
if progress_token in self._progress_callbacks:
402-
callback = self._progress_callbacks[progress_token]
403-
try:
404-
await callback(
405-
notification.root.params.progress,
406-
notification.root.params.total,
407-
notification.root.params.message,
408-
)
409-
except Exception as e:
410-
logging.warning(
411-
"Progress callback raised an exception: %s",
412-
e,
413-
)
414-
await self._received_notification(notification)
415-
await self._handle_incoming(notification)
399+
continue
400+
# Handle cancellation notifications
401+
if isinstance(notification.root, CancelledNotification):
402+
cancelled_id = notification.root.params.requestId
403+
if cancelled_id in self._in_flight:
404+
await self._in_flight[cancelled_id].cancel()
405+
else:
406+
# Handle progress notifications callback
407+
if isinstance(notification.root, ProgressNotification):
408+
progress_token = notification.root.params.progressToken
409+
# If there is a progress callback for this token,
410+
# call it with the progress information
411+
if progress_token in self._progress_callbacks:
412+
callback = self._progress_callbacks[progress_token]
413+
try:
414+
await callback(
415+
notification.root.params.progress,
416+
notification.root.params.total,
417+
notification.root.params.message,
418+
)
419+
except Exception as e:
420+
logging.warning(
421+
"Progress callback raised an exception: %s",
422+
e,
423+
)
424+
await self._received_notification(notification)
425+
await self._handle_incoming(notification)
416426
else: # Response or error
417427
stream = self._response_streams.pop(message.message.root.id, None)
418428
if stream:

tests/shared/test_progress_notifications.py

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from typing import Any, cast
2+
from unittest.mock import patch
23

34
import anyio
45
import pytest
@@ -10,12 +11,16 @@
1011
from mcp.server.models import InitializationOptions
1112
from mcp.server.session import ServerSession
1213
from mcp.shared.context import RequestContext
14+
from mcp.shared.memory import create_connected_server_and_client_session
1315
from mcp.shared.progress import progress
1416
from mcp.shared.session import (
1517
BaseSession,
1618
RequestResponder,
1719
SessionMessage,
1820
)
21+
from mcp.types import (
22+
TextContent,
23+
)
1924

2025

2126
@pytest.mark.anyio
@@ -347,3 +352,78 @@ async def handle_client_message(
347352
assert server_progress_updates[3]["progress"] == 100
348353
assert server_progress_updates[3]["total"] == 100
349354
assert server_progress_updates[3]["message"] == "Processing results..."
355+
356+
357+
@pytest.mark.anyio
358+
async def test_progress_callback_exception_logging():
359+
"""Test that exceptions in progress callbacks are logged and \
360+
don't crash the session."""
361+
# Track logged warnings
362+
logged_warnings = []
363+
364+
def mock_warning(msg, *args):
365+
logged_warnings.append(msg % args if args else msg)
366+
367+
# Create a progress callback that raises an exception
368+
async def failing_progress_callback(
369+
progress: float, total: float | None, message: str | None
370+
) -> None:
371+
raise ValueError("Progress callback failed!")
372+
373+
# Create a server with a tool that sends progress notifications
374+
server = Server(name="TestProgressServer")
375+
376+
@server.call_tool()
377+
async def handle_call_tool(
378+
name: str, arguments: dict | None
379+
) -> list[types.TextContent]:
380+
if name == "progress_tool":
381+
# Send a progress notification
382+
await server.request_context.session.send_progress_notification(
383+
progress_token=server.request_context.request_id,
384+
progress=50.0,
385+
total=100.0,
386+
message="Halfway done",
387+
)
388+
return [types.TextContent(type="text", text="progress_result")]
389+
raise ValueError(f"Unknown tool: {name}")
390+
391+
@server.list_tools()
392+
async def handle_list_tools() -> list[types.Tool]:
393+
return [
394+
types.Tool(
395+
name="progress_tool",
396+
description="A tool that sends progress notifications",
397+
inputSchema={},
398+
)
399+
]
400+
401+
# Test with mocked logging
402+
with patch("mcp.shared.session.logging.warning", side_effect=mock_warning):
403+
async with create_connected_server_and_client_session(server) as client_session:
404+
# Send a request with a failing progress callback
405+
result = await client_session.send_request(
406+
types.ClientRequest(
407+
types.CallToolRequest(
408+
method="tools/call",
409+
params=types.CallToolRequestParams(
410+
name="progress_tool", arguments={}
411+
),
412+
)
413+
),
414+
types.CallToolResult,
415+
progress_callback=failing_progress_callback,
416+
)
417+
418+
# Verify the request completed successfully despite the callback failure
419+
assert len(result.content) == 1
420+
content = result.content[0]
421+
assert isinstance(content, TextContent)
422+
assert content.text == "progress_result"
423+
424+
# Check that a warning was logged for the progress callback exception
425+
assert len(logged_warnings) > 0
426+
assert any(
427+
"Progress callback raised an exception" in warning
428+
for warning in logged_warnings
429+
)

0 commit comments

Comments
 (0)