Skip to content

Commit fc6a2b2

Browse files
dprov and David authored
Support passing files uploaded to Gemini Files API and setting custom media type (#2270)
Co-authored-by: David <[email protected]>
1 parent da80f5d commit fc6a2b2

File tree

3 files changed: 92 additions and 24 deletions

pydantic_ai_slim/pydantic_ai/messages.py

Lines changed: 73 additions & 15 deletions
Original file line number | Diff line number | Diff line change
@@ -85,7 +85,7 @@ def otel_event(self, settings: InstrumentationSettings) -> Event:
8585
__repr__ = _utils.dataclasses_no_defaults_repr
8686

8787

88-
@dataclass(repr=False)
88+
@dataclass(init=False, repr=False)
8989
class FileUrl(ABC):
9090
"""Abstract base class for any URL-based file."""
9191

@@ -106,11 +106,29 @@ class FileUrl(ABC):
106106
- `GoogleModel`: `VideoUrl.vendor_metadata` is used as `video_metadata`: https://ai.google.dev/gemini-api/docs/video-understanding#customize-video-processing
107107
"""
108108

109-
@property
109+
_media_type: str | None = field(init=False, repr=False)
110+
111+
def __init__(
112+
self,
113+
url: str,
114+
force_download: bool = False,
115+
vendor_metadata: dict[str, Any] | None = None,
116+
media_type: str | None = None,
117+
) -> None:
118+
self.url = url
119+
self.vendor_metadata = vendor_metadata
120+
self.force_download = force_download
121+
self._media_type = media_type
122+
110123
@abstractmethod
111-
def media_type(self) -> str:
124+
def _infer_media_type(self) -> str:
112125
"""Return the media type of the file, based on the url."""
113126

127+
@property
128+
def media_type(self) -> str:
129+
"""Return the media type of the file, based on the url or the provided `_media_type`."""
130+
return self._media_type or self._infer_media_type()
131+
114132
@property
115133
@abstractmethod
116134
def format(self) -> str:
@@ -119,7 +137,7 @@ def format(self) -> str:
119137
__repr__ = _utils.dataclasses_no_defaults_repr
120138

121139

122-
@dataclass(repr=False)
140+
@dataclass(init=False, repr=False)
123141
class VideoUrl(FileUrl):
124142
"""A URL to a video."""
125143

@@ -129,8 +147,18 @@ class VideoUrl(FileUrl):
129147
kind: Literal['video-url'] = 'video-url'
130148
"""Type identifier, this is available on all parts as a discriminator."""
131149

132-
@property
133-
def media_type(self) -> VideoMediaType:
150+
def __init__(
151+
self,
152+
url: str,
153+
force_download: bool = False,
154+
vendor_metadata: dict[str, Any] | None = None,
155+
media_type: str | None = None,
156+
kind: Literal['video-url'] = 'video-url',
157+
) -> None:
158+
super().__init__(url=url, force_download=force_download, vendor_metadata=vendor_metadata, media_type=media_type)
159+
self.kind = kind
160+
161+
def _infer_media_type(self) -> VideoMediaType:
134162
"""Return the media type of the video, based on the url."""
135163
if self.url.endswith('.mkv'):
136164
return 'video/x-matroska'
@@ -170,7 +198,7 @@ def format(self) -> VideoFormat:
170198
return _video_format_lookup[self.media_type]
171199

172200

173-
@dataclass(repr=False)
201+
@dataclass(init=False, repr=False)
174202
class AudioUrl(FileUrl):
175203
"""A URL to an audio file."""
176204

@@ -180,8 +208,18 @@ class AudioUrl(FileUrl):
180208
kind: Literal['audio-url'] = 'audio-url'
181209
"""Type identifier, this is available on all parts as a discriminator."""
182210

183-
@property
184-
def media_type(self) -> AudioMediaType:
211+
def __init__(
212+
self,
213+
url: str,
214+
force_download: bool = False,
215+
vendor_metadata: dict[str, Any] | None = None,
216+
media_type: str | None = None,
217+
kind: Literal['audio-url'] = 'audio-url',
218+
) -> None:
219+
super().__init__(url=url, force_download=force_download, vendor_metadata=vendor_metadata, media_type=media_type)
220+
self.kind = kind
221+
222+
def _infer_media_type(self) -> AudioMediaType:
185223
"""Return the media type of the audio file, based on the url.
186224
187225
References:
@@ -208,7 +246,7 @@ def format(self) -> AudioFormat:
208246
return _audio_format_lookup[self.media_type]
209247

210248

211-
@dataclass(repr=False)
249+
@dataclass(init=False, repr=False)
212250
class ImageUrl(FileUrl):
213251
"""A URL to an image."""
214252

@@ -218,8 +256,18 @@ class ImageUrl(FileUrl):
218256
kind: Literal['image-url'] = 'image-url'
219257
"""Type identifier, this is available on all parts as a discriminator."""
220258

221-
@property
222-
def media_type(self) -> ImageMediaType:
259+
def __init__(
260+
self,
261+
url: str,
262+
force_download: bool = False,
263+
vendor_metadata: dict[str, Any] | None = None,
264+
media_type: str | None = None,
265+
kind: Literal['image-url'] = 'image-url',
266+
) -> None:
267+
super().__init__(url=url, force_download=force_download, vendor_metadata=vendor_metadata, media_type=media_type)
268+
self.kind = kind
269+
270+
def _infer_media_type(self) -> ImageMediaType:
223271
"""Return the media type of the image, based on the url."""
224272
if self.url.endswith(('.jpg', '.jpeg')):
225273
return 'image/jpeg'
@@ -241,7 +289,7 @@ def format(self) -> ImageFormat:
241289
return _image_format_lookup[self.media_type]
242290

243291

244-
@dataclass(repr=False)
292+
@dataclass(init=False, repr=False)
245293
class DocumentUrl(FileUrl):
246294
"""The URL of the document."""
247295

@@ -251,8 +299,18 @@ class DocumentUrl(FileUrl):
251299
kind: Literal['document-url'] = 'document-url'
252300
"""Type identifier, this is available on all parts as a discriminator."""
253301

254-
@property
255-
def media_type(self) -> str:
302+
def __init__(
303+
self,
304+
url: str,
305+
force_download: bool = False,
306+
vendor_metadata: dict[str, Any] | None = None,
307+
media_type: str | None = None,
308+
kind: Literal['document-url'] = 'document-url',
309+
) -> None:
310+
super().__init__(url=url, force_download=force_download, vendor_metadata=vendor_metadata, media_type=media_type)
311+
self.kind = kind
312+
313+
def _infer_media_type(self) -> str:
256314
"""Return the media type of the document, based on the url."""
257315
type_, _ = guess_type(self.url)
258316
if type_ is None:

pydantic_ai_slim/pydantic_ai/models/google.py

Lines changed: 6 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -411,7 +411,12 @@ async def _map_user_prompt(self, part: UserPromptPart) -> list[PartDict]:
411411
file_data_dict['video_metadata'] = item.vendor_metadata
412412
content.append(file_data_dict) # type: ignore
413413
elif isinstance(item, FileUrl):
414-
if self.system == 'google-gla' or item.force_download:
414+
if item.force_download or (
415+
# google-gla does not support passing file urls directly, except for youtube videos
416+
# (see above) and files uploaded to the file API (which cannot be downloaded anyway)
417+
self.system == 'google-gla'
418+
and not item.url.startswith(r'https://generativelanguage.googleapis.com/v1beta/files')
419+
):
415420
downloaded_item = await download_item(item, data_format='base64')
416421
inline_data = {'data': downloaded_item['data'], 'mime_type': downloaded_item['data_type']}
417422
content.append({'inline_data': inline_data}) # type: ignore

tests/test_messages.py

Lines changed: 13 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -10,16 +10,20 @@ def test_image_url():
1010
assert image_url.media_type == 'image/jpeg'
1111
assert image_url.format == 'jpeg'
1212

13+
image_url = ImageUrl(url='https://example.com/image', media_type='image/jpeg')
14+
assert image_url.media_type == 'image/jpeg'
15+
assert image_url.format == 'jpeg'
1316

14-
def test_video_url():
15-
with pytest.raises(ValueError, match='Unknown video file extension: https://example.com/video.potato'):
16-
video_url = VideoUrl(url='https://example.com/video.potato')
17-
video_url.media_type
1817

18+
def test_video_url():
1919
video_url = VideoUrl(url='https://example.com/video.mp4')
2020
assert video_url.media_type == 'video/mp4'
2121
assert video_url.format == 'mp4'
2222

23+
video_url = VideoUrl(url='https://example.com/video', media_type='video/mp4')
24+
assert video_url.media_type == 'video/mp4'
25+
assert video_url.format == 'mp4'
26+
2327

2428
@pytest.mark.parametrize(
2529
'url,is_youtube',
@@ -38,14 +42,14 @@ def test_youtube_video_url(url: str, is_youtube: bool):
3842

3943

4044
def test_document_url():
41-
with pytest.raises(ValueError, match='Unknown document file extension: https://example.com/document.potato'):
42-
document_url = DocumentUrl(url='https://example.com/document.potato')
43-
document_url.media_type
44-
4545
document_url = DocumentUrl(url='https://example.com/document.pdf')
4646
assert document_url.media_type == 'application/pdf'
4747
assert document_url.format == 'pdf'
4848

49+
document_url = DocumentUrl(url='https://example.com/document', media_type='application/pdf')
50+
assert document_url.media_type == 'application/pdf'
51+
assert document_url.format == 'pdf'
52+
4953

5054
@pytest.mark.parametrize(
5155
'media_type, format',
@@ -122,6 +126,7 @@ def test_binary_content_document(media_type: str, format: str):
122126
pytest.param(AudioUrl('foobar.flac'), 'audio/flac', 'flac', id='flac'),
123127
pytest.param(AudioUrl('foobar.aiff'), 'audio/aiff', 'aiff', id='aiff'),
124128
pytest.param(AudioUrl('foobar.aac'), 'audio/aac', 'aac', id='aac'),
129+
pytest.param(AudioUrl('foobar', media_type='audio/mpeg'), 'audio/mpeg', 'mp3', id='mp3'),
125130
],
126131
)
127132
def test_audio_url(audio_url: AudioUrl, media_type: str, format: str):

0 commit comments

Comments
 (0)