1212- Server polls client's task status via tasks/get, tasks/result, etc.
1313"""
1414
15+ from dataclasses import dataclass , field
1516from typing import TYPE_CHECKING , Any , Protocol
1617
1718import mcp .types as types
1819from mcp .shared .context import RequestContext
20+ from mcp .shared .session import RequestResponder
1921
2022if TYPE_CHECKING :
2123 from mcp .client .session import ClientSession
@@ -109,7 +111,11 @@ async def __call__(
109111 ) -> types .CreateTaskResult | types .ErrorData : ... # pragma: no branch
110112
111113
112- # Default handlers for experimental task requests (return "not supported" errors)
114+ # =============================================================================
115+ # Default Handlers (return "not supported" errors)
116+ # =============================================================================
117+
118+
113119async def default_get_task_handler (
114120 context : RequestContext ["ClientSession" , Any ],
115121 params : types .GetTaskRequestParams ,
@@ -150,7 +156,7 @@ async def default_cancel_task_handler(
150156 )
151157
152158
153- async def default_task_augmented_sampling_callback (
159+ async def default_task_augmented_sampling (
154160 context : RequestContext ["ClientSession" , Any ],
155161 params : types .CreateMessageRequestParams ,
156162 task_metadata : types .TaskMetadata ,
@@ -161,7 +167,7 @@ async def default_task_augmented_sampling_callback(
161167 )
162168
163169
164- async def default_task_augmented_elicitation_callback (
170+ async def default_task_augmented_elicitation (
165171 context : RequestContext ["ClientSession" , Any ],
166172 params : types .ElicitRequestParams ,
167173 task_metadata : types .TaskMetadata ,
@@ -172,58 +178,118 @@ async def default_task_augmented_elicitation_callback(
172178 )
173179
174180
175- def build_client_tasks_capability (
176- * ,
177- list_tasks_handler : ListTasksHandlerFnT | None = None ,
178- cancel_task_handler : CancelTaskHandlerFnT | None = None ,
179- task_augmented_sampling_callback : TaskAugmentedSamplingFnT | None = None ,
180- task_augmented_elicitation_callback : TaskAugmentedElicitationFnT | None = None ,
181- ) -> types .ClientTasksCapability | None :
182- """Build ClientTasksCapability from the provided handlers.
183-
184- This helper builds the appropriate capability object based on which
185- handlers are provided (non-None and not the default handlers).
181+ @dataclass
182+ class ExperimentalTaskHandlers :
183+ """Container for experimental task handlers.
186184
187- WARNING: This is experimental and may change without notice.
185+ Groups all task-related handlers that handle server -> client requests.
186+ This includes both pure task requests (get, list, cancel, result) and
187+ task-augmented request handlers (sampling, elicitation with task field).
188188
189- Args:
190- list_tasks_handler: Handler for tasks/list requests
191- cancel_task_handler: Handler for tasks/cancel requests
192- task_augmented_sampling_callback: Handler for task-augmented sampling
193- task_augmented_elicitation_callback: Handler for task-augmented elicitation
189+ WARNING: These APIs are experimental and may change without notice.
194190
195- Returns:
196- ClientTasksCapability if any handlers are provided, None otherwise
191+ Example:
192+ handlers = ExperimentalTaskHandlers(
193+ get_task=my_get_task_handler,
194+ list_tasks=my_list_tasks_handler,
195+ )
196+ session = ClientSession(..., experimental_task_handlers=handlers)
197197 """
198- has_list = list_tasks_handler is not None and list_tasks_handler is not default_list_tasks_handler
199- has_cancel = cancel_task_handler is not None and cancel_task_handler is not default_cancel_task_handler
200- has_sampling = (
201- task_augmented_sampling_callback is not None
202- and task_augmented_sampling_callback is not default_task_augmented_sampling_callback
203- )
204- has_elicitation = (
205- task_augmented_elicitation_callback is not None
206- and task_augmented_elicitation_callback is not default_task_augmented_elicitation_callback
207- )
208198
209- # If no handlers are provided, return None
210- if not any ([has_list , has_cancel , has_sampling , has_elicitation ]):
211- return None
212-
213- # Build requests capability if any request handlers are provided
214- requests_capability : types .ClientTasksRequestsCapability | None = None
215- if has_sampling or has_elicitation :
216- requests_capability = types .ClientTasksRequestsCapability (
217- sampling = types .TasksSamplingCapability (createMessage = types .TasksCreateMessageCapability ())
218- if has_sampling
219- else None ,
220- elicitation = types .TasksElicitationCapability (create = types .TasksCreateElicitationCapability ())
221- if has_elicitation
222- else None ,
199+ # Pure task request handlers
200+ get_task : GetTaskHandlerFnT = field (default = default_get_task_handler )
201+ get_task_result : GetTaskResultHandlerFnT = field (default = default_get_task_result_handler )
202+ list_tasks : ListTasksHandlerFnT = field (default = default_list_tasks_handler )
203+ cancel_task : CancelTaskHandlerFnT = field (default = default_cancel_task_handler )
204+
205+ # Task-augmented request handlers
206+ augmented_sampling : TaskAugmentedSamplingFnT = field (default = default_task_augmented_sampling )
207+ augmented_elicitation : TaskAugmentedElicitationFnT = field (default = default_task_augmented_elicitation )
208+
209+ def build_capability (self ) -> types .ClientTasksCapability | None :
210+ """Build ClientTasksCapability from the configured handlers.
211+
212+ Returns a capability object that reflects which handlers are configured
213+ (i.e., not using the default "not supported" handlers).
214+
215+ Returns:
216+ ClientTasksCapability if any handlers are provided, None otherwise
217+ """
218+ has_list = self .list_tasks is not default_list_tasks_handler
219+ has_cancel = self .cancel_task is not default_cancel_task_handler
220+ has_sampling = self .augmented_sampling is not default_task_augmented_sampling
221+ has_elicitation = self .augmented_elicitation is not default_task_augmented_elicitation
222+
223+ # If no handlers are provided, return None
224+ if not any ([has_list , has_cancel , has_sampling , has_elicitation ]):
225+ return None
226+
227+ # Build requests capability if any request handlers are provided
228+ requests_capability : types .ClientTasksRequestsCapability | None = None
229+ if has_sampling or has_elicitation :
230+ requests_capability = types .ClientTasksRequestsCapability (
231+ sampling = types .TasksSamplingCapability (createMessage = types .TasksCreateMessageCapability ())
232+ if has_sampling
233+ else None ,
234+ elicitation = types .TasksElicitationCapability (create = types .TasksCreateElicitationCapability ())
235+ if has_elicitation
236+ else None ,
237+ )
238+
239+ return types .ClientTasksCapability (
240+ list = types .TasksListCapability () if has_list else None ,
241+ cancel = types .TasksCancelCapability () if has_cancel else None ,
242+ requests = requests_capability ,
223243 )
224244
225- return types .ClientTasksCapability (
226- list = types .TasksListCapability () if has_list else None ,
227- cancel = types .TasksCancelCapability () if has_cancel else None ,
228- requests = requests_capability ,
229- )
245+ @staticmethod
246+ def handles_request (request : types .ServerRequest ) -> bool :
247+ """Check if this handler handles the given request type."""
248+ return isinstance (
249+ request .root ,
250+ types .GetTaskRequest | types .GetTaskPayloadRequest | types .ListTasksRequest | types .CancelTaskRequest ,
251+ )
252+
253+ async def handle_request (
254+ self ,
255+ ctx : RequestContext ["ClientSession" , Any ],
256+ responder : RequestResponder [types .ServerRequest , types .ClientResult ],
257+ ) -> None :
258+ """Handle a task-related request from the server.
259+
260+ Call handles_request() first to check if this handler can handle the request.
261+ """
262+ from pydantic import TypeAdapter
263+
264+ client_response_type : TypeAdapter [types .ClientResult | types .ErrorData ] = TypeAdapter (
265+ types .ClientResult | types .ErrorData
266+ )
267+
268+ match responder .request .root :
269+ case types .GetTaskRequest (params = params ):
270+ response = await self .get_task (ctx , params )
271+ client_response = client_response_type .validate_python (response )
272+ await responder .respond (client_response )
273+
274+ case types .GetTaskPayloadRequest (params = params ):
275+ response = await self .get_task_result (ctx , params )
276+ client_response = client_response_type .validate_python (response )
277+ await responder .respond (client_response )
278+
279+ case types .ListTasksRequest (params = params ):
280+ response = await self .list_tasks (ctx , params )
281+ client_response = client_response_type .validate_python (response )
282+ await responder .respond (client_response )
283+
284+ case types .CancelTaskRequest (params = params ):
285+ response = await self .cancel_task (ctx , params )
286+ client_response = client_response_type .validate_python (response )
287+ await responder .respond (client_response )
288+
289+ case _: # pragma: no cover
290+ raise ValueError (f"Unhandled request type: { type (responder .request .root )} " )
291+
292+
293+ # Backwards compatibility aliases
294+ default_task_augmented_sampling_callback = default_task_augmented_sampling
295+ default_task_augmented_elicitation_callback = default_task_augmented_elicitation
0 commit comments