
feat: add elicitation callback support to MCP servers #2373

Open · wants to merge 11 commits into base: main
Changes from 9 commits
148 changes: 143 additions & 5 deletions docs/mcp/client.md
@@ -12,7 +12,7 @@ pip/uv-add "pydantic-ai-slim[mcp]"
```

!!! note
MCP integration requires Python 3.10 or higher.
MCP integration requires Python 3.10 or higher.

## Usage

@@ -34,7 +34,7 @@ You can use the [`async with agent`][pydantic_ai.Agent.__aenter__] context manag
[Streamable HTTP](https://modelcontextprotocol.io/introduction#streamable-http) transport to a server.

!!! note
[`MCPServerStreamableHTTP`][pydantic_ai.mcp.MCPServerStreamableHTTP] requires an MCP server to be running and accepting HTTP connections before running the agent. Running the server is not managed by Pydantic AI.
[`MCPServerStreamableHTTP`][pydantic_ai.mcp.MCPServerStreamableHTTP] requires an MCP server to be running and accepting HTTP connections before running the agent. Running the server is not managed by Pydantic AI.

Before creating the Streamable HTTP client, we need to run a server that supports the Streamable HTTP transport.

@@ -100,7 +100,7 @@ Will display as follows:
[`MCPServerSSE`][pydantic_ai.mcp.MCPServerSSE] connects over HTTP using the [HTTP + Server Sent Events transport](https://spec.modelcontextprotocol.io/specification/2024-11-05/basic/transports/#http-with-sse) to a server.

!!! note
[`MCPServerSSE`][pydantic_ai.mcp.MCPServerSSE] requires an MCP server to be running and accepting HTTP connections before running the agent. Running the server is not managed by Pydantic AI.
[`MCPServerSSE`][pydantic_ai.mcp.MCPServerSSE] requires an MCP server to be running and accepting HTTP connections before running the agent. Running the server is not managed by Pydantic AI.

The name "HTTP" is used since this implementation will be adapted in future to use the new
[Streamable HTTP](https://github.com/modelcontextprotocol/specification/pull/206) currently in development.
@@ -280,13 +280,13 @@ async def main():
```

1. When you supply `http_client`, Pydantic AI re-uses this client for every
request. Anything supported by **httpx** (`verify`, `cert`, custom
request. Anything supported by **httpx** (`verify`, `cert`, custom
proxies, timeouts, etc.) therefore applies to all MCP traffic.

## MCP Sampling

!!! info "What is MCP Sampling?"
In MCP [sampling](https://modelcontextprotocol.io/docs/concepts/sampling) is a system by which an MCP server can make LLM calls via the MCP client - effectively proxying requests to an LLM via the client over whatever transport is being used.
In MCP [sampling](https://modelcontextprotocol.io/docs/concepts/sampling) is a system by which an MCP server can make LLM calls via the MCP client - effectively proxying requests to an LLM via the client over whatever transport is being used.

Sampling is extremely useful when MCP servers need to use Gen AI but you don't want to provision them each with their own LLM credentials or when a public MCP server would like the connecting client to pay for LLM calls.

@@ -391,3 +391,141 @@ server = MCPServerStdio(
allow_sampling=False,
)
```

## Elicitation

In MCP, [elicitation](https://modelcontextprotocol.io/docs/concepts/elicitation) allows a server to request [structured input](https://modelcontextprotocol.io/specification/2025-06-18/client/elicitation#supported-schema-types) from the client when it needs missing or additional context during a session.

Elicitation lets a server essentially say "Hold on - I need to know X before I can continue" rather than requiring everything upfront or taking a shot in the dark.

### How Elicitation works

Elicitation introduces a new protocol message type called [`ElicitRequest`](https://modelcontextprotocol.io/specification/2025-06-18/schema#elicitrequest), which is sent from the server to the client when it needs additional information. The client can then respond with an [`ElicitResult`](https://modelcontextprotocol.io/specification/2025-06-18/schema#elicitresult) or an `ErrorData` message.
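
On the Pydantic AI side, the client half of this exchange is an async callback that receives the server's `ElicitRequestParams` and returns an `ElicitResult`. Below is a minimal, non-interactive sketch that accepts every request with a canned answer; the full interactive example appears in [Setting up Elicitation](#setting-up-elicitation):

```python {py="3.10" test="skip"}
from typing import Any

from mcp.client.session import ClientSession
from mcp.shared.context import RequestContext
from mcp.types import ElicitRequestParams, ElicitResult


async def auto_accept_elicitation(
    context: RequestContext[ClientSession, Any, Any],
    params: ElicitRequestParams,
) -> ElicitResult:
    """Accept every elicitation request with a canned answer."""
    # params.message is the human-readable prompt from the server;
    # params.requestedSchema (when present) describes the fields it wants as a flat JSON schema.
    print(params.message)
    return ElicitResult(action='accept', content={'response': 'yes'})
```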

Here's a typical interaction:

- The user makes a request to the MCP server (e.g. "Book a table at that Italian place")
- The server identifies that it needs more information (e.g. "Which Italian place?", "What date and time?")
- The server sends an `ElicitRequest` to the client asking for the missing information.
- The client receives the request, presents it to the user (e.g. via a terminal prompt, GUI dialog, or web interface).
- The user provides the requested information, or chooses to `decline` or `cancel` the request.
- The client sends an `ElicitResult` back to the server with the user's response.
- With the structured data, the server can continue processing the original request.

This allows for a more interactive and user-friendly experience, especially for multi-stage workflows. Instead of requiring all information upfront, the server can ask for it as needed, making the interaction feel more natural.
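
Concretely, the request and response payloads look roughly like this (illustrative values only, written as Python dicts rather than raw JSON-RPC):

```python {py="3.10" test="skip"}
# What the server sends in an `ElicitRequest` (illustrative):
elicit_request_params = {
    'message': 'Please provide your booking details:',
    'requestedSchema': {
        'type': 'object',
        'properties': {
            'restaurant': {'type': 'string', 'description': 'Choose a restaurant'},
            'party_size': {'type': 'integer', 'minimum': 1, 'maximum': 8},
            'date': {'type': 'string', 'description': 'Reservation date (DD-MM-YYYY)'},
        },
        'required': ['restaurant', 'party_size', 'date'],
    },
}

# What the client sends back in an `ElicitResult` (illustrative):
elicit_result = {
    'action': 'accept',  # or 'decline' / 'cancel'
    'content': {'restaurant': 'Luigi', 'party_size': 2, 'date': '24-12-2025'},
}
```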

### Setting up Elicitation

To enable elicitation, provide an [`elicitation_callback`][pydantic_ai.mcp.MCPServer.elicitation_callback] function when creating your MCP server instance:

```python {title="restaurant_server.py" py="3.10"}
from mcp.server.fastmcp import Context, FastMCP
from pydantic import BaseModel, Field

mcp = FastMCP(name='Restaurant Booking')


class BookingDetails(BaseModel):
    """Schema for restaurant booking information."""

    restaurant: str = Field(description='Choose a restaurant')
    party_size: int = Field(description='Number of people', ge=1, le=8)
    date: str = Field(description='Reservation date (DD-MM-YYYY)')


@mcp.tool()
async def book_table(ctx: Context) -> str:
    """Book a restaurant table with user input."""
    # Ask user for booking details using Pydantic schema
    result = await ctx.elicit(message='Please provide your booking details:', schema=BookingDetails)

    if result.action == 'accept' and result.data:
        booking = result.data
        return f'✅ Booked table for {booking.party_size} at {booking.restaurant} on {booking.date}'
    elif result.action == 'decline':
        return 'No problem! Maybe another time.'
    else:  # cancel
        return 'Booking cancelled.'


if __name__ == '__main__':
    mcp.run(transport='stdio')
```

```python {title="client_example.py" py="3.10" requires="restaurant_server.py" test="skip"}
import asyncio
from typing import Any

from mcp.client.session import ClientSession
from mcp.shared.context import RequestContext
from mcp.types import ElicitRequestParams, ElicitResult

from pydantic_ai import Agent
from pydantic_ai.mcp import MCPServerStdio


async def handle_elicitation(
    context: RequestContext[ClientSession, Any, Any],
    params: ElicitRequestParams,
) -> ElicitResult:
    """Handle elicitation requests from MCP server."""
    print(f'\n{params.message}')

    if not params.requestedSchema:
        response = input('Response: ')
        return ElicitResult(action='accept', content={'response': response})

    # Collect data for each field
    properties = params.requestedSchema['properties']
    data = {}

    for field, info in properties.items():
        description = info.get('description', field)

        value = input(f'{description}: ')

        # Convert to proper type based on JSON schema
        if info.get('type') == 'integer':
            data[field] = int(value)
        else:
            data[field] = value

    # Confirm
    confirm = input('\nConfirm booking? (y/n/c): ').lower()

    if confirm == 'y':
        print('Booking details:', data)
        return ElicitResult(action='accept', content=data)
    elif confirm == 'n':
        return ElicitResult(action='decline')
    else:
        return ElicitResult(action='cancel')


# Set up MCP server connection
restaurant_server = MCPServerStdio(
    command='python', args=['restaurant_server.py'], elicitation_callback=handle_elicitation
)

# Create agent
agent = Agent('openai:gpt-4o', toolsets=[restaurant_server])


async def main():
    """Run the agent to book a restaurant table."""
    async with agent:
        result = await agent.run('Book me a table')
        print(f'\nResult: {result.output}')


if __name__ == '__main__':
    asyncio.run(main())
```

### Supported Schema Types

MCP elicitation supports string, number, boolean, and enum types with flat object structures only. These limitations ensure reliable cross-client compatibility. See [supported schema types](https://modelcontextprotocol.io/specification/2025-06-18/client/elicitation#supported-schema-types) for details.
Comment on lines +527 to +529

Member: Is this needed?

Contributor Author: I think this section is valuable since the schema type limitations are a key constraint that everybody should know. Should we move it earlier, or remove it entirely?

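As a rough illustration of that constraint (hypothetical models, not part of this PR), a flat model with primitive and enum fields works, while a model with nested objects does not:

```python {py="3.10" test="skip"}
from enum import Enum

from pydantic import BaseModel


class Cuisine(str, Enum):
    italian = 'italian'
    japanese = 'japanese'


class FlatPreferences(BaseModel):
    """Supported: only strings, numbers, booleans and enums at the top level."""

    cuisine: Cuisine
    party_size: int
    outdoor_seating: bool


class Address(BaseModel):
    street: str
    city: str


class NestedPreferences(BaseModel):
    """Not supported for elicitation: nested object fields."""

    cuisine: Cuisine
    address: Address  # nested models fall outside the flat-object subset
```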

### Security

MCP Elicitation requires careful handling - servers must not request sensitive information, and clients must implement user approval controls with clear explanations. See [security considerations](https://modelcontextprotocol.io/specification/2025-06-18/client/elicitation#security-considerations) for details.
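
One way a client can enforce this on its side is to inspect the requested schema before prompting the user. A defensive sketch (the field-name heuristic is purely illustrative, not an API provided by Pydantic AI):

```python {py="3.10" test="skip"}
from typing import Any

from mcp.client.session import ClientSession
from mcp.shared.context import RequestContext
from mcp.types import ElicitRequestParams, ElicitResult

SENSITIVE_HINTS = ('password', 'token', 'secret', 'api_key', 'ssn')


async def guarded_elicitation(
    context: RequestContext[ClientSession, Any, Any],
    params: ElicitRequestParams,
) -> ElicitResult:
    """Decline elicitation requests that appear to ask for credentials."""
    properties = (params.requestedSchema or {}).get('properties', {})
    for field, info in properties.items():
        text = f'{field} {info.get("description", "")}'.lower()
        if any(hint in text for hint in SENSITIVE_HINTS):
            return ElicitResult(action='decline')

    # Otherwise prompt the user for real input and approval,
    # e.g. as in `client_example.py` above; cancelled here for brevity.
    return ElicitResult(action='cancel')
```
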
29 changes: 25 additions & 4 deletions pydantic_ai_slim/pydantic_ai/mcp.py
@@ -18,14 +18,13 @@
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
from typing_extensions import Self, assert_never, deprecated

from pydantic_ai._run_context import RunContext
from pydantic_ai.tools import ToolDefinition
from pydantic_ai.tools import RunContext, ToolDefinition

from .toolsets.abstract import AbstractToolset, ToolsetTool

try:
from mcp import types as mcp_types
from mcp.client.session import ClientSession, LoggingFnT
from mcp.client.session import ClientSession, ElicitationFnT, LoggingFnT
from mcp.client.sse import sse_client
from mcp.client.stdio import StdioServerParameters, stdio_client
from mcp.client.streamable_http import GetSessionIdCallback, streamablehttp_client
@@ -41,7 +40,13 @@
# after mcp imports so any import error maps to this file, not _mcp.py
from . import _mcp, _utils, exceptions, messages, models

__all__ = 'MCPServer', 'MCPServerStdio', 'MCPServerHTTP', 'MCPServerSSE', 'MCPServerStreamableHTTP'
__all__ = (
'MCPServer',
'MCPServerStdio',
'MCPServerHTTP',
'MCPServerSSE',
'MCPServerStreamableHTTP',
)

TOOL_SCHEMA_VALIDATOR = pydantic_core.SchemaValidator(
schema=pydantic_core.core_schema.dict_schema(
@@ -65,6 +70,7 @@ class MCPServer(AbstractToolset[Any], ABC):
allow_sampling: bool
sampling_model: models.Model | None
max_retries: int
elicitation_callback: ElicitationFnT | None = None

_id: str | None

@@ -87,6 +93,7 @@ def __init__(
allow_sampling: bool = True,
sampling_model: models.Model | None = None,
max_retries: int = 1,
elicitation_callback: ElicitationFnT | None = None,
*,
id: str | None = None,
):
@@ -99,6 +106,7 @@ def __init__(
self.allow_sampling = allow_sampling
self.sampling_model = sampling_model
self.max_retries = max_retries
self.elicitation_callback = elicitation_callback

self._id = id or tool_prefix

@@ -247,6 +255,7 @@ async def __aenter__(self) -> Self:
read_stream=self._read_stream,
write_stream=self._write_stream,
sampling_callback=self._sampling_callback if self.allow_sampling else None,
elicitation_callback=self.elicitation_callback,
logging_callback=self.log_handler,
read_timeout_seconds=timedelta(seconds=self.read_timeout),
)
@@ -445,6 +454,9 @@ async def main():
max_retries: int
"""The maximum number of times to retry a tool call."""

elicitation_callback: ElicitationFnT | None = None
"""Callback function to handle elicitation requests from the server."""

def __init__(
self,
command: str,
@@ -460,6 +472,7 @@ def __init__(
allow_sampling: bool = True,
sampling_model: models.Model | None = None,
max_retries: int = 1,
elicitation_callback: ElicitationFnT | None = None,
*,
id: str | None = None,
):
@@ -479,6 +492,7 @@ def __init__(
allow_sampling: Whether to allow MCP sampling through this client.
sampling_model: The model to use for sampling.
max_retries: The maximum number of times to retry a tool call.
elicitation_callback: Callback function to handle elicitation requests from the server.
id: An optional unique ID for the MCP server. An MCP server needs to have an ID in order to be used in a durable execution environment like Temporal, in which case the ID will be used to identify the server's activities within the workflow.
"""
self.command = command
@@ -496,6 +510,7 @@ def __init__(
allow_sampling,
sampling_model,
max_retries,
elicitation_callback,
id=id,
)

@@ -605,6 +620,9 @@ class _MCPServerHTTP(MCPServer):
max_retries: int
"""The maximum number of times to retry a tool call."""

elicitation_callback: ElicitationFnT | None = None
"""Callback function to handle elicitation requests from the server."""

def __init__(
self,
*,
@@ -621,6 +639,7 @@
allow_sampling: bool = True,
sampling_model: models.Model | None = None,
max_retries: int = 1,
elicitation_callback: ElicitationFnT | None = None,
**_deprecated_kwargs: Any,
):
"""Build a new MCP server.
@@ -639,6 +658,7 @@
allow_sampling: Whether to allow MCP sampling through this client.
sampling_model: The model to use for sampling.
max_retries: The maximum number of times to retry a tool call.
elicitation_callback: Callback function to handle elicitation requests from the server.
"""
if 'sse_read_timeout' in _deprecated_kwargs:
if read_timeout is not None:
@@ -668,6 +688,7 @@
allow_sampling,
sampling_model,
max_retries,
elicitation_callback,
id=id,
)

2 changes: 1 addition & 1 deletion pydantic_ai_slim/pyproject.toml
@@ -77,7 +77,7 @@ tavily = ["tavily-python>=0.5.0"]
# CLI
cli = ["rich>=13", "prompt-toolkit>=3", "argcomplete>=3.5.0"]
# MCP
mcp = ["mcp>=1.10.0; python_version >= '3.10'"]
mcp = ["mcp>=1.12.3; python_version >= '3.10'"]
# Evals
evals = ["pydantic-evals=={{ version }}"]
# A2A
Expand Down
20 changes: 18 additions & 2 deletions tests/mcp_server.py
@@ -13,7 +13,7 @@
    TextContent,
    TextResourceContents,
)
from pydantic import AnyUrl
from pydantic import AnyUrl, BaseModel

mcp = FastMCP('Pydantic AI MCP Server')
log_level = 'unset'
@@ -186,7 +186,7 @@ async def echo_deps(ctx: Context[ServerSessionT, LifespanContextT, RequestT]) ->


@mcp.tool()
async def use_sampling(ctx: Context, foo: str) -> str: # type: ignore
async def use_sampling(ctx: Context[ServerSessionT, LifespanContextT, RequestT], foo: str) -> str:
    """Use sampling callback."""

    result = await ctx.session.create_message(
@@ -202,6 +202,22 @@
    return result.model_dump_json(indent=2)


class UserResponse(BaseModel):
    response: str


@mcp.tool()
async def use_elicitation(ctx: Context[ServerSessionT, LifespanContextT, RequestT], question: str) -> str:
Member: This is not the correct type.

It's probably this:

Suggested change
async def use_elicitation(ctx: Context[ServerSessionT, LifespanContextT, RequestT], question: str) -> str:
async def use_elicitation(ctx: Context[ServerSession, None], question: str) -> str:

Contributor Author: I noticed the other MCP test functions (echo_deps, use_sampling) all use Context[ServerSessionT, LifespanContextT, RequestT] as their types. For consistency, should we update the other functions, or is there a reason use_elicitation should use concrete types instead of generics?

    """Use elicitation callback to ask the user a question."""

    result = await ctx.elicit(message=question, schema=UserResponse)

    if result.action == 'accept' and result.data:
        return f'User responded: {result.data.response}'
    else:
        return f'User {result.action}ed the elicitation'


@mcp._mcp_server.set_logging_level() # pyright: ignore[reportPrivateUsage]
async def set_logging_level(level: str) -> None:
global log_level