diff --git a/pydantic_ai_slim/pydantic_ai/messages.py b/pydantic_ai_slim/pydantic_ai/messages.py index 731ccb5ca6..4f98d995a3 100644 --- a/pydantic_ai_slim/pydantic_ai/messages.py +++ b/pydantic_ai_slim/pydantic_ai/messages.py @@ -85,7 +85,7 @@ def otel_event(self, settings: InstrumentationSettings) -> Event: __repr__ = _utils.dataclasses_no_defaults_repr -@dataclass(repr=False) +@dataclass(init=False, repr=False) class FileUrl(ABC): """Abstract base class for any URL-based file.""" @@ -106,11 +106,29 @@ class FileUrl(ABC): - `GoogleModel`: `VideoUrl.vendor_metadata` is used as `video_metadata`: https://ai.google.dev/gemini-api/docs/video-understanding#customize-video-processing """ - @property + _media_type: str | None = field(init=False, repr=False) + + def __init__( + self, + url: str, + force_download: bool = False, + vendor_metadata: dict[str, Any] | None = None, + media_type: str | None = None, + ) -> None: + self.url = url + self.vendor_metadata = vendor_metadata + self.force_download = force_download + self._media_type = media_type + @abstractmethod - def media_type(self) -> str: + def _infer_media_type(self) -> str: """Return the media type of the file, based on the url.""" + @property + def media_type(self) -> str: + """Return the media type of the file, based on the url or the provided `_media_type`.""" + return self._media_type or self._infer_media_type() + @property @abstractmethod def format(self) -> str: @@ -119,7 +137,7 @@ def format(self) -> str: __repr__ = _utils.dataclasses_no_defaults_repr -@dataclass(repr=False) +@dataclass(init=False, repr=False) class VideoUrl(FileUrl): """A URL to a video.""" @@ -129,8 +147,18 @@ class VideoUrl(FileUrl): kind: Literal['video-url'] = 'video-url' """Type identifier, this is available on all parts as a discriminator.""" - @property - def media_type(self) -> VideoMediaType: + def __init__( + self, + url: str, + force_download: bool = False, + vendor_metadata: dict[str, Any] | None = None, + media_type: str | None = None, + kind: Literal['video-url'] = 'video-url', + ) -> None: + super().__init__(url=url, force_download=force_download, vendor_metadata=vendor_metadata, media_type=media_type) + self.kind = kind + + def _infer_media_type(self) -> VideoMediaType: """Return the media type of the video, based on the url.""" if self.url.endswith('.mkv'): return 'video/x-matroska' @@ -170,7 +198,7 @@ def format(self) -> VideoFormat: return _video_format_lookup[self.media_type] -@dataclass(repr=False) +@dataclass(init=False, repr=False) class AudioUrl(FileUrl): """A URL to an audio file.""" @@ -180,8 +208,18 @@ class AudioUrl(FileUrl): kind: Literal['audio-url'] = 'audio-url' """Type identifier, this is available on all parts as a discriminator.""" - @property - def media_type(self) -> AudioMediaType: + def __init__( + self, + url: str, + force_download: bool = False, + vendor_metadata: dict[str, Any] | None = None, + media_type: str | None = None, + kind: Literal['audio-url'] = 'audio-url', + ) -> None: + super().__init__(url=url, force_download=force_download, vendor_metadata=vendor_metadata, media_type=media_type) + self.kind = kind + + def _infer_media_type(self) -> AudioMediaType: """Return the media type of the audio file, based on the url. References: @@ -208,7 +246,7 @@ def format(self) -> AudioFormat: return _audio_format_lookup[self.media_type] -@dataclass(repr=False) +@dataclass(init=False, repr=False) class ImageUrl(FileUrl): """A URL to an image.""" @@ -218,8 +256,18 @@ class ImageUrl(FileUrl): kind: Literal['image-url'] = 'image-url' """Type identifier, this is available on all parts as a discriminator.""" - @property - def media_type(self) -> ImageMediaType: + def __init__( + self, + url: str, + force_download: bool = False, + vendor_metadata: dict[str, Any] | None = None, + media_type: str | None = None, + kind: Literal['image-url'] = 'image-url', + ) -> None: + super().__init__(url=url, force_download=force_download, vendor_metadata=vendor_metadata, media_type=media_type) + self.kind = kind + + def _infer_media_type(self) -> ImageMediaType: """Return the media type of the image, based on the url.""" if self.url.endswith(('.jpg', '.jpeg')): return 'image/jpeg' @@ -241,7 +289,7 @@ def format(self) -> ImageFormat: return _image_format_lookup[self.media_type] -@dataclass(repr=False) +@dataclass(init=False, repr=False) class DocumentUrl(FileUrl): """The URL of the document.""" @@ -251,8 +299,18 @@ class DocumentUrl(FileUrl): kind: Literal['document-url'] = 'document-url' """Type identifier, this is available on all parts as a discriminator.""" - @property - def media_type(self) -> str: + def __init__( + self, + url: str, + force_download: bool = False, + vendor_metadata: dict[str, Any] | None = None, + media_type: str | None = None, + kind: Literal['document-url'] = 'document-url', + ) -> None: + super().__init__(url=url, force_download=force_download, vendor_metadata=vendor_metadata, media_type=media_type) + self.kind = kind + + def _infer_media_type(self) -> str: """Return the media type of the document, based on the url.""" type_, _ = guess_type(self.url) if type_ is None: diff --git a/pydantic_ai_slim/pydantic_ai/models/google.py b/pydantic_ai_slim/pydantic_ai/models/google.py index 9ec1260d4e..ad5da243c1 100644 --- a/pydantic_ai_slim/pydantic_ai/models/google.py +++ b/pydantic_ai_slim/pydantic_ai/models/google.py @@ -411,7 +411,12 @@ async def _map_user_prompt(self, part: UserPromptPart) -> list[PartDict]: file_data_dict['video_metadata'] = item.vendor_metadata content.append(file_data_dict) # type: ignore elif isinstance(item, FileUrl): - if self.system == 'google-gla' or item.force_download: + if item.force_download or ( + # google-gla does not support passing file urls directly, except for youtube videos + # (see above) and files uploaded to the file API (which cannot be downloaded anyway) + self.system == 'google-gla' + and not item.url.startswith(r'https://generativelanguage.googleapis.com/v1beta/files') + ): downloaded_item = await download_item(item, data_format='base64') inline_data = {'data': downloaded_item['data'], 'mime_type': downloaded_item['data_type']} content.append({'inline_data': inline_data}) # type: ignore diff --git a/tests/test_messages.py b/tests/test_messages.py index d95aca32a7..e1f3a0b9fe 100644 --- a/tests/test_messages.py +++ b/tests/test_messages.py @@ -17,16 +17,20 @@ def test_image_url(): assert image_url.media_type == 'image/jpeg' assert image_url.format == 'jpeg' + image_url = ImageUrl(url='https://example.com/image', media_type='image/jpeg') + assert image_url.media_type == 'image/jpeg' + assert image_url.format == 'jpeg' -def test_video_url(): - with pytest.raises(ValueError, match='Unknown video file extension: https://example.com/video.potato'): - video_url = VideoUrl(url='https://example.com/video.potato') - video_url.media_type +def test_video_url(): video_url = VideoUrl(url='https://example.com/video.mp4') assert video_url.media_type == 'video/mp4' assert video_url.format == 'mp4' + video_url = VideoUrl(url='https://example.com/video', media_type='video/mp4') + assert video_url.media_type == 'video/mp4' + assert video_url.format == 'mp4' + @pytest.mark.parametrize( 'url,is_youtube', @@ -45,14 +49,14 @@ def test_youtube_video_url(url: str, is_youtube: bool): def test_document_url(): - with pytest.raises(ValueError, match='Unknown document file extension: https://example.com/document.potato'): - document_url = DocumentUrl(url='https://example.com/document.potato') - document_url.media_type - document_url = DocumentUrl(url='https://example.com/document.pdf') assert document_url.media_type == 'application/pdf' assert document_url.format == 'pdf' + document_url = DocumentUrl(url='https://example.com/document', media_type='application/pdf') + assert document_url.media_type == 'application/pdf' + assert document_url.format == 'pdf' + @pytest.mark.parametrize( 'media_type, format', @@ -129,6 +133,7 @@ def test_binary_content_document(media_type: str, format: str): pytest.param(AudioUrl('foobar.flac'), 'audio/flac', 'flac', id='flac'), pytest.param(AudioUrl('foobar.aiff'), 'audio/aiff', 'aiff', id='aiff'), pytest.param(AudioUrl('foobar.aac'), 'audio/aac', 'aac', id='aac'), + pytest.param(AudioUrl('foobar', media_type='audio/mpeg'), 'audio/mpeg', 'mp3', id='mp3'), ], ) def test_audio_url(audio_url: AudioUrl, media_type: str, format: str):