Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 22 additions & 18 deletions pydantic_ai_slim/pydantic_ai/messages.py
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why aren't we using the mimetypes stdlib module? mimetypes.guess_type() already parses URLs and the current implementation doesn't take into account case insensitivity, etc.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@Viicos Interestingly we already use that in DocumentUrl._infer_media_type, after checking a bunch of types ourselves :/

@fedexman Can you see if we can use mimetypes.guess_type() for all of these?

The method can be changed to just return str rather than XMediaType, as I don't think that type is used on any public fields.

Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from os import PathLike
from pathlib import Path
from typing import TYPE_CHECKING, Annotated, Any, Literal, TypeAlias, cast, overload
from urllib.parse import urlparse

import pydantic
import pydantic_core
Expand Down Expand Up @@ -228,21 +229,22 @@ def __init__(

def _infer_media_type(self) -> VideoMediaType:
"""Return the media type of the video, based on the url."""
if self.url.endswith('.mkv'):
path = urlparse(self.url).path
if path.endswith('.mkv'):
return 'video/x-matroska'
elif self.url.endswith('.mov'):
elif path.endswith('.mov'):
return 'video/quicktime'
elif self.url.endswith('.mp4'):
elif path.endswith('.mp4'):
return 'video/mp4'
elif self.url.endswith('.webm'):
elif path.endswith('.webm'):
return 'video/webm'
elif self.url.endswith('.flv'):
elif path.endswith('.flv'):
return 'video/x-flv'
elif self.url.endswith(('.mpeg', '.mpg')):
elif path.endswith(('.mpeg', '.mpg')):
return 'video/mpeg'
elif self.url.endswith('.wmv'):
elif path.endswith('.wmv'):
return 'video/x-ms-wmv'
elif self.url.endswith('.three_gp'):
elif path.endswith('.three_gp'):
return 'video/3gpp'
# Assume that YouTube videos are mp4 because there would be no extension
# to infer from. This should not be a problem, as Gemini disregards media
Expand Down Expand Up @@ -308,17 +310,18 @@ def _infer_media_type(self) -> AudioMediaType:
References:
- Gemini: https://ai.google.dev/gemini-api/docs/audio#supported-formats
"""
if self.url.endswith('.mp3'):
path = urlparse(self.url).path
if path.endswith('.mp3'):
return 'audio/mpeg'
if self.url.endswith('.wav'):
if path.endswith('.wav'):
return 'audio/wav'
if self.url.endswith('.flac'):
if path.endswith('.flac'):
return 'audio/flac'
if self.url.endswith('.oga'):
if path.endswith('.oga'):
return 'audio/ogg'
if self.url.endswith('.aiff'):
if path.endswith('.aiff'):
return 'audio/aiff'
if self.url.endswith('.aac'):
if path.endswith('.aac'):
return 'audio/aac'

raise ValueError(
Expand Down Expand Up @@ -367,13 +370,14 @@ def __init__(

def _infer_media_type(self) -> ImageMediaType:
"""Return the media type of the image, based on the url."""
if self.url.endswith(('.jpg', '.jpeg')):
path = urlparse(self.url).path
if path.endswith(('.jpg', '.jpeg')):
return 'image/jpeg'
elif self.url.endswith('.png'):
elif path.endswith('.png'):
return 'image/png'
elif self.url.endswith('.gif'):
elif path.endswith('.gif'):
return 'image/gif'
elif self.url.endswith('.webp'):
elif path.endswith('.webp'):
return 'image/webp'
else:
raise ValueError(
Expand Down
72 changes: 72 additions & 0 deletions tests/test_messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -333,6 +333,78 @@ def test_video_url_invalid():
VideoUrl('foobar.potato').media_type


@pytest.mark.parametrize(
'url,media_type,format',
[
pytest.param(
'https://example.com/video.mp4?query=param',
'video/mp4',
'mp4',
id='mp4_with_query',
),
pytest.param(
'https://example.com/video.webm?X-Amz-Algorithm=AWS4-HMAC-SHA256',
'video/webm',
'webm',
id='webm_with_aws_params',
),
],
)
def test_video_url_with_query_parameters(url: str, media_type: str, format: str):
"""Test that VideoUrl correctly infers media type from URLs with query parameters (e.g., presigned URLs)."""
video_url = VideoUrl(url)
assert video_url.media_type == media_type
assert video_url.format == format


@pytest.mark.parametrize(
'url,media_type,format',
[
pytest.param(
'https://example.com/audio.mp3?query=param',
'audio/mpeg',
'mp3',
id='mp3_with_query',
),
pytest.param(
'https://example.com/audio.wav?X-Amz-Algorithm=AWS4-HMAC-SHA256',
'audio/wav',
'wav',
id='wav_with_aws_params',
),
],
)
def test_audio_url_with_query_parameters(url: str, media_type: str, format: str):
"""Test that AudioUrl correctly infers media type from URLs with query parameters (e.g., presigned URLs)."""
audio_url = AudioUrl(url)
assert audio_url.media_type == media_type
assert audio_url.format == format


@pytest.mark.parametrize(
'url,media_type,format',
[
pytest.param(
'https://example.com/image.png?query=param',
'image/png',
'png',
id='png_with_query',
),
pytest.param(
'https://example.com/image.jpg?X-Amz-Algorithm=AWS4-HMAC-SHA256',
'image/jpeg',
'jpeg',
id='jpg_with_aws_params',
),
],
)
def test_image_url_with_query_parameters(url: str, media_type: str, format: str):
"""Test that ImageUrl correctly infers media type from URLs with query parameters (e.g., presigned URLs)."""
image_url = ImageUrl(url)
assert image_url.media_type == media_type
assert image_url.format == format


def test_thinking_part_delta_apply_to_thinking_part_delta():
"""Test lines 768-775: Apply ThinkingPartDelta to another ThinkingPartDelta."""
original_delta = ThinkingPartDelta(
Expand Down