Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
17 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 1 addition & 12 deletions pydantic_ai_slim/pydantic_ai/_agent_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

import asyncio
import dataclasses
import hashlib
from collections import defaultdict, deque
from collections.abc import AsyncIterator, Awaitable, Callable, Iterator, Sequence
from contextlib import asynccontextmanager, contextmanager
Expand Down Expand Up @@ -650,13 +649,6 @@ def build_run_context(ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT
)


def multi_modal_content_identifier(identifier: str | bytes) -> str:
"""Generate stable identifier for multi-modal content to help LLM in finding a specific file in tool call responses."""
if isinstance(identifier, str):
identifier = identifier.encode('utf-8')
return hashlib.sha1(identifier).hexdigest()[:6]


async def process_function_tools( # noqa: C901
tool_manager: ToolManager[DepsT],
tool_calls: list[_messages.ToolCallPart],
Expand Down Expand Up @@ -915,10 +907,7 @@ async def _call_tool(
f'`ToolReturn` should be used directly.'
)
elif isinstance(content, _messages.MultiModalContent):
if isinstance(content, _messages.BinaryContent):
identifier = content.identifier or multi_modal_content_identifier(content.data)
else:
identifier = multi_modal_content_identifier(content.url)
identifier = content.identifier

return_values.append(f'See file {identifier}')
user_contents.extend([f'This is file {identifier}:', content])
Expand Down
73 changes: 60 additions & 13 deletions pydantic_ai_slim/pydantic_ai/messages.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations as _annotations

import base64
import hashlib
from abc import ABC, abstractmethod
from collections.abc import Sequence
from dataclasses import KW_ONLY, dataclass, field, replace
Expand Down Expand Up @@ -88,6 +89,13 @@ def otel_message_parts(self, settings: InstrumentationSettings) -> list[_otel_me
__repr__ = _utils.dataclasses_no_defaults_repr


def _multi_modal_content_identifier(identifier: str | bytes) -> str:
"""Generate stable identifier for multi-modal content to help LLM in finding a specific file in tool call responses."""
if isinstance(identifier, str):
identifier = identifier.encode('utf-8')
return hashlib.sha1(identifier).hexdigest()[:6]


@dataclass(init=False, repr=False)
class FileUrl(ABC):
"""Abstract base class for any URL-based file."""
Expand Down Expand Up @@ -115,17 +123,31 @@ class FileUrl(ABC):
compare=False, default=None
)

identifier: str | None = None
"""The identifier of the file, such as a unique ID. generating one from the url if not explicitly set

This identifier can be provided to the model in a message to allow it to refer to this file in a tool call argument,
and the tool can look up the file in question by iterating over the message history and finding the matching `FileUrl`.

This identifier is only automatically passed to the model when the `FileUrl` is returned by a tool.
If you're passing the `FileUrl` as a user message, it's up to you to include a separate text part with the identifier,
e.g. "This is file <identifier>:" preceding the `FileUrl`.
"""

def __init__(
self,
url: str,
*,
force_download: bool = False,
vendor_metadata: dict[str, Any] | None = None,
media_type: str | None = None,
identifier: str | None = None,
) -> None:
self.url = url
self.vendor_metadata = vendor_metadata
self.force_download = force_download
self.vendor_metadata = vendor_metadata
self._media_type = media_type
self.identifier = identifier or _multi_modal_content_identifier(url)

@pydantic.computed_field
@property
Expand Down Expand Up @@ -162,11 +184,12 @@ class VideoUrl(FileUrl):
def __init__(
self,
url: str,
*,
force_download: bool = False,
vendor_metadata: dict[str, Any] | None = None,
media_type: str | None = None,
kind: Literal['video-url'] = 'video-url',
*,
identifier: str | None = None,
# Required for inline-snapshot which expects all dataclass `__init__` methods to take all field names as kwargs.
_media_type: str | None = None,
) -> None:
Expand All @@ -175,6 +198,7 @@ def __init__(
force_download=force_download,
vendor_metadata=vendor_metadata,
media_type=media_type or _media_type,
identifier=identifier,
)
self.kind = kind

Expand Down Expand Up @@ -235,11 +259,12 @@ class AudioUrl(FileUrl):
def __init__(
self,
url: str,
*,
force_download: bool = False,
vendor_metadata: dict[str, Any] | None = None,
media_type: str | None = None,
kind: Literal['audio-url'] = 'audio-url',
*,
identifier: str | None = None,
# Required for inline-snapshot which expects all dataclass `__init__` methods to take all field names as kwargs.
_media_type: str | None = None,
) -> None:
Expand All @@ -248,6 +273,7 @@ def __init__(
force_download=force_download,
vendor_metadata=vendor_metadata,
media_type=media_type or _media_type,
identifier=identifier,
)
self.kind = kind

Expand Down Expand Up @@ -295,11 +321,12 @@ class ImageUrl(FileUrl):
def __init__(
self,
url: str,
*,
force_download: bool = False,
vendor_metadata: dict[str, Any] | None = None,
media_type: str | None = None,
kind: Literal['image-url'] = 'image-url',
*,
identifier: str | None = None,
# Required for inline-snapshot which expects all dataclass `__init__` methods to take all field names as kwargs.
_media_type: str | None = None,
) -> None:
Expand All @@ -308,6 +335,7 @@ def __init__(
force_download=force_download,
vendor_metadata=vendor_metadata,
media_type=media_type or _media_type,
identifier=identifier,
)
self.kind = kind

Expand Down Expand Up @@ -350,11 +378,12 @@ class DocumentUrl(FileUrl):
def __init__(
self,
url: str,
*,
force_download: bool = False,
vendor_metadata: dict[str, Any] | None = None,
media_type: str | None = None,
kind: Literal['document-url'] = 'document-url',
*,
identifier: str | None = None,
# Required for inline-snapshot which expects all dataclass `__init__` methods to take all field names as kwargs.
_media_type: str | None = None,
) -> None:
Expand All @@ -363,6 +392,7 @@ def __init__(
force_download=force_download,
vendor_metadata=vendor_metadata,
media_type=media_type or _media_type,
identifier=identifier,
)
self.kind = kind

Expand Down Expand Up @@ -405,24 +435,26 @@ def format(self) -> DocumentFormat:
raise ValueError(f'Unknown document media type: {media_type}') from e


@dataclass(repr=False)
@dataclass(init=False, repr=False)
class BinaryContent:
"""Binary content, e.g. an audio or image file."""

data: bytes
"""The binary data."""

media_type: AudioMediaType | ImageMediaType | DocumentMediaType | str
"""The media type of the binary data."""

_: KW_ONLY

identifier: str | None = None
"""Identifier for the binary content, such as a URL or unique ID.
media_type: AudioMediaType | ImageMediaType | DocumentMediaType | str
"""The media type of the binary data."""

This identifier can be provided to the model in a message to allow it to refer to this file in a tool call argument, and the tool can look up the file in question by iterating over the message history and finding the matching `BinaryContent`.
identifier: str
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@DouweM @kyuam32 -- was this intended to go to str vs keeping str | None? we store old json to rehydrate later into chags and now ModelMessageAdapter.validate_json throws because it expects identifier to have a value for those old dicts

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@kousun12 Ay that was unintentional, I was relying on the fact that we always set an identifier in __init__, but that wouldn't work for validation. Can you create a new issue please?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Related issue has been filed: #3103

"""Identifier for the binary content, such as a unique ID. generating one from the data if not explicitly set
This identifier can be provided to the model in a message to allow it to refer to this file in a tool call argument,
and the tool can look up the file in question by iterating over the message history and finding the matching `BinaryContent`.

This identifier is only automatically passed to the model when the `BinaryContent` is returned by a tool. If you're passing the `BinaryContent` as a user message, it's up to you to include a separate text part with the identifier, e.g. "This is file <identifier>:" preceding the `BinaryContent`.
This identifier is only automatically passed to the model when the `BinaryContent` is returned by a tool.
If you're passing the `BinaryContent` as a user message, it's up to you to include a separate text part with the identifier,
e.g. "This is file <identifier>:" preceding the `BinaryContent`.
"""

vendor_metadata: dict[str, Any] | None = None
Expand All @@ -435,6 +467,21 @@ class BinaryContent:
kind: Literal['binary'] = 'binary'
"""Type identifier, this is available on all parts as a discriminator."""

def __init__(
self,
data: bytes,
*,
media_type: AudioMediaType | ImageMediaType | DocumentMediaType | str,
identifier: str | None = None,
vendor_metadata: dict[str, Any] | None = None,
kind: Literal['binary'] = 'binary',
) -> None:
self.data = data
self.media_type = media_type
self.identifier = identifier or _multi_modal_content_identifier(data)
self.vendor_metadata = vendor_metadata
self.kind = kind

@property
def is_audio(self) -> bool:
"""Return `True` if the media type is an audio type."""
Expand Down
62 changes: 58 additions & 4 deletions tests/test_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,9 @@
from pydantic_ai.agent import AgentRunResult, WrapperAgent
from pydantic_ai.messages import (
AgentStreamEvent,
AudioUrl,
BinaryContent,
DocumentUrl,
ImageUrl,
ModelMessage,
ModelMessagesTypeAdapter,
Expand All @@ -42,6 +44,7 @@
ToolReturn,
ToolReturnPart,
UserPromptPart,
VideoUrl,
)
from pydantic_ai.models.function import AgentInfo, FunctionModel
from pydantic_ai.models.test import TestModel
Expand Down Expand Up @@ -3088,7 +3091,7 @@ def test_binary_content_serializable():
'media_type': 'text/plain',
'vendor_metadata': None,
'kind': 'binary',
'identifier': None,
'identifier': 'f7ff9e',
},
],
'timestamp': IsStr(),
Expand Down Expand Up @@ -3143,6 +3146,7 @@ def test_image_url_serializable_missing_media_type():
'vendor_metadata': None,
'kind': 'image-url',
'media_type': 'image/jpeg',
'identifier': 'a72e39',
},
],
'timestamp': IsStr(),
Expand Down Expand Up @@ -3204,6 +3208,7 @@ def test_image_url_serializable():
'vendor_metadata': None,
'kind': 'image-url',
'media_type': 'image/jpeg',
'identifier': 'bdd86d',
},
],
'timestamp': IsStr(),
Expand Down Expand Up @@ -3248,7 +3253,7 @@ def test_image_url_serializable():
def test_tool_return_part_binary_content_serialization():
"""Test that ToolReturnPart can properly serialize BinaryContent."""
png_data = b'\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x00\x01\x00\x00\x00\x01\x08\x02\x00\x00\x00\x90wS\xde\x00\x00\x00\x0cIDATx\x9cc```\x00\x00\x00\x04\x00\x01\xf6\x178\x00\x00\x00\x00IEND\xaeB`\x82'
binary_content = BinaryContent(png_data, media_type='image/png', identifier='image_id_1')
binary_content = BinaryContent(png_data, media_type='image/png')

tool_return = ToolReturnPart(tool_name='test_tool', content=binary_content, tool_call_id='test_call_123')

Expand All @@ -3257,12 +3262,12 @@ def test_tool_return_part_binary_content_serialization():
assert '"kind":"binary"' in response_str
assert '"media_type":"image/png"' in response_str
assert '"data":"' in response_str
assert '"identifier":"image_id_1"' in response_str
assert '"identifier":"14a01a"' in response_str

response_obj = tool_return.model_response_object()
assert response_obj['return_value']['kind'] == 'binary'
assert response_obj['return_value']['media_type'] == 'image/png'
assert response_obj['return_value']['identifier'] == 'image_id_1'
assert response_obj['return_value']['identifier'] == '14a01a'
assert 'data' in response_obj['return_value']


Expand Down Expand Up @@ -3332,6 +3337,55 @@ def get_image() -> BinaryContent:
)


def test_tool_returning_file_url_with_identifier():
"""Test that a tool returning FileUrl subclasses with identifiers works correctly."""

def llm(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse:
if len(messages) == 1:
return ModelResponse(parts=[ToolCallPart('get_files', {})])
else:
return ModelResponse(parts=[TextPart('Files received')])

agent = Agent(FunctionModel(llm))

@agent.tool_plain
def get_files():
"""Return various file URLs with custom identifiers."""
return [
ImageUrl(url='https://example.com/image.jpg', identifier='img_001'),
VideoUrl(url='https://example.com/video.mp4', identifier='vid_002'),
AudioUrl(url='https://example.com/audio.mp3', identifier='aud_003'),
DocumentUrl(url='https://example.com/document.pdf', identifier='doc_004'),
]

result = agent.run_sync('Get some files')
assert result.all_messages()[2] == snapshot(
ModelRequest(
parts=[
ToolReturnPart(
tool_name='get_files',
content=['See file img_001', 'See file vid_002', 'See file aud_003', 'See file doc_004'],
tool_call_id=IsStr(),
timestamp=IsNow(tz=timezone.utc),
),
UserPromptPart(
content=[
'This is file img_001:',
ImageUrl(url='https://example.com/image.jpg', identifier='img_001'),
'This is file vid_002:',
VideoUrl(url='https://example.com/video.mp4', identifier='vid_002'),
'This is file aud_003:',
AudioUrl(url='https://example.com/audio.mp3', identifier='aud_003'),
'This is file doc_004:',
DocumentUrl(url='https://example.com/document.pdf', identifier='doc_004'),
],
timestamp=IsNow(tz=timezone.utc),
),
]
)
)


def test_instructions_raise_error_when_system_prompt_is_set():
agent = Agent('test', instructions='An instructions!')

Expand Down