Skip to content

Commit 1063634

Browse files
committed
Adding support for roots changed notification and initialized notification.
1 parent dced223 commit 1063634

File tree

2 files changed

+242
-11
lines changed

2 files changed

+242
-11
lines changed

src/mcp/server/lowlevel/server.py

Lines changed: 52 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,7 @@ async def main():
6868
from __future__ import annotations as _annotations
6969

7070
import contextvars
71+
import inspect
7172
import json
7273
import logging
7374
import warnings
@@ -104,6 +105,9 @@ async def main():
104105
# This will be properly typed in each Server instance's context
105106
request_ctx: contextvars.ContextVar[RequestContext[ServerSession, Any, Any]] = contextvars.ContextVar("request_ctx")
106107

108+
# Context variable to hold the current ServerSession, accessible by notification handlers
109+
current_session_ctx: contextvars.ContextVar[ServerSession] = contextvars.ContextVar("current_server_session")
110+
107111

108112
class NotificationOptions:
109113
def __init__(
@@ -520,6 +524,36 @@ async def handler(req: types.ProgressNotification):
520524

521525
return decorator
522526

527+
def initialized_notification(self):
528+
"""Decorator to register a handler for InitializedNotification."""
529+
530+
def decorator(
531+
func: (
532+
Callable[[types.InitializedNotification, ServerSession], Awaitable[None]]
533+
| Callable[[types.InitializedNotification], Awaitable[None]]
534+
),
535+
):
536+
logger.debug("Registering handler for InitializedNotification")
537+
self.notification_handlers[types.InitializedNotification] = func
538+
return func
539+
540+
return decorator
541+
542+
def roots_list_changed_notification(self):
543+
"""Decorator to register a handler for RootsListChangedNotification."""
544+
545+
def decorator(
546+
func: (
547+
Callable[[types.RootsListChangedNotification, ServerSession], Awaitable[None]]
548+
| Callable[[types.RootsListChangedNotification], Awaitable[None]]
549+
),
550+
):
551+
logger.debug("Registering handler for RootsListChangedNotification")
552+
self.notification_handlers[types.RootsListChangedNotification] = func
553+
return func
554+
555+
return decorator
556+
523557
def completion(self):
524558
"""Provides completions for prompts and resource templates"""
525559

@@ -591,22 +625,26 @@ async def run(
591625

592626
async def _handle_message(
593627
self,
594-
message: RequestResponder[types.ClientRequest, types.ServerResult] | types.ClientNotification | Exception,
628+
message: (RequestResponder[types.ClientRequest, types.ServerResult] | types.ClientNotification | Exception),
595629
session: ServerSession,
596630
lifespan_context: LifespanResultT,
597631
raise_exceptions: bool = False,
598632
):
599-
with warnings.catch_warnings(record=True) as w:
600-
# TODO(Marcelo): We should be checking if message is Exception here.
601-
match message: # type: ignore[reportMatchNotExhaustive]
602-
case RequestResponder(request=types.ClientRequest(root=req)) as responder:
603-
with responder:
604-
await self._handle_request(message, req, session, lifespan_context, raise_exceptions)
605-
case types.ClientNotification(root=notify):
606-
await self._handle_notification(notify)
633+
session_token = current_session_ctx.set(session)
634+
try:
635+
with warnings.catch_warnings(record=True) as w:
636+
# TODO(Marcelo): We should be checking if message is Exception here.
637+
match message: # type: ignore[reportMatchNotExhaustive]
638+
case RequestResponder(request=types.ClientRequest(root=req)) as responder:
639+
with responder:
640+
await self._handle_request(message, req, session, lifespan_context, raise_exceptions)
641+
case types.ClientNotification(root=notify):
642+
await self._handle_notification(notify)
607643

608644
for warning in w:
609645
logger.info("Warning: %s: %s", warning.category.__name__, warning.message)
646+
finally:
647+
current_session_ctx.reset(session_token)
610648

611649
async def _handle_request(
612650
self,
@@ -666,7 +704,11 @@ async def _handle_notification(self, notify: Any):
666704
logger.debug("Dispatching notification of type %s", type(notify).__name__)
667705

668706
try:
669-
await handler(notify)
707+
sig = inspect.signature(handler)
708+
if "session" in sig.parameters:
709+
await handler(notify, current_session_ctx.get())
710+
else:
711+
await handler(notify)
670712
except Exception:
671713
logger.exception("Uncaught exception in notification handler")
672714

tests/shared/test_progress_notifications.py renamed to tests/shared/test_notifications.py

Lines changed: 190 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import asyncio
12
from typing import Any, cast
23

34
import anyio
@@ -10,11 +11,11 @@
1011
from mcp.server.models import InitializationOptions
1112
from mcp.server.session import ServerSession
1213
from mcp.shared.context import RequestContext
14+
from mcp.shared.message import SessionMessage
1315
from mcp.shared.progress import progress
1416
from mcp.shared.session import (
1517
BaseSession,
1618
RequestResponder,
17-
SessionMessage,
1819
)
1920

2021

@@ -333,3 +334,191 @@ async def handle_client_message(
333334
assert server_progress_updates[3]["progress"] == 100
334335
assert server_progress_updates[3]["total"] == 100
335336
assert server_progress_updates[3]["message"] == "Processing results..."
337+
338+
339+
@pytest.mark.anyio
340+
async def test_initialized_notification():
341+
"""Test that the server receives and handles InitializedNotification."""
342+
server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](1)
343+
client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage | Exception](1)
344+
345+
server = Server("test")
346+
initialized_received = asyncio.Event()
347+
348+
@server.initialized_notification()
349+
async def handle_initialized(notification: types.InitializedNotification):
350+
initialized_received.set()
351+
352+
async def run_server():
353+
await server.run(
354+
client_to_server_receive,
355+
server_to_client_send,
356+
server.create_initialization_options(),
357+
)
358+
359+
async def message_handler(
360+
message: (RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception),
361+
) -> None:
362+
if isinstance(message, Exception):
363+
raise message
364+
365+
async with (
366+
ClientSession(
367+
server_to_client_receive,
368+
client_to_server_send,
369+
message_handler=message_handler,
370+
) as client_session,
371+
anyio.create_task_group() as tg,
372+
):
373+
tg.start_soon(run_server)
374+
await client_session.initialize()
375+
await initialized_received.wait()
376+
tg.cancel_scope.cancel()
377+
378+
assert initialized_received.is_set()
379+
380+
381+
@pytest.mark.anyio
382+
async def test_roots_list_changed_notification():
383+
"""Test that the server receives and handles RootsListChangedNotification."""
384+
server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](1)
385+
client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage | Exception](1)
386+
387+
server = Server("test")
388+
roots_list_changed_received = asyncio.Event()
389+
390+
@server.roots_list_changed_notification()
391+
async def handle_roots_list_changed(
392+
notification: types.RootsListChangedNotification,
393+
):
394+
roots_list_changed_received.set()
395+
396+
async def run_server():
397+
await server.run(
398+
client_to_server_receive,
399+
server_to_client_send,
400+
server.create_initialization_options(),
401+
)
402+
403+
async def message_handler(
404+
message: (RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception),
405+
) -> None:
406+
if isinstance(message, Exception):
407+
raise message
408+
409+
async with (
410+
ClientSession(
411+
server_to_client_receive,
412+
client_to_server_send,
413+
message_handler=message_handler,
414+
) as client_session,
415+
anyio.create_task_group() as tg,
416+
):
417+
tg.start_soon(run_server)
418+
await client_session.initialize()
419+
await client_session.send_notification(
420+
types.ClientNotification(
421+
root=types.RootsListChangedNotification(method="notifications/roots/list_changed", params=None)
422+
)
423+
)
424+
await roots_list_changed_received.wait()
425+
tg.cancel_scope.cancel()
426+
427+
assert roots_list_changed_received.is_set()
428+
429+
430+
@pytest.mark.anyio
431+
async def test_initialized_notification_with_session():
432+
"""Test that the server receives and handles InitializedNotification with a session."""
433+
server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](1)
434+
client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage | Exception](1)
435+
436+
server = Server("test")
437+
initialized_received = asyncio.Event()
438+
received_session = None
439+
440+
@server.initialized_notification()
441+
async def handle_initialized(notification: types.InitializedNotification, session: ServerSession):
442+
nonlocal received_session
443+
received_session = session
444+
initialized_received.set()
445+
446+
async def run_server():
447+
await server.run(
448+
client_to_server_receive,
449+
server_to_client_send,
450+
server.create_initialization_options(),
451+
)
452+
453+
async def message_handler(
454+
message: (RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception),
455+
) -> None:
456+
if isinstance(message, Exception):
457+
raise message
458+
459+
async with (
460+
ClientSession(
461+
server_to_client_receive,
462+
client_to_server_send,
463+
message_handler=message_handler,
464+
) as client_session,
465+
anyio.create_task_group() as tg,
466+
):
467+
tg.start_soon(run_server)
468+
await client_session.initialize()
469+
await initialized_received.wait()
470+
tg.cancel_scope.cancel()
471+
472+
assert initialized_received.is_set()
473+
assert isinstance(received_session, ServerSession)
474+
475+
476+
@pytest.mark.anyio
477+
async def test_roots_list_changed_notification_with_session():
478+
"""Test that the server receives and handles RootsListChangedNotification with a session."""
479+
server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](1)
480+
client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage | Exception](1)
481+
482+
server = Server("test")
483+
roots_list_changed_received = asyncio.Event()
484+
received_session = None
485+
486+
@server.roots_list_changed_notification()
487+
async def handle_roots_list_changed(notification: types.RootsListChangedNotification, session: ServerSession):
488+
nonlocal received_session
489+
received_session = session
490+
roots_list_changed_received.set()
491+
492+
async def run_server():
493+
await server.run(
494+
client_to_server_receive,
495+
server_to_client_send,
496+
server.create_initialization_options(),
497+
)
498+
499+
async def message_handler(
500+
message: (RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception),
501+
) -> None:
502+
if isinstance(message, Exception):
503+
raise message
504+
505+
async with (
506+
ClientSession(
507+
server_to_client_receive,
508+
client_to_server_send,
509+
message_handler=message_handler,
510+
) as client_session,
511+
anyio.create_task_group() as tg,
512+
):
513+
tg.start_soon(run_server)
514+
await client_session.initialize()
515+
await client_session.send_notification(
516+
types.ClientNotification(
517+
root=types.RootsListChangedNotification(method="notifications/roots/list_changed", params=None)
518+
)
519+
)
520+
await roots_list_changed_received.wait()
521+
tg.cancel_scope.cancel()
522+
523+
assert roots_list_changed_received.is_set()
524+
assert isinstance(received_session, ServerSession)

0 commit comments

Comments
 (0)