Skip to content
62 changes: 43 additions & 19 deletions pydantic_ai_slim/pydantic_ai/mcp.py
Original file line number Diff line number Diff line change
Expand Up @@ -490,12 +490,23 @@ async def direct_call_tool(
):
# The MCP SDK wraps primitives and generic types like list in a `result` key, but we want to use the raw value returned by the tool function.
# See https://github.com/modelcontextprotocol/python-sdk#structured-output
return_value = structured
if isinstance(structured, dict) and len(structured) == 1 and 'result' in structured:
return structured['result']
return structured
return_value = structured['result']
return messages.ToolReturn(return_value=return_value, metadata=result.meta) if result.meta else return_value

mapped = [await self._map_tool_result_part(part) for part in result.content]
return mapped[0] if len(mapped) == 1 else mapped
if result.meta:
# The following branching cannot be tested until FastMCP is updated to version 2.13.1
# such that the MCP server can generate ToolResult and result.meta can be specified.
# TODO: Add tests for the following branching once FastMCP is updated.
return ( # pragma: no cover
messages.ToolReturn(return_value=mapped[0], metadata=result.meta)
if len(mapped) == 1
else messages.ToolReturn(return_value=mapped, metadata=result.meta)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This will cause us to run into trouble here, because we're putting BinaryContent inside ToolReturn.return_value:

if (
isinstance(tool_return.return_value, _messages.MultiModalContent)
or isinstance(tool_return.return_value, list)
and any(
isinstance(content, _messages.MultiModalContent)
for content in tool_return.return_value # type: ignore
)
):
raise exceptions.UserError(
f'The `return_value` of tool {tool_call.tool_name!r} contains invalid nested `MultiModalContent` objects. '
f'Please use `content` instead.'
)

So we should make that logic less strict, so that if it finds a ToolReturn.return_value with BinaryContent in it, it'll go through this same logic to move them to the content field:

return_values: list[Any] = []
user_contents: list[str | _messages.UserContent] = []
for content in contents:
if isinstance(content, _messages.ToolReturn):
raise exceptions.UserError(
f'The return value of tool {tool_call.tool_name!r} contains invalid nested `ToolReturn` objects. '
f'`ToolReturn` should be used directly.'
)
elif isinstance(content, _messages.MultiModalContent):
identifier = content.identifier
return_values.append(f'See file {identifier}')
user_contents.extend([f'This is file {identifier}:', content])
else:
return_values.append(content)
tool_return = _messages.ToolReturn(
return_value=return_values[0] if len(return_values) == 1 and not result_is_list else return_values,
content=user_contents,
)

(Although part of me is thinking that maybe I should do #3253 first, because that's going to affect so much of this logic and likely make it easier, as we can keep binary parts on ToolReturn.return_value without having to move them... If you can wait with this PR, I could maybe look at that one first?)

We should also update the tests to not just test direct_call_tool, but actually test using these tools in an agent run, to ensure we go through that part of the agent graph tool call result parsing code.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Okay, maybe I should wait for #3253 to be completed first?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should also update the tests to not just test direct_call_tool, but actually test using these tools in an agent run, to ensure we go through that part of the agent graph tool call result parsing code.

Okay, I will get to these.

)
else:
return mapped[0] if len(mapped) == 1 else mapped
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think the code here can be deduplicated so that the `if result.meta:` → `ToolReturn` branch is written only once.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.


async def call_tool(
self,
Expand Down Expand Up @@ -574,16 +585,18 @@ async def list_resource_templates(self) -> list[ResourceTemplate]:
return [ResourceTemplate.from_mcp_sdk(t) for t in result.resourceTemplates]

@overload
async def read_resource(self, uri: str) -> str | messages.BinaryContent | list[str | messages.BinaryContent]: ...
async def read_resource(
self, uri: str
) -> str | messages.TextPart | messages.BinaryContent | list[str | messages.TextPart | messages.BinaryContent]: ...

@overload
async def read_resource(
self, uri: Resource
) -> str | messages.BinaryContent | list[str | messages.BinaryContent]: ...
) -> str | messages.TextPart | messages.BinaryContent | list[str | messages.TextPart | messages.BinaryContent]: ...

async def read_resource(
self, uri: str | Resource
) -> str | messages.BinaryContent | list[str | messages.BinaryContent]:
) -> str | messages.TextPart | messages.BinaryContent | list[str | messages.TextPart | messages.BinaryContent]:
"""Read the contents of a specific resource by URI.

Args:
Expand Down Expand Up @@ -682,24 +695,29 @@ async def _sampling_callback(

async def _map_tool_result_part(
self, part: mcp_types.ContentBlock
) -> str | messages.BinaryContent | dict[str, Any] | list[Any]:
) -> str | messages.TextPart | messages.BinaryContent | dict[str, Any] | list[Any]:
# See https://github.com/jlowin/fastmcp/blob/main/docs/servers/tools.mdx#return-values

if isinstance(part, mcp_types.TextContent):
text = part.text
if text.startswith(('[', '{')):
try:
return pydantic_core.from_json(text)
except ValueError:
pass
return text
if part.meta:
return messages.TextPart(content=text, metadata=part.meta)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we'll need a new messages.TextContent object, as the way we're using TextPart here is incompatible with the """A plain text response from a model.""" docstring, and we do try to distinguish between parts and content.

Then we should add TextContent to the UserContent union, so that it's allowed in place of str.

And we should update the logic here to treat UserContent just like str in terms of what's sent back to the model, i.e. not serializing the whole object which would include the metadata:

def model_response_str(self) -> str:
"""Return a string representation of the content for the model."""
if isinstance(self.content, str):
return self.content
else:
return tool_return_ta.dump_json(self.content).decode()
def model_response_object(self) -> dict[str, Any]:
"""Return a dictionary representation of the content, wrapping non-dict types appropriately."""
# gemini supports JSON dict return values, but no other JSON types, hence we wrap anything else in a dict
json_content = tool_return_ta.dump_python(self.content, mode='json')
if isinstance(json_content, dict):
return json_content # type: ignore[reportUnknownReturn]
else:
return {'return_value': json_content}

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Okay, I see. Then should `messages.TextContent` and `messages.TextPart` look exactly the same, except that `messages.TextContent` can have `metadata` and should not have `id` (the identifier for the part)?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I added a `messages.TextContent`, but importing it as `TextContent` is a bit of a problem because the name conflicts with the MCP types `TextContent`. Perhaps a different name, such as `messages.TextData`, would be better, unless we always refer to it as `messages.TextContent` or alias the import (e.g. `from pydantic_ai.messages import TextContent as PydanticAITextContent`), which looks ugly.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we'll need a new messages.TextContent object, as the way we're using TextPart here is incompatible with the """A plain text response from a model.""" docstring, and we do try to distinguish between parts and content.

Then we should add TextContent to the UserContent union, so that it's allowed in place of str.

And we should update the logic here to treat UserContent just like str in terms of what's sent back to the model, i.e. not serializing the whole object which would include the metadata:

def model_response_str(self) -> str:
"""Return a string representation of the content for the model."""
if isinstance(self.content, str):
return self.content
else:
return tool_return_ta.dump_json(self.content).decode()
def model_response_object(self) -> dict[str, Any]:
"""Return a dictionary representation of the content, wrapping non-dict types appropriately."""
# gemini supports JSON dict return values, but no other JSON types, hence we wrap anything else in a dict
json_content = tool_return_ta.dump_python(self.content, mode='json')
if isinstance(json_content, dict):
return json_content # type: ignore[reportUnknownReturn]
else:
return {'return_value': json_content}

For this part, isn't the `if isinstance(self.content, str)` check already okay for my newly added `messages.TextContent`, since it has a `content` attribute, which is of type `str`?

else:
if text.startswith(('[', '{')):
try:
return pydantic_core.from_json(text)
except ValueError:
pass
return text
elif isinstance(part, mcp_types.ImageContent):
return messages.BinaryContent(data=base64.b64decode(part.data), media_type=part.mimeType)
return messages.BinaryContent(
data=base64.b64decode(part.data), media_type=part.mimeType, metadata=part.meta
)
elif isinstance(part, mcp_types.AudioContent):
# NOTE: The FastMCP server doesn't support audio content.
# See <https://github.com/modelcontextprotocol/python-sdk/issues/952> for more details.
return messages.BinaryContent(
data=base64.b64decode(part.data), media_type=part.mimeType
data=base64.b64decode(part.data), media_type=part.mimeType, metadata=part.meta
) # pragma: no cover
elif isinstance(part, mcp_types.EmbeddedResource):
resource = part.resource
Expand All @@ -711,12 +729,16 @@ async def _map_tool_result_part(

def _get_content(
self, resource: mcp_types.TextResourceContents | mcp_types.BlobResourceContents
) -> str | messages.BinaryContent:
) -> str | messages.TextPart | messages.BinaryContent:
if isinstance(resource, mcp_types.TextResourceContents):
return resource.text
return (
resource.text if not resource.meta else messages.TextPart(content=resource.text, metadata=resource.meta)
)
elif isinstance(resource, mcp_types.BlobResourceContents):
return messages.BinaryContent(
data=base64.b64decode(resource.blob), media_type=resource.mimeType or 'application/octet-stream'
data=base64.b64decode(resource.blob),
media_type=resource.mimeType or 'application/octet-stream',
metadata=resource.meta,
)
else:
assert_never(resource)
Expand Down Expand Up @@ -1178,10 +1200,12 @@ def __eq__(self, value: object, /) -> bool:

ToolResult = (
str
| messages.TextPart
| messages.BinaryContent
| messages.ToolReturn
| dict[str, Any]
| list[Any]
| Sequence[str | messages.BinaryContent | dict[str, Any] | list[Any]]
| Sequence[str | messages.TextPart | messages.BinaryContent | dict[str, Any] | list[Any]]
)
"""The result type of an MCP tool call."""

Expand Down
17 changes: 16 additions & 1 deletion pydantic_ai_slim/pydantic_ai/messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -486,6 +486,9 @@ class BinaryContent:
- `OpenAIChatModel`, `OpenAIResponsesModel`: `BinaryContent.vendor_metadata['detail']` is used as `detail` setting for images
"""

