Skip to content

Commit 4c2bdb7

Browse files
committed
refactor: simplify notification handling by passing session
This commit refactors the notification handling logic to eliminate the global context variable and introspection. The `ServerSession` is now explicitly passed to notification handlers, simplifying the control flow and improving explicitness. This addresses the code review feedback to avoid introspection and global state in the low-level server code.
1 parent b03a07c commit 4c2bdb7

File tree

2 files changed

+45
-130
lines changed

2 files changed

+45
-130
lines changed

src/mcp/server/lowlevel/server.py

Lines changed: 32 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,6 @@ async def main():
6868
from __future__ import annotations as _annotations
6969

7070
import contextvars
71-
import inspect
7271
import json
7372
import logging
7473
import warnings
@@ -105,9 +104,6 @@ async def main():
105104
# This will be properly typed in each Server instance's context
106105
request_ctx: contextvars.ContextVar[RequestContext[ServerSession, Any, Any]] = contextvars.ContextVar("request_ctx")
107106

108-
# Context variable to hold the current ServerSession, accessible by notification handlers
109-
current_session_ctx: contextvars.ContextVar[ServerSession] = contextvars.ContextVar("current_server_session")
110-
111107

112108
class NotificationOptions:
113109
def __init__(
@@ -512,16 +508,23 @@ async def handler(req: types.CallToolRequest):
512508

513509
def progress_notification(self):
514510
def decorator(
515-
func: Callable[[str | int, float, float | None, str | None], Awaitable[None]],
511+
func: Callable[
512+
[str | int, float, float | None, str | None, ServerSession | None],
513+
Awaitable[None],
514+
],
516515
):
517516
logger.debug("Registering handler for ProgressNotification")
518517

519-
async def handler(req: types.ProgressNotification):
518+
async def handler(
519+
req: types.ProgressNotification,
520+
session: ServerSession | None = None,
521+
):
520522
await func(
521523
req.params.progressToken,
522524
req.params.progress,
523525
req.params.total,
524526
req.params.message,
527+
session,
525528
)
526529

527530
self.notification_handlers[types.ProgressNotification] = handler
@@ -533,10 +536,10 @@ def initialized_notification(self):
533536
"""Decorator to register a handler for InitializedNotification."""
534537

535538
def decorator(
536-
func: (
537-
Callable[[types.InitializedNotification, ServerSession], Awaitable[None]]
538-
| Callable[[types.InitializedNotification], Awaitable[None]]
539-
),
539+
func: Callable[
540+
[types.InitializedNotification, ServerSession | None],
541+
Awaitable[None],
542+
],
540543
):
541544
logger.debug("Registering handler for InitializedNotification")
542545
self.notification_handlers[types.InitializedNotification] = func
@@ -548,10 +551,10 @@ def roots_list_changed_notification(self):
548551
"""Decorator to register a handler for RootsListChangedNotification."""
549552

550553
def decorator(
551-
func: (
552-
Callable[[types.RootsListChangedNotification, ServerSession], Awaitable[None]]
553-
| Callable[[types.RootsListChangedNotification], Awaitable[None]]
554-
),
554+
func: Callable[
555+
[types.RootsListChangedNotification, ServerSession | None],
556+
Awaitable[None],
557+
],
555558
):
556559
logger.debug("Registering handler for RootsListChangedNotification")
557560
self.notification_handlers[types.RootsListChangedNotification] = func
@@ -635,21 +638,17 @@ async def _handle_message(
635638
lifespan_context: LifespanResultT,
636639
raise_exceptions: bool = False,
637640
):
638-
session_token = current_session_ctx.set(session)
639-
try:
640-
with warnings.catch_warnings(record=True) as w:
641-
# TODO(Marcelo): We should be checking if message is Exception here.
642-
match message: # type: ignore[reportMatchNotExhaustive]
643-
case RequestResponder(request=types.ClientRequest(root=req)) as responder:
644-
with responder:
645-
await self._handle_request(message, req, session, lifespan_context, raise_exceptions)
646-
case types.ClientNotification(root=notify):
647-
await self._handle_notification(notify)
648-
649-
for warning in w:
650-
logger.info("Warning: %s: %s", warning.category.__name__, warning.message)
651-
finally:
652-
current_session_ctx.reset(session_token)
641+
with warnings.catch_warnings(record=True) as w:
642+
# TODO(Marcelo): We should be checking if message is Exception here.
643+
match message: # type: ignore[reportMatchNotExhaustive]
644+
case RequestResponder(request=types.ClientRequest(root=req)) as responder:
645+
with responder:
646+
await self._handle_request(message, req, session, lifespan_context, raise_exceptions)
647+
case types.ClientNotification(root=notify):
648+
await self._handle_notification(notify, session)
649+
650+
for warning in w:
651+
logger.info("Warning: %s: %s", warning.category.__name__, warning.message)
653652

654653
async def _handle_request(
655654
self,
@@ -710,15 +709,14 @@ async def _handle_request(
710709

711710
logger.debug("Response sent")
712711

713-
async def _handle_notification(self, notify: Any):
712+
async def _handle_notification(self, notify: Any, session: ServerSession):
714713
if handler := self.notification_handlers.get(type(notify)): # type: ignore
715714
logger.debug("Dispatching notification of type %s", type(notify).__name__)
716715

717716
try:
718-
sig = inspect.signature(handler)
719-
if "session" in sig.parameters:
720-
await handler(notify, current_session_ctx.get())
721-
else:
717+
try:
718+
await handler(notify, session)
719+
except TypeError:
722720
await handler(notify)
723721
except Exception:
724722
logger.exception("Uncaught exception in notification handler")

tests/shared/test_notifications.py

Lines changed: 13 additions & 96 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
from mcp.shared.context import RequestContext
1414
from mcp.shared.message import SessionMessage
1515
from mcp.shared.progress import progress
16-
from mcp.shared.session import BaseSession, RequestResponder, SessionMessage
16+
from mcp.shared.session import BaseSession, RequestResponder
1717

1818

1919
@pytest.mark.anyio
@@ -62,6 +62,7 @@ async def handle_progress(
6262
progress: float,
6363
total: float | None,
6464
message: str | None,
65+
session: ServerSession | None,
6566
):
6667
server_progress_updates.append(
6768
{
@@ -228,6 +229,7 @@ async def handle_progress(
228229
progress: float,
229230
total: float | None,
230231
message: str | None,
232+
session: ServerSession | None,
231233
):
232234
server_progress_updates.append(
233235
{"token": progress_token, "progress": progress, "total": total, "message": message}
@@ -332,9 +334,15 @@ async def test_initialized_notification():
332334

333335
server = Server("test")
334336
initialized_received = asyncio.Event()
337+
received_session: ServerSession | None = None
335338

336339
@server.initialized_notification()
337-
async def handle_initialized(notification: types.InitializedNotification):
340+
async def handle_initialized(
341+
notification: types.InitializedNotification,
342+
session: ServerSession | None = None,
343+
):
344+
nonlocal received_session
345+
received_session = session
338346
initialized_received.set()
339347

340348
async def run_server():
@@ -364,6 +372,7 @@ async def message_handler(
364372
tg.cancel_scope.cancel()
365373

366374
assert initialized_received.is_set()
375+
assert isinstance(received_session, ServerSession)
367376

368377

369378
@pytest.mark.anyio
@@ -374,105 +383,13 @@ async def test_roots_list_changed_notification():
374383

375384
server = Server("test")
376385
roots_list_changed_received = asyncio.Event()
386+
received_session: ServerSession | None = None
377387

378388
@server.roots_list_changed_notification()
379389
async def handle_roots_list_changed(
380390
notification: types.RootsListChangedNotification,
391+
session: ServerSession | None = None,
381392
):
382-
roots_list_changed_received.set()
383-
384-
async def run_server():
385-
await server.run(
386-
client_to_server_receive,
387-
server_to_client_send,
388-
server.create_initialization_options(),
389-
)
390-
391-
async def message_handler(
392-
message: (RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception),
393-
) -> None:
394-
if isinstance(message, Exception):
395-
raise message
396-
397-
async with (
398-
ClientSession(
399-
server_to_client_receive,
400-
client_to_server_send,
401-
message_handler=message_handler,
402-
) as client_session,
403-
anyio.create_task_group() as tg,
404-
):
405-
tg.start_soon(run_server)
406-
await client_session.initialize()
407-
await client_session.send_notification(
408-
types.ClientNotification(
409-
root=types.RootsListChangedNotification(method="notifications/roots/list_changed", params=None)
410-
)
411-
)
412-
await roots_list_changed_received.wait()
413-
tg.cancel_scope.cancel()
414-
415-
assert roots_list_changed_received.is_set()
416-
417-
418-
@pytest.mark.anyio
419-
async def test_initialized_notification_with_session():
420-
"""Test that the server receives and handles InitializedNotification with a session."""
421-
server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](1)
422-
client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage | Exception](1)
423-
424-
server = Server("test")
425-
initialized_received = asyncio.Event()
426-
received_session = None
427-
428-
@server.initialized_notification()
429-
async def handle_initialized(notification: types.InitializedNotification, session: ServerSession):
430-
nonlocal received_session
431-
received_session = session
432-
initialized_received.set()
433-
434-
async def run_server():
435-
await server.run(
436-
client_to_server_receive,
437-
server_to_client_send,
438-
server.create_initialization_options(),
439-
)
440-
441-
async def message_handler(
442-
message: (RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception),
443-
) -> None:
444-
if isinstance(message, Exception):
445-
raise message
446-
447-
async with (
448-
ClientSession(
449-
server_to_client_receive,
450-
client_to_server_send,
451-
message_handler=message_handler,
452-
) as client_session,
453-
anyio.create_task_group() as tg,
454-
):
455-
tg.start_soon(run_server)
456-
await client_session.initialize()
457-
await initialized_received.wait()
458-
tg.cancel_scope.cancel()
459-
460-
assert initialized_received.is_set()
461-
assert isinstance(received_session, ServerSession)
462-
463-
464-
@pytest.mark.anyio
465-
async def test_roots_list_changed_notification_with_session():
466-
"""Test that the server receives and handles RootsListChangedNotification with a session."""
467-
server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](1)
468-
client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage | Exception](1)
469-
470-
server = Server("test")
471-
roots_list_changed_received = asyncio.Event()
472-
received_session = None
473-
474-
@server.roots_list_changed_notification()
475-
async def handle_roots_list_changed(notification: types.RootsListChangedNotification, session: ServerSession):
476393
nonlocal received_session
477394
received_session = session
478395
roots_list_changed_received.set()

0 commit comments

Comments
 (0)