Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 14 additions & 1 deletion pydantic_ai_slim/pydantic_ai/models/google.py
Original file line number Diff line number Diff line change
Expand Up @@ -472,7 +472,10 @@ async def _map_messages(self, messages: list[ModelMessage]) -> tuple[ContentDict
message_parts = [{'text': ''}]
contents.append({'role': 'user', 'parts': message_parts})
elif isinstance(m, ModelResponse):
contents.append(_content_model_response(m))
model_content = _content_model_response(m)
# Skip model responses with empty parts (e.g., thinking-only responses)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@ethanabrooks Does this mean this would be a non-issue once we start sending back thinking parts? (#2594) I mean to work on that next week

if model_content.get('parts'):
contents.append(model_content)
else:
assert_never(m)
if instructions := self._get_instructions(messages):
Expand Down Expand Up @@ -594,12 +597,17 @@ def timestamp(self) -> datetime:

def _content_model_response(m: ModelResponse) -> ContentDict:
parts: list[PartDict] = []
has_function_calls = False
has_text_parts = False

for item in m.parts:
if isinstance(item, ToolCallPart):
function_call = FunctionCallDict(name=item.tool_name, args=item.args_as_dict(), id=item.tool_call_id)
parts.append({'function_call': function_call})
has_function_calls = True
elif isinstance(item, TextPart):
parts.append({'text': item.content})
has_text_parts = True
elif isinstance(item, ThinkingPart): # pragma: no cover
# NOTE: We don't send ThinkingPart to the providers yet. If you are unsatisfied with this,
# please open an issue. The below code is the code to send thinking to the provider.
Expand All @@ -615,6 +623,11 @@ def _content_model_response(m: ModelResponse) -> ContentDict:
parts.append({'code_execution_result': item.content})
else:
assert_never(item)

# If we only have function calls without text, add minimal text to satisfy Google API
if has_function_calls and not has_text_parts:
parts.append({'text': 'I have completed the function calls above.'})

return ContentDict(role='model', parts=parts)


Expand Down
189 changes: 188 additions & 1 deletion tests/models/test_google.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
ImageUrl,
ModelRequest,
ModelResponse,
ModelResponsePart,
PartDeltaEvent,
PartStartEvent,
RetryPromptPart,
Expand Down Expand Up @@ -60,7 +61,12 @@
Outcome,
)

from pydantic_ai.models.google import GoogleModel, GoogleModelSettings, _metadata_as_usage # type: ignore
from pydantic_ai.models.google import (
GoogleModel,
GoogleModelSettings,
_content_model_response, # pyright: ignore[reportPrivateUsage]
_metadata_as_usage,
)
from pydantic_ai.providers.google import GoogleProvider

