Skip to content
Merged
Show file tree
Hide file tree
Changes from 16 commits
Commits
Show all changes
17 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 1 addition & 12 deletions pydantic_ai_slim/pydantic_ai/_agent_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

import asyncio
import dataclasses
import hashlib
from collections import defaultdict, deque
from collections.abc import AsyncIterator, Awaitable, Callable, Iterator, Sequence
from contextlib import asynccontextmanager, contextmanager
Expand Down Expand Up @@ -650,13 +649,6 @@ def build_run_context(ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT
)


def multi_modal_content_identifier(identifier: str | bytes) -> str:
"""Generate stable identifier for multi-modal content to help LLM in finding a specific file in tool call responses."""
if isinstance(identifier, str):
identifier = identifier.encode('utf-8')
return hashlib.sha1(identifier).hexdigest()[:6]


async def process_function_tools( # noqa: C901
tool_manager: ToolManager[DepsT],
tool_calls: list[_messages.ToolCallPart],
Expand Down Expand Up @@ -915,10 +907,7 @@ async def _call_tool(
f'`ToolReturn` should be used directly.'
)
elif isinstance(content, _messages.MultiModalContent):
if isinstance(content, _messages.BinaryContent):
identifier = content.identifier or multi_modal_content_identifier(content.data)
else:
identifier = multi_modal_content_identifier(content.url)
identifier = content.identifier

return_values.append(f'See file {identifier}')
user_contents.extend([f'This is file {identifier}:', content])
Expand Down
63 changes: 54 additions & 9 deletions pydantic_ai_slim/pydantic_ai/messages.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations as _annotations

import base64
import hashlib
from abc import ABC, abstractmethod
from collections.abc import Sequence
from dataclasses import KW_ONLY, dataclass, field, replace
Expand Down Expand Up @@ -88,13 +89,31 @@ def otel_message_parts(self, settings: InstrumentationSettings) -> list[_otel_me
__repr__ = _utils.dataclasses_no_defaults_repr


def _multi_modal_content_identifier(identifier: str | bytes) -> str:
"""Generate stable identifier for multi-modal content to help LLM in finding a specific file in tool call responses."""
if isinstance(identifier, str):
identifier = identifier.encode('utf-8')
return hashlib.sha1(identifier).hexdigest()[:6]


@dataclass(init=False, repr=False)
class FileUrl(ABC):
"""Abstract base class for any URL-based file."""

url: str
"""The URL of the file."""

identifier: str
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please make this the final field, like it is on the constructor

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@DouweM Thank you for the review!
Moved identifier field to be the final field in the dataclass definition

"""The identifier of the file, such as a unique ID. generating one from the url if not explicitly set

This identifier can be provided to the model in a message to allow it to refer to this file in a tool call argument,
and the tool can look up the file in question by iterating over the message history and finding the matching `FileUrl`.

This identifier is only automatically passed to the model when the `FileUrl` is returned by a tool.
If you're passing the `FileUrl` as a user message, it's up to you to include a separate text part with the identifier,
e.g. "This is file <identifier>:" preceding the `FileUrl`.
"""

_: KW_ONLY

force_download: bool = False
Expand All @@ -121,11 +140,13 @@ def __init__(
force_download: bool = False,
vendor_metadata: dict[str, Any] | None = None,
media_type: str | None = None,
identifier: str | None = None,
) -> None:
self.url = url
self.vendor_metadata = vendor_metadata
self.force_download = force_download
self.vendor_metadata = vendor_metadata
self._media_type = media_type
self.identifier = identifier or _multi_modal_content_identifier(url)

@pydantic.computed_field
@property
Expand Down Expand Up @@ -166,6 +187,7 @@ def __init__(
vendor_metadata: dict[str, Any] | None = None,
media_type: str | None = None,
kind: Literal['video-url'] = 'video-url',
identifier: str | None = None,
*,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please move the * ahead of force_download so all arguments other than url need to be keywords arguments -- same for the other subclasses

Copy link
Contributor Author

@kyuam32 kyuam32 Sep 2, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

refactored with same rules for other MultiModalContent types

# Required for inline-snapshot which expects all dataclass `__init__` methods to take all field names as kwargs.
_media_type: str | None = None,
Expand All @@ -175,6 +197,7 @@ def __init__(
force_download=force_download,
vendor_metadata=vendor_metadata,
media_type=media_type or _media_type,
identifier=identifier,
)
self.kind = kind

Expand Down Expand Up @@ -239,6 +262,7 @@ def __init__(
vendor_metadata: dict[str, Any] | None = None,
media_type: str | None = None,
kind: Literal['audio-url'] = 'audio-url',
identifier: str | None = None,
*,
# Required for inline-snapshot which expects all dataclass `__init__` methods to take all field names as kwargs.
_media_type: str | None = None,
Expand All @@ -248,6 +272,7 @@ def __init__(
force_download=force_download,
vendor_metadata=vendor_metadata,
media_type=media_type or _media_type,
identifier=identifier,
)
self.kind = kind

Expand Down Expand Up @@ -299,6 +324,7 @@ def __init__(
vendor_metadata: dict[str, Any] | None = None,
media_type: str | None = None,
kind: Literal['image-url'] = 'image-url',
identifier: str | None = None,
*,
# Required for inline-snapshot which expects all dataclass `__init__` methods to take all field names as kwargs.
_media_type: str | None = None,
Expand All @@ -308,6 +334,7 @@ def __init__(
force_download=force_download,
vendor_metadata=vendor_metadata,
media_type=media_type or _media_type,
identifier=identifier,
)
self.kind = kind

Expand Down Expand Up @@ -354,6 +381,7 @@ def __init__(
vendor_metadata: dict[str, Any] | None = None,
media_type: str | None = None,
kind: Literal['document-url'] = 'document-url',
identifier: str | None = None,
*,
# Required for inline-snapshot which expects all dataclass `__init__` methods to take all field names as kwargs.
_media_type: str | None = None,
Expand All @@ -363,6 +391,7 @@ def __init__(
force_download=force_download,
vendor_metadata=vendor_metadata,
media_type=media_type or _media_type,
identifier=identifier,
)
self.kind = kind

Expand Down Expand Up @@ -405,7 +434,7 @@ def format(self) -> DocumentFormat:
raise ValueError(f'Unknown document media type: {media_type}') from e


@dataclass(repr=False)
@dataclass(init=False, repr=False)
class BinaryContent:
"""Binary content, e.g. an audio or image file."""

Expand All @@ -415,16 +444,18 @@ class BinaryContent:
media_type: AudioMediaType | ImageMediaType | DocumentMediaType | str
"""The media type of the binary data."""

_: KW_ONLY

identifier: str | None = None
"""Identifier for the binary content, such as a URL or unique ID.

This identifier can be provided to the model in a message to allow it to refer to this file in a tool call argument, and the tool can look up the file in question by iterating over the message history and finding the matching `BinaryContent`.
identifier: str
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@DouweM @kyuam32 -- was this intended to go to str vs keeping str | None? we store old json to rehydrate later into chags and now ModelMessageAdapter.validate_json throws because it expects identifier to have a value for those old dicts

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@kousun12 Ay that was unintentional, I was relying on the fact that we always set an identifier in __init__, but that wouldn't work for validation. Can you create a new issue please?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Related issue has been filed: #3103

"""Identifier for the binary content, such as a unique ID. generating one from the data if not explicitly set
This identifier can be provided to the model in a message to allow it to refer to this file in a tool call argument,
and the tool can look up the file in question by iterating over the message history and finding the matching `BinaryContent`.

This identifier is only automatically passed to the model when the `BinaryContent` is returned by a tool. If you're passing the `BinaryContent` as a user message, it's up to you to include a separate text part with the identifier, e.g. "This is file <identifier>:" preceding the `BinaryContent`.
This identifier is only automatically passed to the model when the `BinaryContent` is returned by a tool.
If you're passing the `BinaryContent` as a user message, it's up to you to include a separate text part with the identifier,
e.g. "This is file <identifier>:" preceding the `BinaryContent`.
"""

_: KW_ONLY
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please move this ahead of media_type

Copy link
Contributor Author

@kyuam32 kyuam32 Sep 2, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Moved _: KW_ONLY ahead of media_type, after data field


vendor_metadata: dict[str, Any] | None = None
"""Vendor-specific metadata for the file.

Expand All @@ -435,6 +466,20 @@ class BinaryContent:
kind: Literal['binary'] = 'binary'
"""Type identifier, this is available on all parts as a discriminator."""

def __init__(
self,
data: bytes,
media_type: AudioMediaType | ImageMediaType | DocumentMediaType | str,
identifier: str | None = None,
vendor_metadata: dict[str, Any] | None = None,
kind: Literal['binary'] = 'binary',
) -> None:
self.data = data
self.media_type = media_type
self.identifier = identifier or _multi_modal_content_identifier(data)
self.vendor_metadata = vendor_metadata
self.kind = kind

@property
def is_audio(self) -> bool:
"""Return `True` if the media type is an audio type."""
Expand Down
62 changes: 58 additions & 4 deletions tests/test_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,9 @@
from pydantic_ai.agent import AgentRunResult, WrapperAgent
from pydantic_ai.messages import (
AgentStreamEvent,
AudioUrl,
BinaryContent,
DocumentUrl,
ImageUrl,
ModelMessage,
ModelMessagesTypeAdapter,
Expand All @@ -42,6 +44,7 @@
ToolReturn,
ToolReturnPart,
UserPromptPart,
VideoUrl,
)
from pydantic_ai.models.function import AgentInfo, FunctionModel
from pydantic_ai.models.test import TestModel
Expand Down Expand Up @@ -3088,7 +3091,7 @@ def test_binary_content_serializable():
'media_type': 'text/plain',
'vendor_metadata': None,
'kind': 'binary',
'identifier': None,
'identifier': 'f7ff9e',
},
],
'timestamp': IsStr(),
Expand Down Expand Up @@ -3143,6 +3146,7 @@ def test_image_url_serializable_missing_media_type():
'vendor_metadata': None,
'kind': 'image-url',
'media_type': 'image/jpeg',
'identifier': 'a72e39',
},
],
'timestamp': IsStr(),
Expand Down Expand Up @@ -3204,6 +3208,7 @@ def test_image_url_serializable():
'vendor_metadata': None,
'kind': 'image-url',
'media_type': 'image/jpeg',
'identifier': 'bdd86d',
},
],
'timestamp': IsStr(),
Expand Down Expand Up @@ -3248,7 +3253,7 @@ def test_image_url_serializable():
def test_tool_return_part_binary_content_serialization():
"""Test that ToolReturnPart can properly serialize BinaryContent."""
png_data = b'\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x00\x01\x00\x00\x00\x01\x08\x02\x00\x00\x00\x90wS\xde\x00\x00\x00\x0cIDATx\x9cc```\x00\x00\x00\x04\x00\x01\xf6\x178\x00\x00\x00\x00IEND\xaeB`\x82'
binary_content = BinaryContent(png_data, media_type='image/png', identifier='image_id_1')
binary_content = BinaryContent(png_data, media_type='image/png')

