Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 51 additions & 1 deletion pydantic_ai_slim/pydantic_ai/_mcp.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,19 @@
import base64
from collections.abc import Sequence
from typing import Literal
from typing import Any, Literal, cast

import logfire
from pydantic.alias_generators import to_snake

from pydantic_ai.agent.abstract import AbstractAgent

from . import exceptions, messages
from .agent import AgentDepsT, OutputDataT

try:
from mcp import types as mcp_types
from mcp.server.lowlevel.server import Server, StructuredContent
from mcp.types import Tool
except ImportError as _import_error:
raise ImportError(
'Please install the `mcp` package to use the MCP server, '
Expand Down Expand Up @@ -121,3 +129,45 @@ def map_from_sampling_content(
return messages.TextPart(content=content.text)
else:
raise NotImplementedError('Image and Audio responses in sampling are not yet supported')


def agent_to_mcp(
    agent: AbstractAgent[AgentDepsT, OutputDataT],
    *,
    server_name: str | None = None,
    tool_name: str | None = None,
    tool_description: str | None = None,
    # TODO(Marcelo): Should this actually be a factory that is created in every tool call?
    deps: AgentDepsT = None,
) -> Server:
    """Expose a Pydantic AI agent as a low-level MCP `Server` with a single tool.

    The returned server advertises one tool (named after `tool_name`, falling back
    to the agent's name) that accepts a `prompt` string and runs the agent with it.

    Args:
        agent: The agent to expose.
        server_name: Name for the MCP server; defaults to the agent's name.
        tool_name: Name for the single exposed tool; defaults to the agent's name.
        tool_description: Optional human-readable description for the tool.
        deps: Dependencies object passed to every `agent.run()` call.

    Returns:
        A configured `mcp.server.lowlevel.server.Server` instance.
    """
    # Normalize display names (which may contain spaces) into snake_case identifiers.
    server_name = to_snake((server_name or agent.name or 'PydanticAI Agent').replace(' ', '_'))
    tool_name = to_snake((tool_name or agent.name or 'PydanticAI Agent').replace(' ', '_'))
    app = Server(name=server_name)

    async def list_tools() -> list[Tool]:
        # Advertise the single agent-backed tool.
        return [
            Tool(
                name=tool_name,
                description=tool_description,
                # `prompt` is the only argument and it is mandatory — advertise that
                # in the schema so well-behaved clients reject calls without it.
                inputSchema={
                    'type': 'object',
                    'properties': {'prompt': {'type': 'string'}},
                    'required': ['prompt'],
                },
                # TODO(Marcelo): Derive this from the agent's output schema once an
                # `output_json_schema` property is available on `AbstractAgent`.
                outputSchema={'type': 'object', 'properties': {}},
            )
        ]

    async def call_tool(name: str, args: dict[str, Any]) -> StructuredContent:
        if name != tool_name:
            raise ValueError(f'Unknown tool: {name}')

        # TODO(Marcelo): Should we pass the `message_history` instead?
        # Validate the argument instead of blindly casting: the MCP server does not
        # validate tool arguments against `inputSchema` for us, so a missing or
        # non-string `prompt` must be rejected with a clear error here.
        prompt = args.get('prompt')
        if not isinstance(prompt, str):
            raise ValueError("Tool argument 'prompt' is required and must be a string")

        logfire.info(f'Calling tool: {name} with args: {args}')

        result = await agent.run(user_prompt=prompt, deps=deps)

        return dict(result=result.output)

    # Decorators can't be applied inside a function body without upsetting type
    # checkers, so register the handlers by invoking the decorator factories directly.
    app.list_tools()(list_tools)
    app.call_tool()(call_tool)

    return app
39 changes: 26 additions & 13 deletions pydantic_ai_slim/pydantic_ai/agent/abstract.py
Original file line number Diff line number Diff line change
@@ -1,27 +1,19 @@
from __future__ import annotations as _annotations

import inspect
import warnings
from abc import ABC, abstractmethod
from collections.abc import AsyncIterable, AsyncIterator, Awaitable, Callable, Iterator, Mapping, Sequence
from contextlib import AbstractAsyncContextManager, asynccontextmanager, contextmanager
from types import FrameType
from typing import TYPE_CHECKING, Any, Generic, TypeAlias, cast, overload

from typing_extensions import Self, TypeIs, TypeVar
from typing_extensions import TypeIs, TypeVar

from pydantic_graph import End
from pydantic_graph._utils import get_event_loop

from .. import (
_agent_graph,
_system_prompt,
_utils,
exceptions,
messages as _messages,
models,
result,
usage as _usage,
)
from .. import _agent_graph, _system_prompt, _utils, exceptions, messages as _messages, models, result, usage as _usage
from .._tool_manager import ToolManager
from ..output import OutputDataT, OutputSpec
from ..result import AgentStream, FinalResult, StreamedRunResult
Expand All @@ -42,6 +34,7 @@
from fasta2a.broker import Broker
from fasta2a.schema import AgentProvider, Skill
from fasta2a.storage import Storage
from mcp.server.lowlevel import Server
from starlette.middleware import Middleware
from starlette.routing import BaseRoute, Route
from starlette.types import ExceptionHandler, Lifespan
Expand Down Expand Up @@ -940,8 +933,28 @@ def to_a2a(
lifespan=lifespan,
)

def to_mcp(
    self,
    *,
    server_name: str | None = None,
    tool_name: str | None = None,
    tool_description: str | None = None,
    deps: AgentDepsT = None,
) -> Server:
    """Convert this agent into a low-level MCP `Server` that exposes it as a single tool.

    Args:
        server_name: Name for the MCP server; defaults to the agent's name.
        tool_name: Name for the exposed tool; defaults to the agent's name.
        tool_description: Optional human-readable description for the exposed tool.
        deps: Dependencies object passed to every agent run triggered by a tool call.

    Returns:
        An `mcp.server.lowlevel.Server` instance ready to be served.
    """
    # Imported lazily so the optional `mcp` dependency is only required when used.
    from .._mcp import agent_to_mcp

    # stacklevel=2 makes the warning point at the caller's `to_mcp()` call site
    # rather than at this library frame.
    warnings.warn(
        'The `to_mcp` method is experimental, and may change in the future.',
        UserWarning,
        stacklevel=2,
    )

    return agent_to_mcp(
        self,
        server_name=server_name,
        tool_name=tool_name,
        tool_description=tool_description,
        deps=deps,
    )

async def to_cli(
self: Self,
self,
deps: AgentDepsT = None,
prog_name: str = 'pydantic-ai',
message_history: list[_messages.ModelMessage] | None = None,
Expand Down Expand Up @@ -978,7 +991,7 @@ async def main():
)

def to_cli_sync(
self: Self,
self,
deps: AgentDepsT = None,
prog_name: str = 'pydantic-ai',
message_history: list[_messages.ModelMessage] | None = None,
Expand Down
Loading