@@ -49,6 +49,95 @@ async def __call__(
4949 ) -> None : ... # pragma: no branch
5050
5151
52+ # Experimental: Task handler protocols for server -> client requests
53+ class GetTaskHandlerFnT (Protocol ):
54+ """Handler for tasks/get requests from server.
55+
56+ WARNING: This is experimental and may change without notice.
57+ """
58+
59+ async def __call__ (
60+ self ,
61+ context : RequestContext ["ClientSession" , Any ],
62+ params : types .GetTaskRequestParams ,
63+ ) -> types .GetTaskResult | types .ErrorData : ... # pragma: no branch
64+
65+
66+ class GetTaskResultHandlerFnT (Protocol ):
67+ """Handler for tasks/result requests from server.
68+
69+ WARNING: This is experimental and may change without notice.
70+ """
71+
72+ async def __call__ (
73+ self ,
74+ context : RequestContext ["ClientSession" , Any ],
75+ params : types .GetTaskPayloadRequestParams ,
76+ ) -> types .GetTaskPayloadResult | types .ErrorData : ... # pragma: no branch
77+
78+
79+ class ListTasksHandlerFnT (Protocol ):
80+ """Handler for tasks/list requests from server.
81+
82+ WARNING: This is experimental and may change without notice.
83+ """
84+
85+ async def __call__ (
86+ self ,
87+ context : RequestContext ["ClientSession" , Any ],
88+ params : types .PaginatedRequestParams | None ,
89+ ) -> types .ListTasksResult | types .ErrorData : ... # pragma: no branch
90+
91+
92+ class CancelTaskHandlerFnT (Protocol ):
93+ """Handler for tasks/cancel requests from server.
94+
95+ WARNING: This is experimental and may change without notice.
96+ """
97+
98+ async def __call__ (
99+ self ,
100+ context : RequestContext ["ClientSession" , Any ],
101+ params : types .CancelTaskRequestParams ,
102+ ) -> types .CancelTaskResult | types .ErrorData : ... # pragma: no branch
103+
104+
105+ class TaskAugmentedSamplingFnT (Protocol ):
106+ """Handler for task-augmented sampling/createMessage requests from server.
107+
108+ When server sends a CreateMessageRequest with task field, this callback
109+ is invoked. The callback should create a task, spawn background work,
110+ and return CreateTaskResult immediately.
111+
112+ WARNING: This is experimental and may change without notice.
113+ """
114+
115+ async def __call__ (
116+ self ,
117+ context : RequestContext ["ClientSession" , Any ],
118+ params : types .CreateMessageRequestParams ,
119+ task_metadata : types .TaskMetadata ,
120+ ) -> types .CreateTaskResult | types .ErrorData : ... # pragma: no branch
121+
122+
123+ class TaskAugmentedElicitationFnT (Protocol ):
124+ """Handler for task-augmented elicitation/create requests from server.
125+
126+ When server sends an ElicitRequest with task field, this callback
127+ is invoked. The callback should create a task, spawn background work,
128+ and return CreateTaskResult immediately.
129+
130+ WARNING: This is experimental and may change without notice.
131+ """
132+
133+ async def __call__ (
134+ self ,
135+ context : RequestContext ["ClientSession" , Any ],
136+ params : types .ElicitRequestParams ,
137+ task_metadata : types .TaskMetadata ,
138+ ) -> types .CreateTaskResult | types .ErrorData : ... # pragma: no branch
139+
140+
52141class MessageHandlerFnT (Protocol ):
53142 async def __call__ (
54143 self ,
@@ -97,6 +186,69 @@ async def _default_logging_callback(
97186 pass
98187
99188
189+ # Default handlers for experimental task requests (return "not supported" errors)
190+ async def _default_get_task_handler (
191+ context : RequestContext ["ClientSession" , Any ],
192+ params : types .GetTaskRequestParams ,
193+ ) -> types .GetTaskResult | types .ErrorData :
194+ return types .ErrorData (
195+ code = types .METHOD_NOT_FOUND ,
196+ message = "tasks/get not supported" ,
197+ )
198+
199+
200+ async def _default_get_task_result_handler (
201+ context : RequestContext ["ClientSession" , Any ],
202+ params : types .GetTaskPayloadRequestParams ,
203+ ) -> types .GetTaskPayloadResult | types .ErrorData :
204+ return types .ErrorData (
205+ code = types .METHOD_NOT_FOUND ,
206+ message = "tasks/result not supported" ,
207+ )
208+
209+
210+ async def _default_list_tasks_handler (
211+ context : RequestContext ["ClientSession" , Any ],
212+ params : types .PaginatedRequestParams | None ,
213+ ) -> types .ListTasksResult | types .ErrorData :
214+ return types .ErrorData (
215+ code = types .METHOD_NOT_FOUND ,
216+ message = "tasks/list not supported" ,
217+ )
218+
219+
220+ async def _default_cancel_task_handler (
221+ context : RequestContext ["ClientSession" , Any ],
222+ params : types .CancelTaskRequestParams ,
223+ ) -> types .CancelTaskResult | types .ErrorData :
224+ return types .ErrorData (
225+ code = types .METHOD_NOT_FOUND ,
226+ message = "tasks/cancel not supported" ,
227+ )
228+
229+
230+ async def _default_task_augmented_sampling_callback (
231+ context : RequestContext ["ClientSession" , Any ],
232+ params : types .CreateMessageRequestParams ,
233+ task_metadata : types .TaskMetadata ,
234+ ) -> types .CreateTaskResult | types .ErrorData :
235+ return types .ErrorData (
236+ code = types .INVALID_REQUEST ,
237+ message = "Task-augmented sampling not supported" ,
238+ )
239+
240+
241+ async def _default_task_augmented_elicitation_callback (
242+ context : RequestContext ["ClientSession" , Any ],
243+ params : types .ElicitRequestParams ,
244+ task_metadata : types .TaskMetadata ,
245+ ) -> types .CreateTaskResult | types .ErrorData :
246+ return types .ErrorData (
247+ code = types .INVALID_REQUEST ,
248+ message = "Task-augmented elicitation not supported" ,
249+ )
250+
251+
100252ClientResponse : TypeAdapter [types .ClientResult | types .ErrorData ] = TypeAdapter (types .ClientResult | types .ErrorData )
101253
102254
@@ -120,6 +272,14 @@ def __init__(
120272 logging_callback : LoggingFnT | None = None ,
121273 message_handler : MessageHandlerFnT | None = None ,
122274 client_info : types .Implementation | None = None ,
275+ tasks_capability : types .ClientTasksCapability | None = None ,
276+ # Experimental: Task handlers for server -> client requests
277+ get_task_handler : GetTaskHandlerFnT | None = None ,
278+ get_task_result_handler : GetTaskResultHandlerFnT | None = None ,
279+ list_tasks_handler : ListTasksHandlerFnT | None = None ,
280+ cancel_task_handler : CancelTaskHandlerFnT | None = None ,
281+ task_augmented_sampling_callback : TaskAugmentedSamplingFnT | None = None ,
282+ task_augmented_elicitation_callback : TaskAugmentedElicitationFnT | None = None ,
123283 ) -> None :
124284 super ().__init__ (
125285 read_stream ,
@@ -134,9 +294,21 @@ def __init__(
134294 self ._list_roots_callback = list_roots_callback or _default_list_roots_callback
135295 self ._logging_callback = logging_callback or _default_logging_callback
136296 self ._message_handler = message_handler or _default_message_handler
297+ self ._tasks_capability = tasks_capability
137298 self ._tool_output_schemas : dict [str , dict [str , Any ] | None ] = {}
138299 self ._server_capabilities : types .ServerCapabilities | None = None
139300 self ._experimental : ExperimentalClientFeatures | None = None
301+ # Experimental: Task handlers
302+ self ._get_task_handler = get_task_handler or _default_get_task_handler
303+ self ._get_task_result_handler = get_task_result_handler or _default_get_task_result_handler
304+ self ._list_tasks_handler = list_tasks_handler or _default_list_tasks_handler
305+ self ._cancel_task_handler = cancel_task_handler or _default_cancel_task_handler
306+ self ._task_augmented_sampling_callback = (
307+ task_augmented_sampling_callback or _default_task_augmented_sampling_callback
308+ )
309+ self ._task_augmented_elicitation_callback = (
310+ task_augmented_elicitation_callback or _default_task_augmented_elicitation_callback
311+ )
140312
141313 async def initialize (self ) -> types .InitializeResult :
142314 sampling = types .SamplingCapability () if self ._sampling_callback is not _default_sampling_callback else None
@@ -162,6 +334,7 @@ async def initialize(self) -> types.InitializeResult:
162334 elicitation = elicitation ,
163335 experimental = None ,
164336 roots = roots ,
337+ tasks = self ._tasks_capability ,
165338 ),
166339 clientInfo = self ._client_info ,
167340 ),
@@ -187,7 +360,7 @@ def get_server_capabilities(self) -> types.ServerCapabilities | None:
187360 return self ._server_capabilities
188361
189362 @property
190- def experimental (self ) -> " ExperimentalClientFeatures" :
363+ def experimental (self ) -> ExperimentalClientFeatures :
191364 """Experimental APIs for tasks and other features.
192365
193366 WARNING: These APIs are experimental and may change without notice.
@@ -534,13 +707,21 @@ async def _received_request(self, responder: RequestResponder[types.ServerReques
534707 match responder .request .root :
535708 case types .CreateMessageRequest (params = params ):
536709 with responder :
537- response = await self ._sampling_callback (ctx , params )
710+ # Check if this is a task-augmented request
711+ if params .task is not None :
712+ response = await self ._task_augmented_sampling_callback (ctx , params , params .task )
713+ else :
714+ response = await self ._sampling_callback (ctx , params )
538715 client_response = ClientResponse .validate_python (response )
539716 await responder .respond (client_response )
540717
541718 case types .ElicitRequest (params = params ):
542719 with responder :
543- response = await self ._elicitation_callback (ctx , params )
720+ # Check if this is a task-augmented request
721+ if params .task is not None :
722+ response = await self ._task_augmented_elicitation_callback (ctx , params , params .task )
723+ else :
724+ response = await self ._elicitation_callback (ctx , params )
544725 client_response = ClientResponse .validate_python (response )
545726 await responder .respond (client_response )
546727
@@ -553,7 +734,33 @@ async def _received_request(self, responder: RequestResponder[types.ServerReques
553734 case types .PingRequest (): # pragma: no cover
554735 with responder :
555736 return await responder .respond (types .ClientResult (root = types .EmptyResult ()))
556- case _:
737+
738+ # Experimental: Task management requests from server
739+ case types .GetTaskRequest (params = params ):
740+ with responder :
741+ response = await self ._get_task_handler (ctx , params )
742+ client_response = ClientResponse .validate_python (response )
743+ await responder .respond (client_response )
744+
745+ case types .GetTaskPayloadRequest (params = params ):
746+ with responder :
747+ response = await self ._get_task_result_handler (ctx , params )
748+ client_response = ClientResponse .validate_python (response )
749+ await responder .respond (client_response )
750+
751+ case types .ListTasksRequest (params = params ):
752+ with responder :
753+ response = await self ._list_tasks_handler (ctx , params )
754+ client_response = ClientResponse .validate_python (response )
755+ await responder .respond (client_response )
756+
757+ case types .CancelTaskRequest (params = params ):
758+ with responder :
759+ response = await self ._cancel_task_handler (ctx , params )
760+ client_response = ClientResponse .validate_python (response )
761+ await responder .respond (client_response )
762+
763+ case _: # pragma: no cover
557764 raise NotImplementedError ()
558765
559766 async def _handle_incoming (
0 commit comments