pytestmark = [
Expand Down Expand Up @@ -1828,3 +1834,184 @@ class CityLocation(BaseModel):
agent = Agent(m, output_type=PromptedOutput(CityLocation), builtin_tools=[UrlContextTool()])
result = await agent.run('What is the largest city in Mexico?')
assert result.output == snapshot(CityLocation(city='Mexico City', country='Mexico'))


@pytest.mark.parametrize(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It would be great if we had a minimal reproducible example that triggers the error against the actual API, so we can verify what parts of the fix are necessary (e.g. is sending back thinking parts enough? does a one-space string work? is it only Vertex or also GLA?). Right now, we're effectively duplicating the fix in tests, rather than verifying that a previously-failing example now succeeds.

'model_parts,expected_contents',
[
pytest.param(
[ToolCallPart(tool_name='test_tool', args={'param': 'value'}, tool_call_id='call_123')],
[
{
'role': 'model',
'parts': [
{
'function_call': {
'args': {'param': 'value'},
'id': 'call_123',
'name': 'test_tool',
}
},
{'text': 'I have completed the function calls above.'},
],
}
],
id='function_call_without_text',
),
pytest.param(
[],
[],
id='empty_response_parts',
),
pytest.param(
[
ToolCallPart(tool_name='test_tool', args={'param': 'value'}, tool_call_id='call_123'),
TextPart(content='Here is the result:'),
],
[
{
'role': 'model',
'parts': [
{
'function_call': {
'args': {'param': 'value'},
'id': 'call_123',
'name': 'test_tool',
}
},
{'text': 'Here is the result:'},
],
}
],
id='function_call_with_text',
),
pytest.param(
[ThinkingPart(content='Let me think about this...')],
[],
id='thinking_only_response_skipped',
),
],
)
async def test_google_model_response_part_handling(
    google_provider: GoogleProvider, model_parts: list[ModelResponsePart], expected_contents: list[dict[str, Any]]
):
    """Test Google model's handling of different response part combinations for API compatibility.

    Parametrized (see the decorator above) over combinations of ToolCallPart,
    TextPart, and ThinkingPart, asserting the exact `contents` payload that
    `_map_messages` would send to the Google API.
    """
    model = GoogleModel('gemini-2.0-flash', provider=google_provider)

    # Build a synthetic prior model response; usage and model_name are required
    # constructor fields but are irrelevant to the message-mapping behavior
    # under test here.
    model_response = ModelResponse(
        parts=model_parts,
        usage=RequestUsage(input_tokens=10, output_tokens=5),
        model_name='gemini-2.0-flash',
    )

    # _map_messages returns a 2-tuple; only the mapped `contents` list is
    # asserted — the first element is deliberately discarded.
    _, contents = await model._map_messages([model_response])  # pyright: ignore[reportPrivateUsage]
    assert contents == expected_contents


# NOTE(review): this test drives the private helper `_content_model_response`
# directly with hand-built ModelResponse parts and compares against literal
# Google GenAI ContentDict-shaped dicts ({'role': 'model', 'parts': [...]}).
@pytest.mark.parametrize(
    'model_parts,expected_parts',
    [
        # A lone function call gets a synthetic trailing text part appended
        # (added by `_content_model_response` to satisfy the Google API).
        pytest.param(
            [ToolCallPart(tool_name='test_tool', args={'param': 'value'}, tool_call_id='call_123')],
            [
                {
                    'function_call': {
                        'args': {'param': 'value'},
                        'id': 'call_123',
                        'name': 'test_tool',
                    }
                },
                {'text': 'I have completed the function calls above.'},
            ],
            id='function_call_only_adds_text',
        ),
        # Function call accompanied by real text: no synthetic text is added.
        pytest.param(
            [
                ToolCallPart(tool_name='test_tool', args={'param': 'value'}, tool_call_id='call_123'),
                TextPart(content='Here is the result:'),
            ],
            [
                {
                    'function_call': {
                        'args': {'param': 'value'},
                        'id': 'call_123',
                        'name': 'test_tool',
                    }
                },
                {'text': 'Here is the result:'},
            ],
            id='function_call_with_text_no_addition',
        ),
        # Pure text passes through unchanged.
        pytest.param(
            [TextPart(content='Just text response')],
            [{'text': 'Just text response'}],
            id='text_only_no_addition',
        ),
        # Multiple function calls and still no text: exactly one synthetic
        # text part is appended after all the calls.
        pytest.param(
            [
                ToolCallPart(tool_name='tool1', args={'a': 1}, tool_call_id='call_1'),
                ToolCallPart(tool_name='tool2', args={'b': 2}, tool_call_id='call_2'),
            ],
            [
                {
                    'function_call': {
                        'args': {'a': 1},
                        'id': 'call_1',
                        'name': 'tool1',
                    }
                },
                {
                    'function_call': {
                        'args': {'b': 2},
                        'id': 'call_2',
                        'name': 'tool2',
                    }
                },
                {'text': 'I have completed the function calls above.'},
            ],
            id='multiple_function_calls_only',
        ),
        # ThinkingPart is not sent to the provider (per the NOTE in
        # `_content_model_response`), so a thinking-only response maps to
        # an empty parts list.
        pytest.param(
            [ThinkingPart(content='Let me think...')],
            [],
            id='thinking_only_empty_parts',
        ),
        # Thinking is dropped, but the remaining function call still triggers
        # the synthetic text part.
        pytest.param(
            [
                ThinkingPart(content='Let me think...'),
                ToolCallPart(tool_name='test_tool', args={'param': 'value'}, tool_call_id='call_123'),
            ],
            [
                {
                    'function_call': {
                        'args': {'param': 'value'},
                        'id': 'call_123',
                        'name': 'test_tool',
                    }
                },
                {'text': 'I have completed the function calls above.'},
            ],
            id='thinking_and_function_call',
        ),
        # No parts at all: the helper still returns a model-role dict,
        # just with an empty parts list.
        pytest.param(
            [],
            [],
            id='empty_parts',
        ),
    ],
)
def test_content_model_response_function_call_handling(
    model_parts: list[ModelResponsePart], expected_parts: list[dict[str, Any]]
):
    """Test _content_model_response function's handling of function calls without text.

    Asserts, per parametrized case, the exact parts list produced for the
    Google API, including the synthetic text part added when a response
    contains function calls but no text.
    """

    # usage and model_name are required ModelResponse fields but do not
    # affect `_content_model_response`, which only reads `parts`.
    model_response = ModelResponse(
        parts=model_parts,
        usage=RequestUsage(input_tokens=10, output_tokens=5),
        model_name='gemini-2.0-flash',
    )

    result = _content_model_response(model_response)

    # The helper always returns a model-role ContentDict; only `parts` varies.
    expected_result = {'role': 'model', 'parts': expected_parts}
    assert result == expected_result
Loading