metadata: Any = None
"""Additional data that can be accessed programmatically by the application but is not sent to the LLM."""

_identifier: Annotated[str | None, pydantic.Field(alias='identifier', default=None, exclude=True)] = field(
compare=False, default=None
)
Expand All @@ -500,6 +503,7 @@ def __init__(
media_type: AudioMediaType | ImageMediaType | DocumentMediaType | str,
identifier: str | None = None,
vendor_metadata: dict[str, Any] | None = None,
metadata: Any = None,
kind: Literal['binary'] = 'binary',
# Required for inline-snapshot which expects all dataclass `__init__` methods to take all field names as kwargs.
_identifier: str | None = None,
Expand All @@ -508,6 +512,7 @@ def __init__(
self.media_type = media_type
self._identifier = identifier or _identifier
self.vendor_metadata = vendor_metadata
self.metadata = metadata
self.kind = kind

@staticmethod
Expand All @@ -519,6 +524,7 @@ def narrow_type(bc: BinaryContent) -> BinaryContent | BinaryImage:
media_type=bc.media_type,
identifier=bc.identifier,
vendor_metadata=bc.vendor_metadata,
metadata=bc.metadata,
)
else:
return bc
Expand Down Expand Up @@ -622,11 +628,17 @@ def __init__(
identifier: str | None = None,
vendor_metadata: dict[str, Any] | None = None,
# Required for inline-snapshot which expects all dataclass `__init__` methods to take all field names as kwargs.
metadata: Any = None,
kind: Literal['binary'] = 'binary',
_identifier: str | None = None,
):
super().__init__(
data=data, media_type=media_type, identifier=identifier or _identifier, vendor_metadata=vendor_metadata
data=data,
media_type=media_type,
identifier=identifier or _identifier,
vendor_metadata=vendor_metadata,
metadata=metadata,
kind=kind,
)

if not self.is_image:
Expand Down Expand Up @@ -1031,6 +1043,9 @@ class TextPart:

This is used for data that is required to be sent back to APIs, as well as data users may want to access programmatically."""

metadata: Any = None
"""Additional data that can be accessed programmatically by the application but is not sent to the LLM."""

part_kind: Literal['text'] = 'text'
"""Part type identifier, this is available on all parts as a discriminator."""

Expand Down
30 changes: 30 additions & 0 deletions tests/mcp_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,36 @@ async def get_weather_forecast(location: str) -> str:
return f'The weather in {location} is sunny and 26 degrees Celsius.'


@mcp.tool(structured_output=False, annotations=ToolAnnotations(title='Collatz Conjecture sequence generator'))
async def get_collatz_conjecture(n: int) -> TextContent:
    """Generate the Collatz conjecture sequence for a given number.
    This tool attaches response metadata.

    Args:
        n: The starting number for the Collatz sequence.
    Returns:
        A list representing the Collatz sequence with attached metadata.
    """
    # NOTE: the docstring above is kept verbatim on purpose -- the test suite
    # checks the tool's description against its first line.
    if n <= 0:
        raise ValueError('Starting number for the Collatz conjecture must be a positive integer.')

    # Walk the sequence with a separate cursor so the caller-supplied start
    # value stays available for the metadata payload.
    current = n
    steps = [current]
    while current != 1:
        # Halve even values; map odd values to 3x + 1.
        current = current // 2 if current % 2 == 0 else 3 * current + 1
        steps.append(current)

    # The sequence is sent as stringified text; the start value and the
    # sequence length are passed alongside it via `_meta`.
    response_meta = {'pydantic_ai': {'tool': 'collatz_conjecture', 'n': n, 'length': len(steps)}}
    return TextContent(type='text', text=str(steps), _meta=response_meta)


@mcp.tool()
async def get_image_resource() -> EmbeddedResource:
data = Path(__file__).parent.joinpath('assets/kiwi.png').read_bytes()
Expand Down
26 changes: 23 additions & 3 deletions tests/test_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -3785,6 +3785,7 @@ def test_binary_content_serializable():
'data': 'SGVsbG8=',
'media_type': 'text/plain',
'vendor_metadata': None,
'metadata': None,
'kind': 'binary',
'identifier': 'f7ff9e',
},
Expand All @@ -3800,7 +3801,13 @@ def test_binary_content_serializable():
},
{
'parts': [
{'content': 'success (no tool calls)', 'id': None, 'part_kind': 'text', 'provider_details': None}
{
'content': 'success (no tool calls)',
'id': None,
'part_kind': 'text',
'metadata': None,
'provider_details': None,
}
],
'usage': {
'input_tokens': 56,
Expand Down Expand Up @@ -3862,7 +3869,13 @@ def test_image_url_serializable_missing_media_type():
},
{
'parts': [
{'content': 'success (no tool calls)', 'id': None, 'part_kind': 'text', 'provider_details': None}
{
'content': 'success (no tool calls)',
'id': None,
'part_kind': 'text',
'metadata': None,
'provider_details': None,
}
],
'usage': {
'input_tokens': 51,
Expand Down Expand Up @@ -3931,7 +3944,13 @@ def test_image_url_serializable():
},
{
'parts': [
{'content': 'success (no tool calls)', 'id': None, 'part_kind': 'text', 'provider_details': None}
{
'content': 'success (no tool calls)',
'id': None,
'part_kind': 'text',
'metadata': None,
'provider_details': None,
}
],
'usage': {
'input_tokens': 51,
Expand Down Expand Up @@ -3978,6 +3997,7 @@ def test_tool_return_part_binary_content_serialization():
'data': 'iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAIAAACQd1PeAAAADElEQVR4nGNgYGAAAAAEAAH2FzgAAAAASUVORK5CYII=',
'media_type': 'image/png',
'vendor_metadata': None,
'metadata': None,
'_identifier': None,
'kind': 'binary',
}
Expand Down
19 changes: 17 additions & 2 deletions tests/test_mcp.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ async def test_stdio_server(run_context: RunContext[int]):
server = MCPServerStdio('python', ['-m', 'tests.mcp_server'])
async with server:
tools = [tool.tool_def for tool in (await server.get_tools(run_context)).values()]
assert len(tools) == snapshot(18)
assert len(tools) == snapshot(19)
assert tools[0].name == 'celsius_to_fahrenheit'
assert isinstance(tools[0].description, str)
assert tools[0].description.startswith('Convert Celsius to Fahrenheit.')
Expand All @@ -105,6 +105,21 @@ async def test_stdio_server(run_context: RunContext[int]):
assert result == snapshot(32.0)


async def test_tool_response_single_text_part_metadata(run_context: RunContext[int]):
    """A single text part that carries metadata is returned as a `TextPart` with both content and metadata."""
    server = MCPServerStdio('python', ['-m', 'tests.mcp_server'])
    async with server:
        available = await server.get_tools(run_context)
        tool_defs = [entry.tool_def for entry in available.values()]
        assert len(tool_defs) == snapshot(19)

        # The Collatz tool is expected at index 2 of the tool listing.
        collatz_def = tool_defs[2]
        assert collatz_def.name == 'get_collatz_conjecture'
        assert isinstance(collatz_def.description, str)
        assert collatz_def.description.startswith('Generate the Collatz conjecture sequence for a given number.')

        result = await server.direct_call_tool('get_collatz_conjecture', {'n': 7})
        # The text payload and its metadata must both be present on the part.
        assert isinstance(result, TextPart)
        assert result.content == snapshot('[7, 22, 11, 34, 17, 52, 26, 13, 40, 20, 10, 5, 16, 8, 4, 2, 1]')
        assert result.metadata == snapshot({'pydantic_ai': {'tool': 'collatz_conjecture', 'n': 7, 'length': 17}})


async def test_reentrant_context_manager():
server = MCPServerStdio('python', ['-m', 'tests.mcp_server'])
async with server:
Expand Down Expand Up @@ -156,7 +171,7 @@ async def test_stdio_server_with_cwd(run_context: RunContext[int]):
server = MCPServerStdio('python', ['mcp_server.py'], cwd=test_dir)
async with server:
tools = await server.get_tools(run_context)
assert len(tools) == snapshot(18)
assert len(tools) == snapshot(19)


async def test_process_tool_call(run_context: RunContext[int]) -> int:
Expand Down
3 changes: 3 additions & 0 deletions tests/test_messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -457,6 +457,7 @@ def test_file_part_serialization_roundtrip():
'data': 'ZmFrZQ==',
'media_type': 'image/jpeg',
'identifier': 'c053ec',
'metadata': None,
'vendor_metadata': None,
'kind': 'binary',
},
Expand Down Expand Up @@ -605,6 +606,7 @@ def test_binary_content_validation_with_optional_identifier():
'data': b'fake',
'vendor_metadata': None,
'kind': 'binary',
'metadata': None,
'media_type': 'image/jpeg',
'identifier': 'c053ec',
}
Expand All @@ -621,6 +623,7 @@ def test_binary_content_validation_with_optional_identifier():
'data': b'fake',
'vendor_metadata': None,
'kind': 'binary',
'metadata': None,
'media_type': 'image/png',
'identifier': 'foo',
}
Expand Down