Skip to content

Commit b802dc4

Browse files
committed
Operation token plumbing to support async elicitation/sampling
1 parent f8ca895 commit b802dc4

File tree

14 files changed

+433
-46
lines changed

14 files changed

+433
-46
lines changed

examples/snippets/clients/async_tools_client.py

Lines changed: 56 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313

1414
from mcp import ClientSession, StdioServerParameters, types
1515
from mcp.client.stdio import stdio_client
16+
from mcp.shared.context import RequestContext
1617

1718
# Create server parameters for stdio connection
1819
server_params = StdioServerParameters(
@@ -22,6 +23,22 @@
2223
)
2324

2425

26+
async def elicitation_callback(context: RequestContext[ClientSession, None], params: types.ElicitRequestParams):
27+
"""Handle elicitation requests from the server."""
28+
if "data_migration" in params.message:
29+
return types.ElicitResult(
30+
action="accept",
31+
content={"continue_processing": True, "priority_level": "normal"},
32+
)
33+
else:
34+
return types.ElicitResult(action="decline")
35+
36+
37+
async def logging_callback(params: types.LoggingMessageNotificationParams):
38+
"""Handle logging messages from the server."""
39+
print(f"Server log: {params.data}", file=sys.stderr)
40+
41+
2542
async def demonstrate_sync_tool(session: ClientSession):
2643
"""Demonstrate calling a synchronous tool."""
2744
print("\n=== Synchronous Tool Demo ===")
@@ -174,6 +191,37 @@ async def demonstrate_data_processing(session: ClientSession):
174191
await asyncio.sleep(0.8)
175192

176193

194+
async def demonstrate_elicitation(session: ClientSession):
195+
"""Demonstrate async elicitation tool."""
196+
print("\n=== Async Elicitation Demo ===")
197+
198+
result = await session.call_tool("async_elicitation_tool", arguments={"operation": "data_migration"})
199+
200+
if result.operation:
201+
token = result.operation.token
202+
print(f"Elicitation operation started with token: {token}")
203+
204+
# Poll for completion
205+
while True:
206+
status = await session.get_operation_status(token)
207+
print(f"Status: {status.status}")
208+
209+
if status.status == "completed":
210+
final_result = await session.get_operation_result(token)
211+
for content in final_result.result.content:
212+
if isinstance(content, types.TextContent):
213+
print(f"Elicitation result: {content.text}")
214+
break
215+
elif status.status == "failed":
216+
print(f"Elicitation failed: {status.error}")
217+
break
218+
elif status.status in ("canceled", "unknown"):
219+
print(f"Elicitation ended with status: {status.status}")
220+
break
221+
222+
await asyncio.sleep(0.5)
223+
224+
177225
async def run():
178226
"""Run all async tool demonstrations."""
179227
# Determine protocol version from command line
@@ -189,7 +237,13 @@ async def run():
189237

190238
async with stdio_client(server_params) as (read, write):
191239
# Use configured protocol version
192-
async with ClientSession(read, write, protocol_version=protocol_version) as session:
240+
async with ClientSession(
241+
read,
242+
write,
243+
protocol_version=protocol_version,
244+
elicitation_callback=elicitation_callback,
245+
logging_callback=logging_callback,
246+
) as session:
193247
# Initialize the connection
194248
await session.initialize()
195249

@@ -206,6 +260,7 @@ async def run():
206260
await demonstrate_async_tool(session)
207261
await demonstrate_batch_processing(session)
208262
await demonstrate_data_processing(session)
263+
await demonstrate_elicitation(session)
209264

210265
print("\n=== All demonstrations complete! ===")
211266

examples/snippets/servers/async_tools.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,12 +7,61 @@
77

88
import asyncio
99

10+
from pydantic import BaseModel, Field
11+
1012
from mcp.server.fastmcp import Context, FastMCP
1113

1214
# Create an MCP server with async operations support
1315
mcp = FastMCP("Async Tools Demo")
1416

1517

