Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 22 additions & 2 deletions pydantic_ai_slim/pydantic_ai/models/google.py
Original file line number Diff line number Diff line change
Expand Up @@ -413,7 +413,7 @@ async def _process_streamed_response(
_timestamp=first_chunk.create_time or _utils.now_utc(),
)

async def _map_messages(self, messages: list[ModelMessage]) -> tuple[ContentDict | None, list[ContentUnionDict]]:
async def _map_messages(self, messages: list[ModelMessage]) -> tuple[ContentDict | None, list[ContentUnionDict]]: # noqa: C901 # noqa: C901
contents: list[ContentUnionDict] = []
system_parts: list[PartDict] = []

Expand Down Expand Up @@ -457,7 +457,27 @@ async def _map_messages(self, messages: list[ModelMessage]) -> tuple[ContentDict
message_parts = [{'text': ''}]
contents.append({'role': 'user', 'parts': message_parts})
elif isinstance(m, ModelResponse):
contents.append(_content_model_response(m))
model_content = _content_model_response(m)
# Skip model responses with empty parts (e.g., thinking-only responses)
if model_content.get('parts'):
# Check if the model response contains only function calls without text
if parts := model_content.get('parts', []):
has_function_calls = False
has_text_parts = False
for part in parts:
if isinstance(part, dict):
if 'function_call' in part:
has_function_calls = True
if 'text' in part:
has_text_parts = True

# If we only have function calls without text, add minimal text to satisfy Google API
if has_function_calls and not has_text_parts:
# Add a minimal text part to make the conversation valid for Google API
parts.append({'text': 'I have completed the function calls above.'})
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does this need text or could be empty?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could we do this inside _content_model_response based on the ModelResponse's parts rather than parsing the returned ContentDict? We already iterate over all items there, so we can keep track of what we've seen.

model_content['parts'] = parts

contents.append(model_content)
else:
assert_never(m)
if instructions := self._get_instructions(messages):
Expand Down
191 changes: 190 additions & 1 deletion tests/models/test_google.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import datetime
import os
from typing import Any
from typing import Any, Union

import pytest
from httpx import Timeout
Expand All @@ -26,6 +26,7 @@
ImageUrl,
ModelRequest,
ModelResponse,
ModelResponsePart,
PartDeltaEvent,
PartStartEvent,
RetryPromptPart,
Expand Down Expand Up @@ -1740,3 +1741,191 @@ async def get_user_country() -> str:
'What is the largest city in the user country? Use the get_user_country tool and then your own world knowledge.',
usage_limits=UsageLimits(total_tokens_limit=9, count_tokens_before_request=True),
)


@pytest.mark.parametrize(
'model_parts,expected_contents',
[
pytest.param(
[ToolCallPart(tool_name='test_tool', args={'param': 'value'}, tool_call_id='call_123')],
[
{
'role': 'model',
'parts': [
{
'function_call': {
'args': {'param': 'value'},
'id': 'call_123',
'name': 'test_tool',
}
},
{'text': 'I have completed the function calls above.'},
],
}
],
id='function_call_without_text',
),
pytest.param(
[],
[],
id='empty_response_parts',
),
pytest.param(
[
ToolCallPart(tool_name='test_tool', args={'param': 'value'}, tool_call_id='call_123'),
TextPart(content='Here is the result:'),
],
[
{
'role': 'model',
'parts': [
{
'function_call': {
'args': {'param': 'value'},
'id': 'call_123',
'name': 'test_tool',
}
},
{'text': 'Here is the result:'},
],
}
],
id='function_call_with_text',
),
pytest.param(
[ThinkingPart(content='Let me think about this...')],
[],
id='thinking_only_response_skipped',
),
],
)
async def test_google_model_response_part_handling(
google_provider: GoogleProvider, model_parts: list[ModelResponsePart], expected_contents: list[dict[str, Any]]
):
"""Test Google model's handling of different response part combinations for API compatibility."""
model = GoogleModel('gemini-2.0-flash', provider=google_provider)

model_response = ModelResponse(
parts=model_parts,
usage=Usage(requests=1, request_tokens=10, response_tokens=5, total_tokens=15),
model_name='gemini-2.0-flash',
)

_, contents = await model._map_messages([model_response]) # pyright: ignore[reportPrivateUsage]
assert contents == expected_contents


class FunctionCallDict(TypedDict):
name: str
args: dict[str, Any]
id: str


class FunctionCallPartDict(TypedDict):
function_call: FunctionCallDict


class TextPartDict(TypedDict):
text: str


class OtherPartDict(TypedDict, total=False):
other_field: str


# Union of all possible part types we're testing
TestPartDict = Union[FunctionCallPartDict, TextPartDict, OtherPartDict, str] # str for non-dict parts


class MockContentResponse(TypedDict, total=False):
role: str
parts: list[TestPartDict]


class ExpectedContent(TypedDict, total=False):
role: str
parts: list[TestPartDict]


@pytest.mark.parametrize(
'mock_content_response,expected_contents',
[
pytest.param(
MockContentResponse(
{
'role': 'model',
'parts': [
'not_a_dict', # Non-dict part to test isinstance check
{'function_call': {'name': 'test', 'args': {}, 'id': '123'}},
],
}
),
[
ExpectedContent(
{
'role': 'model',
'parts': [
'not_a_dict',
{'function_call': {'name': 'test', 'args': {}, 'id': '123'}},
{'text': 'I have completed the function calls above.'},
],
}
)
],
id='non_dict_parts_with_function_call',
),
pytest.param(
MockContentResponse(
{
'role': 'model',
'parts': [
{'other_field': 'value'}, # Dict without function_call or text
{'function_call': {'name': 'test', 'args': {}, 'id': '123'}},
],
}
),
[
ExpectedContent(
{
'role': 'model',
'parts': [
{'other_field': 'value'},
{'function_call': {'name': 'test', 'args': {}, 'id': '123'}},
{'text': 'I have completed the function calls above.'},
],
}
)
],
id='dict_parts_without_function_call_or_text',
),
pytest.param(
MockContentResponse({'role': 'model'}), # No 'parts' key
[],
id='no_parts_key',
),
pytest.param(
MockContentResponse({'role': 'model', 'parts': []}), # Empty parts
[],
id='empty_parts_list',
),
],
)
async def test_google_model_response_edge_cases(
google_provider: GoogleProvider,
mock_content_response: MockContentResponse,
expected_contents: list[ExpectedContent],
):
"""Test Google model's _map_messages method with various edge cases for function call handling."""
from unittest.mock import patch

model = GoogleModel('gemini-2.0-flash', provider=google_provider)
model_response = ModelResponse(
parts=[ToolCallPart(tool_name='test_tool', args={'param': 'value'}, tool_call_id='call_123')],
usage=Usage(requests=1, request_tokens=10, response_tokens=5, total_tokens=15),
model_name='gemini-2.0-flash',
)

with patch('pydantic_ai.models.google._content_model_response') as mock_content:
mock_content.return_value = mock_content_response
_, contents = await model._map_messages([model_response]) # pyright: ignore[reportPrivateUsage]
assert contents == expected_contents
Loading