diff --git a/pydantic_ai_slim/pydantic_ai/_agent_graph.py b/pydantic_ai_slim/pydantic_ai/_agent_graph.py index 4d53518380..a0aebf906b 100644 --- a/pydantic_ai_slim/pydantic_ai/_agent_graph.py +++ b/pydantic_ai_slim/pydantic_ai/_agent_graph.py @@ -2,7 +2,6 @@ import asyncio import dataclasses -import hashlib from collections import defaultdict, deque from collections.abc import AsyncIterator, Awaitable, Callable, Iterator, Sequence from contextlib import asynccontextmanager, contextmanager @@ -650,13 +649,6 @@ def build_run_context(ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT ) -def multi_modal_content_identifier(identifier: str | bytes) -> str: - """Generate stable identifier for multi-modal content to help LLM in finding a specific file in tool call responses.""" - if isinstance(identifier, str): - identifier = identifier.encode('utf-8') - return hashlib.sha1(identifier).hexdigest()[:6] - - async def process_function_tools( # noqa: C901 tool_manager: ToolManager[DepsT], tool_calls: list[_messages.ToolCallPart], @@ -915,10 +907,7 @@ async def _call_tool( f'`ToolReturn` should be used directly.' ) elif isinstance(content, _messages.MultiModalContent): - if isinstance(content, _messages.BinaryContent): - identifier = content.identifier or multi_modal_content_identifier(content.data) - else: - identifier = multi_modal_content_identifier(content.url) + identifier = content.identifier return_values.append(f'See file {identifier}') user_contents.extend([f'This is file {identifier}:', content]) diff --git a/pydantic_ai_slim/pydantic_ai/messages.py b/pydantic_ai_slim/pydantic_ai/messages.py index 3d9129ade8..c755885223 100644 --- a/pydantic_ai_slim/pydantic_ai/messages.py +++ b/pydantic_ai_slim/pydantic_ai/messages.py @@ -1,6 +1,7 @@ from __future__ import annotations as _annotations import base64 +import hashlib from abc import ABC, abstractmethod from collections.abc import Sequence from dataclasses import KW_ONLY, dataclass, field, replace @@ -88,6 +89,13 @@ def otel_message_parts(self, settings: InstrumentationSettings) -> list[_otel_me __repr__ = _utils.dataclasses_no_defaults_repr +def _multi_modal_content_identifier(identifier: str | bytes) -> str: + """Generate stable identifier for multi-modal content to help LLM in finding a specific file in tool call responses.""" + if isinstance(identifier, str): + identifier = identifier.encode('utf-8') + return hashlib.sha1(identifier).hexdigest()[:6] + + @dataclass(init=False, repr=False) class FileUrl(ABC): """Abstract base class for any URL-based file.""" @@ -115,17 +123,31 @@ class FileUrl(ABC): compare=False, default=None ) + identifier: str | None = None + """The identifier of the file, such as a unique ID. generating one from the url if not explicitly set + + This identifier can be provided to the model in a message to allow it to refer to this file in a tool call argument, + and the tool can look up the file in question by iterating over the message history and finding the matching `FileUrl`. + + This identifier is only automatically passed to the model when the `FileUrl` is returned by a tool. + If you're passing the `FileUrl` as a user message, it's up to you to include a separate text part with the identifier, + e.g. "This is file :" preceding the `FileUrl`. + """ + def __init__( self, url: str, + *, force_download: bool = False, vendor_metadata: dict[str, Any] | None = None, media_type: str | None = None, + identifier: str | None = None, ) -> None: self.url = url - self.vendor_metadata = vendor_metadata self.force_download = force_download + self.vendor_metadata = vendor_metadata self._media_type = media_type + self.identifier = identifier or _multi_modal_content_identifier(url) @pydantic.computed_field @property @@ -162,11 +184,12 @@ class VideoUrl(FileUrl): def __init__( self, url: str, + *, force_download: bool = False, vendor_metadata: dict[str, Any] | None = None, media_type: str | None = None, kind: Literal['video-url'] = 'video-url', - *, + identifier: str | None = None, # Required for inline-snapshot which expects all dataclass `__init__` methods to take all field names as kwargs. _media_type: str | None = None, ) -> None: @@ -175,6 +198,7 @@ def __init__( force_download=force_download, vendor_metadata=vendor_metadata, media_type=media_type or _media_type, + identifier=identifier, ) self.kind = kind @@ -235,11 +259,12 @@ class AudioUrl(FileUrl): def __init__( self, url: str, + *, force_download: bool = False, vendor_metadata: dict[str, Any] | None = None, media_type: str | None = None, kind: Literal['audio-url'] = 'audio-url', - *, + identifier: str | None = None, # Required for inline-snapshot which expects all dataclass `__init__` methods to take all field names as kwargs. _media_type: str | None = None, ) -> None: @@ -248,6 +273,7 @@ def __init__( force_download=force_download, vendor_metadata=vendor_metadata, media_type=media_type or _media_type, + identifier=identifier, ) self.kind = kind @@ -295,11 +321,12 @@ class ImageUrl(FileUrl): def __init__( self, url: str, + *, force_download: bool = False, vendor_metadata: dict[str, Any] | None = None, media_type: str | None = None, kind: Literal['image-url'] = 'image-url', - *, + identifier: str | None = None, # Required for inline-snapshot which expects all dataclass `__init__` methods to take all field names as kwargs. _media_type: str | None = None, ) -> None: @@ -308,6 +335,7 @@ def __init__( force_download=force_download, vendor_metadata=vendor_metadata, media_type=media_type or _media_type, + identifier=identifier, ) self.kind = kind @@ -350,11 +378,12 @@ class DocumentUrl(FileUrl): def __init__( self, url: str, + *, force_download: bool = False, vendor_metadata: dict[str, Any] | None = None, media_type: str | None = None, kind: Literal['document-url'] = 'document-url', - *, + identifier: str | None = None, # Required for inline-snapshot which expects all dataclass `__init__` methods to take all field names as kwargs. _media_type: str | None = None, ) -> None: @@ -363,6 +392,7 @@ def __init__( force_download=force_download, vendor_metadata=vendor_metadata, media_type=media_type or _media_type, + identifier=identifier, ) self.kind = kind @@ -405,24 +435,26 @@ def format(self) -> DocumentFormat: raise ValueError(f'Unknown document media type: {media_type}') from e -@dataclass(repr=False) +@dataclass(init=False, repr=False) class BinaryContent: """Binary content, e.g. an audio or image file.""" data: bytes """The binary data.""" - media_type: AudioMediaType | ImageMediaType | DocumentMediaType | str - """The media type of the binary data.""" - _: KW_ONLY - identifier: str | None = None - """Identifier for the binary content, such as a URL or unique ID. + media_type: AudioMediaType | ImageMediaType | DocumentMediaType | str + """The media type of the binary data.""" - This identifier can be provided to the model in a message to allow it to refer to this file in a tool call argument, and the tool can look up the file in question by iterating over the message history and finding the matching `BinaryContent`. + identifier: str + """Identifier for the binary content, such as a unique ID. generating one from the data if not explicitly set + This identifier can be provided to the model in a message to allow it to refer to this file in a tool call argument, + and the tool can look up the file in question by iterating over the message history and finding the matching `BinaryContent`. - This identifier is only automatically passed to the model when the `BinaryContent` is returned by a tool. If you're passing the `BinaryContent` as a user message, it's up to you to include a separate text part with the identifier, e.g. "This is file :" preceding the `BinaryContent`. + This identifier is only automatically passed to the model when the `BinaryContent` is returned by a tool. + If you're passing the `BinaryContent` as a user message, it's up to you to include a separate text part with the identifier, + e.g. "This is file :" preceding the `BinaryContent`. """ vendor_metadata: dict[str, Any] | None = None @@ -435,6 +467,21 @@ class BinaryContent: kind: Literal['binary'] = 'binary' """Type identifier, this is available on all parts as a discriminator.""" + def __init__( + self, + data: bytes, + *, + media_type: AudioMediaType | ImageMediaType | DocumentMediaType | str, + identifier: str | None = None, + vendor_metadata: dict[str, Any] | None = None, + kind: Literal['binary'] = 'binary', + ) -> None: + self.data = data + self.media_type = media_type + self.identifier = identifier or _multi_modal_content_identifier(data) + self.vendor_metadata = vendor_metadata + self.kind = kind + @property def is_audio(self) -> bool: """Return `True` if the media type is an audio type.""" diff --git a/tests/test_agent.py b/tests/test_agent.py index dc2c004ec7..2b7965eb97 100644 --- a/tests/test_agent.py +++ b/tests/test_agent.py @@ -28,7 +28,9 @@ from pydantic_ai.agent import AgentRunResult, WrapperAgent from pydantic_ai.messages import ( AgentStreamEvent, + AudioUrl, BinaryContent, + DocumentUrl, ImageUrl, ModelMessage, ModelMessagesTypeAdapter, @@ -42,6 +44,7 @@ ToolReturn, ToolReturnPart, UserPromptPart, + VideoUrl, ) from pydantic_ai.models.function import AgentInfo, FunctionModel from pydantic_ai.models.test import TestModel @@ -3088,7 +3091,7 @@ def test_binary_content_serializable(): 'media_type': 'text/plain', 'vendor_metadata': None, 'kind': 'binary', - 'identifier': None, + 'identifier': 'f7ff9e', }, ], 'timestamp': IsStr(), @@ -3143,6 +3146,7 @@ def test_image_url_serializable_missing_media_type(): 'vendor_metadata': None, 'kind': 'image-url', 'media_type': 'image/jpeg', + 'identifier': 'a72e39', }, ], 'timestamp': IsStr(), @@ -3204,6 +3208,7 @@ def test_image_url_serializable(): 'vendor_metadata': None, 'kind': 'image-url', 'media_type': 'image/jpeg', + 'identifier': 'bdd86d', }, ], 'timestamp': IsStr(), @@ -3248,7 +3253,7 @@ def test_image_url_serializable(): def test_tool_return_part_binary_content_serialization(): """Test that ToolReturnPart can properly serialize BinaryContent.""" png_data = b'\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x00\x01\x00\x00\x00\x01\x08\x02\x00\x00\x00\x90wS\xde\x00\x00\x00\x0cIDATx\x9cc```\x00\x00\x00\x04\x00\x01\xf6\x178\x00\x00\x00\x00IEND\xaeB`\x82' - binary_content = BinaryContent(png_data, media_type='image/png', identifier='image_id_1') + binary_content = BinaryContent(png_data, media_type='image/png') tool_return = ToolReturnPart(tool_name='test_tool', content=binary_content, tool_call_id='test_call_123') @@ -3257,12 +3262,12 @@ def test_tool_return_part_binary_content_serialization(): assert '"kind":"binary"' in response_str assert '"media_type":"image/png"' in response_str assert '"data":"' in response_str - assert '"identifier":"image_id_1"' in response_str + assert '"identifier":"14a01a"' in response_str response_obj = tool_return.model_response_object() assert response_obj['return_value']['kind'] == 'binary' assert response_obj['return_value']['media_type'] == 'image/png' - assert response_obj['return_value']['identifier'] == 'image_id_1' + assert response_obj['return_value']['identifier'] == '14a01a' assert 'data' in response_obj['return_value'] @@ -3332,6 +3337,55 @@ def get_image() -> BinaryContent: ) +def test_tool_returning_file_url_with_identifier(): + """Test that a tool returning FileUrl subclasses with identifiers works correctly.""" + + def llm(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse: + if len(messages) == 1: + return ModelResponse(parts=[ToolCallPart('get_files', {})]) + else: + return ModelResponse(parts=[TextPart('Files received')]) + + agent = Agent(FunctionModel(llm)) + + @agent.tool_plain + def get_files(): + """Return various file URLs with custom identifiers.""" + return [ + ImageUrl(url='https://example.com/image.jpg', identifier='img_001'), + VideoUrl(url='https://example.com/video.mp4', identifier='vid_002'), + AudioUrl(url='https://example.com/audio.mp3', identifier='aud_003'), + DocumentUrl(url='https://example.com/document.pdf', identifier='doc_004'), + ] + + result = agent.run_sync('Get some files') + assert result.all_messages()[2] == snapshot( + ModelRequest( + parts=[ + ToolReturnPart( + tool_name='get_files', + content=['See file img_001', 'See file vid_002', 'See file aud_003', 'See file doc_004'], + tool_call_id=IsStr(), + timestamp=IsNow(tz=timezone.utc), + ), + UserPromptPart( + content=[ + 'This is file img_001:', + ImageUrl(url='https://example.com/image.jpg', identifier='img_001'), + 'This is file vid_002:', + VideoUrl(url='https://example.com/video.mp4', identifier='vid_002'), + 'This is file aud_003:', + AudioUrl(url='https://example.com/audio.mp3', identifier='aud_003'), + 'This is file doc_004:', + DocumentUrl(url='https://example.com/document.pdf', identifier='doc_004'), + ], + timestamp=IsNow(tz=timezone.utc), + ), + ] + ) + ) + + def test_instructions_raise_error_when_system_prompt_is_set(): agent = Agent('test', instructions='An instructions!')