diff --git a/pydantic_ai_slim/pydantic_ai/models/google.py b/pydantic_ai_slim/pydantic_ai/models/google.py index 4b53a79565..1f2257616e 100644 --- a/pydantic_ai_slim/pydantic_ai/models/google.py +++ b/pydantic_ai_slim/pydantic_ai/models/google.py @@ -413,7 +413,7 @@ async def _process_streamed_response( _timestamp=first_chunk.create_time or _utils.now_utc(), ) - async def _map_messages(self, messages: list[ModelMessage]) -> tuple[ContentDict | None, list[ContentUnionDict]]: + async def _map_messages(self, messages: list[ModelMessage]) -> tuple[ContentDict | None, list[ContentUnionDict]]: # noqa: C901 # noqa: C901 contents: list[ContentUnionDict] = [] system_parts: list[PartDict] = [] @@ -457,7 +457,27 @@ async def _map_messages(self, messages: list[ModelMessage]) -> tuple[ContentDict message_parts = [{'text': ''}] contents.append({'role': 'user', 'parts': message_parts}) elif isinstance(m, ModelResponse): - contents.append(_content_model_response(m)) + model_content = _content_model_response(m) + # Skip model responses with empty parts (e.g., thinking-only responses) + if model_content.get('parts'): + # Check if the model response contains only function calls without text + if parts := model_content.get('parts', []): + has_function_calls = False + has_text_parts = False + for part in parts: + if isinstance(part, dict): + if 'function_call' in part: + has_function_calls = True + if 'text' in part: + has_text_parts = True + + # If we only have function calls without text, add minimal text to satisfy Google API + if has_function_calls and not has_text_parts: + # Add a minimal text part to make the conversation valid for Google API + parts.append({'text': 'I have completed the function calls above.'}) + model_content['parts'] = parts + + contents.append(model_content) else: assert_never(m) if instructions := self._get_instructions(messages): diff --git a/tests/models/test_google.py b/tests/models/test_google.py index 066566ee71..f95c52febe 100644 --- a/tests/models/test_google.py +++ b/tests/models/test_google.py @@ -2,7 +2,7 @@ import datetime import os -from typing import Any +from typing import Any, Union import pytest from httpx import Timeout @@ -26,6 +26,7 @@ ImageUrl, ModelRequest, ModelResponse, + ModelResponsePart, PartDeltaEvent, PartStartEvent, RetryPromptPart, @@ -1740,3 +1741,191 @@ async def get_user_country() -> str: 'What is the largest city in the user country? Use the get_user_country tool and then your own world knowledge.', usage_limits=UsageLimits(total_tokens_limit=9, count_tokens_before_request=True), ) + + +@pytest.mark.parametrize( + 'model_parts,expected_contents', + [ + pytest.param( + [ToolCallPart(tool_name='test_tool', args={'param': 'value'}, tool_call_id='call_123')], + [ + { + 'role': 'model', + 'parts': [ + { + 'function_call': { + 'args': {'param': 'value'}, + 'id': 'call_123', + 'name': 'test_tool', + } + }, + {'text': 'I have completed the function calls above.'}, + ], + } + ], + id='function_call_without_text', + ), + pytest.param( + [], + [], + id='empty_response_parts', + ), + pytest.param( + [ + ToolCallPart(tool_name='test_tool', args={'param': 'value'}, tool_call_id='call_123'), + TextPart(content='Here is the result:'), + ], + [ + { + 'role': 'model', + 'parts': [ + { + 'function_call': { + 'args': {'param': 'value'}, + 'id': 'call_123', + 'name': 'test_tool', + } + }, + {'text': 'Here is the result:'}, + ], + } + ], + id='function_call_with_text', + ), + pytest.param( + [ThinkingPart(content='Let me think about this...')], + [], + id='thinking_only_response_skipped', + ), + ], +) +async def test_google_model_response_part_handling( + google_provider: GoogleProvider, model_parts: list[ModelResponsePart], expected_contents: list[dict[str, Any]] +): + """Test Google model's handling of different response part combinations for API compatibility.""" + model = GoogleModel('gemini-2.0-flash', provider=google_provider) + + model_response = ModelResponse( + parts=model_parts, + usage=Usage(requests=1, request_tokens=10, response_tokens=5, total_tokens=15), + model_name='gemini-2.0-flash', + ) + + _, contents = await model._map_messages([model_response]) # pyright: ignore[reportPrivateUsage] + assert contents == expected_contents + + +class FunctionCallDict(TypedDict): + name: str + args: dict[str, Any] + id: str + + +class FunctionCallPartDict(TypedDict): + function_call: FunctionCallDict + + +class TextPartDict(TypedDict): + text: str + + +class OtherPartDict(TypedDict, total=False): + other_field: str + + +# Union of all possible part types we're testing +TestPartDict = Union[FunctionCallPartDict, TextPartDict, OtherPartDict, str] # str for non-dict parts + + +class MockContentResponse(TypedDict, total=False): + role: str + parts: list[TestPartDict] + + +class ExpectedContent(TypedDict, total=False): + role: str + parts: list[TestPartDict] + + +@pytest.mark.parametrize( + 'mock_content_response,expected_contents', + [ + pytest.param( + MockContentResponse( + { + 'role': 'model', + 'parts': [ + 'not_a_dict', # Non-dict part to test isinstance check + {'function_call': {'name': 'test', 'args': {}, 'id': '123'}}, + ], + } + ), + [ + ExpectedContent( + { + 'role': 'model', + 'parts': [ + 'not_a_dict', + {'function_call': {'name': 'test', 'args': {}, 'id': '123'}}, + {'text': 'I have completed the function calls above.'}, + ], + } + ) + ], + id='non_dict_parts_with_function_call', + ), + pytest.param( + MockContentResponse( + { + 'role': 'model', + 'parts': [ + {'other_field': 'value'}, # Dict without function_call or text + {'function_call': {'name': 'test', 'args': {}, 'id': '123'}}, + ], + } + ), + [ + ExpectedContent( + { + 'role': 'model', + 'parts': [ + {'other_field': 'value'}, + {'function_call': {'name': 'test', 'args': {}, 'id': '123'}}, + {'text': 'I have completed the function calls above.'}, + ], + } + ) + ], + id='dict_parts_without_function_call_or_text', + ), + pytest.param( + MockContentResponse({'role': 'model'}), # No 'parts' key + [], + id='no_parts_key', + ), + pytest.param( + MockContentResponse({'role': 'model', 'parts': []}), # Empty parts + [], + id='empty_parts_list', + ), + ], +) +async def test_google_model_response_edge_cases( + google_provider: GoogleProvider, + mock_content_response: MockContentResponse, + expected_contents: list[ExpectedContent], +): + """Test Google model's _map_messages method with various edge cases for function call handling.""" + from unittest.mock import patch + + model = GoogleModel('gemini-2.0-flash', provider=google_provider) + model_response = ModelResponse( + parts=[ToolCallPart(tool_name='test_tool', args={'param': 'value'}, tool_call_id='call_123')], + usage=Usage(requests=1, request_tokens=10, response_tokens=5, total_tokens=15), + model_name='gemini-2.0-flash', + ) + + with patch('pydantic_ai.models.google._content_model_response') as mock_content: + mock_content.return_value = mock_content_response + _, contents = await model._map_messages([model_response]) # pyright: ignore[reportPrivateUsage] + assert contents == expected_contents