Skip to content

Commit 8fbb6fe

Browse files
committed
Support session-binding in TaskStore
1 parent d29c215 commit 8fbb6fe

File tree

9 files changed

+246
-26
lines changed

9 files changed

+246
-26
lines changed

examples/shared/in_memory_task_store.py

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,9 @@ def __init__(self) -> None:
4141
self._tasks: dict[str, StoredTask] = {}
4242
self._cleanup_tasks: dict[str, asyncio.Task[None]] = {}
4343

44-
async def create_task(self, task: TaskMetadata, request_id: RequestId, request: Request[Any, Any]) -> None:
44+
async def create_task(
45+
self, task: TaskMetadata, request_id: RequestId, request: Request[Any, Any], session_id: str | None = None
46+
) -> None:
4547
"""Create a new task with the given metadata and original request."""
4648
task_id = task.taskId
4749

@@ -61,7 +63,7 @@ async def create_task(self, task: TaskMetadata, request_id: RequestId, request:
6163
if task.keepAlive is not None:
6264
self._schedule_cleanup(task_id, task.keepAlive / 1000.0)
6365

64-
async def get_task(self, task_id: str) -> Task | None:
66+
async def get_task(self, task_id: str, session_id: str | None = None) -> Task | None:
6567
"""Get the current status of a task."""
6668
stored = self._tasks.get(task_id)
6769
if stored is None:
@@ -70,7 +72,7 @@ async def get_task(self, task_id: str) -> Task | None:
7072
# Return a copy to prevent external modification
7173
return Task(**stored.task.model_dump())
7274

73-
async def store_task_result(self, task_id: str, result: Result) -> None:
75+
async def store_task_result(self, task_id: str, result: Result, session_id: str | None = None) -> None:
7476
"""Store the result of a completed task."""
7577
stored = self._tasks.get(task_id)
7678
if stored is None:
@@ -84,7 +86,7 @@ async def store_task_result(self, task_id: str, result: Result) -> None:
8486
self._cancel_cleanup(task_id)
8587
self._schedule_cleanup(task_id, stored.task.keepAlive / 1000.0)
8688

87-
async def get_task_result(self, task_id: str) -> Result:
89+
async def get_task_result(self, task_id: str, session_id: str | None = None) -> Result:
8890
"""Retrieve the stored result of a task."""
8991
stored = self._tasks.get(task_id)
9092
if stored is None:
@@ -95,7 +97,9 @@ async def get_task_result(self, task_id: str) -> Result:
9597

9698
return stored.result
9799

98-
async def update_task_status(self, task_id: str, status: TaskStatus, error: str | None = None) -> None:
100+
async def update_task_status(
101+
self, task_id: str, status: TaskStatus, error: str | None = None, session_id: str | None = None
102+
) -> None:
99103
"""Update a task's status."""
100104
stored = self._tasks.get(task_id)
101105
if stored is None:
@@ -110,7 +114,7 @@ async def update_task_status(self, task_id: str, status: TaskStatus, error: str
110114
self._cancel_cleanup(task_id)
111115
self._schedule_cleanup(task_id, stored.task.keepAlive / 1000.0)
112116

113-
async def list_tasks(self, cursor: str | None = None) -> dict[str, Any]:
117+
async def list_tasks(self, cursor: str | None = None, session_id: str | None = None) -> dict[str, Any]:
114118
"""
115119
List tasks, optionally starting from a pagination cursor.
116120
@@ -134,7 +138,7 @@ async def list_tasks(self, cursor: str | None = None) -> dict[str, Any]:
134138

135139
return {"tasks": tasks, "nextCursor": next_cursor}
136140

137-
async def delete_task(self, task_id: str) -> None:
141+
async def delete_task(self, task_id: str, session_id: str | None = None) -> None:
138142
"""Delete a task from storage."""
139143
if task_id not in self._tasks:
140144
raise ValueError(f"Task with ID {task_id} not found")
Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
"""
2+
Run from the repository root:
3+
uv run examples/snippets/clients/task_based_tool_client.py
4+
5+
Prerequisites:
6+
The task_based_tool server must be running on http://localhost:8000
7+
Start it with:
8+
cd examples/snippets && uv run server task_based_tool streamable-http
9+
"""
10+
11+
import asyncio
12+
13+
from mcp import ClientSession
14+
from mcp.client.streamable_http import MCP_SESSION_ID, streamablehttp_client
15+
from mcp.types import CallToolResult
16+
17+
18+
async def main():
19+
async with streamablehttp_client(
20+
"http://localhost:3000/mcp",
21+
headers={MCP_SESSION_ID: "5771f709-66f5-4176-9f32-ce91e3117df2"},
22+
terminate_on_close=False,
23+
) as (
24+
read_stream,
25+
write_stream,
26+
_,
27+
):
28+
async with ClientSession(read_stream, write_stream) as session:
29+
result = await session.get_task_result("736054ac-5f10-409e-a06a-526761ea827a", CallToolResult)
30+
print(result)
31+
32+
33+
if __name__ == "__main__":
34+
asyncio.run(main())
Lines changed: 122 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,122 @@
1+
"""
2+
Run from the repository root:
3+
uv run examples/snippets/clients/task_based_tool_client.py
4+
5+
Prerequisites:
6+
The task_based_tool server must be running on http://localhost:8000
7+
Start it with:
8+
cd examples/snippets && uv run server task_based_tool streamable-http
9+
"""
10+
11+
import asyncio
12+
13+
from mcp import ClientSession, types
14+
from mcp.client.streamable_http import streamablehttp_client
15+
from mcp.shared.context import RequestContext
16+
from mcp.shared.request import TaskHandlerOptions
17+
18+
19+
async def elicitation_handler(
20+
context: RequestContext[ClientSession, None], params: types.ElicitRequestParams
21+
) -> types.ElicitResult | types.ErrorData:
22+
"""
23+
Handle elicitation requests from the server.
24+
25+
This handler collects user feedback with a predefined schema including:
26+
- rating (1-5, required)
27+
- comments (optional text up to 500 chars)
28+
- recommend (boolean, required)
29+
"""
30+
print(f"\n🎯 Elicitation request received: {params.message}")
31+
print(f"Schema: {params.requestedSchema}")
32+
await asyncio.sleep(5)
33+
34+
# In a real application, you would collect this data from the user
35+
# For this example, we'll return mock data
36+
feedback_data: dict[str, str | int | float | bool | None] = {
37+
"rating": 5,
38+
"comments": "The task execution was excellent and fast!",
39+
"recommend": True,
40+
}
41+
42+
print(f"📝 Returning feedback: {feedback_data}")
43+
44+
return types.ElicitResult(action="accept", content=feedback_data)
45+
46+
47+
async def main():
48+
"""
49+
Demonstrate task-based execution with begin_call_tool.
50+
51+
This example shows how to:
52+
1. Start a long-running tool call with begin_call_tool()
53+
2. Get task status updates through callbacks
54+
3. Wait for the final result with polling
55+
4. Handle elicitation requests from the server
56+
"""
57+
# Connect to the task-based tool example server via streamable HTTP
58+
async with streamablehttp_client("http://localhost:3000/mcp", terminate_on_close=False) as (
59+
read_stream,
60+
write_stream,
61+
_,
62+
):
63+
async with ClientSession(read_stream, write_stream, elicitation_callback=elicitation_handler) as session:
64+
# Initialize the connection
65+
await session.initialize()
66+
67+
print("Starting task-based tool execution...")
68+
69+
# Track callback invocations
70+
task_created = False
71+
status_updates: list[str] = []
72+
73+
async def on_task_created() -> None:
74+
"""Called when the task is first created."""
75+
nonlocal task_created
76+
task_created = True
77+
print("✓ Task created on server")
78+
79+
async def on_task_status(task_result: types.GetTaskResult) -> None:
80+
"""Called whenever the task status is polled."""
81+
status_updates.append(task_result.status)
82+
print(f" Status ({task_result.taskId}): {task_result.status}")
83+
84+
# Begin the tool call (returns immediately with a PendingRequest)
85+
print("\nCalling begin_call_tool...")
86+
# pending_request = session.begin_call_tool(
87+
# "collect-user-info",
88+
# arguments={"infoType": "feedback"},
89+
# )
90+
pending_request = session.begin_call_tool(
91+
"delay",
92+
arguments={},
93+
)
94+
95+
print("Tool call initiated! Now waiting for result with task polling...\n")
96+
97+
# Wait for the result with task callbacks
98+
result = await pending_request.result(
99+
TaskHandlerOptions(on_task_created=on_task_created, on_task_status=on_task_status)
100+
)
101+
102+
# Display the result
103+
print("\n✓ Tool execution completed!")
104+
if result.content:
105+
content_block = result.content[0]
106+
if isinstance(content_block, types.TextContent):
107+
print(f"Result: {content_block.text}")
108+
else:
109+
print(f"Result: {content_block}")
110+
else:
111+
print("Result: No content")
112+
113+
# Show callback statistics
114+
print("\nTask callbacks:")
115+
print(f" - Task created callback: {'Yes' if task_created else 'No'}")
116+
print(f" - Status updates received: {len(status_updates)}")
117+
if status_updates:
118+
print(f" - Final status: {status_updates[-1]}")
119+
120+
121+
if __name__ == "__main__":
122+
asyncio.run(main())

src/mcp/client/session.py

Lines changed: 31 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,7 @@ def __init__(
128128
message_handler: MessageHandlerFnT | None = None,
129129
client_info: types.Implementation | None = None,
130130
task_store: TaskStore | None = None,
131+
session_id: str | None = None,
131132
) -> None:
132133
super().__init__(
133134
read_stream,
@@ -136,6 +137,7 @@ def __init__(
136137
types.ServerNotification,
137138
read_timeout_seconds=read_timeout_seconds,
138139
task_store=task_store,
140+
session_id=session_id,
139141
)
140142
self._client_info = client_info or DEFAULT_CLIENT_INFO
141143
self._sampling_callback = sampling_callback or _default_sampling_callback
@@ -605,6 +607,29 @@ async def delete_task(self, task_id: str) -> types.EmptyResult:
605607
)
606608

607609
async def _received_request(self, responder: RequestResponder[types.ServerRequest, types.ClientResult]) -> None:
610+
# Handle task creation if task metadata is present
611+
if responder.request_meta and responder.request_meta.task and self._task_store:
612+
task_meta = responder.request_meta.task
613+
# Create the task in the task store
614+
await self._task_store.create_task(
615+
task_meta,
616+
responder.request_id,
617+
responder.request.root,
618+
session_id=self._session_id, # type: ignore[arg-type]
619+
)
620+
# Send task created notification with related task metadata
621+
notification_params = types.TaskCreatedNotificationParams(
622+
_meta=types.NotificationParams.Meta(
623+
**{types.RELATED_TASK_META_KEY: types.RelatedTaskMetadata(taskId=task_meta.taskId)}
624+
)
625+
)
626+
await self.send_notification(
627+
types.ClientNotification(
628+
types.TaskCreatedNotification(method="notifications/tasks/created", params=notification_params)
629+
),
630+
related_request_id=responder.request_id,
631+
)
632+
608633
ctx = RequestContext[ClientSession, Any](
609634
request_id=responder.request_id,
610635
meta=responder.request_meta,
@@ -638,7 +663,7 @@ async def _received_request(self, responder: RequestResponder[types.ServerReques
638663
case types.GetTaskRequest(params=params):
639664
# Handle get task requests if task store is available
640665
if self._task_store:
641-
task = await self._task_store.get_task(params.taskId)
666+
task = await self._task_store.get_task(params.taskId, session_id=self._session_id)
642667
if task is None:
643668
with responder:
644669
await responder.respond(
@@ -666,7 +691,7 @@ async def _received_request(self, responder: RequestResponder[types.ServerReques
666691
case types.GetTaskPayloadRequest(params=params):
667692
# Handle get task result requests if task store is available
668693
if self._task_store:
669-
task = await self._task_store.get_task(params.taskId)
694+
task = await self._task_store.get_task(params.taskId, session_id=self._session_id)
670695
if task is None:
671696
with responder:
672697
await responder.respond(
@@ -683,7 +708,7 @@ async def _received_request(self, responder: RequestResponder[types.ServerReques
683708
)
684709
)
685710
else:
686-
result = await self._task_store.get_task_result(params.taskId)
711+
result = await self._task_store.get_task_result(params.taskId, session_id=self._session_id)
687712
# Add related-task metadata
688713
result_dict = result.model_dump(by_alias=True, mode="json", exclude_none=True)
689714
if "_meta" not in result_dict:
@@ -701,7 +726,9 @@ async def _received_request(self, responder: RequestResponder[types.ServerReques
701726
# Handle list tasks requests if task store is available
702727
if self._task_store:
703728
try:
704-
result = await self._task_store.list_tasks(params.cursor if params else None)
729+
result = await self._task_store.list_tasks(
730+
params.cursor if params else None, session_id=self._session_id
731+
)
705732
with responder:
706733
await responder.respond(
707734
types.ClientResult(

src/mcp/server/lowlevel/server.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -664,6 +664,8 @@ async def run(
664664
# the initialization lifecycle, but can do so with any available node
665665
# rather than requiring initialization for each connection.
666666
stateless: bool = False,
667+
# Optional session identifier to pass to ServerSession for multi-session task stores
668+
session_id: str | None = None,
667669
):
668670
async with AsyncExitStack() as stack:
669671
lifespan_context = await stack.enter_async_context(self.lifespan(self))
@@ -674,6 +676,7 @@ async def run(
674676
initialization_options,
675677
stateless=stateless,
676678
task_store=self.task_store,
679+
session_id=session_id,
677680
)
678681
)
679682

0 commit comments

Comments
 (0)