Skip to content

Commit 77649aa

Browse files
authored
fix: send the right data format to gemini image input (#991)
1 parent 2f2713b commit 77649aa

File tree

4 files changed

+42
-28
lines changed

4 files changed

+42
-28
lines changed

pydantic_ai_slim/pydantic_ai/models/gemini.py

Lines changed: 21 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -320,18 +320,24 @@ async def _map_user_prompt(part: UserPromptPart) -> list[_GeminiPartUnion]:
320320
content.append({'text': item})
321321
elif isinstance(item, BinaryContent):
322322
base64_encoded = base64.b64encode(item.data).decode('utf-8')
323-
content.append(_GeminiInlineDataPart(data=base64_encoded, mime_type=item.media_type))
323+
content.append(
324+
_GeminiInlineDataPart(inline_data={'data': base64_encoded, 'mime_type': item.media_type})
325+
)
324326
elif isinstance(item, (AudioUrl, ImageUrl)):
325327
try:
326-
content.append(_GeminiFileDataData(file_uri=item.url, mime_type=item.media_type))
328+
content.append(
329+
_GeminiFileDataPart(file_data={'file_uri': item.url, 'mime_type': item.media_type})
330+
)
327331
except ValueError:
328332
# Download the file if can't find the mime type.
329333
client = cached_async_http_client()
330334
response = await client.get(item.url, follow_redirects=True)
331335
response.raise_for_status()
332336
base64_encoded = base64.b64encode(response.content).decode('utf-8')
333337
content.append(
334-
_GeminiInlineDataPart(data=base64_encoded, mime_type=response.headers['Content-Type'])
338+
_GeminiInlineDataPart(
339+
inline_data={'data': base64_encoded, 'mime_type': response.headers['Content-Type']}
340+
)
335341
)
336342
else:
337343
assert_never(item)
@@ -528,20 +534,28 @@ class _GeminiTextPart(TypedDict):
528534
text: str
529535

530536

537+
class _GeminiInlineData(TypedDict):
538+
data: str
539+
mime_type: Annotated[str, pydantic.Field(alias='mimeType')]
540+
541+
531542
class _GeminiInlineDataPart(TypedDict):
532543
"""See <https://ai.google.dev/api/caching#Blob>."""
533544

534-
data: str
535-
mime_type: Annotated[str, pydantic.Field(alias='mimeType')]
545+
inline_data: Annotated[_GeminiInlineData, pydantic.Field(alias='inlineData')]
536546

537547

538-
class _GeminiFileDataData(TypedDict):
548+
class _GeminiFileData(TypedDict):
539549
"""See <https://ai.google.dev/api/caching#FileData>."""
540550

541551
file_uri: Annotated[str, pydantic.Field(alias='fileUri')]
542552
mime_type: Annotated[str, pydantic.Field(alias='mimeType')]
543553

544554

555+
class _GeminiFileDataPart(TypedDict):
556+
file_data: Annotated[_GeminiFileData, pydantic.Field(alias='fileData')]
557+
558+
545559
class _GeminiFunctionCallPart(TypedDict):
546560
function_call: Annotated[_GeminiFunctionCall, pydantic.Field(alias='functionCall')]
547561

@@ -617,7 +631,7 @@ def _part_discriminator(v: Any) -> str:
617631
Annotated[_GeminiFunctionCallPart, pydantic.Tag('function_call')],
618632
Annotated[_GeminiFunctionResponsePart, pydantic.Tag('function_response')],
619633
Annotated[_GeminiInlineDataPart, pydantic.Tag('inline_data')],
620-
Annotated[_GeminiFileDataData, pydantic.Tag('file_data')],
634+
Annotated[_GeminiFileDataPart, pydantic.Tag('file_data')],
621635
],
622636
pydantic.Discriminator(_part_discriminator),
623637
]

tests/models/cassettes/test_gemini/test_image_as_binary_content_input.yaml

Lines changed: 7 additions & 7 deletions
Large diffs are not rendered by default.

tests/models/cassettes/test_gemini/test_image_url_input.yaml

Lines changed: 12 additions & 12 deletions
Large diffs are not rendered by default.

tests/models/test_gemini.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -998,7 +998,7 @@ async def test_image_as_binary_content_input(
998998
agent = Agent(m)
999999

10001000
result = await agent.run(['What is the name of this fruit?', image_content])
1001-
assert result.data == snapshot('The fruit in the image is a Kiwi.')
1001+
assert result.data == snapshot('The fruit in the image is a kiwi.')
10021002

10031003

10041004
@pytest.mark.vcr()
@@ -1009,4 +1009,4 @@ async def test_image_url_input(allow_model_requests: None, gemini_api_key: str)
10091009
image_url = ImageUrl(url='https://goo.gle/instrument-img')
10101010

10111011
result = await agent.run(['What is the name of this fruit?', image_url])
1012-
assert result.data == snapshot('The image shows an organ, not a fruit.')
1012+
assert result.data == snapshot('This is not a fruit, it is an organ console.')

0 commit comments

Comments
 (0)