@@ -24,6 +24,13 @@ async def __call__(
2424 ) -> types .ListRootsResult | types .ErrorData : ...
2525
2626
27+ class LoggingFnT (Protocol ):
28+ async def __call__ (
29+ self ,
30+ params : types .LoggingMessageNotificationParams ,
31+ ) -> None : ...
32+
33+
2734async def _default_sampling_callback (
2835 context : RequestContext ["ClientSession" , Any ],
2936 params : types .CreateMessageRequestParams ,
@@ -43,6 +50,12 @@ async def _default_list_roots_callback(
4350 )
4451
4552
53+ async def _default_logging_callback (
54+ params : types .LoggingMessageNotificationParams ,
55+ ) -> None :
56+ pass
57+
58+
4659ClientResponse : TypeAdapter [types .ClientResult | types .ErrorData ] = TypeAdapter (
4760 types .ClientResult | types .ErrorData
4861)
@@ -64,6 +77,7 @@ def __init__(
6477 read_timeout_seconds : timedelta | None = None ,
6578 sampling_callback : SamplingFnT | None = None ,
6679 list_roots_callback : ListRootsFnT | None = None ,
80+ logging_callback : LoggingFnT | None = None ,
6781 ) -> None :
6882 super ().__init__ (
6983 read_stream ,
@@ -74,20 +88,15 @@ def __init__(
7488 )
7589 self ._sampling_callback = sampling_callback or _default_sampling_callback
7690 self ._list_roots_callback = list_roots_callback or _default_list_roots_callback
91+ self ._logging_callback = logging_callback or _default_logging_callback
7792
7893 async def initialize (self ) -> types .InitializeResult :
79- sampling = (
80- types .SamplingCapability () if self ._sampling_callback is not None else None
81- )
82- roots = (
83- types .RootsCapability (
84- # TODO: Should this be based on whether we
85- # _will_ send notifications, or only whether
86- # they're supported?
87- listChanged = True ,
88- )
89- if self ._list_roots_callback is not None
90- else None
94+ sampling = types .SamplingCapability ()
95+ roots = types .RootsCapability (
96+ # TODO: Should this be based on whether we
97+ # _will_ send notifications, or only whether
98+ # they're supported?
99+ listChanged = True ,
91100 )
92101
93102 result = await self .send_request (
@@ -327,3 +336,13 @@ async def _received_request(
327336 return await responder .respond (
328337 types .ClientResult (root = types .EmptyResult ())
329338 )
339+
340+ async def _received_notification (
341+ self , notification : types .ServerNotification
342+ ) -> None :
343+ """Handle notifications from the server."""
344+ match notification .root :
345+ case types .LoggingMessageNotification (params = params ):
346+ await self ._logging_callback (params )
347+ case _:
348+ pass
0 commit comments