Skip to content

Commit 77155fc

Browse files
committed
Refactor client task handlers into ExperimentalTaskHandlers dataclass
- Replace 6 individual task handler parameters with single `experimental_task_handlers: ExperimentalTaskHandlers` (keyword-only) - ExperimentalTaskHandlers dataclass groups all handlers and provides: - `build_capability()` - auto-builds ClientTasksCapability from handlers - `handles_request()` - checks if request is task-related - `handle_request()` - dispatches to appropriate handler - Simplify ClientSession._received_request by delegating task requests - Update tests to use new ExperimentalTaskHandlers API
1 parent 52e8c41 commit 77155fc

File tree

4 files changed

+184
-233
lines changed

4 files changed

+184
-233
lines changed

src/mcp/client/experimental/task_handlers.py

Lines changed: 117 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,12 @@
1212
- Server polls client's task status via tasks/get, tasks/result, etc.
1313
"""
1414

15+
from dataclasses import dataclass, field
1516
from typing import TYPE_CHECKING, Any, Protocol
1617

1718
import mcp.types as types
1819
from mcp.shared.context import RequestContext
20+
from mcp.shared.session import RequestResponder
1921

2022
if TYPE_CHECKING:
2123
from mcp.client.session import ClientSession
@@ -109,7 +111,11 @@ async def __call__(
109111
) -> types.CreateTaskResult | types.ErrorData: ... # pragma: no branch
110112

111113

112-
# Default handlers for experimental task requests (return "not supported" errors)
114+
# =============================================================================
115+
# Default Handlers (return "not supported" errors)
116+
# =============================================================================
117+
118+
113119
async def default_get_task_handler(
114120
context: RequestContext["ClientSession", Any],
115121
params: types.GetTaskRequestParams,
@@ -150,7 +156,7 @@ async def default_cancel_task_handler(
150156
)
151157

152158

153-
async def default_task_augmented_sampling_callback(
159+
async def default_task_augmented_sampling(
154160
context: RequestContext["ClientSession", Any],
155161
params: types.CreateMessageRequestParams,
156162
task_metadata: types.TaskMetadata,
@@ -161,7 +167,7 @@ async def default_task_augmented_sampling_callback(
161167
)
162168

163169

164-
async def default_task_augmented_elicitation_callback(
170+
async def default_task_augmented_elicitation(
165171
context: RequestContext["ClientSession", Any],
166172
params: types.ElicitRequestParams,
167173
task_metadata: types.TaskMetadata,
@@ -172,58 +178,118 @@ async def default_task_augmented_elicitation_callback(
172178
)
173179

174180

175-
def build_client_tasks_capability(
176-
*,
177-
list_tasks_handler: ListTasksHandlerFnT | None = None,
178-
cancel_task_handler: CancelTaskHandlerFnT | None = None,
179-
task_augmented_sampling_callback: TaskAugmentedSamplingFnT | None = None,
180-
task_augmented_elicitation_callback: TaskAugmentedElicitationFnT | None = None,
181-
) -> types.ClientTasksCapability | None:
182-
"""Build ClientTasksCapability from the provided handlers.
183-
184-
This helper builds the appropriate capability object based on which
185-
handlers are provided (non-None and not the default handlers).
181+
@dataclass
182+
class ExperimentalTaskHandlers:
183+
"""Container for experimental task handlers.
186184
187-
WARNING: This is experimental and may change without notice.
185+
Groups all task-related handlers that handle server -> client requests.
186+
This includes both pure task requests (get, list, cancel, result) and
187+
task-augmented request handlers (sampling, elicitation with task field).
188188
189-
Args:
190-
list_tasks_handler: Handler for tasks/list requests
191-
cancel_task_handler: Handler for tasks/cancel requests
192-
task_augmented_sampling_callback: Handler for task-augmented sampling
193-
task_augmented_elicitation_callback: Handler for task-augmented elicitation
189+
WARNING: These APIs are experimental and may change without notice.
194190
195-
Returns:
196-
ClientTasksCapability if any handlers are provided, None otherwise
191+
Example:
192+
handlers = ExperimentalTaskHandlers(
193+
get_task=my_get_task_handler,
194+
list_tasks=my_list_tasks_handler,
195+
)
196+
session = ClientSession(..., experimental_task_handlers=handlers)
197197
"""
198-
has_list = list_tasks_handler is not None and list_tasks_handler is not default_list_tasks_handler
199-
has_cancel = cancel_task_handler is not None and cancel_task_handler is not default_cancel_task_handler
200-
has_sampling = (
201-
task_augmented_sampling_callback is not None
202-
and task_augmented_sampling_callback is not default_task_augmented_sampling_callback
203-
)
204-
has_elicitation = (
205-
task_augmented_elicitation_callback is not None
206-
and task_augmented_elicitation_callback is not default_task_augmented_elicitation_callback
207-
)
208198

