feat: add elicitation callback support to MCP servers #2373

Open · wants to merge 11 commits into main · Changes from 5 commits

103 changes: 103 additions & 0 deletions docs/mcp/client.md
@@ -391,3 +391,106 @@ server = MCPServerStdio(
allow_sampling=False,
)
```

## Elicitation

Sometimes MCP servers need to ask the user questions during tool execution. For example, a file management tool might ask "Are you sure you want to delete this file?" before performing a destructive action, or a deployment tool might need confirmation before deploying to production.

In MCP, elicitation allows servers to pause tool execution and request input from the user via the client. The server sends a message, the client presents it to the user, collects their response, and sends it back to the server.

### Setting up Elicitation

To enable elicitation, provide an [`elicitation_callback`][pydantic_ai.mcp.MCPServerStdio.elicitation_callback] function when creating your MCP server instance:

```python {title="simple_elicitation.py" py="3.10"}
from pydantic_ai import Agent
from pydantic_ai.mcp import MCPServerStdio


async def ask_user(message: str) -> str:
print(f"Server asks: {message}")
return input("Your answer: ")


server = MCPServerStdio(
command='python',
args=['my_server.py'],
elicitation_callback=ask_user # (1)!
)

agent = Agent('openai:gpt-4o', toolsets=[server])
```

1. This function is called whenever the server needs user input.

The elicitation callback is an async function that receives the MCP request context and the server's `ElicitRequestParams` (which carries the message), and returns an `ElicitResult` with the user's response. Your callback can be as simple as requesting terminal input, or as sophisticated as showing GUI dialogs, sending notifications, or integrating with web interfaces.
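
Besides accepting, a user can decline or cancel a request; the `action` field of `ElicitResult` carries that choice back to the server. Below is a minimal sketch of a callback that declines automatically under a policy check (the filename and the "no production changes" rule are illustrative, not part of this PR):

```python {title="policy_elicitation.py" py="3.10" test="skip"}
from typing import Any

from mcp.client.session import ClientSession
from mcp.shared.context import RequestContext
from mcp.types import ElicitRequestParams, ElicitResult


async def cautious_callback(
    context: RequestContext[ClientSession, Any, Any], params: ElicitRequestParams
) -> ElicitResult:
    """Decline risky requests outright; ask the user about everything else."""
    if 'production' in params.message.lower():
        # 'decline' reports that the user refused to answer
        return ElicitResult(action='decline')
    answer = input(f'{params.message} ')
    # 'accept' returns the collected content to the server
    return ElicitResult(action='accept', content={'response': answer})
```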

### File Deletion Example

Here's a practical example showing how an MCP server might use elicitation for confirmation dialogs:

```python {title="file_server.py" py="3.10"}
from mcp.server.fastmcp import Context, FastMCP

app = FastMCP('File Manager')

@app.tool()
async def delete_file(ctx: Context, filename: str) -> str:
"""Delete a file after getting user confirmation."""
# The server asks the client for input
user_response = await ctx.session.elicitation(
message=f"Delete '{filename}'? This cannot be undone! (yes/no)",
timeout_seconds=30,
)

if user_response.lower() in ['yes', 'y']:
# In real life, you'd actually delete the file here
return f"Deleted {filename}"
else:
return f"Cancelled deletion of {filename}"

if __name__ == '__main__':
app.run()
```

The corresponding client handles the confirmation request:

```python {title="file_client.py" py="3.10" test="skip"}
from pydantic_ai import Agent
from pydantic_ai.mcp import MCPServerStdio


async def handle_confirmation(message: str) -> str:
"""Present the confirmation dialog to the user."""
print(f"Warning: {message}")

while True:
response = input("Your choice: ").strip().lower()
if response in ['yes', 'no', 'y', 'n']:
return response
print("Please answer 'yes' or 'no'")


server = MCPServerStdio(
command='python',
args=['file_server.py'],
elicitation_callback=handle_confirmation
)

agent = Agent('openai:gpt-4o', toolsets=[server])

async def main():
async with agent:
result = await agent.run('Delete the file called important_data.txt')
print(result.output)
```

When executed, this produces an interactive confirmation exchange in the terminal:

```
Warning: Delete 'important_data.txt'? This cannot be undone! (yes/no)
Your choice: no
Cancelled deletion of important_data.txt
```

The interaction flows from the AI agent's tool call to the MCP server, through the server's elicitation request to your callback, and back to the server with the user's decision.
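
Elicitation also supports structured responses: the server passes a Pydantic model as the `schema` argument to `ctx.elicit`, and the client's `content` dict is validated against it, as the `use_elicitation` tool in this PR's test server does with its `UserResponse` model. Here is a sketch of a server tool requesting structured approval (the `DeploymentApproval` model and `deploy` tool are hypothetical):

```python {title="deploy_server.py" py="3.10" test="skip"}
from mcp.server.fastmcp import Context, FastMCP
from pydantic import BaseModel

app = FastMCP('Deployment Manager')


class DeploymentApproval(BaseModel):
    """Schema the client's answer must satisfy."""

    approved: bool
    reason: str


@app.tool()
async def deploy(ctx: Context, environment: str) -> str:
    """Deploy only after collecting a structured approval from the user."""
    result = await ctx.elicit(
        message=f'Approve deployment to {environment}?',
        schema=DeploymentApproval,
    )
    # result.data is an instance of DeploymentApproval when accepted
    if result.action == 'accept' and result.data and result.data.approved:
        return f'Deploying to {environment}: {result.data.reason}'
    return 'Deployment cancelled'
```
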
23 changes: 19 additions & 4 deletions pydantic_ai_slim/pydantic_ai/mcp.py
@@ -18,14 +18,13 @@
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
from typing_extensions import Self, assert_never, deprecated

from pydantic_ai._run_context import RunContext
from pydantic_ai.tools import ToolDefinition
from pydantic_ai.tools import RunContext, ToolDefinition

from .toolsets.abstract import AbstractToolset, ToolsetTool

try:
from mcp import types as mcp_types
from mcp.client.session import ClientSession, LoggingFnT
from mcp.client.session import ClientSession, ElicitationFnT, LoggingFnT
from mcp.client.sse import sse_client
from mcp.client.stdio import StdioServerParameters, stdio_client
from mcp.client.streamable_http import GetSessionIdCallback, streamablehttp_client
@@ -41,7 +40,13 @@
# after mcp imports so any import error maps to this file, not _mcp.py
from . import _mcp, _utils, exceptions, messages, models

__all__ = 'MCPServer', 'MCPServerStdio', 'MCPServerHTTP', 'MCPServerSSE', 'MCPServerStreamableHTTP'
__all__ = (
'MCPServer',
'MCPServerStdio',
'MCPServerHTTP',
'MCPServerSSE',
'MCPServerStreamableHTTP',
)

TOOL_SCHEMA_VALIDATOR = pydantic_core.SchemaValidator(
schema=pydantic_core.core_schema.dict_schema(
@@ -66,6 +71,7 @@ class MCPServer(AbstractToolset[Any], ABC):
allow_sampling: bool = True
max_retries: int = 1
sampling_model: models.Model | None = None
elicitation_callback: ElicitationFnT | None = None
# } end of "abstract fields"

_enter_lock: Lock = field(compare=False)
@@ -207,6 +213,7 @@ async def __aenter__(self) -> Self:
read_stream=self._read_stream,
write_stream=self._write_stream,
sampling_callback=self._sampling_callback if self.allow_sampling else None,
elicitation_callback=self.elicitation_callback,
logging_callback=self.log_handler,
read_timeout_seconds=timedelta(seconds=self.read_timeout),
)
@@ -398,6 +405,9 @@ async def main():
sampling_model: models.Model | None = None
"""The model to use for sampling."""

elicitation_callback: ElicitationFnT | None = None
"""Callback function to handle elicitation requests from the server."""

@asynccontextmanager
async def client_streams(
self,
@@ -499,6 +509,9 @@ class _MCPServerHTTP(MCPServer):
sampling_model: models.Model | None = None
"""The model to use for sampling."""

elicitation_callback: ElicitationFnT | None = None
"""Callback function to handle elicitation requests from the server."""

def __init__(
self,
*,
@@ -514,6 +527,7 @@ def __init__(
allow_sampling: bool = True,
max_retries: int = 1,
sampling_model: models.Model | None = None,
elicitation_callback: ElicitationFnT | None = None,
**kwargs: Any,
):
# Handle deprecated sse_read_timeout parameter
@@ -542,6 +556,7 @@ def __init__(
self.allow_sampling = allow_sampling
self.max_retries = max_retries
self.sampling_model = sampling_model
self.elicitation_callback = elicitation_callback
self.read_timeout = read_timeout

@property
2 changes: 1 addition & 1 deletion pydantic_ai_slim/pyproject.toml
@@ -77,7 +77,7 @@ tavily = ["tavily-python>=0.5.0"]
# CLI
cli = ["rich>=13", "prompt-toolkit>=3", "argcomplete>=3.5.0"]
# MCP
mcp = ["mcp>=1.10.0; python_version >= '3.10'"]
mcp = ["mcp>=1.12.2; python_version >= '3.10'"]
# Evals
evals = ["pydantic-evals=={{ version }}"]
# A2A
20 changes: 18 additions & 2 deletions tests/mcp_server.py
@@ -13,7 +13,7 @@
TextContent,
TextResourceContents,
)
from pydantic import AnyUrl
from pydantic import AnyUrl, BaseModel

mcp = FastMCP('Pydantic AI MCP Server')
log_level = 'unset'
@@ -186,7 +186,7 @@ async def echo_deps(ctx: Context[ServerSessionT, LifespanContextT, RequestT]) ->


@mcp.tool()
async def use_sampling(ctx: Context, foo: str) -> str: # type: ignore
async def use_sampling(ctx: Context[ServerSessionT, LifespanContextT, RequestT], foo: str) -> str:
"""Use sampling callback."""

result = await ctx.session.create_message(
@@ -202,6 +202,22 @@ async def use_sampling(ctx: Context, foo: str) -> str: # type: ignore
return result.model_dump_json(indent=2)


class UserResponse(BaseModel):
response: str


@mcp.tool()
async def use_elicitation(ctx: Context[ServerSessionT, LifespanContextT, RequestT], question: str) -> str:

Review thread on this line:

Member: This is not the correct type. It's probably this:

Suggested change
-async def use_elicitation(ctx: Context[ServerSessionT, LifespanContextT, RequestT], question: str) -> str:
+async def use_elicitation(ctx: Context[ServerSession, None], question: str) -> str:

Contributor (author): I noticed the other MCP test functions (echo_deps, use_sampling) all use Context[ServerSessionT, LifespanContextT, RequestT] as their types. For consistency, should we update the other functions, or is there a reason use_elicitation should use concrete types instead of generics?

"""Use elicitation callback to ask the user a question."""

result = await ctx.elicit(message=question, schema=UserResponse)

if result.action == 'accept' and result.data:
return f'User responded: {result.data.response}'
else:
return f'User {result.action}ed the elicitation'


@mcp._mcp_server.set_logging_level() # pyright: ignore[reportPrivateUsage]
async def set_logging_level(level: str) -> None:
global log_level
47 changes: 43 additions & 4 deletions tests/test_mcp.py
@@ -34,7 +34,9 @@

with try_import() as imports_successful:
from mcp import ErrorData, McpError, SamplingMessage
from mcp.types import CreateMessageRequestParams, ImageContent, TextContent
from mcp.client.session import ClientSession
from mcp.shared.context import RequestContext
from mcp.types import CreateMessageRequestParams, ElicitRequestParams, ElicitResult, ImageContent, TextContent

from pydantic_ai._mcp import map_from_mcp_params, map_from_model_response
from pydantic_ai.mcp import CallToolFunc, MCPServerSSE, MCPServerStdio, ToolResult
@@ -74,7 +76,7 @@ async def test_stdio_server(run_context: RunContext[int]):
server = MCPServerStdio('python', ['-m', 'tests.mcp_server'])
async with server:
tools = [tool.tool_def for tool in (await server.get_tools(run_context)).values()]
assert len(tools) == snapshot(16)
assert len(tools) == snapshot(17)
assert tools[0].name == 'celsius_to_fahrenheit'
assert isinstance(tools[0].description, str)
assert tools[0].description.startswith('Convert Celsius to Fahrenheit.')
@@ -122,7 +124,7 @@ async def test_stdio_server_with_cwd(run_context: RunContext[int]):
server = MCPServerStdio('python', ['mcp_server.py'], cwd=test_dir)
async with server:
tools = await server.get_tools(run_context)
assert len(tools) == snapshot(16)
assert len(tools) == snapshot(17)


async def test_process_tool_call(run_context: RunContext[int]) -> int:
@@ -297,7 +299,7 @@ async def test_log_level_unset(run_context: RunContext[int]):
assert server.log_level is None
async with server:
tools = [tool.tool_def for tool in (await server.get_tools(run_context)).values()]
assert len(tools) == snapshot(16)
assert len(tools) == snapshot(17)
assert tools[13].name == 'get_log_level'

result = await server.direct_call_tool('get_log_level', {})
@@ -1322,3 +1324,40 @@ def test_map_from_mcp_params_model_response():
def test_map_from_model_response():
with pytest.raises(UnexpectedModelBehavior, match='Unexpected part type: ThinkingPart, expected TextPart'):
map_from_model_response(ModelResponse(parts=[ThinkingPart(content='Thinking...')]))


async def test_elicitation_callback_functionality(run_context: RunContext[int]):
"""Test that elicitation callback is actually called and works."""
# Track callback execution
callback_called = False
callback_message = None
callback_response = 'Yes, proceed with the action'

async def mock_elicitation_callback(
context: RequestContext[ClientSession, Any, Any], params: ElicitRequestParams
) -> ElicitResult:
nonlocal callback_called, callback_message
callback_called = True
callback_message = params.message
return ElicitResult(action='accept', content={'response': callback_response})

server = MCPServerStdio('python', ['-m', 'tests.mcp_server'], elicitation_callback=mock_elicitation_callback)

async with server:
# Call the tool that uses elicitation
result = await server.direct_call_tool('use_elicitation', {'question': 'Should I continue?'})

# Verify the callback was called
assert callback_called, 'Elicitation callback should have been called'
assert callback_message == 'Should I continue?', 'Callback should receive the question'
assert result == f'User responded: {callback_response}', 'Tool should return the callback response'


async def test_elicitation_callback_not_set(run_context: RunContext[int]):
"""Test that elicitation fails when no callback is set."""
server = MCPServerStdio('python', ['-m', 'tests.mcp_server'])

async with server:
# Should raise an error when elicitation is attempted without callback
with pytest.raises(ModelRetry, match='Elicitation not supported'):
await server.direct_call_tool('use_elicitation', {'question': 'Should I continue?'})
8 changes: 4 additions & 4 deletions uv.lock

Some generated files are not rendered by default.