Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
104 changes: 85 additions & 19 deletions pydantic_ai_slim/pydantic_ai/mcp.py
Original file line number Diff line number Diff line change
Expand Up @@ -255,11 +255,79 @@ async def direct_call_tool(
# The MCP SDK wraps primitives and generic types like list in a `result` key, but we want to use the raw value returned by the tool function.
# See https://github.com/modelcontextprotocol/python-sdk#structured-output
if isinstance(structured, dict) and len(structured) == 1 and 'result' in structured:
return structured['result']
return structured
return (
messages.ToolReturn(return_value=structured['result'], metadata=result.meta)
if result.meta
else structured['result']
)
return messages.ToolReturn(return_value=structured, metadata=result.meta) if result.meta else structured
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This and the above can be simplified to only build ToolReturn once


mapped = [await self._map_tool_result_part(part) for part in result.content]
return mapped[0] if len(mapped) == 1 else mapped
mapped_part_metadata_tuple_list = [await self._map_tool_result_part(part) for part in result.content]
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's call this parts_with_metadata

if (
all(mapped_part_metadata_tuple[1] is None for mapped_part_metadata_tuple in mapped_part_metadata_tuple_list)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this can be all(metadata is None for (part, metadata) in mapped_part_metadata_tuple_list). Let's use tuple destructuring everywhere instead of index-based reading.

Also let's move this to a parts_have_metadata = any(...) variable (inverted to be positive) as we use it multiple times (e.g. if not parts_have_metadata and result.metadata is None reads pretty clearly)

and result.meta is None
):
# There is no metadata in the tool result or its parts, return just the mapped values
return (
mapped_part_metadata_tuple_list[0][0]
if len(mapped_part_metadata_tuple_list[0]) == 1
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shouldn't this be len(mapped_part_metadata_tuple_list) == 1?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I really don't like [i] for tuples, and we already know we have no part metadata here, so let's make this something like:

parts = [part for (part, _) in parts_and_metadata]
# use parts as we did before

else [mapped_part_metadata_tuple[0] for mapped_part_metadata_tuple in mapped_part_metadata_tuple_list]
)
elif (
all(mapped_part_metadata_tuple[1] is None for mapped_part_metadata_tuple in mapped_part_metadata_tuple_list)
and result.meta is not None
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if not parts_have_metadata and result.meta is not None

):
# There is no metadata in the tool result parts, but there is metadata in the tool result
return messages.ToolReturn(
return_value=(
mapped_part_metadata_tuple_list[0][0]
if len(mapped_part_metadata_tuple_list[0]) == 1
else [
mapped_part_metadata_tuple[0] for mapped_part_metadata_tuple in mapped_part_metadata_tuple_list
]
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same comment as up; and we can likely also deduplicate this into a single if not parts_have_metadata branch that only does the parts = [part for (part, _) in parts_and_metadata] and parts[0] if len(parts) == 1 else parts thing once

),
metadata=result.meta,
)
else:
# There is metadata in the tool result parts and there may be a metadata in the tool result, return a ToolReturn object
return_values: list[Any] = []
user_contents: list[Any] = []
return_metadata: dict[str, Any] = {}
for idx, (mapped_part, part_metadata) in enumerate(mapped_part_metadata_tuple_list):
if part_metadata is not None:
# Merge the metadata dictionaries, with part metadata taking precedence
if return_metadata.get('content', None) is None:
# Create an empty list if it doesn't exist yet
return_metadata['content'] = list[dict[str, Any]]()
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can likely use return_metadata.setdefault('content', []) just once before the for

return_metadata['content'].append({str(idx): part_metadata})
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A list of one-item dicts feels a bit awkward to me; it makes it a lot harder to find the metadata for a specific item then if this were a dict with int indexes or a list with holes. dict[int, dict[str, Any]] would probably be best.

if isinstance(mapped_part, messages.BinaryContent):
identifier = mapped_part.identifier

return_values.append(f'See file {identifier}')
user_contents.append([f'This is file {identifier}:', mapped_part])
else:
user_contents.append(mapped_part)
Comment on lines +303 to +309
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please put in a comment that this should be kept up to date with this code:

elif isinstance(content, _messages.MultiModalContent):
identifier = content.identifier
return_values.append(f'See file {identifier}')
user_contents.extend([f'This is file {identifier}:', content])
else:
return_values.append(content)

And put a comment there pointing back at this code.

I'd like to deduplicate it, but I have some more work to do in this area for #3253 so it'll be easier to just come up with a cleaner approach then.


if result.meta is not None and return_metadata.get('content', None) is not None:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's not store things in return_metadata['content'] yet above, and instead just have a parts_metadata var. Then once we get here, we can determine what the final return_metadata should be by checking result.meta and parts_metadata and building an appropriate dict from those.

# Merge the tool result metadata into the return metadata, with part metadata taking precedence
return_metadata['result'] = result.meta
elif result.meta is not None and return_metadata.get('content', None) is None:
return_metadata = result.meta
elif (
result.meta is None
and return_metadata.get('content', None) is not None
and len(return_metadata['content']) == 1
):
# If there is only one content metadata, unwrap it
return_metadata = return_metadata['content'][0]
# TODO: What else should we cover here?

# Finally, construct and return the ToolReturn object
return messages.ToolReturn(
return_value=return_values,
content=user_contents,
metadata=return_metadata,
)

async def call_tool(
self,
Expand Down Expand Up @@ -374,35 +442,32 @@ async def _sampling_callback(

async def _map_tool_result_part(
self, part: mcp_types.ContentBlock
) -> str | messages.BinaryContent | dict[str, Any] | list[Any]:
) -> tuple[str | messages.BinaryContent | dict[str, Any] | list[Any], dict[str, Any] | None]:
# See https://github.com/jlowin/fastmcp/blob/main/docs/servers/tools.mdx#return-values

metadata: dict[str, Any] | None = part.meta
if isinstance(part, mcp_types.TextContent):
text = part.text
if text.startswith(('[', '{')):
try:
return pydantic_core.from_json(text)
return pydantic_core.from_json(text), metadata
except ValueError:
pass
return text
return text, metadata
elif isinstance(part, mcp_types.ImageContent):
return messages.BinaryContent(data=base64.b64decode(part.data), media_type=part.mimeType)
elif isinstance(part, mcp_types.AudioContent):
return messages.BinaryContent(data=base64.b64decode(part.data), media_type=part.mimeType), metadata
elif isinstance(part, mcp_types.AudioContent): # pragma: no cover
# NOTE: The FastMCP server doesn't support audio content.
# See <https://github.com/modelcontextprotocol/python-sdk/issues/952> for more details.
return messages.BinaryContent(
data=base64.b64decode(part.data), media_type=part.mimeType
) # pragma: no cover
return messages.BinaryContent(data=base64.b64decode(part.data), media_type=part.mimeType), metadata
elif isinstance(part, mcp_types.EmbeddedResource):
resource = part.resource
return self._get_content(resource)
return self._get_content(part.resource), metadata
elif isinstance(part, mcp_types.ResourceLink):
resource_result: mcp_types.ReadResourceResult = await self._client.read_resource(part.uri)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note that resource_result can also have its own meta, and so can each of its contents 😅 So we may need to build a single nested metadata dict from all of those

return (
self._get_content(resource_result.contents[0])
if len(resource_result.contents) == 1
else [self._get_content(resource) for resource in resource_result.contents]
)
if len(resource_result.contents) == 1:
return self._get_content(resource_result.contents[0]), metadata
else:
return [self._get_content(resource) for resource in resource_result.contents], metadata
else:
assert_never(part)

Expand Down Expand Up @@ -875,6 +940,7 @@ def __eq__(self, value: object, /) -> bool:
ToolResult = (
str
| messages.BinaryContent
| messages.ToolReturn
| dict[str, Any]
| list[Any]
| Sequence[str | messages.BinaryContent | dict[str, Any] | list[Any]]
Expand Down
39 changes: 39 additions & 0 deletions tests/mcp_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,45 @@ async def get_image_resource_link() -> ResourceLink:
)


@mcp.tool(structured_output=False, annotations=ToolAnnotations(title='Collatz Conjecture sequence generator'))
async def get_collatz_conjecture(n: int) -> TextContent:
"""Generate the Collatz conjecture sequence for a given number.
This tool attaches response metadata.

Args:
n: The starting number for the Collatz sequence.
Returns:
A list representing the Collatz sequence with attached metadata.
"""
if n <= 0:
raise ValueError('Starting number for the Collatz conjecture must be a positive integer.')

input_param_n = n # store the original input value

sequence = [n]
while n != 1:
if n % 2 == 0:
n = n // 2
else:
n = 3 * n + 1
sequence.append(n)

return TextContent(
type='text',
text=str(sequence),
_meta={'pydantic_ai': {'tool': 'collatz_conjecture', 'n': input_param_n, 'length': len(sequence)}},
)


@mcp.tool()
async def get_structured_text_content_with_metadata() -> dict[str, Any]:
"""Return structured dict with metadata."""
return {
'result': 'This is some text content.',
'_meta': {'pydantic_ai': {'source': 'get_structured_text_content_with_metadata'}},
}


@mcp.resource('resource://kiwi.png', mime_type='image/png')
async def kiwi_resource() -> bytes:
return Path(__file__).parent.joinpath('assets/kiwi.png').read_bytes()
Expand Down
35 changes: 33 additions & 2 deletions tests/test_mcp.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from pydantic_ai.agent import Agent
from pydantic_ai.exceptions import ModelRetry, UnexpectedModelBehavior, UserError
from pydantic_ai.mcp import MCPServerStreamableHTTP, load_mcp_servers
from pydantic_ai.messages import ToolReturn
from pydantic_ai.models import Model
from pydantic_ai.models.test import TestModel
from pydantic_ai.tools import RunContext
Expand Down Expand Up @@ -77,7 +78,7 @@ async def test_stdio_server(run_context: RunContext[int]):
server = MCPServerStdio('python', ['-m', 'tests.mcp_server'])
async with server:
tools = [tool.tool_def for tool in (await server.get_tools(run_context)).values()]
assert len(tools) == snapshot(18)
assert len(tools) == snapshot(20)
assert tools[0].name == 'celsius_to_fahrenheit'
assert isinstance(tools[0].description, str)
assert tools[0].description.startswith('Convert Celsius to Fahrenheit.')
Expand All @@ -87,6 +88,36 @@ async def test_stdio_server(run_context: RunContext[int]):
assert result == snapshot(32.0)


async def test_tool_response_metadata(run_context: RunContext[int]):
server = MCPServerStdio('python', ['-m', 'tests.mcp_server'])
async with server:
tools = [tool.tool_def for tool in (await server.get_tools(run_context)).values()]
assert len(tools) == snapshot(20)
assert tools[4].name == 'get_collatz_conjecture'
assert isinstance(tools[4].description, str)
assert tools[4].description.startswith('Generate the Collatz conjecture sequence for a given number.')

result = await server.direct_call_tool('get_collatz_conjecture', {'n': 7})
assert isinstance(result, ToolReturn)
assert result.return_value == snapshot([7, 22, 11, 34, 17, 52, 26, 13, 40, 20, 10, 5, 16, 8, 4, 2, 1])
assert result.metadata == snapshot({'pydantic_ai': {'tool': 'collatz_conjecture', 'n': 7, 'length': 17}})


async def test_tool_structured_response_metadata(run_context: RunContext[int]):
server = MCPServerStdio('python', ['-m', 'tests.mcp_server'])
async with server:
tools = [tool.tool_def for tool in (await server.get_tools(run_context)).values()]
assert len(tools) == snapshot(20)
assert tools[5].name == 'get_structured_text_content_with_metadata'
assert isinstance(tools[5].description, str)
assert tools[5].description.startswith('Return structured dict with metadata.')

result = await server.direct_call_tool('get_structured_text_content_with_metadata', {})
assert isinstance(result, ToolReturn)
assert result.return_value == 'This is some text content.'
assert result.metadata == snapshot({'pydantic_ai': {'source': 'get_structured_text_content_with_metadata'}})


async def test_reentrant_context_manager():
server = MCPServerStdio('python', ['-m', 'tests.mcp_server'])
async with server:
Expand Down Expand Up @@ -138,7 +169,7 @@ async def test_stdio_server_with_cwd(run_context: RunContext[int]):
server = MCPServerStdio('python', ['mcp_server.py'], cwd=test_dir)
async with server:
tools = await server.get_tools(run_context)
assert len(tools) == snapshot(18)
assert len(tools) == snapshot(20)


async def test_process_tool_call(run_context: RunContext[int]) -> int:
Expand Down
Loading