Skip to content

Commit 0f5e987

Browse files
committed
Add Agent.to_mcp() method
1 parent 51fec9f commit 0f5e987

File tree

2 files changed

+77
-14
lines changed

2 files changed

+77
-14
lines changed

pydantic_ai_slim/pydantic_ai/_mcp.py

Lines changed: 51 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,19 @@
11
import base64
22
from collections.abc import Sequence
3-
from typing import Literal
3+
from typing import Any, Literal, cast
4+
5+
import logfire
6+
from pydantic.alias_generators import to_snake
7+
8+
from pydantic_ai.agent.abstract import AbstractAgent
49

510
from . import exceptions, messages
11+
from .agent import AgentDepsT, OutputDataT
612

713
try:
814
from mcp import types as mcp_types
15+
from mcp.server.lowlevel.server import Server, StructuredContent
16+
from mcp.types import Tool
917
except ImportError as _import_error:
1018
raise ImportError(
1119
'Please install the `mcp` package to use the MCP server, '
@@ -121,3 +129,45 @@ def map_from_sampling_content(
121129
return messages.TextPart(content=content.text)
122130
else:
123131
raise NotImplementedError('Image and Audio responses in sampling are not yet supported')
132+
133+
134+
def agent_to_mcp(
135+
agent: AbstractAgent[AgentDepsT, OutputDataT],
136+
*,
137+
server_name: str | None = None,
138+
tool_name: str | None = None,
139+
tool_description: str | None = None,
140+
# TODO(Marcelo): Should this actually be a factory that is created in every tool call?
141+
deps: AgentDepsT = None,
142+
) -> Server:
143+
server_name = to_snake((server_name or agent.name or 'PydanticAI Agent').replace(' ', '_'))
144+
tool_name = to_snake((tool_name or agent.name or 'PydanticAI Agent').replace(' ', '_'))
145+
app = Server(name=server_name)
146+
147+
async def list_tools() -> list[Tool]:
148+
return [
149+
Tool(
150+
name=tool_name,
151+
description=tool_description,
152+
inputSchema={'type': 'object', 'properties': {'prompt': {'type': 'string'}}},
153+
# TODO(Marcelo): How do I get this?
154+
outputSchema={'type': 'object', 'properties': {}},
155+
)
156+
]
157+
158+
async def call_tool(name: str, args: dict[str, Any]) -> StructuredContent:
159+
if name != tool_name:
160+
raise ValueError(f'Unknown tool: {name}')
161+
162+
# TODO(Marcelo): Should we pass the `message_history` instead?
163+
prompt = cast(str, args['prompt'])
164+
logfire.info(f'Calling tool: {name} with args: {args}')
165+
166+
result = await agent.run(user_prompt=prompt, deps=deps)
167+
168+
return dict(result=result.output)
169+
170+
app.list_tools()(list_tools)
171+
app.call_tool()(call_tool)
172+
173+
return app

pydantic_ai_slim/pydantic_ai/agent/abstract.py

Lines changed: 26 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,27 +1,19 @@
11
from __future__ import annotations as _annotations
22

33
import inspect
4+
import warnings
45
from abc import ABC, abstractmethod
56
from collections.abc import AsyncIterable, AsyncIterator, Awaitable, Callable, Iterator, Mapping, Sequence
67
from contextlib import AbstractAsyncContextManager, asynccontextmanager, contextmanager
78
from types import FrameType
89
from typing import TYPE_CHECKING, Any, Generic, TypeAlias, cast, overload
910

10-
from typing_extensions import Self, TypeIs, TypeVar
11+
from typing_extensions import TypeIs, TypeVar
1112

1213
from pydantic_graph import End
1314
from pydantic_graph._utils import get_event_loop
1415

15-
from .. import (
16-
_agent_graph,
17-
_system_prompt,
18-
_utils,
19-
exceptions,
20-
messages as _messages,
21-
models,
22-
result,
23-
usage as _usage,
24-
)
16+
from .. import _agent_graph, _system_prompt, _utils, exceptions, messages as _messages, models, result, usage as _usage
2517
from .._tool_manager import ToolManager
2618
from ..output import OutputDataT, OutputSpec
2719
from ..result import AgentStream, FinalResult, StreamedRunResult
@@ -42,6 +34,7 @@
4234
from fasta2a.broker import Broker
4335
from fasta2a.schema import AgentProvider, Skill
4436
from fasta2a.storage import Storage
37+
from mcp.server.lowlevel import Server
4538
from starlette.middleware import Middleware
4639
from starlette.routing import BaseRoute, Route
4740
from starlette.types import ExceptionHandler, Lifespan
@@ -940,8 +933,28 @@ def to_a2a(
940933
lifespan=lifespan,
941934
)
942935

936+
def to_mcp(
937+
self,
938+
*,
939+
server_name: str | None = None,
940+
tool_name: str | None = None,
941+
tool_description: str | None = None,
942+
deps: AgentDepsT = None,
943+
) -> Server:
944+
from .._mcp import agent_to_mcp
945+
946+
warnings.warn('The `to_mcp` method is experimental, and may change in the future.', UserWarning)
947+
948+
return agent_to_mcp(
949+
self,
950+
server_name=server_name,
951+
tool_name=tool_name,
952+
tool_description=tool_description,
953+
deps=deps,
954+
)
955+
943956
async def to_cli(
944-
self: Self,
957+
self,
945958
deps: AgentDepsT = None,
946959
prog_name: str = 'pydantic-ai',
947960
message_history: list[_messages.ModelMessage] | None = None,
@@ -978,7 +991,7 @@ async def main():
978991
)
979992

980993
def to_cli_sync(
981-
self: Self,
994+
self,
982995
deps: AgentDepsT = None,
983996
prog_name: str = 'pydantic-ai',
984997
message_history: list[_messages.ModelMessage] | None = None,

0 commit comments

Comments
 (0)