18+
class UserPreferences(BaseModel):
19+
"""Schema for collecting user preferences."""
20+
21+
continue_processing: bool = Field(description="Should we continue with the operation?")
22+
priority_level: str = Field(
23+
default="normal",
24+
description="Priority level: low, normal, high",
25+
)
26+
27+
28+
@mcp.tool(invocation_modes=["async"])
29+
async def async_elicitation_tool(operation: str, ctx: Context) -> str: # type: ignore[type-arg]
30+
"""An async tool that uses elicitation to get user input."""
31+
await ctx.info(f"Starting operation: {operation}")
32+
33+
# Simulate some initial processing
34+
await asyncio.sleep(0.5)
35+
await ctx.report_progress(0.3, 1.0, "Initial processing complete")
36+
37+
await ctx.debug("About to call elicit")
38+
try:
39+
# Ask user for preferences
40+
result = await ctx.elicit(
41+
message=f"Operation '{operation}' requires user input. How should we proceed?",
42+
schema=UserPreferences,
43+
)
44+
await ctx.debug(f"Elicit result: {result}")
45+
except Exception as e:
46+
await ctx.error(f"Elicitation failed: {e}")
47+
raise
48+
49+
if result.action == "accept" and result.data:
50+
if result.data.continue_processing:
51+
await ctx.info(f"Continuing with {result.data.priority_level} priority")
52+
# Simulate processing based on user choice
53+
processing_time = {"low": 0.5, "normal": 1.0, "high": 1.5}.get(result.data.priority_level, 1.0)
54+
await asyncio.sleep(processing_time)
55+
await ctx.report_progress(1.0, 1.0, "Operation complete")
56+
return f"Operation '{operation}' completed successfully with {result.data.priority_level} priority"
57+
else:
58+
await ctx.warning("User chose not to continue")
59+
return f"Operation '{operation}' cancelled by user"
60+
else:
61+
await ctx.error("User declined or cancelled the operation")
62+
return f"Operation '{operation}' aborted"
63+
64+
1665
@mcp.tool()
1766
def sync_tool(x: int) -> str:
1867
"""An implicitly-synchronous tool."""

