@@ -22,6 +22,14 @@ async def __call__(
2222 ) -> types .CreateMessageResult | types .ErrorData : ...
2323
2424
25+ class ElicitationFnT (Protocol ):
26+ async def __call__ (
27+ self ,
28+ context : RequestContext ["ClientSession" , Any ],
29+ params : types .ElicitRequestParams ,
30+ ) -> types .ElicitResult | types .ErrorData : ...
31+
32+
2533class ListRootsFnT (Protocol ):
2634 async def __call__ (
2735 self , context : RequestContext ["ClientSession" , Any ]
@@ -62,6 +70,16 @@ async def _default_sampling_callback(
6270 )
6371
6472
73+ async def _default_elicitation_callback (
74+ context : RequestContext ["ClientSession" , Any ],
75+ params : types .ElicitRequestParams ,
76+ ) -> types .ElicitResult | types .ErrorData :
77+ return types .ErrorData (
78+ code = types .INVALID_REQUEST ,
79+ message = "Elicitation not supported" ,
80+ )
81+
82+
6583async def _default_list_roots_callback (
6684 context : RequestContext ["ClientSession" , Any ],
6785) -> types .ListRootsResult | types .ErrorData :
@@ -97,6 +115,7 @@ def __init__(
97115 write_stream : MemoryObjectSendStream [SessionMessage ],
98116 read_timeout_seconds : timedelta | None = None ,
99117 sampling_callback : SamplingFnT | None = None ,
118+ elicitation_callback : ElicitationFnT | None = None ,
100119 list_roots_callback : ListRootsFnT | None = None ,
101120 logging_callback : LoggingFnT | None = None ,
102121 message_handler : MessageHandlerFnT | None = None ,
@@ -111,12 +130,16 @@ def __init__(
111130 )
112131 self ._client_info = client_info or DEFAULT_CLIENT_INFO
113132 self ._sampling_callback = sampling_callback or _default_sampling_callback
133+ self ._elicitation_callback = (
134+ elicitation_callback or _default_elicitation_callback
135+ )
114136 self ._list_roots_callback = list_roots_callback or _default_list_roots_callback
115137 self ._logging_callback = logging_callback or _default_logging_callback
116138 self ._message_handler = message_handler or _default_message_handler
117139
118140 async def initialize (self ) -> types .InitializeResult :
119141 sampling = types .SamplingCapability ()
142+ elicitation = types .ElicitationCapability ()
120143 roots = types .RootsCapability (
121144 # TODO: Should this be based on whether we
122145 # _will_ send notifications, or only whether
@@ -132,6 +155,7 @@ async def initialize(self) -> types.InitializeResult:
132155 protocolVersion = types .LATEST_PROTOCOL_VERSION ,
133156 capabilities = types .ClientCapabilities (
134157 sampling = sampling ,
158+ elicitation = elicitation ,
135159 experimental = None ,
136160 roots = roots ,
137161 ),
@@ -355,6 +379,12 @@ async def _received_request(
355379 client_response = ClientResponse .validate_python (response )
356380 await responder .respond (client_response )
357381
382+ case types .ElicitRequest (params = params ):
383+ with responder :
384+ response = await self ._elicitation_callback (ctx , params )
385+ client_response = ClientResponse .validate_python (response )
386+ await responder .respond (client_response )
387+
358388 case types .ListRootsRequest ():
359389 with responder :
360390 response = await self ._list_roots_callback (ctx )
0 commit comments