Skip to content

Commit 5e1b3bb

Browse files
Merge branch 'main' into lorenzocesconetto/fix-callback-exception-warning
2 parents e973737 + 5441767 commit 5e1b3bb

File tree

7 files changed

+448
-190
lines changed

7 files changed

+448
-190
lines changed

src/mcp/client/session.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -116,12 +116,18 @@ def __init__(
116116
self._message_handler = message_handler or _default_message_handler
117117

118118
async def initialize(self) -> types.InitializeResult:
119-
sampling = types.SamplingCapability()
120-
roots = types.RootsCapability(
119+
sampling = (
120+
types.SamplingCapability()
121+
if self._sampling_callback is not _default_sampling_callback
122+
else None
123+
)
124+
roots = (
121125
# TODO: Should this be based on whether we
122126
# _will_ send notifications, or only whether
123127
# they're supported?
124-
listChanged=True,
128+
types.RootsCapability(listChanged=True)
129+
if self._list_roots_callback is not _default_list_roots_callback
130+
else None
125131
)
126132

127133
result = await self.send_request(

src/mcp/server/lowlevel/server.py

Lines changed: 12 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -147,7 +147,7 @@ def __init__(
147147
}
148148
self.notification_handlers: dict[type, Callable[..., Awaitable[None]]] = {}
149149
self.notification_options = NotificationOptions()
150-
logger.debug(f"Initializing server '{name}'")
150+
logger.debug("Initializing server %r", name)
151151

152152
def create_initialization_options(
153153
self,
@@ -510,7 +510,7 @@ async def run(
510510

511511
async with anyio.create_task_group() as tg:
512512
async for message in session.incoming_messages:
513-
logger.debug(f"Received message: {message}")
513+
logger.debug("Received message: %s", message)
514514

515515
tg.start_soon(
516516
self._handle_message,
@@ -543,7 +543,9 @@ async def _handle_message(
543543
await self._handle_notification(notify)
544544

545545
for warning in w:
546-
logger.info(f"Warning: {warning.category.__name__}: {warning.message}")
546+
logger.info(
547+
"Warning: %s: %s", warning.category.__name__, warning.message
548+
)
547549

548550
async def _handle_request(
549551
self,
@@ -553,10 +555,9 @@ async def _handle_request(
553555
lifespan_context: LifespanResultT,
554556
raise_exceptions: bool,
555557
):
556-
logger.info(f"Processing request of type {type(req).__name__}")
557-
if type(req) in self.request_handlers:
558-
handler = self.request_handlers[type(req)]
559-
logger.debug(f"Dispatching request of type {type(req).__name__}")
558+
logger.info("Processing request of type %s", type(req).__name__)
559+
if handler := self.request_handlers.get(type(req)): # type: ignore
560+
logger.debug("Dispatching request of type %s", type(req).__name__)
560561

561562
token = None
562563
try:
@@ -602,16 +603,13 @@ async def _handle_request(
602603
logger.debug("Response sent")
603604

604605
async def _handle_notification(self, notify: Any):
605-
if type(notify) in self.notification_handlers:
606-
assert type(notify) in self.notification_handlers
607-
608-
handler = self.notification_handlers[type(notify)]
609-
logger.debug(f"Dispatching notification of type {type(notify).__name__}")
606+
if handler := self.notification_handlers.get(type(notify)): # type: ignore
607+
logger.debug("Dispatching notification of type %s", type(notify).__name__)
610608

611609
try:
612610
await handler(notify)
613-
except Exception as err:
614-
logger.error(f"Uncaught exception in notification handler: {err}")
611+
except Exception:
612+
logger.exception("Uncaught exception in notification handler")
615613

616614

617615
async def _ping_handler(request: types.PingRequest) -> types.ServerResult:

src/mcp/server/streamable_http.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -397,7 +397,8 @@ async def _handle_post_request(
397397
await response(scope, receive, send)
398398

399399
# Process the message after sending the response
400-
session_message = SessionMessage(message)
400+
metadata = ServerMessageMetadata(request_context=request)
401+
session_message = SessionMessage(message, metadata=metadata)
401402
await writer.send(session_message)
402403

403404
return
@@ -412,7 +413,8 @@ async def _handle_post_request(
412413

413414
if self.is_json_response_enabled:
414415
# Process the message
415-
session_message = SessionMessage(message)
416+
metadata = ServerMessageMetadata(request_context=request)
417+
session_message = SessionMessage(message, metadata=metadata)
416418
await writer.send(session_message)
417419
try:
418420
# Process messages from the request-specific stream
@@ -511,7 +513,8 @@ async def sse_writer():
511513
async with anyio.create_task_group() as tg:
512514
tg.start_soon(response, scope, receive, send)
513515
# Then send the message to be processed by the server
514-
session_message = SessionMessage(message)
516+
metadata = ServerMessageMetadata(request_context=request)
517+
session_message = SessionMessage(message, metadata=metadata)
515518
await writer.send(session_message)
516519
except Exception:
517520
logger.exception("SSE response error")

src/mcp/types.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -218,7 +218,7 @@ class RootsCapability(BaseModel):
218218

219219

220220
class SamplingCapability(BaseModel):
221-
"""Capability for logging operations."""
221+
"""Capability for sampling operations."""
222222

223223
model_config = ConfigDict(extra="allow")
224224

tests/client/test_session.py

Lines changed: 167 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,11 @@
1+
from typing import Any
2+
13
import anyio
24
import pytest
35

46
import mcp.types as types
57
from mcp.client.session import DEFAULT_CLIENT_INFO, ClientSession
8+
from mcp.shared.context import RequestContext
69
from mcp.shared.message import SessionMessage
710
from mcp.shared.session import RequestResponder
811
from mcp.shared.version import SUPPORTED_PROTOCOL_VERSIONS
@@ -380,3 +383,167 @@ async def mock_server():
380383
# Should raise RuntimeError for unsupported version
381384
with pytest.raises(RuntimeError, match="Unsupported protocol version"):
382385
await session.initialize()
386+
387+
388+
@pytest.mark.anyio
389+
async def test_client_capabilities_default():
390+
"""Test that client capabilities are properly set with default callbacks"""
391+
client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[
392+
SessionMessage
393+
](1)
394+
server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[
395+
SessionMessage
396+
](1)
397+
398+
received_capabilities = None
399+
400+
async def mock_server():
401+
nonlocal received_capabilities
402+
403+
session_message = await client_to_server_receive.receive()
404+
jsonrpc_request = session_message.message
405+
assert isinstance(jsonrpc_request.root, JSONRPCRequest)
406+
request = ClientRequest.model_validate(
407+
jsonrpc_request.model_dump(by_alias=True, mode="json", exclude_none=True)
408+
)
409+
assert isinstance(request.root, InitializeRequest)
410+
received_capabilities = request.root.params.capabilities
411+
412+
result = ServerResult(
413+
InitializeResult(
414+
protocolVersion=LATEST_PROTOCOL_VERSION,
415+
capabilities=ServerCapabilities(),
416+
serverInfo=Implementation(name="mock-server", version="0.1.0"),
417+
)
418+
)
419+
420+
async with server_to_client_send:
421+
await server_to_client_send.send(
422+
SessionMessage(
423+
JSONRPCMessage(
424+
JSONRPCResponse(
425+
jsonrpc="2.0",
426+
id=jsonrpc_request.root.id,
427+
result=result.model_dump(
428+
by_alias=True, mode="json", exclude_none=True
429+
),
430+
)
431+
)
432+
)
433+
)
434+
# Receive initialized notification
435+
await client_to_server_receive.receive()
436+
437+
async with (
438+
ClientSession(
439+
server_to_client_receive,
440+
client_to_server_send,
441+
) as session,
442+
anyio.create_task_group() as tg,
443+
client_to_server_send,
444+
client_to_server_receive,
445+
server_to_client_send,
446+
server_to_client_receive,
447+
):
448+
tg.start_soon(mock_server)
449+
await session.initialize()
450+
451+
# Assert that capabilities are properly set with defaults
452+
assert received_capabilities is not None
453+
assert received_capabilities.sampling is None # No custom sampling callback
454+
assert received_capabilities.roots is None # No custom list_roots callback
455+
456+
457+
@pytest.mark.anyio
458+
async def test_client_capabilities_with_custom_callbacks():
459+
"""Test that client capabilities are properly set with custom callbacks"""
460+
client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[
461+
SessionMessage
462+
](1)
463+
server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[
464+
SessionMessage
465+
](1)
466+
467+
received_capabilities = None
468+
469+
async def custom_sampling_callback(
470+
context: RequestContext["ClientSession", Any],
471+
params: types.CreateMessageRequestParams,
472+
) -> types.CreateMessageResult | types.ErrorData:
473+
return types.CreateMessageResult(
474+
role="assistant",
475+
content=types.TextContent(type="text", text="test"),
476+
model="test-model",
477+
)
478+
479+
async def custom_list_roots_callback(
480+
context: RequestContext["ClientSession", Any],
481+
) -> types.ListRootsResult | types.ErrorData:
482+
return types.ListRootsResult(roots=[])
483+
484+
async def mock_server():
485+
nonlocal received_capabilities
486+
487+
session_message = await client_to_server_receive.receive()
488+
jsonrpc_request = session_message.message
489+
assert isinstance(jsonrpc_request.root, JSONRPCRequest)
490+
request = ClientRequest.model_validate(
491+
jsonrpc_request.model_dump(by_alias=True, mode="json", exclude_none=True)
492+
)
493+
assert isinstance(request.root, InitializeRequest)
494+
received_capabilities = request.root.params.capabilities
495+
496+
result = ServerResult(
497+
InitializeResult(
498+
protocolVersion=LATEST_PROTOCOL_VERSION,
499+
capabilities=ServerCapabilities(),
500+
serverInfo=Implementation(name="mock-server", version="0.1.0"),
501+
)
502+
)
503+
504+
async with server_to_client_send:
505+
await server_to_client_send.send(
506+
SessionMessage(
507+
JSONRPCMessage(
508+
JSONRPCResponse(
509+
jsonrpc="2.0",
510+
id=jsonrpc_request.root.id,
511+
result=result.model_dump(
512+
by_alias=True, mode="json", exclude_none=True
513+
),
514+
)
515+
)
516+
)
517+
)
518+
# Receive initialized notification
519+
await client_to_server_receive.receive()
520+
521+
async with (
522+
ClientSession(
523+
server_to_client_receive,
524+
client_to_server_send,
525+
sampling_callback=custom_sampling_callback,
526+
list_roots_callback=custom_list_roots_callback,
527+
) as session,
528+
anyio.create_task_group() as tg,
529+
client_to_server_send,
530+
client_to_server_receive,
531+
server_to_client_send,
532+
server_to_client_receive,
533+
):
534+
tg.start_soon(mock_server)
535+
await session.initialize()
536+
537+
# Assert that capabilities are properly set with custom callbacks
538+
assert received_capabilities is not None
539+
assert (
540+
received_capabilities.sampling is not None
541+
) # Custom sampling callback provided
542+
assert isinstance(received_capabilities.sampling, types.SamplingCapability)
543+
assert (
544+
received_capabilities.roots is not None
545+
) # Custom list_roots callback provided
546+
assert isinstance(received_capabilities.roots, types.RootsCapability)
547+
assert (
548+
received_capabilities.roots.listChanged is True
549+
) # Should be True for custom callback

0 commit comments

Comments
 (0)