Skip to content

Commit 46ba28f

Browse files
authored
Add identifier field to FileUrl and subclasses (#2636)
1 parent ca079f5 commit 46ba28f

File tree

3 files changed

+119
-29
lines changed

3 files changed

+119
-29
lines changed

pydantic_ai_slim/pydantic_ai/_agent_graph.py

Lines changed: 1 addition & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22

33
import asyncio
44
import dataclasses
5-
import hashlib
65
from collections import defaultdict, deque
76
from collections.abc import AsyncIterator, Awaitable, Callable, Iterator, Sequence
87
from contextlib import asynccontextmanager, contextmanager
@@ -650,13 +649,6 @@ def build_run_context(ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT
650649
)
651650

652651

653-
def multi_modal_content_identifier(identifier: str | bytes) -> str:
654-
"""Generate stable identifier for multi-modal content to help LLM in finding a specific file in tool call responses."""
655-
if isinstance(identifier, str):
656-
identifier = identifier.encode('utf-8')
657-
return hashlib.sha1(identifier).hexdigest()[:6]
658-
659-
660652
async def process_function_tools( # noqa: C901
661653
tool_manager: ToolManager[DepsT],
662654
tool_calls: list[_messages.ToolCallPart],
@@ -915,10 +907,7 @@ async def _call_tool(
915907
f'`ToolReturn` should be used directly.'
916908
)
917909
elif isinstance(content, _messages.MultiModalContent):
918-
if isinstance(content, _messages.BinaryContent):
919-
identifier = content.identifier or multi_modal_content_identifier(content.data)
920-
else:
921-
identifier = multi_modal_content_identifier(content.url)
910+
identifier = content.identifier
922911

923912
return_values.append(f'See file {identifier}')
924913
user_contents.extend([f'This is file {identifier}:', content])

pydantic_ai_slim/pydantic_ai/messages.py

Lines changed: 60 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from __future__ import annotations as _annotations
22

33
import base64
4+
import hashlib
45
from abc import ABC, abstractmethod
56
from collections.abc import Sequence
67
from dataclasses import KW_ONLY, dataclass, field, replace
@@ -88,6 +89,13 @@ def otel_message_parts(self, settings: InstrumentationSettings) -> list[_otel_me
8889
__repr__ = _utils.dataclasses_no_defaults_repr
8990

9091

92+
def _multi_modal_content_identifier(identifier: str | bytes) -> str:
93+
"""Generate stable identifier for multi-modal content to help LLM in finding a specific file in tool call responses."""
94+
if isinstance(identifier, str):
95+
identifier = identifier.encode('utf-8')
96+
return hashlib.sha1(identifier).hexdigest()[:6]
97+
98+
9199
@dataclass(init=False, repr=False)
92100
class FileUrl(ABC):
93101
"""Abstract base class for any URL-based file."""
@@ -115,17 +123,31 @@ class FileUrl(ABC):
115123
compare=False, default=None
116124
)
117125

126+
identifier: str | None = None
127+
"""The identifier of the file, such as a unique ID. If not explicitly set, one is generated from the URL.
128+
129+
This identifier can be provided to the model in a message to allow it to refer to this file in a tool call argument,
130+
and the tool can look up the file in question by iterating over the message history and finding the matching `FileUrl`.
131+
132+
This identifier is only automatically passed to the model when the `FileUrl` is returned by a tool.
133+
If you're passing the `FileUrl` as a user message, it's up to you to include a separate text part with the identifier,
134+
e.g. "This is file <identifier>:" preceding the `FileUrl`.
135+
"""
136+
118137
def __init__(
    self,
    url: str,
    *,
    force_download: bool = False,
    vendor_metadata: dict[str, Any] | None = None,
    media_type: str | None = None,
    identifier: str | None = None,
) -> None:
    """Store the URL and its options, deriving an identifier from the URL when none is supplied."""
    self.url = url
    self.force_download = force_download
    self.vendor_metadata = vendor_metadata
    self._media_type = media_type
    # Any falsy identifier (None or an empty string) falls back to the URL-derived one.
    if identifier:
        self.identifier = identifier
    else:
        self.identifier = _multi_modal_content_identifier(url)
129151

130152
@pydantic.computed_field
131153
@property
@@ -162,11 +184,12 @@ class VideoUrl(FileUrl):
162184
def __init__(
163185
self,
164186
url: str,
187+
*,
165188
force_download: bool = False,
166189
vendor_metadata: dict[str, Any] | None = None,
167190
media_type: str | None = None,
168191
kind: Literal['video-url'] = 'video-url',
169-
*,
192+
identifier: str | None = None,
170193
# Required for inline-snapshot which expects all dataclass `__init__` methods to take all field names as kwargs.
171194
_media_type: str | None = None,
172195
) -> None:
@@ -175,6 +198,7 @@ def __init__(
175198
force_download=force_download,
176199
vendor_metadata=vendor_metadata,
177200
media_type=media_type or _media_type,
201+
identifier=identifier,
178202
)
179203
self.kind = kind
180204

@@ -235,11 +259,12 @@ class AudioUrl(FileUrl):
235259
def __init__(
236260
self,
237261
url: str,
262+
*,
238263
force_download: bool = False,
239264
vendor_metadata: dict[str, Any] | None = None,
240265
media_type: str | None = None,
241266
kind: Literal['audio-url'] = 'audio-url',
242-
*,
267+
identifier: str | None = None,
243268
# Required for inline-snapshot which expects all dataclass `__init__` methods to take all field names as kwargs.
244269
_media_type: str | None = None,
245270
) -> None:
@@ -248,6 +273,7 @@ def __init__(
248273
force_download=force_download,
249274
vendor_metadata=vendor_metadata,
250275
media_type=media_type or _media_type,
276+
identifier=identifier,
251277
)
252278
self.kind = kind
253279

@@ -295,11 +321,12 @@ class ImageUrl(FileUrl):
295321
def __init__(
296322
self,
297323
url: str,
324+
*,
298325
force_download: bool = False,
299326
vendor_metadata: dict[str, Any] | None = None,
300327
media_type: str | None = None,
301328
kind: Literal['image-url'] = 'image-url',
302-
*,
329+
identifier: str | None = None,
303330
# Required for inline-snapshot which expects all dataclass `__init__` methods to take all field names as kwargs.
304331
_media_type: str | None = None,
305332
) -> None:
@@ -308,6 +335,7 @@ def __init__(
308335
force_download=force_download,
309336
vendor_metadata=vendor_metadata,
310337
media_type=media_type or _media_type,
338+
identifier=identifier,
311339
)
312340
self.kind = kind
313341

@@ -350,11 +378,12 @@ class DocumentUrl(FileUrl):
350378
def __init__(
351379
self,
352380
url: str,
381+
*,
353382
force_download: bool = False,
354383
vendor_metadata: dict[str, Any] | None = None,
355384
media_type: str | None = None,
356385
kind: Literal['document-url'] = 'document-url',
357-
*,
386+
identifier: str | None = None,
358387
# Required for inline-snapshot which expects all dataclass `__init__` methods to take all field names as kwargs.
359388
_media_type: str | None = None,
360389
) -> None:
@@ -363,6 +392,7 @@ def __init__(
363392
force_download=force_download,
364393
vendor_metadata=vendor_metadata,
365394
media_type=media_type or _media_type,
395+
identifier=identifier,
366396
)
367397
self.kind = kind
368398

@@ -405,24 +435,26 @@ def format(self) -> DocumentFormat:
405435
raise ValueError(f'Unknown document media type: {media_type}') from e
406436

407437

408-
@dataclass(repr=False)
438+
@dataclass(init=False, repr=False)
409439
class BinaryContent:
410440
"""Binary content, e.g. an audio or image file."""
411441

412442
data: bytes
413443
"""The binary data."""
414444

415-
media_type: AudioMediaType | ImageMediaType | DocumentMediaType | str
416-
"""The media type of the binary data."""
417-
418445
_: KW_ONLY
419446

420-
identifier: str | None = None
421-
"""Identifier for the binary content, such as a URL or unique ID.
447+
media_type: AudioMediaType | ImageMediaType | DocumentMediaType | str
448+
"""The media type of the binary data."""
422449

423-
This identifier can be provided to the model in a message to allow it to refer to this file in a tool call argument, and the tool can look up the file in question by iterating over the message history and finding the matching `BinaryContent`.
450+
identifier: str
451+
"""Identifier for the binary content, such as a unique ID. If not explicitly set, one is generated from the data.
452+
This identifier can be provided to the model in a message to allow it to refer to this file in a tool call argument,
453+
and the tool can look up the file in question by iterating over the message history and finding the matching `BinaryContent`.
424454
425-
This identifier is only automatically passed to the model when the `BinaryContent` is returned by a tool. If you're passing the `BinaryContent` as a user message, it's up to you to include a separate text part with the identifier, e.g. "This is file <identifier>:" preceding the `BinaryContent`.
455+
This identifier is only automatically passed to the model when the `BinaryContent` is returned by a tool.
456+
If you're passing the `BinaryContent` as a user message, it's up to you to include a separate text part with the identifier,
457+
e.g. "This is file <identifier>:" preceding the `BinaryContent`.
426458
"""
427459

428460
vendor_metadata: dict[str, Any] | None = None
@@ -435,6 +467,21 @@ class BinaryContent:
435467
kind: Literal['binary'] = 'binary'
436468
"""Type identifier, this is available on all parts as a discriminator."""
437469

470+
def __init__(
    self,
    data: bytes,
    *,
    media_type: AudioMediaType | ImageMediaType | DocumentMediaType | str,
    identifier: str | None = None,
    vendor_metadata: dict[str, Any] | None = None,
    kind: Literal['binary'] = 'binary',
) -> None:
    """Store the binary payload and its metadata, deriving an identifier from the data when none is supplied."""
    self.data = data
    self.media_type = media_type
    # Any falsy identifier (None or an empty string) falls back to the data-derived one.
    if identifier:
        self.identifier = identifier
    else:
        self.identifier = _multi_modal_content_identifier(data)
    self.vendor_metadata = vendor_metadata
    self.kind = kind
484+
438485
@property
439486
def is_audio(self) -> bool:
440487
"""Return `True` if the media type is an audio type."""

tests/test_agent.py

Lines changed: 58 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,9 @@
2828
from pydantic_ai.agent import AgentRunResult, WrapperAgent
2929
from pydantic_ai.messages import (
3030
AgentStreamEvent,
31+
AudioUrl,
3132
BinaryContent,
33+
DocumentUrl,
3234
ImageUrl,
3335
ModelMessage,
3436
ModelMessagesTypeAdapter,
@@ -42,6 +44,7 @@
4244
ToolReturn,
4345
ToolReturnPart,
4446
UserPromptPart,
47+
VideoUrl,
4548
)
4649
from pydantic_ai.models.function import AgentInfo, FunctionModel
4750
from pydantic_ai.models.test import TestModel
@@ -3088,7 +3091,7 @@ def test_binary_content_serializable():
30883091
'media_type': 'text/plain',
30893092
'vendor_metadata': None,
30903093
'kind': 'binary',
3091-
'identifier': None,
3094+
'identifier': 'f7ff9e',
30923095
},
30933096
],
30943097
'timestamp': IsStr(),
@@ -3143,6 +3146,7 @@ def test_image_url_serializable_missing_media_type():
31433146
'vendor_metadata': None,
31443147
'kind': 'image-url',
31453148
'media_type': 'image/jpeg',
3149+
'identifier': 'a72e39',
31463150
},
31473151
],
31483152
'timestamp': IsStr(),
@@ -3204,6 +3208,7 @@ def test_image_url_serializable():
32043208
'vendor_metadata': None,
32053209
'kind': 'image-url',
32063210
'media_type': 'image/jpeg',
3211+
'identifier': 'bdd86d',
32073212
},
32083213
],
32093214
'timestamp': IsStr(),
@@ -3248,7 +3253,7 @@ def test_image_url_serializable():
32483253
def test_tool_return_part_binary_content_serialization():
32493254
"""Test that ToolReturnPart can properly serialize BinaryContent."""
32503255
png_data = b'\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x00\x01\x00\x00\x00\x01\x08\x02\x00\x00\x00\x90wS\xde\x00\x00\x00\x0cIDATx\x9cc```\x00\x00\x00\x04\x00\x01\xf6\x178\x00\x00\x00\x00IEND\xaeB`\x82'
3251-
binary_content = BinaryContent(png_data, media_type='image/png', identifier='image_id_1')
3256+
binary_content = BinaryContent(png_data, media_type='image/png')
32523257

32533258
tool_return = ToolReturnPart(tool_name='test_tool', content=binary_content, tool_call_id='test_call_123')
32543259

@@ -3257,12 +3262,12 @@ def test_tool_return_part_binary_content_serialization():
32573262
assert '"kind":"binary"' in response_str
32583263
assert '"media_type":"image/png"' in response_str
32593264
assert '"data":"' in response_str
3260-
assert '"identifier":"image_id_1"' in response_str
3265+
assert '"identifier":"14a01a"' in response_str
32613266

32623267
response_obj = tool_return.model_response_object()
32633268
assert response_obj['return_value']['kind'] == 'binary'
32643269
assert response_obj['return_value']['media_type'] == 'image/png'
3265-
assert response_obj['return_value']['identifier'] == 'image_id_1'
3270+
assert response_obj['return_value']['identifier'] == '14a01a'
32663271
assert 'data' in response_obj['return_value']
32673272

32683273

@@ -3332,6 +3337,55 @@ def get_image() -> BinaryContent:
33323337
)
33333338

33343339

3340+
def test_tool_returning_file_url_with_identifier():
    """Explicit identifiers on FileUrl subclasses returned by a tool appear verbatim in the tool return and the follow-up user prompt."""

    def llm(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse:
        # Once the history holds more than the initial request, reply with plain text.
        if len(messages) != 1:
            return ModelResponse(parts=[TextPart('Files received')])
        # First turn: request the tool call.
        return ModelResponse(parts=[ToolCallPart('get_files', {})])

    agent = Agent(FunctionModel(llm))

    @agent.tool_plain
    def get_files():
        """Return various file URLs with custom identifiers."""
        image = ImageUrl(url='https://example.com/image.jpg', identifier='img_001')
        video = VideoUrl(url='https://example.com/video.mp4', identifier='vid_002')
        audio = AudioUrl(url='https://example.com/audio.mp3', identifier='aud_003')
        document = DocumentUrl(url='https://example.com/document.pdf', identifier='doc_004')
        return [image, video, audio, document]

    run_result = agent.run_sync('Get some files')
    # The message at index 2 is the request carrying the tool return and the accompanying user prompt.
    tool_request = run_result.all_messages()[2]
    assert tool_request == snapshot(
        ModelRequest(
            parts=[
                ToolReturnPart(
                    tool_name='get_files',
                    content=['See file img_001', 'See file vid_002', 'See file aud_003', 'See file doc_004'],
                    tool_call_id=IsStr(),
                    timestamp=IsNow(tz=timezone.utc),
                ),
                UserPromptPart(
                    content=[
                        'This is file img_001:',
                        ImageUrl(url='https://example.com/image.jpg', identifier='img_001'),
                        'This is file vid_002:',
                        VideoUrl(url='https://example.com/video.mp4', identifier='vid_002'),
                        'This is file aud_003:',
                        AudioUrl(url='https://example.com/audio.mp3', identifier='aud_003'),
                        'This is file doc_004:',
                        DocumentUrl(url='https://example.com/document.pdf', identifier='doc_004'),
                    ],
                    timestamp=IsNow(tz=timezone.utc),
                ),
            ]
        )
    )
3387+
3388+
33353389
def test_instructions_raise_error_when_system_prompt_is_set():
33363390
agent = Agent('test', instructions='An instructions!')
33373391

0 commit comments

Comments
 (0)