Skip to content

Commit ca7ccac

Browse files
committed
notifications and client side
1 parent 341ad92 commit ca7ccac

File tree

17 files changed

+3323
-15
lines changed

17 files changed

+3323
-15
lines changed

examples/servers/simple-task/mcp_simple_task/server.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import anyio
99
import click
1010
import mcp.types as types
11+
import uvicorn
1112
from anyio.abc import TaskGroup
1213
from mcp.server.lowlevel import Server
1314
from mcp.server.streamable_http_manager import StreamableHTTPSessionManager
@@ -107,8 +108,6 @@ async def handle_get_task_result(request: types.GetTaskPayloadRequest) -> types.
107108
@click.command()
108109
@click.option("--port", default=8000, help="Port to listen on")
109110
def main(port: int) -> int:
110-
import uvicorn
111-
112111
session_manager = StreamableHTTPSessionManager(app=server)
113112

114113
@asynccontextmanager

src/mcp/client/session.py

Lines changed: 211 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,95 @@ async def __call__(
4949
) -> None: ... # pragma: no branch
5050

5151

52+
# Experimental: Task handler protocols for server -> client requests
53+
class GetTaskHandlerFnT(Protocol):
54+
"""Handler for tasks/get requests from server.
55+
56+
WARNING: This is experimental and may change without notice.
57+
"""
58+
59+
async def __call__(
60+
self,
61+
context: RequestContext["ClientSession", Any],
62+
params: types.GetTaskRequestParams,
63+
) -> types.GetTaskResult | types.ErrorData: ... # pragma: no branch
64+
65+
66+
class GetTaskResultHandlerFnT(Protocol):
67+
"""Handler for tasks/result requests from server.
68+
69+
WARNING: This is experimental and may change without notice.
70+
"""
71+
72+
async def __call__(
73+
self,
74+
context: RequestContext["ClientSession", Any],
75+
params: types.GetTaskPayloadRequestParams,
76+
) -> types.GetTaskPayloadResult | types.ErrorData: ... # pragma: no branch
77+
78+
79+
class ListTasksHandlerFnT(Protocol):
80+
"""Handler for tasks/list requests from server.
81+
82+
WARNING: This is experimental and may change without notice.
83+
"""
84+
85+
async def __call__(
86+
self,
87+
context: RequestContext["ClientSession", Any],
88+
params: types.PaginatedRequestParams | None,
89+
) -> types.ListTasksResult | types.ErrorData: ... # pragma: no branch
90+
91+
92+
class CancelTaskHandlerFnT(Protocol):
93+
"""Handler for tasks/cancel requests from server.
94+
95+
WARNING: This is experimental and may change without notice.
96+
"""
97+
98+
async def __call__(
99+
self,
100+
context: RequestContext["ClientSession", Any],
101+
params: types.CancelTaskRequestParams,
102+
) -> types.CancelTaskResult | types.ErrorData: ... # pragma: no branch
103+
104+
105+
class TaskAugmentedSamplingFnT(Protocol):
106+
"""Handler for task-augmented sampling/createMessage requests from server.
107+
108+
When server sends a CreateMessageRequest with task field, this callback
109+
is invoked. The callback should create a task, spawn background work,
110+
and return CreateTaskResult immediately.
111+
112+
WARNING: This is experimental and may change without notice.
113+
"""
114+
115+
async def __call__(
116+
self,
117+
context: RequestContext["ClientSession", Any],
118+
params: types.CreateMessageRequestParams,
119+
task_metadata: types.TaskMetadata,
120+
) -> types.CreateTaskResult | types.ErrorData: ... # pragma: no branch
121+
122+
123+
class TaskAugmentedElicitationFnT(Protocol):
124+
"""Handler for task-augmented elicitation/create requests from server.
125+
126+
When server sends an ElicitRequest with task field, this callback
127+
is invoked. The callback should create a task, spawn background work,
128+
and return CreateTaskResult immediately.
129+
130+
WARNING: This is experimental and may change without notice.
131+
"""
132+
133+
async def __call__(
134+
self,
135+
context: RequestContext["ClientSession", Any],
136+
params: types.ElicitRequestParams,
137+
task_metadata: types.TaskMetadata,
138+
) -> types.CreateTaskResult | types.ErrorData: ... # pragma: no branch
139+
140+
52141
class MessageHandlerFnT(Protocol):
53142
async def __call__(
54143
self,
@@ -97,6 +186,69 @@ async def _default_logging_callback(
97186
pass
98187

99188

189+
# Default handlers for experimental task requests (return "not supported" errors)
190+
async def _default_get_task_handler(
191+
context: RequestContext["ClientSession", Any],
192+
params: types.GetTaskRequestParams,
193+
) -> types.GetTaskResult | types.ErrorData:
194+
return types.ErrorData(
195+
code=types.METHOD_NOT_FOUND,
196+
message="tasks/get not supported",
197+
)
198+
199+
200+
async def _default_get_task_result_handler(
201+
context: RequestContext["ClientSession", Any],
202+
params: types.GetTaskPayloadRequestParams,
203+
) -> types.GetTaskPayloadResult | types.ErrorData:
204+
return types.ErrorData(
205+
code=types.METHOD_NOT_FOUND,
206+
message="tasks/result not supported",
207+
)
208+
209+
210+
async def _default_list_tasks_handler(
211+
context: RequestContext["ClientSession", Any],
212+
params: types.PaginatedRequestParams | None,
213+
) -> types.ListTasksResult | types.ErrorData:
214+
return types.ErrorData(
215+
code=types.METHOD_NOT_FOUND,
216+
message="tasks/list not supported",
217+
)
218+
219+
220+
async def _default_cancel_task_handler(
221+
context: RequestContext["ClientSession", Any],
222+
params: types.CancelTaskRequestParams,
223+
) -> types.CancelTaskResult | types.ErrorData:
224+
return types.ErrorData(
225+
code=types.METHOD_NOT_FOUND,
226+
message="tasks/cancel not supported",
227+
)
228+
229+
230+
async def _default_task_augmented_sampling_callback(
231+
context: RequestContext["ClientSession", Any],
232+
params: types.CreateMessageRequestParams,
233+
task_metadata: types.TaskMetadata,
234+
) -> types.CreateTaskResult | types.ErrorData:
235+
return types.ErrorData(
236+
code=types.INVALID_REQUEST,
237+
message="Task-augmented sampling not supported",
238+
)
239+
240+
241+
async def _default_task_augmented_elicitation_callback(
242+
context: RequestContext["ClientSession", Any],
243+
params: types.ElicitRequestParams,
244+
task_metadata: types.TaskMetadata,
245+
) -> types.CreateTaskResult | types.ErrorData:
246+
return types.ErrorData(
247+
code=types.INVALID_REQUEST,
248+
message="Task-augmented elicitation not supported",
249+
)
250+
251+
100252
ClientResponse: TypeAdapter[types.ClientResult | types.ErrorData] = TypeAdapter(types.ClientResult | types.ErrorData)
101253

102254

@@ -120,6 +272,14 @@ def __init__(
120272
logging_callback: LoggingFnT | None = None,
121273
message_handler: MessageHandlerFnT | None = None,
122274
client_info: types.Implementation | None = None,
275+
tasks_capability: types.ClientTasksCapability | None = None,
276+
# Experimental: Task handlers for server -> client requests
277+
get_task_handler: GetTaskHandlerFnT | None = None,
278+
get_task_result_handler: GetTaskResultHandlerFnT | None = None,
279+
list_tasks_handler: ListTasksHandlerFnT | None = None,
280+
cancel_task_handler: CancelTaskHandlerFnT | None = None,
281+
task_augmented_sampling_callback: TaskAugmentedSamplingFnT | None = None,
282+
task_augmented_elicitation_callback: TaskAugmentedElicitationFnT | None = None,
123283
) -> None:
124284
super().__init__(
125285
read_stream,
@@ -134,9 +294,21 @@ def __init__(
134294
self._list_roots_callback = list_roots_callback or _default_list_roots_callback
135295
self._logging_callback = logging_callback or _default_logging_callback
136296
self._message_handler = message_handler or _default_message_handler
297+
self._tasks_capability = tasks_capability
137298
self._tool_output_schemas: dict[str, dict[str, Any] | None] = {}
138299
self._server_capabilities: types.ServerCapabilities | None = None
139300
self._experimental: ExperimentalClientFeatures | None = None
301+
# Experimental: Task handlers
302+
self._get_task_handler = get_task_handler or _default_get_task_handler
303+
self._get_task_result_handler = get_task_result_handler or _default_get_task_result_handler
304+
self._list_tasks_handler = list_tasks_handler or _default_list_tasks_handler
305+
self._cancel_task_handler = cancel_task_handler or _default_cancel_task_handler
306+
self._task_augmented_sampling_callback = (
307+
task_augmented_sampling_callback or _default_task_augmented_sampling_callback
308+
)
309+
self._task_augmented_elicitation_callback = (
310+
task_augmented_elicitation_callback or _default_task_augmented_elicitation_callback
311+
)
140312

141313
async def initialize(self) -> types.InitializeResult:
142314
sampling = types.SamplingCapability() if self._sampling_callback is not _default_sampling_callback else None
@@ -162,6 +334,7 @@ async def initialize(self) -> types.InitializeResult:
162334
elicitation=elicitation,
163335
experimental=None,
164336
roots=roots,
337+
tasks=self._tasks_capability,
165338
),
166339
clientInfo=self._client_info,
167340
),
@@ -187,7 +360,7 @@ def get_server_capabilities(self) -> types.ServerCapabilities | None:
187360
return self._server_capabilities
188361

189362
@property
190-
def experimental(self) -> "ExperimentalClientFeatures":
363+
def experimental(self) -> ExperimentalClientFeatures:
191364
"""Experimental APIs for tasks and other features.
192365
193366
WARNING: These APIs are experimental and may change without notice.
@@ -534,13 +707,21 @@ async def _received_request(self, responder: RequestResponder[types.ServerReques
534707
match responder.request.root:
535708
case types.CreateMessageRequest(params=params):
536709
with responder:
537-
response = await self._sampling_callback(ctx, params)
710+
# Check if this is a task-augmented request
711+
if params.task is not None:
712+
response = await self._task_augmented_sampling_callback(ctx, params, params.task)
713+
else:
714+
response = await self._sampling_callback(ctx, params)
538715
client_response = ClientResponse.validate_python(response)
539716
await responder.respond(client_response)
540717

541718
case types.ElicitRequest(params=params):
542719
with responder:
543-
response = await self._elicitation_callback(ctx, params)
720+
# Check if this is a task-augmented request
721+
if params.task is not None:
722+
response = await self._task_augmented_elicitation_callback(ctx, params, params.task)
723+
else:
724+
response = await self._elicitation_callback(ctx, params)
544725
client_response = ClientResponse.validate_python(response)
545726
await responder.respond(client_response)
546727

@@ -553,7 +734,33 @@ async def _received_request(self, responder: RequestResponder[types.ServerReques
553734
case types.PingRequest(): # pragma: no cover
554735
with responder:
555736
return await responder.respond(types.ClientResult(root=types.EmptyResult()))
556-
case _:
737+
738+
# Experimental: Task management requests from server
739+
case types.GetTaskRequest(params=params):
740+
with responder:
741+
response = await self._get_task_handler(ctx, params)
742+
client_response = ClientResponse.validate_python(response)
743+
await responder.respond(client_response)
744+
745+
case types.GetTaskPayloadRequest(params=params):
746+
with responder:
747+
response = await self._get_task_result_handler(ctx, params)
748+
client_response = ClientResponse.validate_python(response)
749+
await responder.respond(client_response)
750+
751+
case types.ListTasksRequest(params=params):
752+
with responder:
753+
response = await self._list_tasks_handler(ctx, params)
754+
client_response = ClientResponse.validate_python(response)
755+
await responder.respond(client_response)
756+
757+
case types.CancelTaskRequest(params=params):
758+
with responder:
759+
response = await self._cancel_task_handler(ctx, params)
760+
client_response = ClientResponse.validate_python(response)
761+
await responder.respond(client_response)
762+
763+
case _: # pragma: no cover
557764
raise NotImplementedError()
558765

559766
async def _handle_incoming(

src/mcp/server/lowlevel/server.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -67,12 +67,14 @@ async def main():
6767

6868
from __future__ import annotations as _annotations
6969

70+
import base64
7071
import contextvars
7172
import json
7273
import logging
7374
import warnings
7475
from collections.abc import AsyncIterator, Awaitable, Callable, Iterable
7576
from contextlib import AbstractAsyncContextManager, AsyncExitStack, asynccontextmanager
77+
from importlib.metadata import version as pkg_version
7678
from typing import Any, Generic, TypeAlias, cast
7779

7880
import anyio
@@ -165,19 +167,17 @@ def create_initialization_options(
165167
) -> InitializationOptions:
166168
"""Create initialization options from this server instance."""
167169

168-
def pkg_version(package: str) -> str:
170+
def get_package_version(package: str) -> str:
169171
try:
170-
from importlib.metadata import version
171-
172-
return version(package)
172+
return pkg_version(package)
173173
except Exception: # pragma: no cover
174174
pass
175175

176176
return "unknown" # pragma: no cover
177177

178178
return InitializationOptions(
179179
server_name=self.name,
180-
server_version=self.version if self.version else pkg_version("mcp"),
180+
server_version=self.version if self.version else get_package_version("mcp"),
181181
capabilities=self.get_capabilities(
182182
notification_options or NotificationOptions(),
183183
experimental_capabilities or {},
@@ -344,8 +344,6 @@ def create_content(data: str | bytes, mime_type: str | None):
344344
mimeType=mime_type or "text/plain",
345345
)
346346
case bytes() as data: # pragma: no cover
347-
import base64
348-
349347
return types.BlobResourceContents(
350348
uri=req.params.uri,
351349
blob=base64.b64encode(data).decode(),

src/mcp/shared/experimental/tasks/__init__.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,13 @@
55
- TaskStore: Abstract interface for task state storage
66
- TaskContext: Context object for task work to interact with state/notifications
77
- InMemoryTaskStore: Reference implementation for testing/development
8+
- TaskMessageQueue: FIFO queue for task messages delivered via tasks/result
9+
- InMemoryTaskMessageQueue: Reference implementation for message queue
810
- Helper functions: run_task, is_terminal, create_task_state, generate_task_id
911
1012
Architecture:
1113
- TaskStore is pure storage - it doesn't know about execution
14+
- TaskMessageQueue stores messages to be delivered via tasks/result
1215
- TaskContext wraps store + session, providing a clean API for task work
1316
- run_task is optional convenience for spawning in-process tasks
1417
@@ -24,15 +27,31 @@
2427
task_execution,
2528
)
2629
from mcp.shared.experimental.tasks.in_memory_task_store import InMemoryTaskStore
30+
from mcp.shared.experimental.tasks.message_queue import (
31+
InMemoryTaskMessageQueue,
32+
QueuedMessage,
33+
TaskMessageQueue,
34+
)
35+
from mcp.shared.experimental.tasks.result_handler import (
36+
TaskResultHandler,
37+
create_task_result_handler,
38+
)
2739
from mcp.shared.experimental.tasks.store import TaskStore
40+
from mcp.shared.experimental.tasks.task_session import TaskSession
2841

2942
__all__ = [
3043
"TaskStore",
3144
"TaskContext",
45+
"TaskSession",
46+
"TaskResultHandler",
3247
"InMemoryTaskStore",
48+
"TaskMessageQueue",
49+
"InMemoryTaskMessageQueue",
50+
"QueuedMessage",
3351
"run_task",
3452
"task_execution",
3553
"is_terminal",
3654
"create_task_state",
3755
"generate_task_id",
56+
"create_task_result_handler",
3857
]

0 commit comments

Comments
 (0)