diff --git a/docs/mcp/client.md b/docs/mcp/client.md
index ce2f30d33..5d9e723d9 100644
--- a/docs/mcp/client.md
+++ b/docs/mcp/client.md
@@ -100,7 +100,7 @@ Will display as follows:
 [`MCPServerSSE`][pydantic_ai.mcp.MCPServerSSE] connects over HTTP using the [HTTP + Server Sent Events transport](https://spec.modelcontextprotocol.io/specification/2024-11-05/basic/transports/#http-with-sse) to a server.
 
 !!! note
-    [`MCPServerSSE`][pydantic_ai.mcp.MCPServerSSE] requires an MCP server to be running and accepting HTTP connections before running the agent. Running the server is not managed by Pydantic AI.
+    [`MCPServerSSE`][pydantic_ai.mcp.MCPServerSSE] requires an MCP server to be running and accepting HTTP connections before running the agent. Running the server is not managed by Pydantic AI.
 
 The name "HTTP" is used since this implementation will be adapted in future to use the new
 [Streamable HTTP](https://github.com/modelcontextprotocol/specification/pull/206) currently in development.
@@ -280,13 +280,13 @@ async def main():
 ```
 
 1. When you supply `http_client`, Pydantic AI re-uses this client for every
-   request. Anything supported by **httpx** (`verify`, `cert`, custom 
+   request. Anything supported by **httpx** (`verify`, `cert`, custom
    proxies, timeouts, etc.) therefore applies to all MCP traffic.
 
 ## MCP Sampling
 
 !!! info "What is MCP Sampling?"
-    In MCP [sampling](https://modelcontextprotocol.io/docs/concepts/sampling) is a system by which an MCP server can make LLM calls via the MCP client - effectively proxying requests to an LLM via the client over whatever transport is being used.
+    In MCP, [sampling](https://modelcontextprotocol.io/docs/concepts/sampling) is a system by which an MCP server can make LLM calls via the MCP client - effectively proxying requests to an LLM via the client over whatever transport is being used.
 
 Sampling is extremely useful when MCP servers need to use Gen AI but you don't want to provision them each with their own LLM credentials or when a public MCP server would like the connecting client to pay for LLM calls.
@@ -391,3 +391,143 @@ server = MCPServerStdio(
     allow_sampling=False,
 )
 ```
+
+## Elicitation
+
+In MCP, [elicitation](https://modelcontextprotocol.io/docs/concepts/elicitation) allows a server to request [structured input](https://modelcontextprotocol.io/specification/2025-06-18/client/elicitation#supported-schema-types) from the client when it needs missing or additional context during a session.
+
+Elicitation lets a server essentially say "Hold on - I need to know X before I can continue" rather than requiring everything upfront or taking a shot in the dark.
+
+### How Elicitation works
+
+Elicitation introduces a new protocol message type called [`ElicitRequest`](https://modelcontextprotocol.io/specification/2025-06-18/schema#elicitrequest), which is sent from the server to the client when it needs additional information. The client can then respond with an [`ElicitResult`](https://modelcontextprotocol.io/specification/2025-06-18/schema#elicitresult) or an `ErrorData` message.
+
+Here's a typical interaction:
+
+- The user makes a request to the MCP server (e.g. "Book a table at that Italian place").
+- The server identifies that it needs more information (e.g. "Which Italian place?", "What date and time?").
+- The server sends an `ElicitRequest` to the client asking for the missing information.
+- The client receives the request and presents it to the user (e.g. via a terminal prompt, GUI dialog, or web interface).
+- The user provides the requested information, or declines or cancels the request.
+- The client sends an `ElicitResult` back to the server with the user's response.
+- With the structured data, the server can continue processing the original request.
+
+This allows for a more interactive and user-friendly experience, especially for multi-stage workflows. Instead of requiring all information upfront, the server can ask for it as needed, making the interaction feel more natural.
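+
+On the wire this exchange is an ordinary JSON-RPC request/response pair. As a rough sketch for the booking example (the payload shapes follow the MCP elicitation schema; the message and values below are invented for illustration):
+
+```python {test="skip"}
+# Server -> client: ask for the missing booking details.
+elicit_request = {
+    'jsonrpc': '2.0',
+    'id': 1,
+    'method': 'elicitation/create',
+    'params': {
+        'message': 'Which Italian place, and for what date?',
+        'requestedSchema': {
+            'type': 'object',
+            'properties': {
+                'restaurant': {'type': 'string', 'description': 'Choose a restaurant'},
+                'date': {'type': 'string', 'description': 'Reservation date (DD-MM-YYYY)'},
+            },
+            'required': ['restaurant', 'date'],
+        },
+    },
+}
+
+# Client -> server: the user accepted and supplied both fields.
+elicit_result = {
+    'jsonrpc': '2.0',
+    'id': 1,
+    'result': {
+        'action': 'accept',
+        'content': {'restaurant': "Luigi's", 'date': '14-02-2025'},
+    },
+}
+```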
+
+### Setting up Elicitation
+
+To enable elicitation, provide an [`elicitation_callback`][pydantic_ai.mcp.MCPServer.elicitation_callback] function when creating your MCP server instance:
+
+```python {title="restaurant_server.py" py="3.10"}
+from mcp.server.fastmcp import Context, FastMCP
+from pydantic import BaseModel, Field
+
+mcp = FastMCP(name='Restaurant Booking')
+
+
+class BookingDetails(BaseModel):
+    """Schema for restaurant booking information."""
+
+    restaurant: str = Field(description='Choose a restaurant')
+    party_size: int = Field(description='Number of people', ge=1, le=8)
+    date: str = Field(description='Reservation date (DD-MM-YYYY)')
+
+
+@mcp.tool()
+async def book_table(ctx: Context) -> str:
+    """Book a restaurant table with user input."""
+    # Ask user for booking details using Pydantic schema
+    result = await ctx.elicit(message='Please provide your booking details:', schema=BookingDetails)
+
+    if result.action == 'accept' and result.data:
+        booking = result.data
+        return f'✅ Booked table for {booking.party_size} at {booking.restaurant} on {booking.date}'
+    elif result.action == 'decline':
+        return 'No problem! Maybe another time.'
+    else:  # cancel
+        return 'Booking cancelled.'
+
+
+if __name__ == '__main__':
+    mcp.run(transport='stdio')
+```
+
+This server demonstrates elicitation by requesting structured booking details from the client when the `book_table` tool is called. Here's how to create a client that handles these elicitation requests:
+
+```python {title="client_example.py" py="3.10" requires="restaurant_server.py" test="skip"}
+import asyncio
+from typing import Any
+
+from mcp.client.session import ClientSession
+from mcp.shared.context import RequestContext
+from mcp.types import ElicitRequestParams, ElicitResult
+
+from pydantic_ai import Agent
+from pydantic_ai.mcp import MCPServerStdio
+
+
+async def handle_elicitation(
+    context: RequestContext[ClientSession, Any, Any],
+    params: ElicitRequestParams,
+) -> ElicitResult:
+    """Handle elicitation requests from the MCP server."""
+    print(f'\n{params.message}')
+
+    if not params.requestedSchema:
+        response = input('Response: ')
+        return ElicitResult(action='accept', content={'response': response})
+
+    # Collect data for each field
+    properties = params.requestedSchema['properties']
+    data = {}
+
+    for field, info in properties.items():
+        description = info.get('description', field)
+
+        value = input(f'{description}: ')
+
+        # Convert to proper type based on JSON schema
+        if info.get('type') == 'integer':
+            data[field] = int(value)
+        else:
+            data[field] = value
+
+    # Confirm
+    confirm = input('\nConfirm booking? (y/n/c): ').lower()
+
+    if confirm == 'y':
+        print('Booking details:', data)
+        return ElicitResult(action='accept', content=data)
+    elif confirm == 'n':
+        return ElicitResult(action='decline')
+    else:
+        return ElicitResult(action='cancel')
+
+
+# Set up MCP server connection
+restaurant_server = MCPServerStdio(
+    command='python', args=['restaurant_server.py'], elicitation_callback=handle_elicitation
+)
+
+# Create agent
+agent = Agent('openai:gpt-4o', toolsets=[restaurant_server])
+
+
+async def main():
+    """Run the agent to book a restaurant table."""
+    async with agent:
+        result = await agent.run('Book me a table')
+        print(f'\nResult: {result.output}')
+
+
+if __name__ == '__main__':
+    asyncio.run(main())
+```
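+
+The callback above prompts a human interactively. In non-interactive settings (tests, CI) a callback can answer programmatically instead. A minimal sketch, assuming canned answers keyed by field name (the `CANNED` values are invented for this example):
+
+```python {test="skip"}
+from typing import Any
+
+from mcp.client.session import ClientSession
+from mcp.shared.context import RequestContext
+from mcp.types import ElicitRequestParams, ElicitResult
+
+# Invented fixture data -- replace with whatever your tests need.
+CANNED = {'restaurant': "Luigi's", 'party_size': 2, 'date': '14-02-2025'}
+
+
+async def auto_accept(
+    context: RequestContext[ClientSession, Any, Any],
+    params: ElicitRequestParams,
+) -> ElicitResult:
+    """Answer every requested field from CANNED, declining anything unknown."""
+    if not params.requestedSchema:
+        return ElicitResult(action='decline')
+    fields = params.requestedSchema['properties']
+    if not all(field in CANNED for field in fields):
+        # Don't guess at fields we have no canned answer for.
+        return ElicitResult(action='decline')
+    return ElicitResult(action='accept', content={field: CANNED[field] for field in fields})
+```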
+
+### Supported Schema Types
+
+MCP elicitation supports string, number, boolean, and enum types with flat object structures only. These limitations ensure reliable cross-client compatibility. See [supported schema types](https://modelcontextprotocol.io/specification/2025-06-18/client/elicitation#supported-schema-types) for details.
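+
+For reference, a flat `requestedSchema` exercising each supported primitive might look like this (an illustrative sketch -- nested objects and arrays are not allowed):
+
+```python {test="skip"}
+requested_schema = {
+    'type': 'object',
+    'properties': {
+        'name': {'type': 'string', 'description': 'Your name'},
+        'party_size': {'type': 'integer', 'minimum': 1, 'maximum': 8},
+        'outdoor_seating': {'type': 'boolean', 'description': 'Sit outside?'},
+        'occasion': {'type': 'string', 'enum': ['birthday', 'anniversary', 'other']},
+    },
+    'required': ['name'],
+}
+```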
+
+### Security
+
+MCP elicitation requires careful handling - servers must not request sensitive information, and clients must implement user approval controls with clear explanations. See [security considerations](https://modelcontextprotocol.io/specification/2025-06-18/client/elicitation#security-considerations) for details.
diff --git a/pydantic_ai_slim/pydantic_ai/mcp.py b/pydantic_ai_slim/pydantic_ai/mcp.py
index 8dd155992..35f0297f5 100644
--- a/pydantic_ai_slim/pydantic_ai/mcp.py
+++ b/pydantic_ai_slim/pydantic_ai/mcp.py
@@ -18,14 +18,13 @@
 from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
 from typing_extensions import Self, assert_never, deprecated
 
-from pydantic_ai._run_context import RunContext
-from pydantic_ai.tools import ToolDefinition
+from pydantic_ai.tools import RunContext, ToolDefinition
 
 from .toolsets.abstract import AbstractToolset, ToolsetTool
 
 try:
     from mcp import types as mcp_types
-    from mcp.client.session import ClientSession, LoggingFnT
+    from mcp.client.session import ClientSession, ElicitationFnT, LoggingFnT
     from mcp.client.sse import sse_client
     from mcp.client.stdio import StdioServerParameters, stdio_client
     from mcp.client.streamable_http import GetSessionIdCallback, streamablehttp_client
@@ -65,6 +64,7 @@ class MCPServer(AbstractToolset[Any], ABC):
     allow_sampling: bool
     sampling_model: models.Model | None
     max_retries: int
+    elicitation_callback: ElicitationFnT | None = None
 
     _id: str | None
 
@@ -87,6 +87,7 @@ def __init__(
         allow_sampling: bool = True,
         sampling_model: models.Model | None = None,
         max_retries: int = 1,
+        elicitation_callback: ElicitationFnT | None = None,
         *,
         id: str | None = None,
     ):
@@ -99,6 +100,7 @@ def __init__(
         self.allow_sampling = allow_sampling
         self.sampling_model = sampling_model
         self.max_retries = max_retries
+        self.elicitation_callback = elicitation_callback
 
         self._id = id or tool_prefix
 
@@ -247,6 +249,7 @@ async def __aenter__(self) -> Self:
                 read_stream=self._read_stream,
                 write_stream=self._write_stream,
                 sampling_callback=self._sampling_callback if self.allow_sampling else None,
+                elicitation_callback=self.elicitation_callback,
                 logging_callback=self.log_handler,
                 read_timeout_seconds=timedelta(seconds=self.read_timeout),
             )
@@ -445,6 +448,9 @@ async def main():
     max_retries: int
     """The maximum number of times to retry a tool call."""
 
+    elicitation_callback: ElicitationFnT | None = None
+    """Callback function to handle elicitation requests from the server."""
+
     def __init__(
         self,
         command: str,
@@ -460,6 +466,7 @@ def __init__(
         allow_sampling: bool = True,
         sampling_model: models.Model | None = None,
         max_retries: int = 1,
+        elicitation_callback: ElicitationFnT | None = None,
         *,
         id: str | None = None,
     ):
@@ -479,6 +486,7 @@ def __init__(
             allow_sampling: Whether to allow MCP sampling through this client.
             sampling_model: The model to use for sampling.
             max_retries: The maximum number of times to retry a tool call.
+            elicitation_callback: Callback function to handle elicitation requests from the server.
             id: An optional unique ID for the MCP server. An MCP server needs to have an ID in order to be used in a durable execution environment like Temporal, in which case the ID will be used to identify the server's activities within the workflow.
         """
         self.command = command
@@ -496,6 +504,7 @@ def __init__(
             allow_sampling,
             sampling_model,
             max_retries,
+            elicitation_callback,
             id=id,
         )
 
@@ -605,6 +614,9 @@ class _MCPServerHTTP(MCPServer):
     max_retries: int
    """The maximum number of times to retry a tool call."""
 
+    elicitation_callback: ElicitationFnT | None = None
+    """Callback function to handle elicitation requests from the server."""
+
     def __init__(
         self,
         *,
@@ -621,6 +633,7 @@ def __init__(
         allow_sampling: bool = True,
         sampling_model: models.Model | None = None,
         max_retries: int = 1,
+        elicitation_callback: ElicitationFnT | None = None,
         **_deprecated_kwargs: Any,
     ):
         """Build a new MCP server.
@@ -639,6 +652,7 @@ def __init__(
             allow_sampling: Whether to allow MCP sampling through this client.
             sampling_model: The model to use for sampling.
             max_retries: The maximum number of times to retry a tool call.
+            elicitation_callback: Callback function to handle elicitation requests from the server.
         """
         if 'sse_read_timeout' in _deprecated_kwargs:
             if read_timeout is not None:
@@ -668,6 +682,7 @@ def __init__(
             allow_sampling,
             sampling_model,
             max_retries,
+            elicitation_callback,
             id=id,
         )
 
diff --git a/pydantic_ai_slim/pyproject.toml b/pydantic_ai_slim/pyproject.toml
index b1f1ea09e..dba2b6910 100644
--- a/pydantic_ai_slim/pyproject.toml
+++ b/pydantic_ai_slim/pyproject.toml
@@ -77,7 +77,7 @@ tavily = ["tavily-python>=0.5.0"]
 # CLI
 cli = ["rich>=13", "prompt-toolkit>=3", "argcomplete>=3.5.0"]
 # MCP
-mcp = ["mcp>=1.10.0; python_version >= '3.10'"]
+mcp = ["mcp>=1.12.3; python_version >= '3.10'"]
 # Evals
 evals = ["pydantic-evals=={{ version }}"]
 # A2A
diff --git a/tests/mcp_server.py b/tests/mcp_server.py
index f7fa75e9b..07a866a50 100644
--- a/tests/mcp_server.py
+++ b/tests/mcp_server.py
@@ -13,7 +13,7 @@
     TextContent,
     TextResourceContents,
 )
-from pydantic import AnyUrl
+from pydantic import AnyUrl, BaseModel
 
 mcp = FastMCP('Pydantic AI MCP Server')
 log_level = 'unset'
@@ -186,7 +186,7 @@ async def echo_deps(ctx: Context[ServerSessionT, LifespanContextT, RequestT]) -> str:
 
 
 @mcp.tool()
-async def use_sampling(ctx: Context, foo: str) -> str:  # type: ignore
+async def use_sampling(ctx: Context[ServerSessionT, LifespanContextT, RequestT], foo: str) -> str:
     """Use sampling callback."""
 
     result = await ctx.session.create_message(
@@ -202,6 +202,22 @@ async def use_sampling(ctx: Context, foo: str) -> str:  # type: ignore
     return result.model_dump_json(indent=2)
 
 
+class UserResponse(BaseModel):
+    response: str
+
+
+@mcp.tool()
+async def use_elicitation(ctx: Context[ServerSessionT, LifespanContextT, RequestT], question: str) -> str:
+    """Use elicitation callback to ask the user a question."""
+
+    result = await ctx.elicit(message=question, schema=UserResponse)
+
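+    # `ctx.elicit` sends an ElicitRequest to the connected client and waits for
+    # its ElicitResult; `action` is one of 'accept', 'decline' or 'cancel'.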
+    if result.action == 'accept' and result.data:
+        return f'User responded: {result.data.response}'
+    else:
+        return f'User {result.action}ed the elicitation'
+
+
 @mcp._mcp_server.set_logging_level()  # pyright: ignore[reportPrivateUsage]
 async def set_logging_level(level: str) -> None:
     global log_level
diff --git a/tests/test_mcp.py b/tests/test_mcp.py
index 178d0b362..c73184ade 100644
--- a/tests/test_mcp.py
+++ b/tests/test_mcp.py
@@ -34,7 +34,9 @@ with try_import() as imports_successful:
     from mcp import ErrorData, McpError, SamplingMessage
-    from mcp.types import CreateMessageRequestParams, ImageContent, TextContent
+    from mcp.client.session import ClientSession
+    from mcp.shared.context import RequestContext
+    from mcp.types import CreateMessageRequestParams, ElicitRequestParams, ElicitResult, ImageContent, TextContent
 
     from pydantic_ai._mcp import map_from_mcp_params, map_from_model_response
     from pydantic_ai.mcp import CallToolFunc, MCPServerSSE, MCPServerStdio, ToolResult
@@ -74,7 +76,7 @@ async def test_stdio_server(run_context: RunContext[int]):
     server = MCPServerStdio('python', ['-m', 'tests.mcp_server'])
     async with server:
         tools = [tool.tool_def for tool in (await server.get_tools(run_context)).values()]
-        assert len(tools) == snapshot(16)
+        assert len(tools) == snapshot(17)
         assert tools[0].name == 'celsius_to_fahrenheit'
         assert isinstance(tools[0].description, str)
         assert tools[0].description.startswith('Convert Celsius to Fahrenheit.')
@@ -122,7 +124,7 @@ async def test_stdio_server_with_cwd(run_context: RunContext[int]):
     server = MCPServerStdio('python', ['mcp_server.py'], cwd=test_dir)
     async with server:
         tools = await server.get_tools(run_context)
-        assert len(tools) == snapshot(16)
+        assert len(tools) == snapshot(17)
 
 
 async def test_process_tool_call(run_context: RunContext[int]) -> int:
@@ -297,7 +299,7 @@ async def test_log_level_unset(run_context: RunContext[int]):
     assert server.log_level is None
     async with server:
         tools = [tool.tool_def for tool in (await server.get_tools(run_context)).values()]
-        assert len(tools) == snapshot(16)
+        assert len(tools) == snapshot(17)
         assert tools[13].name == 'get_log_level'
 
         result = await server.direct_call_tool('get_log_level', {})
@@ -1322,3 +1324,40 @@ def test_map_from_mcp_params_model_response():
 def test_map_from_model_response():
     with pytest.raises(UnexpectedModelBehavior, match='Unexpected part type: ThinkingPart, expected TextPart'):
         map_from_model_response(ModelResponse(parts=[ThinkingPart(content='Thinking...')]))
+
+
+async def test_elicitation_callback_functionality(run_context: RunContext[int]):
+    """Test that elicitation callback is actually called and works."""
+    # Track callback execution
+    callback_called = False
+    callback_message = None
+    callback_response = 'Yes, proceed with the action'
+
+    async def mock_elicitation_callback(
+        context: RequestContext[ClientSession, Any, Any], params: ElicitRequestParams
+    ) -> ElicitResult:
+        nonlocal callback_called, callback_message
+        callback_called = True
+        callback_message = params.message
+        return ElicitResult(action='accept', content={'response': callback_response})
+
+    server = MCPServerStdio('python', ['-m', 'tests.mcp_server'], elicitation_callback=mock_elicitation_callback)
+
+    async with server:
+        # Call the tool that uses elicitation
+        result = await server.direct_call_tool('use_elicitation', {'question': 'Should I continue?'})
+
+        # Verify the callback was called
+        assert callback_called, 'Elicitation callback should have been called'
+        assert callback_message == 'Should I continue?', 'Callback should receive the question'
+        assert result == f'User responded: {callback_response}', 'Tool should return the callback response'
+
+
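+# Without an elicitation_callback, the MCP ClientSession answers ElicitRequests
+# with an 'Elicitation not supported' error, which Pydantic AI surfaces as a
+# ModelRetry on the tool call.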
+async def test_elicitation_callback_not_set(run_context: RunContext[int]):
+    """Test that elicitation fails when no callback is set."""
+    server = MCPServerStdio('python', ['-m', 'tests.mcp_server'])
+
+    async with server:
+        # Should raise an error when elicitation is attempted without callback
+        with pytest.raises(ModelRetry, match='Elicitation not supported'):
+            await server.direct_call_tool('use_elicitation', {'question': 'Should I continue?'})
diff --git a/uv.lock b/uv.lock
index 27c6a7222..aa243b1e6 100644
--- a/uv.lock
+++ b/uv.lock
@@ -1965,7 +1965,7 @@ wheels = [
 
 [[package]]
 name = "mcp"
-version = "1.12.1"
+version = "1.12.3"
 source = { registry = "https://pypi.org/simple" }
 dependencies = [
     { name = "anyio", marker = "python_full_version >= '3.10'" },
     { name = "starlette", marker = "python_full_version >= '3.10'" },
     { name = "uvicorn", marker = "python_full_version >= '3.10' and sys_platform != 'emscripten'" },
 ]
-sdist = { url = "https://files.pythonhosted.org/packages/5c/5a/16cef13b2e60d5f865fbc96372efb23dc8b0591f102dd55003b4ae62f9b1/mcp-1.12.1.tar.gz", hash = "sha256:d1d0bdeb09e4b17c1a72b356248bf3baf75ab10db7008ef865c4afbeb0eb810e", size = 425768, upload-time = "2025-07-22T16:51:41.66Z" }
+sdist = { url = "https://files.pythonhosted.org/packages/4d/19/9955e2df5384ff5dd25d38f8e88aaf89d2d3d9d39f27e7383eaf0b293836/mcp-1.12.3.tar.gz", hash = "sha256:ab2e05f5e5c13e1dc90a4a9ef23ac500a6121362a564447855ef0ab643a99fed", size = 427203, upload-time = "2025-07-31T18:36:36.795Z" }
 wheels = [
-    { url = "https://files.pythonhosted.org/packages/b9/04/9a967a575518fc958bda1e34a52eae0c7f6accf3534811914fdaf57b0689/mcp-1.12.1-py3-none-any.whl", hash = "sha256:34147f62891417f8b000c39718add844182ba424c8eb2cea250b4267bda4b08b", size = 158463, upload-time = "2025-07-22T16:51:40.086Z" },
+    { url = "https://files.pythonhosted.org/packages/8f/8b/0be74e3308a486f1d127f3f6767de5f9f76454c9b4183210c61cc50999b6/mcp-1.12.3-py3-none-any.whl", hash = "sha256:5483345bf39033b858920a5b6348a303acacf45b23936972160ff152107b850e", size = 158810, upload-time = "2025-07-31T18:36:34.915Z" },
 ]
 
 [package.optional-dependencies]
@@ -3527,8 +3527,8 @@ requires-dist = [
     { name = "groq", marker = "extra == 'groq'", specifier = ">=0.25.0" },
     { name = "httpx", specifier = ">=0.27" },
     { name = "huggingface-hub", extras = ["inference"], marker = "extra == 'huggingface'", specifier = ">=0.33.5" },
     { name = "logfire", marker = "extra == 'logfire'", specifier = ">=3.14.1" },
-    { name = "mcp", marker = "python_full_version >= '3.10' and extra == 'mcp'", specifier = ">=1.10.0" },
+    { name = "mcp", marker = "python_full_version >= '3.10' and extra == 'mcp'", specifier = ">=1.12.3" },
     { name = "mistralai", marker = "extra == 'mistral'", specifier = ">=1.9.2" },
     { name = "openai", marker = "extra == 'openai'", specifier = ">=1.92.0" },
     { name = "opentelemetry-api", specifier = ">=1.28.0" },