Skip to content

Commit 0198f48

Browse files
yamanahlawatKludex
andauthored
feat: add elicitation callback support to MCP servers (#2373)
Co-authored-by: Marcelo Trylesinski <[email protected]>
1 parent d273775 commit 0198f48

File tree

6 files changed

+227
-18
lines changed

6 files changed

+227
-18
lines changed

docs/mcp/client.md

Lines changed: 141 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -280,7 +280,7 @@ async def main():
280280
```
281281

282282
1. When you supply `http_client`, Pydantic AI re-uses this client for every
283-
request. Anything supported by **httpx** (`verify`, `cert`, custom
283+
request. Anything supported by **httpx** (`verify`, `cert`, custom
284284
proxies, timeouts, etc.) therefore applies to all MCP traffic.
285285

286286
## MCP Sampling
@@ -391,3 +391,143 @@ server = MCPServerStdio(
391391
allow_sampling=False,
392392
)
393393
```
394+
395+
## Elicitation
396+
397+
In MCP, [elicitation](https://modelcontextprotocol.io/docs/concepts/elicitation) allows a server to request for [structured input](https://modelcontextprotocol.io/specification/2025-06-18/client/elicitation#supported-schema-types) from the client for missing or additional context during a session.
398+
399+
Elicitation let models essentially say "Hold on - I need to know X before i can continue" rather than requiring everything upfront or taking a shot in the dark.
400+
401+
### How Elicitation works
402+
403+
Elicitation introduces a new protocol message type called [`ElicitRequest`](https://modelcontextprotocol.io/specification/2025-06-18/schema#elicitrequest), which is sent from the server to the client when it needs additional information. The client can then respond with an [`ElicitResult`](https://modelcontextprotocol.io/specification/2025-06-18/schema#elicitresult) or an `ErrorData` message.
404+
405+
Here's a typical interaction:
406+
407+
- User makes a request to the MCP server (e.g. "Book a table at that Italian place")
408+
- The server identifies that it needs more information (e.g. "Which Italian place?", "What date and time?")
409+
- The server sends an `ElicitRequest` to the client asking for the missing information.
410+
- The client receives the request, presents it to the user (e.g. via a terminal prompt, GUI dialog, or web interface).
411+
- User provides the requested information, `decline` or `cancel` the request.
412+
- The client sends an `ElicitResult` back to the server with the user's response.
413+
- With the structured data, the server can continue processing the original request.
414+
415+
This allows for a more interactive and user-friendly experience, especially for multi-staged workflows. Instead of requiring all information upfront, the server can ask for it as needed, making the interaction feel more natural.
416+
417+
### Setting up Elicitation
418+
419+
To enable elicitation, provide an [`elicitation_callback`][pydantic_ai.mcp.MCPServer.elicitation_callback] function when creating your MCP server instance:
420+
421+
```python {title="restaurant_server.py" py="3.10"}
422+
from mcp.server.fastmcp import Context, FastMCP
423+
from pydantic import BaseModel, Field
424+
425+
mcp = FastMCP(name='Restaurant Booking')
426+
427+
428+
class BookingDetails(BaseModel):
429+
"""Schema for restaurant booking information."""
430+
431+
restaurant: str = Field(description='Choose a restaurant')
432+
party_size: int = Field(description='Number of people', ge=1, le=8)
433+
date: str = Field(description='Reservation date (DD-MM-YYYY)')
434+
435+
436+
@mcp.tool()
437+
async def book_table(ctx: Context) -> str:
438+
"""Book a restaurant table with user input."""
439+
# Ask user for booking details using Pydantic schema
440+
result = await ctx.elicit(message='Please provide your booking details:', schema=BookingDetails)
441+
442+
if result.action == 'accept' and result.data:
443+
booking = result.data
444+
return f'✅ Booked table for {booking.party_size} at {booking.restaurant} on {booking.date}'
445+
elif result.action == 'decline':
446+
return 'No problem! Maybe another time.'
447+
else: # cancel
448+
return 'Booking cancelled.'
449+
450+
451+
if __name__ == '__main__':
452+
mcp.run(transport='stdio')
453+
```
454+
455+
This server demonstrates elicitation by requesting structured booking details from the client when the `book_table` tool is called. Here's how to create a client that handles these elicitation requests:
456+
457+
```python {title="client_example.py" py="3.10" requires="restaurant_server.py" test="skip"}
458+
import asyncio
459+
from typing import Any
460+
461+
from mcp.client.session import ClientSession
462+
from mcp.shared.context import RequestContext
463+
from mcp.types import ElicitRequestParams, ElicitResult
464+
465+
from pydantic_ai import Agent
466+
from pydantic_ai.mcp import MCPServerStdio
467+
468+
469+
async def handle_elicitation(
470+
context: RequestContext[ClientSession, Any, Any],
471+
params: ElicitRequestParams,
472+
) -> ElicitResult:
473+
"""Handle elicitation requests from MCP server."""
474+
print(f'\n{params.message}')
475+
476+
if not params.requestedSchema:
477+
response = input('Response: ')
478+
return ElicitResult(action='accept', content={'response': response})
479+
480+
# Collect data for each field
481+
properties = params.requestedSchema['properties']
482+
data = {}
483+
484+
for field, info in properties.items():
485+
description = info.get('description', field)
486+
487+
value = input(f'{description}: ')
488+
489+
# Convert to proper type based on JSON schema
490+
if info.get('type') == 'integer':
491+
data[field] = int(value)
492+
else:
493+
data[field] = value
494+
495+
# Confirm
496+
confirm = input('\nConfirm booking? (y/n/c): ').lower()
497+
498+
if confirm == 'y':
499+
print('Booking details:', data)
500+
return ElicitResult(action='accept', content=data)
501+
elif confirm == 'n':
502+
return ElicitResult(action='decline')
503+
else:
504+
return ElicitResult(action='cancel')
505+
506+
507+
# Set up MCP server connection
508+
restaurant_server = MCPServerStdio(
509+
command='python', args=['restaurant_server.py'], elicitation_callback=handle_elicitation
510+
)
511+
512+
# Create agent
513+
agent = Agent('openai:gpt-4o', toolsets=[restaurant_server])
514+
515+
516+
async def main():
517+
"""Run the agent to book a restaurant table."""
518+
async with agent:
519+
result = await agent.run('Book me a table')
520+
print(f'\nResult: {result.output}')
521+
522+
523+
if __name__ == '__main__':
524+
asyncio.run(main())
525+
```
526+
527+
### Supported Schema Types
528+
529+
MCP elicitation supports string, number, boolean, and enum types with flat object structures only. These limitations ensure reliable cross-client compatibility. See [supported schema types](https://modelcontextprotocol.io/specification/2025-06-18/client/elicitation#supported-schema-types) for details.
530+
531+
### Security
532+
533+
MCP Elicitation requires careful handling - servers must not request sensitive information, and clients must implement user approval controls with clear explanations. See [security considerations](https://modelcontextprotocol.io/specification/2025-06-18/client/elicitation#security-considerations) for details.

pydantic_ai_slim/pydantic_ai/mcp.py

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,14 +18,13 @@
1818
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
1919
from typing_extensions import Self, assert_never, deprecated
2020

21-
from pydantic_ai._run_context import RunContext
22-
from pydantic_ai.tools import ToolDefinition
21+
from pydantic_ai.tools import RunContext, ToolDefinition
2322

2423
from .toolsets.abstract import AbstractToolset, ToolsetTool
2524

2625
try:
2726
from mcp import types as mcp_types
28-
from mcp.client.session import ClientSession, LoggingFnT
27+
from mcp.client.session import ClientSession, ElicitationFnT, LoggingFnT
2928
from mcp.client.sse import sse_client
3029
from mcp.client.stdio import StdioServerParameters, stdio_client
3130
from mcp.client.streamable_http import GetSessionIdCallback, streamablehttp_client
@@ -65,6 +64,7 @@ class MCPServer(AbstractToolset[Any], ABC):
6564
allow_sampling: bool
6665
sampling_model: models.Model | None
6766
max_retries: int
67+
elicitation_callback: ElicitationFnT | None = None
6868

6969
_id: str | None
7070

@@ -87,6 +87,7 @@ def __init__(
8787
allow_sampling: bool = True,
8888
sampling_model: models.Model | None = None,
8989
max_retries: int = 1,
90+
elicitation_callback: ElicitationFnT | None = None,
9091
*,
9192
id: str | None = None,
9293
):
@@ -99,6 +100,7 @@ def __init__(
99100
self.allow_sampling = allow_sampling
100101
self.sampling_model = sampling_model
101102
self.max_retries = max_retries
103+
self.elicitation_callback = elicitation_callback
102104

103105
self._id = id or tool_prefix
104106

@@ -247,6 +249,7 @@ async def __aenter__(self) -> Self:
247249
read_stream=self._read_stream,
248250
write_stream=self._write_stream,
249251
sampling_callback=self._sampling_callback if self.allow_sampling else None,
252+
elicitation_callback=self.elicitation_callback,
250253
logging_callback=self.log_handler,
251254
read_timeout_seconds=timedelta(seconds=self.read_timeout),
252255
)
@@ -445,6 +448,9 @@ async def main():
445448
max_retries: int
446449
"""The maximum number of times to retry a tool call."""
447450

451+
elicitation_callback: ElicitationFnT | None = None
452+
"""Callback function to handle elicitation requests from the server."""
453+
448454
def __init__(
449455
self,
450456
command: str,
@@ -460,6 +466,7 @@ def __init__(
460466
allow_sampling: bool = True,
461467
sampling_model: models.Model | None = None,
462468
max_retries: int = 1,
469+
elicitation_callback: ElicitationFnT | None = None,
463470
*,
464471
id: str | None = None,
465472
):
@@ -479,6 +486,7 @@ def __init__(
479486
allow_sampling: Whether to allow MCP sampling through this client.
480487
sampling_model: The model to use for sampling.
481488
max_retries: The maximum number of times to retry a tool call.
489+
elicitation_callback: Callback function to handle elicitation requests from the server.
482490
id: An optional unique ID for the MCP server. An MCP server needs to have an ID in order to be used in a durable execution environment like Temporal, in which case the ID will be used to identify the server's activities within the workflow.
483491
"""
484492
self.command = command
@@ -496,6 +504,7 @@ def __init__(
496504
allow_sampling,
497505
sampling_model,
498506
max_retries,
507+
elicitation_callback,
499508
id=id,
500509
)
501510

@@ -605,6 +614,9 @@ class _MCPServerHTTP(MCPServer):
605614
max_retries: int
606615
"""The maximum number of times to retry a tool call."""
607616

617+
elicitation_callback: ElicitationFnT | None = None
618+
"""Callback function to handle elicitation requests from the server."""
619+
608620
def __init__(
609621
self,
610622
*,
@@ -621,6 +633,7 @@ def __init__(
621633
allow_sampling: bool = True,
622634
sampling_model: models.Model | None = None,
623635
max_retries: int = 1,
636+
elicitation_callback: ElicitationFnT | None = None,
624637
**_deprecated_kwargs: Any,
625638
):
626639
"""Build a new MCP server.
@@ -639,6 +652,7 @@ def __init__(
639652
allow_sampling: Whether to allow MCP sampling through this client.
640653
sampling_model: The model to use for sampling.
641654
max_retries: The maximum number of times to retry a tool call.
655+
elicitation_callback: Callback function to handle elicitation requests from the server.
642656
"""
643657
if 'sse_read_timeout' in _deprecated_kwargs:
644658
if read_timeout is not None:
@@ -668,6 +682,7 @@ def __init__(
668682
allow_sampling,
669683
sampling_model,
670684
max_retries,
685+
elicitation_callback,
671686
id=id,
672687
)
673688

pydantic_ai_slim/pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,7 @@ tavily = ["tavily-python>=0.5.0"]
8484
# CLI
8585
cli = ["rich>=13", "prompt-toolkit>=3", "argcomplete>=3.5.0", "pyperclip>=1.9.0"]
8686
# MCP
87-
mcp = ["mcp>=1.10.0; python_version >= '3.10'"]
87+
mcp = ["mcp>=1.12.3; python_version >= '3.10'"]
8888
# Evals
8989
evals = ["pydantic-evals=={{ version }}"]
9090
# A2A

tests/mcp_server.py

Lines changed: 20 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,7 @@
33
from typing import Any
44

55
from mcp.server.fastmcp import Context, FastMCP, Image
6-
from mcp.server.session import ServerSessionT
7-
from mcp.shared.context import LifespanContextT, RequestT
6+
from mcp.server.session import ServerSession
87
from mcp.types import (
98
BlobResourceContents,
109
EmbeddedResource,
@@ -13,7 +12,7 @@
1312
TextContent,
1413
TextResourceContents,
1514
)
16-
from pydantic import AnyUrl
15+
from pydantic import AnyUrl, BaseModel
1716

1817
mcp = FastMCP('Pydantic AI MCP Server')
1918
log_level = 'unset'
@@ -170,7 +169,7 @@ async def get_log_level(ctx: Context) -> str: # type: ignore
170169

171170

172171
@mcp.tool()
173-
async def echo_deps(ctx: Context[ServerSessionT, LifespanContextT, RequestT]) -> dict[str, Any]:
172+
async def echo_deps(ctx: Context[ServerSession, None]) -> dict[str, Any]:
174173
"""Echo the run context.
175174
176175
Args:
@@ -186,7 +185,7 @@ async def echo_deps(ctx: Context[ServerSessionT, LifespanContextT, RequestT]) ->
186185

187186

188187
@mcp.tool()
189-
async def use_sampling(ctx: Context, foo: str) -> str: # type: ignore
188+
async def use_sampling(ctx: Context[ServerSession, None], foo: str) -> str:
190189
"""Use sampling callback."""
191190

192191
result = await ctx.session.create_message(
@@ -202,6 +201,22 @@ async def use_sampling(ctx: Context, foo: str) -> str: # type: ignore
202201
return result.model_dump_json(indent=2)
203202

204203

204+
class UserResponse(BaseModel):
205+
response: str
206+
207+
208+
@mcp.tool()
209+
async def use_elicitation(ctx: Context[ServerSession, None], question: str) -> str:
210+
"""Use elicitation callback to ask the user a question."""
211+
212+
result = await ctx.elicit(message=question, schema=UserResponse)
213+
214+
if result.action == 'accept' and result.data:
215+
return f'User responded: {result.data.response}'
216+
else:
217+
return f'User {result.action}ed the elicitation'
218+
219+
205220
@mcp._mcp_server.set_logging_level() # pyright: ignore[reportPrivateUsage]
206221
async def set_logging_level(level: str) -> None:
207222
global log_level

0 commit comments

Comments
 (0)