src/mcp/client/session.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -466,6 +466,7 @@ async def send_roots_list_changed(self) -> None:
466466
async def _received_request(self, responder: RequestResponder[types.ServerRequest, types.ClientResult]) -> None:
467467
ctx = RequestContext[ClientSession, Any](
468468
request_id=responder.request_id,
469+
operation_token=responder.operation.token if responder.operation is not None else None,
469470
meta=responder.request_meta,
470471
session=self,
471472
lifespan_context=None,
@@ -475,12 +476,36 @@ async def _received_request(self, responder: RequestResponder[types.ServerReques
475476
case types.CreateMessageRequest(params=params):
476477
with responder:
477478
response = await self._sampling_callback(ctx, params)
479+
if isinstance(response, types.CreateMessageResult):
480+
response.operation_props = (
481+
types.Operation(token=responder.operation.token)
482+
if responder.operation is not None
483+
else None
484+
)
485+
else:
486+
response.operation = (
487+
types.Operation(token=responder.operation.token)
488+
if responder.operation is not None
489+
else None
490+
)
478491
client_response = ClientResponse.validate_python(response)
479492
await responder.respond(client_response)
480493

481494
case types.ElicitRequest(params=params):
482495
with responder:
483496
response = await self._elicitation_callback(ctx, params)
497+
if isinstance(response, types.ElicitResult):
498+
response.operation_props = (
499+
types.Operation(token=responder.operation.token)
500+
if responder.operation is not None
501+
else None
502+
)
503+
else:
504+
response.operation = (
505+
types.Operation(token=responder.operation.token)
506+
if responder.operation is not None
507+
else None
508+
)
484509
client_response = ClientResponse.validate_python(response)
485510
await responder.respond(client_response)
486511

src/mcp/server/elicitation.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,7 @@ async def elicit_with_validation(
7878
message: str,
7979
schema: type[ElicitSchemaModelT],
8080
related_request_id: RequestId | None = None,
81+
related_operation_token: str | None = None,
8182
) -> ElicitationResult[ElicitSchemaModelT]:
8283
"""Elicit information from the client/user with schema validation.
8384
@@ -96,6 +97,7 @@ async def elicit_with_validation(
9697
message=message,
9798
requestedSchema=json_schema,
9899
related_request_id=related_request_id,
100+
related_operation_token=related_operation_token,
99101
)
100102

101103
if result.action == "accept" and result.content is not None:

src/mcp/server/fastmcp/server.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1213,6 +1213,8 @@ async def report_progress(self, progress: float, total: float | None = None, mes
12131213
progress=progress,
12141214
total=total,
12151215
message=message,
1216+
related_request_id=self.request_id,
1217+
related_operation_token=self.request_context.operation_token,
12161218
)
12171219

12181220
async def read_resource(self, uri: str | AnyUrl) -> Iterable[ReadResourceContents]:
@@ -1255,7 +1257,11 @@ async def elicit(
12551257
"""
12561258

12571259
return await elicit_with_validation(
1258-
session=self.request_context.session, message=message, schema=schema, related_request_id=self.request_id
1260+
session=self.request_context.session,
1261+
message=message,
1262+
schema=schema,
1263+
related_request_id=self.request_id,
1264+
related_operation_token=self.request_context.operation_token,
12591265
)
12601266

12611267
async def log(

src/mcp/server/lowlevel/server.py

Lines changed: 15 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,7 @@ async def main():
9191
from mcp.shared.exceptions import McpError
9292
from mcp.shared.message import ServerMessageMetadata, SessionMessage
9393
from mcp.shared.session import RequestResponder
94-
from mcp.types import RequestId
94+
from mcp.types import Operation, RequestId
9595

9696
logger = logging.getLogger(__name__)
9797

@@ -478,6 +478,10 @@ async def handler(req: types.CallToolRequest):
478478
)
479479
logger.debug(f"Created async operation with token: {operation.token}")
480480

481+
ctx = self.request_context
482+
ctx.operation_token = operation.token
483+
request_ctx.set(ctx)
484+
481485
# Start async execution in background
482486
async def execute_async():
483487
try:
@@ -560,6 +564,9 @@ def _process_tool_result(
560564
content=list(unstructured_content),
561565
structuredContent=maybe_structured_content,
562566
isError=False,
567+
_operation=Operation(token=self.request_context.operation_token)
568+
if self.request_context and self.request_context.operation_token
569+
else None,
563570
)
564571

565572
def _should_execute_async(self, tool: types.Tool) -> bool:
@@ -720,9 +727,7 @@ def send_request_for_operation(self, token: str, request: types.ServerRequest) -
720727
# Add operation token to request
721728
if hasattr(request.root, "params") and request.root.params is not None:
722729
if not hasattr(request.root.params, "operation") or request.root.params.operation is None:
723-
# Create operation field if it doesn't exist
724-
operation_data = types.RequestParams.Operation(token=token)
725-
request.root.params.operation = operation_data
730+
request.root.params.operation = Operation(token=token)
726731
logger.debug(f"Marked operation {token} as input_required and added to request")
727732

728733
def send_notification_for_operation(self, token: str, notification: types.ServerNotification) -> None:
@@ -732,9 +737,7 @@ def send_notification_for_operation(self, token: str, notification: types.Server
732737
# Add operation token to notification
733738
if hasattr(notification.root, "params") and notification.root.params is not None:
734739
if not hasattr(notification.root.params, "operation") or notification.root.params.operation is None:
735-
# Create operation field if it doesn't exist
736-
operation_data = types.NotificationParams.Operation(token=token)
737-
notification.root.params.operation = operation_data
740+
notification.root.params.operation = Operation(token=token)
738741
logger.debug(f"Marked operation {token} as input_required and added to notification")
739742

740743
def complete_request_for_operation(self, token: str) -> None:
@@ -833,25 +836,16 @@ async def _handle_request(
833836
# app.get_request_context()
834837
context_token = request_ctx.set(
835838
RequestContext(
836-
message.request_id,
837-
message.request_meta,
838-
session,
839-
lifespan_context,
839+
request_id=message.request_id,
840+
operation_token=message.operation.token if message.operation else None,
841+
meta=message.request_meta,
842+
session=session,
843+
lifespan_context=lifespan_context,
840844
request=request_data,
841845
)
842846
)
843847
response = await handler(req)
844848

845-
# Handle operation token in response (for input_required operations)
846-
if (
847-
hasattr(req, "params")
848-
and req.params is not None
849-
and hasattr(req.params, "operation")
850-
and req.params.operation is not None
851-
):
852-
operation_token = req.params.operation.token
853-
self.complete_request_for_operation(operation_token)
854-
855849
# Track async operations for cancellation
856850
if isinstance(req, types.CallToolRequest):
857851
result = response.root

0 commit comments

Comments
 (0)