Skip to content

Commit f857e0c

Browse files
authored
Merge branch 'main' into main
2 parents cb929ea + f676f6c commit f857e0c

File tree

5 files changed

+91
-29
lines changed

5 files changed

+91
-29
lines changed

.github/CODEOWNERS

Lines changed: 0 additions & 23 deletions
This file was deleted.

examples/clients/simple-chatbot/mcp_simple_chatbot/main.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -291,8 +291,15 @@ async def process_llm_response(self, llm_response: str) -> str:
291291
"""
292292
import json
293293

294+
def _clean_json_string(json_string: str) -> str:
295+
"""Remove ```json ... ``` or ``` ... ``` wrappers if the LLM response is fenced."""
296+
import re
297+
298+
pattern = r"^```(?:\s*json)?\s*(.*?)\s*```$"
299+
return re.sub(pattern, r"\1", json_string, flags=re.DOTALL | re.IGNORECASE).strip()
300+
294301
try:
295-
tool_call = json.loads(llm_response)
302+
tool_call = json.loads(_clean_json_string(llm_response))
296303
if "tool" in tool_call and "arguments" in tool_call:
297304
logging.info(f"Executing tool: {tool_call['tool']}")
298305
logging.info(f"With arguments: {tool_call['arguments']}")

src/mcp/client/sse.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from anyio.abc import TaskStatus
99
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
1010
from httpx_sse import aconnect_sse
11+
from httpx_sse._exceptions import SSEError
1112

1213
import mcp.types as types
1314
from mcp.shared._httpx_utils import McpHttpClientFactory, create_mcp_http_client
@@ -105,6 +106,9 @@ async def sse_reader(
105106
await read_stream_writer.send(session_message)
106107
case _:
107108
logger.warning(f"Unknown SSE event: {sse.event}")
109+
except SSEError as sse_exc:
110+
logger.exception("Encountered SSE exception")
111+
raise sse_exc
108112
except Exception as exc:
109113
logger.exception("Error in sse_reader")
110114
await read_stream_writer.send(exc)

src/mcp/shared/session.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -395,11 +395,17 @@ async def _receive_loop(self) -> None:
395395
# call it with the progress information
396396
if progress_token in self._progress_callbacks:
397397
callback = self._progress_callbacks[progress_token]
398-
await callback(
399-
notification.root.params.progress,
400-
notification.root.params.total,
401-
notification.root.params.message,
402-
)
398+
try:
399+
await callback(
400+
notification.root.params.progress,
401+
notification.root.params.total,
402+
notification.root.params.message,
403+
)
404+
except Exception as e:
405+
logging.error(
406+
"Progress callback raised an exception: %s",
407+
e,
408+
)
403409
await self._received_notification(notification)
404410
await self._handle_incoming(notification)
405411
except Exception as e:

tests/shared/test_progress_notifications.py

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from typing import Any, cast
2+
from unittest.mock import patch
23

34
import anyio
45
import pytest
@@ -10,6 +11,7 @@
1011
from mcp.server.models import InitializationOptions
1112
from mcp.server.session import ServerSession
1213
from mcp.shared.context import RequestContext
14+
from mcp.shared.memory import create_connected_server_and_client_session
1315
from mcp.shared.progress import progress
1416
from mcp.shared.session import BaseSession, RequestResponder, SessionMessage
1517

@@ -320,3 +322,69 @@ async def handle_client_message(
320322
assert server_progress_updates[3]["progress"] == 100
321323
assert server_progress_updates[3]["total"] == 100
322324
assert server_progress_updates[3]["message"] == "Processing results..."
325+
326+
327+
@pytest.mark.anyio
328+
async def test_progress_callback_exception_logging():
329+
"""Test that exceptions in progress callbacks are logged and \
330+
don't crash the session."""
331+
# Track logged warnings
332+
logged_errors: list[str] = []
333+
334+
def mock_log_error(msg: str, *args: Any) -> None:
335+
logged_errors.append(msg % args if args else msg)
336+
337+
# Create a progress callback that raises an exception
338+
async def failing_progress_callback(progress: float, total: float | None, message: str | None) -> None:
339+
raise ValueError("Progress callback failed!")
340+
341+
# Create a server with a tool that sends progress notifications
342+
server = Server(name="TestProgressServer")
343+
344+
@server.call_tool()
345+
async def handle_call_tool(name: str, arguments: Any) -> list[types.TextContent]:
346+
if name == "progress_tool":
347+
# Send a progress notification
348+
await server.request_context.session.send_progress_notification(
349+
progress_token=server.request_context.request_id,
350+
progress=50.0,
351+
total=100.0,
352+
message="Halfway done",
353+
)
354+
return [types.TextContent(type="text", text="progress_result")]
355+
raise ValueError(f"Unknown tool: {name}")
356+
357+
@server.list_tools()
358+
async def handle_list_tools() -> list[types.Tool]:
359+
return [
360+
types.Tool(
361+
name="progress_tool",
362+
description="A tool that sends progress notifications",
363+
inputSchema={},
364+
)
365+
]
366+
367+
# Test with mocked logging
368+
with patch("mcp.shared.session.logging.error", side_effect=mock_log_error):
369+
async with create_connected_server_and_client_session(server) as client_session:
370+
# Send a request with a failing progress callback
371+
result = await client_session.send_request(
372+
types.ClientRequest(
373+
types.CallToolRequest(
374+
method="tools/call",
375+
params=types.CallToolRequestParams(name="progress_tool", arguments={}),
376+
)
377+
),
378+
types.CallToolResult,
379+
progress_callback=failing_progress_callback,
380+
)
381+
382+
# Verify the request completed successfully despite the callback failure
383+
assert len(result.content) == 1
384+
content = result.content[0]
385+
assert isinstance(content, types.TextContent)
386+
assert content.text == "progress_result"
387+
388+
# Check that a warning was logged for the progress callback exception
389+
assert len(logged_errors) > 0
390+
assert any("Progress callback raised an exception" in warning for warning in logged_errors)

0 commit comments

Comments
 (0)