Skip to content

Commit 52e8c41

Browse files
committed
Add client-side task handler protocols and auto-capability building
- Move task handler protocols to experimental/task_handlers.py - Add build_client_tasks_capability() helper to auto-build ClientTasksCapability from handlers - ClientSession now automatically infers tasks capability from provided handlers - Add Resolver class for async result handling in task message queues - Refactor result_handler to use Resolver pattern - Add test for auto-built capabilities from handlers
1 parent ca7ccac commit 52e8c41

File tree

10 files changed

+472
-264
lines changed

10 files changed

+472
-264
lines changed
Lines changed: 229 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,229 @@
1+
"""
2+
Experimental task handler protocols for server -> client requests.
3+
4+
This module provides Protocol types and default handlers for when servers
5+
send task-related requests to clients (the reverse of normal client -> server flow).
6+
7+
WARNING: These APIs are experimental and may change without notice.
8+
9+
Use cases:
10+
- Server sends task-augmented sampling/elicitation request to client
11+
- Client creates a local task, spawns background work, returns CreateTaskResult
12+
- Server polls client's task status via tasks/get, tasks/result, etc.
13+
"""
14+
15+
from typing import TYPE_CHECKING, Any, Protocol
16+
17+
import mcp.types as types
18+
from mcp.shared.context import RequestContext
19+
20+
if TYPE_CHECKING:
21+
from mcp.client.session import ClientSession
22+
23+
24+
class GetTaskHandlerFnT(Protocol):
25+
"""Handler for tasks/get requests from server.
26+
27+
WARNING: This is experimental and may change without notice.
28+
"""
29+
30+
async def __call__(
31+
self,
32+
context: RequestContext["ClientSession", Any],
33+
params: types.GetTaskRequestParams,
34+
) -> types.GetTaskResult | types.ErrorData: ... # pragma: no branch
35+
36+
37+
class GetTaskResultHandlerFnT(Protocol):
38+
"""Handler for tasks/result requests from server.
39+
40+
WARNING: This is experimental and may change without notice.
41+
"""
42+
43+
async def __call__(
44+
self,
45+
context: RequestContext["ClientSession", Any],
46+
params: types.GetTaskPayloadRequestParams,
47+
) -> types.GetTaskPayloadResult | types.ErrorData: ... # pragma: no branch
48+
49+
50+
class ListTasksHandlerFnT(Protocol):
51+
"""Handler for tasks/list requests from server.
52+
53+
WARNING: This is experimental and may change without notice.
54+
"""
55+
56+
async def __call__(
57+
self,
58+
context: RequestContext["ClientSession", Any],
59+
params: types.PaginatedRequestParams | None,
60+
) -> types.ListTasksResult | types.ErrorData: ... # pragma: no branch
61+
62+
63+
class CancelTaskHandlerFnT(Protocol):
64+
"""Handler for tasks/cancel requests from server.
65+
66+
WARNING: This is experimental and may change without notice.
67+
"""
68+
69+
async def __call__(
70+
self,
71+
context: RequestContext["ClientSession", Any],
72+
params: types.CancelTaskRequestParams,
73+
) -> types.CancelTaskResult | types.ErrorData: ... # pragma: no branch
74+
75+
76+
class TaskAugmentedSamplingFnT(Protocol):
77+
"""Handler for task-augmented sampling/createMessage requests from server.
78+
79+
When server sends a CreateMessageRequest with task field, this callback
80+
is invoked. The callback should create a task, spawn background work,
81+
and return CreateTaskResult immediately.
82+
83+
WARNING: This is experimental and may change without notice.
84+
"""
85+
86+
async def __call__(
87+
self,
88+
context: RequestContext["ClientSession", Any],
89+
params: types.CreateMessageRequestParams,
90+
task_metadata: types.TaskMetadata,
91+
) -> types.CreateTaskResult | types.ErrorData: ... # pragma: no branch
92+
93+
94+
class TaskAugmentedElicitationFnT(Protocol):
95+
"""Handler for task-augmented elicitation/create requests from server.
96+
97+
When server sends an ElicitRequest with task field, this callback
98+
is invoked. The callback should create a task, spawn background work,
99+
and return CreateTaskResult immediately.
100+
101+
WARNING: This is experimental and may change without notice.
102+
"""
103+
104+
async def __call__(
105+
self,
106+
context: RequestContext["ClientSession", Any],
107+
params: types.ElicitRequestParams,
108+
task_metadata: types.TaskMetadata,
109+
) -> types.CreateTaskResult | types.ErrorData: ... # pragma: no branch
110+
111+
112+
# Default handlers for experimental task requests (return "not supported" errors)
113+
async def default_get_task_handler(
114+
context: RequestContext["ClientSession", Any],
115+
params: types.GetTaskRequestParams,
116+
) -> types.GetTaskResult | types.ErrorData:
117+
return types.ErrorData(
118+
code=types.METHOD_NOT_FOUND,
119+
message="tasks/get not supported",
120+
)
121+
122+
123+
async def default_get_task_result_handler(
124+
context: RequestContext["ClientSession", Any],
125+
params: types.GetTaskPayloadRequestParams,
126+
) -> types.GetTaskPayloadResult | types.ErrorData:
127+
return types.ErrorData(
128+
code=types.METHOD_NOT_FOUND,
129+
message="tasks/result not supported",
130+
)
131+
132+
133+
async def default_list_tasks_handler(
134+
context: RequestContext["ClientSession", Any],
135+
params: types.PaginatedRequestParams | None,
136+
) -> types.ListTasksResult | types.ErrorData:
137+
return types.ErrorData(
138+
code=types.METHOD_NOT_FOUND,
139+
message="tasks/list not supported",
140+
)
141+
142+
143+
async def default_cancel_task_handler(
144+
context: RequestContext["ClientSession", Any],
145+
params: types.CancelTaskRequestParams,
146+
) -> types.CancelTaskResult | types.ErrorData:
147+
return types.ErrorData(
148+
code=types.METHOD_NOT_FOUND,
149+
message="tasks/cancel not supported",
150+
)
151+
152+
153+
async def default_task_augmented_sampling_callback(
154+
context: RequestContext["ClientSession", Any],
155+
params: types.CreateMessageRequestParams,
156+
task_metadata: types.TaskMetadata,
157+
) -> types.CreateTaskResult | types.ErrorData:
158+
return types.ErrorData(
159+
code=types.INVALID_REQUEST,
160+
message="Task-augmented sampling not supported",
161+
)
162+
163+
164+
async def default_task_augmented_elicitation_callback(
165+
context: RequestContext["ClientSession", Any],
166+
params: types.ElicitRequestParams,
167+
task_metadata: types.TaskMetadata,
168+
) -> types.CreateTaskResult | types.ErrorData:
169+
return types.ErrorData(
170+
code=types.INVALID_REQUEST,
171+
message="Task-augmented elicitation not supported",
172+
)
173+
174+
175+
def build_client_tasks_capability(
176+
*,
177+
list_tasks_handler: ListTasksHandlerFnT | None = None,
178+
cancel_task_handler: CancelTaskHandlerFnT | None = None,
179+
task_augmented_sampling_callback: TaskAugmentedSamplingFnT | None = None,
180+
task_augmented_elicitation_callback: TaskAugmentedElicitationFnT | None = None,
181+
) -> types.ClientTasksCapability | None:
182+
"""Build ClientTasksCapability from the provided handlers.
183+
184+
This helper builds the appropriate capability object based on which
185+
handlers are provided (non-None and not the default handlers).
186+
187+
WARNING: This is experimental and may change without notice.
188+
189+
Args:
190+
list_tasks_handler: Handler for tasks/list requests
191+
cancel_task_handler: Handler for tasks/cancel requests
192+
task_augmented_sampling_callback: Handler for task-augmented sampling
193+
task_augmented_elicitation_callback: Handler for task-augmented elicitation
194+
195+
Returns:
196+
ClientTasksCapability if any handlers are provided, None otherwise
197+
"""
198+
has_list = list_tasks_handler is not None and list_tasks_handler is not default_list_tasks_handler
199+
has_cancel = cancel_task_handler is not None and cancel_task_handler is not default_cancel_task_handler
200+
has_sampling = (
201+
task_augmented_sampling_callback is not None
202+
and task_augmented_sampling_callback is not default_task_augmented_sampling_callback
203+
)
204+
has_elicitation = (
205+
task_augmented_elicitation_callback is not None
206+
and task_augmented_elicitation_callback is not default_task_augmented_elicitation_callback
207+
)
208+
209+
# If no handlers are provided, return None
210+
if not any([has_list, has_cancel, has_sampling, has_elicitation]):
211+
return None
212+
213+
# Build requests capability if any request handlers are provided
214+
requests_capability: types.ClientTasksRequestsCapability | None = None
215+
if has_sampling or has_elicitation:
216+
requests_capability = types.ClientTasksRequestsCapability(
217+
sampling=types.TasksSamplingCapability(createMessage=types.TasksCreateMessageCapability())
218+
if has_sampling
219+
else None,
220+
elicitation=types.TasksElicitationCapability(create=types.TasksCreateElicitationCapability())
221+
if has_elicitation
222+
else None,
223+
)
224+
225+
return types.ClientTasksCapability(
226+
list=types.TasksListCapability() if has_list else None,
227+
cancel=types.TasksCancelCapability() if has_cancel else None,
228+
requests=requests_capability,
229+
)

0 commit comments

Comments
 (0)