209-
# If no handlers are provided, return None
210-
if not any([has_list, has_cancel, has_sampling, has_elicitation]):
211-
return None
212-
213-
# Build requests capability if any request handlers are provided
214-
requests_capability: types.ClientTasksRequestsCapability | None = None
215-
if has_sampling or has_elicitation:
216-
requests_capability = types.ClientTasksRequestsCapability(
217-
sampling=types.TasksSamplingCapability(createMessage=types.TasksCreateMessageCapability())
218-
if has_sampling
219-
else None,
220-
elicitation=types.TasksElicitationCapability(create=types.TasksCreateElicitationCapability())
221-
if has_elicitation
222-
else None,
199+
# Pure task request handlers
200+
get_task: GetTaskHandlerFnT = field(default=default_get_task_handler)
201+
get_task_result: GetTaskResultHandlerFnT = field(default=default_get_task_result_handler)
202+
list_tasks: ListTasksHandlerFnT = field(default=default_list_tasks_handler)
203+
cancel_task: CancelTaskHandlerFnT = field(default=default_cancel_task_handler)
204+
205+
# Task-augmented request handlers
206+
augmented_sampling: TaskAugmentedSamplingFnT = field(default=default_task_augmented_sampling)
207+
augmented_elicitation: TaskAugmentedElicitationFnT = field(default=default_task_augmented_elicitation)
208+
209+
def build_capability(self) -> types.ClientTasksCapability | None:
210+
"""Build ClientTasksCapability from the configured handlers.
211+
212+
Returns a capability object that reflects which handlers are configured
213+
(i.e., not using the default "not supported" handlers).
214+
215+
Returns:
216+
ClientTasksCapability if any handlers are provided, None otherwise
217+
"""
218+
has_list = self.list_tasks is not default_list_tasks_handler
219+
has_cancel = self.cancel_task is not default_cancel_task_handler
220+
has_sampling = self.augmented_sampling is not default_task_augmented_sampling
221+
has_elicitation = self.augmented_elicitation is not default_task_augmented_elicitation
222+
223+
# If no handlers are provided, return None
224+
if not any([has_list, has_cancel, has_sampling, has_elicitation]):
225+
return None
226+
227+
# Build requests capability if any request handlers are provided
228+
requests_capability: types.ClientTasksRequestsCapability | None = None
229+
if has_sampling or has_elicitation:
230+
requests_capability = types.ClientTasksRequestsCapability(
231+
sampling=types.TasksSamplingCapability(createMessage=types.TasksCreateMessageCapability())
232+
if has_sampling
233+
else None,
234+
elicitation=types.TasksElicitationCapability(create=types.TasksCreateElicitationCapability())
235+
if has_elicitation
236+
else None,
237+
)
238+
239+
return types.ClientTasksCapability(
240+
list=types.TasksListCapability() if has_list else None,
241+
cancel=types.TasksCancelCapability() if has_cancel else None,
242+
requests=requests_capability,
223243
)
224244

225-
return types.ClientTasksCapability(
226-
list=types.TasksListCapability() if has_list else None,
227-
cancel=types.TasksCancelCapability() if has_cancel else None,
228-
requests=requests_capability,
229-
)
245+
@staticmethod
246+
def handles_request(request: types.ServerRequest) -> bool:
247+
"""Check if this handler handles the given request type."""
248+
return isinstance(
249+
request.root,
250+
types.GetTaskRequest | types.GetTaskPayloadRequest | types.ListTasksRequest | types.CancelTaskRequest,
251+
)
252+
253+
async def handle_request(
254+
self,
255+
ctx: RequestContext["ClientSession", Any],
256+
responder: RequestResponder[types.ServerRequest, types.ClientResult],
257+
) -> None:
258+
"""Handle a task-related request from the server.
259+
260+
Call handles_request() first to check if this handler can handle the request.
261+
"""
262+
from pydantic import TypeAdapter
263+
264+
client_response_type: TypeAdapter[types.ClientResult | types.ErrorData] = TypeAdapter(
265+
types.ClientResult | types.ErrorData
266+
)
267+
268+
match responder.request.root:
269+
case types.GetTaskRequest(params=params):
270+
response = await self.get_task(ctx, params)
271+
client_response = client_response_type.validate_python(response)
272+
await responder.respond(client_response)
273+
274+
case types.GetTaskPayloadRequest(params=params):
275+
response = await self.get_task_result(ctx, params)
276+
client_response = client_response_type.validate_python(response)
277+
await responder.respond(client_response)
278+
279+
case types.ListTasksRequest(params=params):
280+
response = await self.list_tasks(ctx, params)
281+
client_response = client_response_type.validate_python(response)
282+
await responder.respond(client_response)
283+
284+
case types.CancelTaskRequest(params=params):
285+
response = await self.cancel_task(ctx, params)
286+
client_response = client_response_type.validate_python(response)
287+
await responder.respond(client_response)
288+
289+
case _: # pragma: no cover
290+
raise ValueError(f"Unhandled request type: {type(responder.request.root)}")
291+
292+
293+
# Backwards compatibility aliases
294+
default_task_augmented_sampling_callback = default_task_augmented_sampling
295+
default_task_augmented_elicitation_callback = default_task_augmented_elicitation

src/mcp/client/session.py

Lines changed: 21 additions & 73 deletions
Original file line numberDiff line numberDiff line change
@@ -10,21 +10,7 @@
1010

1111
import mcp.types as types
1212
from mcp.client.experimental import ExperimentalClientFeatures
13-
from mcp.client.experimental.task_handlers import (
14-
CancelTaskHandlerFnT,
15-
GetTaskHandlerFnT,
16-
GetTaskResultHandlerFnT,
17-
ListTasksHandlerFnT,
18-
TaskAugmentedElicitationFnT,
19-
TaskAugmentedSamplingFnT,
20-
build_client_tasks_capability,
21-
default_cancel_task_handler,
22-
default_get_task_handler,
23-
default_get_task_result_handler,
24-
default_list_tasks_handler,
25-
default_task_augmented_elicitation_callback,
26-
default_task_augmented_sampling_callback,
27-
)
13+
from mcp.client.experimental.task_handlers import ExperimentalTaskHandlers
2814
from mcp.shared.context import RequestContext
2915
from mcp.shared.message import SessionMessage
3016
from mcp.shared.session import BaseSession, ProgressFnT, RequestResponder
@@ -135,14 +121,8 @@ def __init__(
135121
logging_callback: LoggingFnT | None = None,
136122
message_handler: MessageHandlerFnT | None = None,
137123
client_info: types.Implementation | None = None,
138-
tasks_capability: types.ClientTasksCapability | None = None,
139-
# Experimental: Task handlers for server -> client requests
140-
get_task_handler: GetTaskHandlerFnT | None = None,
141-
get_task_result_handler: GetTaskResultHandlerFnT | None = None,
142-
list_tasks_handler: ListTasksHandlerFnT | None = None,
143-
cancel_task_handler: CancelTaskHandlerFnT | None = None,
144-
task_augmented_sampling_callback: TaskAugmentedSamplingFnT | None = None,
145-
task_augmented_elicitation_callback: TaskAugmentedElicitationFnT | None = None,
124+
*,
125+
experimental_task_handlers: ExperimentalTaskHandlers | None = None,
146126
) -> None:
147127
super().__init__(
148128
read_stream,
@@ -159,25 +139,10 @@ def __init__(
159139
self._message_handler = message_handler or _default_message_handler
160140
self._tool_output_schemas: dict[str, dict[str, Any] | None] = {}
161141
self._server_capabilities: types.ServerCapabilities | None = None
162-
self._experimental: ExperimentalClientFeatures | None = None
163-
# Experimental: Task handlers
164-
self._get_task_handler = get_task_handler or default_get_task_handler
165-
self._get_task_result_handler = get_task_result_handler or default_get_task_result_handler
166-
self._list_tasks_handler = list_tasks_handler or default_list_tasks_handler
167-
self._cancel_task_handler = cancel_task_handler or default_cancel_task_handler
168-
self._task_augmented_sampling_callback = (
169-
task_augmented_sampling_callback or default_task_augmented_sampling_callback
170-
)
171-
self._task_augmented_elicitation_callback = (
172-
task_augmented_elicitation_callback or default_task_augmented_elicitation_callback
173-
)
174-
# Build tasks capability from handlers if not explicitly provided
175-
self._tasks_capability = tasks_capability or build_client_tasks_capability(
176-
list_tasks_handler=list_tasks_handler,
177-
cancel_task_handler=cancel_task_handler,
178-
task_augmented_sampling_callback=task_augmented_sampling_callback,
179-
task_augmented_elicitation_callback=task_augmented_elicitation_callback,
180-
)
142+
self._experimental_features: ExperimentalClientFeatures | None = None
143+
144+
# Experimental: Task handlers (use defaults if not provided)
145+
self._task_handlers = experimental_task_handlers or ExperimentalTaskHandlers()
181146

182147
async def initialize(self) -> types.InitializeResult:
183148
sampling = types.SamplingCapability() if self._sampling_callback is not _default_sampling_callback else None
@@ -203,7 +168,7 @@ async def initialize(self) -> types.InitializeResult:
203168
elicitation=elicitation,
204169
experimental=None,
205170
roots=roots,
206-
tasks=self._tasks_capability,
171+
tasks=self._task_handlers.build_capability(),
207172
),
208173
clientInfo=self._client_info,
209174
),
@@ -238,9 +203,9 @@ def experimental(self) -> ExperimentalClientFeatures:
238203
status = await session.experimental.get_task(task_id)
239204
result = await session.experimental.get_task_result(task_id, CallToolResult)
240205
"""
241-
if self._experimental is None:
242-
self._experimental = ExperimentalClientFeatures(self)
243-
return self._experimental
206+
if self._experimental_features is None:
207+
self._experimental_features = ExperimentalClientFeatures(self)
208+
return self._experimental_features
244209

245210
async def send_ping(self) -> types.EmptyResult:
246211
"""Send a ping request."""
@@ -573,12 +538,19 @@ async def _received_request(self, responder: RequestResponder[types.ServerReques
573538
lifespan_context=None,
574539
)
575540

541+
# Delegate to experimental task handler if applicable
542+
if self._task_handlers.handles_request(responder.request):
543+
with responder:
544+
await self._task_handlers.handle_request(ctx, responder)
545+
return None
546+
547+
# Core request handling
576548
match responder.request.root:
577549
case types.CreateMessageRequest(params=params):
578550
with responder:
579551
# Check if this is a task-augmented request
580552
if params.task is not None:
581-
response = await self._task_augmented_sampling_callback(ctx, params, params.task)
553+
response = await self._task_handlers.augmented_sampling(ctx, params, params.task)
582554
else:
583555
response = await self._sampling_callback(ctx, params)
584556
client_response = ClientResponse.validate_python(response)
@@ -588,7 +560,7 @@ async def _received_request(self, responder: RequestResponder[types.ServerReques
588560
with responder:
589561
# Check if this is a task-augmented request
590562
if params.task is not None:
591-
response = await self._task_augmented_elicitation_callback(ctx, params, params.task)
563+
response = await self._task_handlers.augmented_elicitation(ctx, params, params.task)
592564
else:
593565
response = await self._elicitation_callback(ctx, params)
594566
client_response = ClientResponse.validate_python(response)
@@ -604,33 +576,9 @@ async def _received_request(self, responder: RequestResponder[types.ServerReques
604576
with responder:
605577
return await responder.respond(types.ClientResult(root=types.EmptyResult()))
606578

607-
# Experimental: Task management requests from server
608-
case types.GetTaskRequest(params=params):
609-
with responder:
610-
response = await self._get_task_handler(ctx, params)
611-
client_response = ClientResponse.validate_python(response)
612-
await responder.respond(client_response)
613-
614-
case types.GetTaskPayloadRequest(params=params):
615-
with responder:
616-
response = await self._get_task_result_handler(ctx, params)
617-
client_response = ClientResponse.validate_python(response)
618-
await responder.respond(client_response)
619-
620-
case types.ListTasksRequest(params=params):
621-
with responder:
622-
response = await self._list_tasks_handler(ctx, params)
623-
client_response = ClientResponse.validate_python(response)
624-
await responder.respond(client_response)
625-
626-
case types.CancelTaskRequest(params=params):
627-
with responder:
628-
response = await self._cancel_task_handler(ctx, params)
629-
client_response = ClientResponse.validate_python(response)
630-
await responder.respond(client_response)
631-
632579
case _: # pragma: no cover
633580
raise NotImplementedError()
581+
return None
634582

635583
async def _handle_incoming(
636584
self,

0 commit comments

Comments
 (0)