tool_return = ToolReturnPart(tool_name='test_tool', content=binary_content, tool_call_id='test_call_123')

Expand All @@ -3257,12 +3262,12 @@ def test_tool_return_part_binary_content_serialization():
assert '"kind":"binary"' in response_str
assert '"media_type":"image/png"' in response_str
assert '"data":"' in response_str
assert '"identifier":"image_id_1"' in response_str
assert '"identifier":"14a01a"' in response_str

response_obj = tool_return.model_response_object()
assert response_obj['return_value']['kind'] == 'binary'
assert response_obj['return_value']['media_type'] == 'image/png'
assert response_obj['return_value']['identifier'] == 'image_id_1'
assert response_obj['return_value']['identifier'] == '14a01a'
assert 'data' in response_obj['return_value']


Expand Down Expand Up @@ -3332,6 +3337,55 @@ def get_image() -> BinaryContent:
)


def test_tool_returning_file_url_with_identifier():
"""Test that a tool returning FileUrl subclasses with identifiers works correctly."""

def llm(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse:
if len(messages) == 1:
return ModelResponse(parts=[ToolCallPart('get_files', {})])
else:
return ModelResponse(parts=[TextPart('Files received')])

agent = Agent(FunctionModel(llm))

@agent.tool_plain
def get_files():
"""Return various file URLs with custom identifiers."""
return [
ImageUrl(url='https://example.com/image.jpg', identifier='img_001'),
VideoUrl(url='https://example.com/video.mp4', identifier='vid_002'),
AudioUrl(url='https://example.com/audio.mp3', identifier='aud_003'),
DocumentUrl(url='https://example.com/document.pdf', identifier='doc_004'),
]

result = agent.run_sync('Get some files')
assert result.all_messages()[2] == snapshot(
ModelRequest(
parts=[
ToolReturnPart(
tool_name='get_files',
content=['See file img_001', 'See file vid_002', 'See file aud_003', 'See file doc_004'],
tool_call_id=IsStr(),
timestamp=IsNow(tz=timezone.utc),
),
UserPromptPart(
content=[
'This is file img_001:',
ImageUrl(url='https://example.com/image.jpg', identifier='img_001'),
'This is file vid_002:',
VideoUrl(url='https://example.com/video.mp4', identifier='vid_002'),
'This is file aud_003:',
AudioUrl(url='https://example.com/audio.mp3', identifier='aud_003'),
'This is file doc_004:',
DocumentUrl(url='https://example.com/document.pdf', identifier='doc_004'),
],
timestamp=IsNow(tz=timezone.utc),
),
]
)
)


def test_instructions_raise_error_when_system_prompt_is_set():
agent = Agent('test', instructions='An instructions!')

Expand Down