Skip to content

Commit babc23b

Browse files
GDaamnDouweM
andauthored
add identifier field to BinaryContent class (#2231)
Co-authored-by: Douwe Maan <[email protected]>
1 parent 490c3b4 commit babc23b

File tree

3 files changed

+63
-3
lines changed

3 files changed

+63
-3
lines changed

pydantic_ai_slim/pydantic_ai/_agent_graph.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -759,7 +759,7 @@ def process_content(content: Any) -> Any:
759759
)
760760
elif isinstance(content, _messages.MultiModalContentTypes):
761761
if isinstance(content, _messages.BinaryContent):
762-
identifier = multi_modal_content_identifier(content.data)
762+
identifier = content.identifier or multi_modal_content_identifier(content.data)
763763
else:
764764
identifier = multi_modal_content_identifier(content.url)
765765

pydantic_ai_slim/pydantic_ai/messages.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -282,6 +282,14 @@ class BinaryContent:
282282
media_type: AudioMediaType | ImageMediaType | DocumentMediaType | str
283283
"""The media type of the binary data."""
284284

285+
identifier: str | None = None
286+
"""Identifier for the binary content, such as a URL or unique ID.
287+
288+
This identifier can be provided to the model in a message to allow it to refer to this file in a tool call argument, and the tool can look up the file in question by iterating over the message history and finding the matching `BinaryContent`.
289+
290+
This identifier is only automatically passed to the model when the `BinaryContent` is returned by a tool. If you're passing the `BinaryContent` as a user message, it's up to you to include a separate text part with the identifier, e.g. "This is file <identifier>:" preceding the `BinaryContent`.
291+
"""
292+
285293
vendor_metadata: dict[str, Any] | None = None
286294
"""Vendor-specific metadata for the file.
287295

tests/test_agent.py

Lines changed: 54 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2938,7 +2938,13 @@ def test_binary_content_all_messages_json():
29382938
{
29392939
'content': [
29402940
'Hello',
2941-
{'data': 'SGVsbG8=', 'media_type': 'text/plain', 'vendor_metadata': None, 'kind': 'binary'},
2941+
{
2942+
'data': 'SGVsbG8=',
2943+
'media_type': 'text/plain',
2944+
'vendor_metadata': None,
2945+
'kind': 'binary',
2946+
'identifier': None,
2947+
},
29422948
],
29432949
'timestamp': IsStr(),
29442950
'part_kind': 'user-prompt',
@@ -2973,7 +2979,7 @@ def test_binary_content_all_messages_json():
29732979
def test_tool_return_part_binary_content_serialization():
29742980
"""Test that ToolReturnPart can properly serialize BinaryContent."""
29752981
png_data = b'\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x00\x01\x00\x00\x00\x01\x08\x02\x00\x00\x00\x90wS\xde\x00\x00\x00\x0cIDATx\x9cc```\x00\x00\x00\x04\x00\x01\xf6\x178\x00\x00\x00\x00IEND\xaeB`\x82'
2976-
binary_content = BinaryContent(png_data, media_type='image/png')
2982+
binary_content = BinaryContent(png_data, media_type='image/png', identifier='image_id_1')
29772983

29782984
tool_return = ToolReturnPart(tool_name='test_tool', content=binary_content, tool_call_id='test_call_123')
29792985

@@ -2982,10 +2988,12 @@ def test_tool_return_part_binary_content_serialization():
29822988
assert '"kind":"binary"' in response_str
29832989
assert '"media_type":"image/png"' in response_str
29842990
assert '"data":"' in response_str
2991+
assert '"identifier":"image_id_1"' in response_str
29852992

29862993
response_obj = tool_return.model_response_object()
29872994
assert response_obj['return_value']['kind'] == 'binary'
29882995
assert response_obj['return_value']['media_type'] == 'image/png'
2996+
assert response_obj['return_value']['identifier'] == 'image_id_1'
29892997
assert 'data' in response_obj['return_value']
29902998

29912999

@@ -3011,6 +3019,50 @@ def get_image() -> BinaryContent:
30113019
assert result.output == 'Image received'
30123020

30133021

3022+
def test_tool_returning_binary_content_with_identifier():
3023+
"""Test that a tool returning BinaryContent directly works correctly."""
3024+
3025+
def llm(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse:
3026+
if len(messages) == 1:
3027+
return ModelResponse(parts=[ToolCallPart('get_image', {})])
3028+
else:
3029+
return ModelResponse(parts=[TextPart('Image received')])
3030+
3031+
agent = Agent(FunctionModel(llm))
3032+
3033+
@agent.tool_plain
3034+
def get_image() -> BinaryContent:
3035+
"""Return a simple image."""
3036+
png_data = b'\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x00\x01\x00\x00\x00\x01\x08\x02\x00\x00\x00\x90wS\xde\x00\x00\x00\x0cIDATx\x9cc```\x00\x00\x00\x04\x00\x01\xf6\x178\x00\x00\x00\x00IEND\xaeB`\x82'
3037+
return BinaryContent(png_data, media_type='image/png', identifier='image_id_1')
3038+
3039+
# This should work without the serialization error
3040+
result = agent.run_sync('Get an image')
3041+
assert result.all_messages()[2] == snapshot(
3042+
ModelRequest(
3043+
parts=[
3044+
ToolReturnPart(
3045+
tool_name='get_image',
3046+
content='See file image_id_1',
3047+
tool_call_id=IsStr(),
3048+
timestamp=IsNow(tz=timezone.utc),
3049+
),
3050+
UserPromptPart(
3051+
content=[
3052+
'This is file image_id_1:',
3053+
BinaryContent(
3054+
data=b'\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x00\x01\x00\x00\x00\x01\x08\x02\x00\x00\x00\x90wS\xde\x00\x00\x00\x0cIDATx\x9cc```\x00\x00\x00\x04\x00\x01\xf6\x178\x00\x00\x00\x00IEND\xaeB`\x82',
3055+
media_type='image/png',
3056+
identifier='image_id_1',
3057+
),
3058+
],
3059+
timestamp=IsNow(tz=timezone.utc),
3060+
),
3061+
]
3062+
)
3063+
)
3064+
3065+
30143066
def test_instructions_raise_error_when_system_prompt_is_set():
30153067
agent = Agent('test', instructions='An instructions!')
30163068

0 commit comments

Comments
 (0)