Skip to content

Commit 5510e81

Browse files
authored
Fix ImageUrl, VideoUrl, AudioUrl and DocumentUrl not being serializable (#2422)
1 parent 97834d6 commit 5510e81

File tree

3 files changed

+81
-16
lines changed

3 files changed

+81
-16
lines changed

pydantic_ai_slim/pydantic_ai/messages.py

Lines changed: 24 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,7 @@ class FileUrl(ABC):
106106
- `GoogleModel`: `VideoUrl.vendor_metadata` is used as `video_metadata`: https://ai.google.dev/gemini-api/docs/video-understanding#customize-video-processing
107107
"""
108108

109-
_media_type: str | None = field(init=False, repr=False)
109+
_media_type: str | None = field(init=False, repr=False, compare=False)
110110

111111
def __init__(
112112
self,
@@ -120,19 +120,21 @@ def __init__(
120120
self.force_download = force_download
121121
self._media_type = media_type
122122

123-
@abstractmethod
124-
def _infer_media_type(self) -> str:
125-
"""Return the media type of the file, based on the url."""
126-
127123
@property
128124
def media_type(self) -> str:
129-
"""Return the media type of the file, based on the url or the provided `_media_type`."""
125+
"""Return the media type of the file, based on the URL or the provided `media_type`."""
130126
return self._media_type or self._infer_media_type()
131127

128+
@abstractmethod
129+
def _infer_media_type(self) -> str:
130+
"""Infer the media type of the file based on the URL."""
131+
raise NotImplementedError
132+
132133
@property
133134
@abstractmethod
134135
def format(self) -> str:
135136
"""The file format."""
137+
raise NotImplementedError
136138

137139
__repr__ = _utils.dataclasses_no_defaults_repr
138140

@@ -182,7 +184,9 @@ def _infer_media_type(self) -> VideoMediaType:
182184
elif self.is_youtube:
183185
return 'video/mp4'
184186
else:
185-
raise ValueError(f'Unknown video file extension: {self.url}')
187+
raise ValueError(
188+
f'Could not infer media type from video URL: {self.url}. Explicitly provide a `media_type` instead.'
189+
)
186190

187191
@property
188192
def is_youtube(self) -> bool:
@@ -238,7 +242,9 @@ def _infer_media_type(self) -> AudioMediaType:
238242
if self.url.endswith('.aac'):
239243
return 'audio/aac'
240244

241-
raise ValueError(f'Unknown audio file extension: {self.url}')
245+
raise ValueError(
246+
f'Could not infer media type from audio URL: {self.url}. Explicitly provide a `media_type` instead.'
247+
)
242248

243249
@property
244250
def format(self) -> AudioFormat:
@@ -278,7 +284,9 @@ def _infer_media_type(self) -> ImageMediaType:
278284
elif self.url.endswith('.webp'):
279285
return 'image/webp'
280286
else:
281-
raise ValueError(f'Unknown image file extension: {self.url}')
287+
raise ValueError(
288+
f'Could not infer media type from image URL: {self.url}. Explicitly provide a `media_type` instead.'
289+
)
282290

283291
@property
284292
def format(self) -> ImageFormat:
@@ -324,10 +332,16 @@ def _infer_media_type(self) -> str:
324332
return 'application/pdf'
325333
elif self.url.endswith('.rtf'):
326334
return 'application/rtf'
335+
elif self.url.endswith('.docx'):
336+
return 'application/vnd.openxmlformats-officedocument.wordprocessingml.document'
337+
elif self.url.endswith('.xlsx'):
338+
return 'application/vnd.openxmlformats-officedocument.spreadsheetml.sheet'
327339

328340
type_, _ = guess_type(self.url)
329341
if type_ is None:
330-
raise ValueError(f'Unknown document file extension: {self.url}')
342+
raise ValueError(
343+
f'Could not infer media type from document URL: {self.url}. Explicitly provide a `media_type` instead.'
344+
)
331345
return type_
332346

333347
@property

tests/test_agent.py

Lines changed: 52 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2922,7 +2922,7 @@ def validate_output(ctx: RunContext[None], o: Any) -> Any: # pragma: no cover
29222922
agent.run_sync('Hello', output_type=int)
29232923

29242924

2925-
def test_binary_content_all_messages_json():
2925+
def test_binary_content_serializable():
29262926
agent = Agent('test')
29272927

29282928
content = BinaryContent(data=b'Hello', media_type='text/plain')
@@ -2974,6 +2974,57 @@ def test_binary_content_all_messages_json():
29742974
assert messages == result.all_messages()
29752975

29762976

2977+
def test_image_url_serializable():
2978+
agent = Agent('test')
2979+
2980+
content = ImageUrl('https://example.com/chart', media_type='image/jpeg')
2981+
result = agent.run_sync(['Hello', content])
2982+
2983+
serialized = result.all_messages_json()
2984+
assert json.loads(serialized) == snapshot(
2985+
[
2986+
{
2987+
'parts': [
2988+
{
2989+
'content': [
2990+
'Hello',
2991+
{
2992+
'url': 'https://example.com/chart',
2993+
'force_download': False,
2994+
'vendor_metadata': None,
2995+
'kind': 'image-url',
2996+
},
2997+
],
2998+
'timestamp': IsStr(),
2999+
'part_kind': 'user-prompt',
3000+
}
3001+
],
3002+
'instructions': None,
3003+
'kind': 'request',
3004+
},
3005+
{
3006+
'parts': [{'content': 'success (no tool calls)', 'part_kind': 'text'}],
3007+
'usage': {
3008+
'requests': 1,
3009+
'request_tokens': 51,
3010+
'response_tokens': 4,
3011+
'total_tokens': 55,
3012+
'details': None,
3013+
},
3014+
'model_name': 'test',
3015+
'timestamp': IsStr(),
3016+
'kind': 'response',
3017+
'vendor_details': None,
3018+
'vendor_id': None,
3019+
},
3020+
]
3021+
)
3022+
3023+
# We also need to be able to round trip the serialized messages.
3024+
messages = ModelMessagesTypeAdapter.validate_json(serialized)
3025+
assert messages == result.all_messages()
3026+
3027+
29773028
def test_tool_return_part_binary_content_serialization():
29783029
"""Test that ToolReturnPart can properly serialize BinaryContent."""
29793030
png_data = b'\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x00\x01\x00\x00\x00\x01\x08\x02\x00\x00\x00\x90wS\xde\x00\x00\x00\x0cIDATx\x9cc```\x00\x00\x00\x04\x00\x01\xf6\x178\x00\x00\x00\x00IEND\xaeB`\x82'

tests/test_messages.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -153,7 +153,7 @@ def test_audio_url(audio_url: AudioUrl, media_type: str, format: str):
153153

154154

155155
def test_audio_url_invalid():
156-
with pytest.raises(ValueError, match='Unknown audio file extension: foobar.potato'):
156+
with pytest.raises(ValueError, match='Could not infer media type from audio URL: foobar.potato'):
157157
AudioUrl('foobar.potato').media_type
158158

159159

@@ -173,10 +173,10 @@ def test_image_url_formats(image_url: ImageUrl, media_type: str, format: str):
173173

174174

175175
def test_image_url_invalid():
176-
with pytest.raises(ValueError, match='Unknown image file extension: foobar.potato'):
176+
with pytest.raises(ValueError, match='Could not infer media type from image URL: foobar.potato'):
177177
ImageUrl('foobar.potato').media_type
178178

179-
with pytest.raises(ValueError, match='Unknown image file extension: foobar.potato'):
179+
with pytest.raises(ValueError, match='Could not infer media type from image URL: foobar.potato'):
180180
ImageUrl('foobar.potato').format
181181

182182

@@ -213,7 +213,7 @@ def test_document_url_formats(document_url: DocumentUrl, media_type: str, format
213213

214214

215215
def test_document_url_invalid():
216-
with pytest.raises(ValueError, match='Unknown document file extension: foobar.potato'):
216+
with pytest.raises(ValueError, match='Could not infer media type from document URL: foobar.potato'):
217217
DocumentUrl('foobar.potato').media_type
218218

219219
with pytest.raises(ValueError, match='Unknown document media type: text/x-python'):
@@ -301,7 +301,7 @@ def test_video_url_formats(video_url: VideoUrl, media_type: str, format: str):
301301

302302

303303
def test_video_url_invalid():
304-
with pytest.raises(ValueError, match='Unknown video file extension: foobar.potato'):
304+
with pytest.raises(ValueError, match='Could not infer media type from video URL: foobar.potato'):
305305
VideoUrl('foobar.potato').media_type
306306

307307

0 commit comments

Comments
 (0)