Skip to content

Commit 9a487b8

Browse files
committed
Handle function calls without text (#3)
1 parent f25a4e1 commit 9a487b8

File tree

2 files changed

+90
-2
lines changed

2 files changed

+90
-2
lines changed

pydantic_ai_slim/pydantic_ai/models/google.py

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -413,7 +413,7 @@ async def _process_streamed_response(
413413
_timestamp=first_chunk.create_time or _utils.now_utc(),
414414
)
415415

416-
async def _map_messages(self, messages: list[ModelMessage]) -> tuple[ContentDict | None, list[ContentUnionDict]]:
416+
async def _map_messages(self, messages: list[ModelMessage]) -> tuple[ContentDict | None, list[ContentUnionDict]]: # noqa: C901
417417
contents: list[ContentUnionDict] = []
418418
system_parts: list[PartDict] = []
419419

@@ -457,7 +457,27 @@ async def _map_messages(self, messages: list[ModelMessage]) -> tuple[ContentDict
457457
message_parts = [{'text': ''}]
458458
contents.append({'role': 'user', 'parts': message_parts})
459459
elif isinstance(m, ModelResponse):
460-
contents.append(_content_model_response(m))
460+
model_content = _content_model_response(m)
461+
# Skip model responses with empty parts (e.g., thinking-only responses)
462+
if model_content.get('parts'):
463+
# Check if the model response contains only function calls without text
464+
if parts := model_content.get('parts', []):
465+
has_function_calls = False
466+
has_text_parts = False
467+
for part in parts:
468+
if isinstance(part, dict):
469+
if 'function_call' in part:
470+
has_function_calls = True
471+
if 'text' in part:
472+
has_text_parts = True
473+
474+
# If we only have function calls without text, add minimal text to satisfy Google API
475+
if has_function_calls and not has_text_parts:
476+
# Add a minimal text part to make the conversation valid for Google API
477+
parts.append({'text': 'I have completed the function calls above.'})
478+
model_content['parts'] = parts
479+
480+
contents.append(model_content)
461481
else:
462482
assert_never(m)
463483
if instructions := self._get_instructions(messages):

tests/models/test_google.py

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
ImageUrl,
2727
ModelRequest,
2828
ModelResponse,
29+
ModelResponsePart,
2930
PartDeltaEvent,
3031
PartStartEvent,
3132
RetryPromptPart,
@@ -1740,3 +1741,70 @@ async def get_user_country() -> str:
17401741
'What is the largest city in the user country? Use the get_user_country tool and then your own world knowledge.',
17411742
usage_limits=UsageLimits(total_tokens_limit=9, count_tokens_before_request=True),
17421743
)
1744+
1745+
1746+
@pytest.mark.parametrize(
1747+
'model_parts,expected_contents',
1748+
[
1749+
pytest.param(
1750+
[ToolCallPart(tool_name='test_tool', args={'param': 'value'}, tool_call_id='call_123')],
1751+
[
1752+
{
1753+
'role': 'model',
1754+
'parts': [
1755+
{
1756+
'function_call': {
1757+
'args': {'param': 'value'},
1758+
'id': 'call_123',
1759+
'name': 'test_tool',
1760+
}
1761+
},
1762+
{'text': 'I have completed the function calls above.'},
1763+
],
1764+
}
1765+
],
1766+
id='function_call_without_text',
1767+
),
1768+
pytest.param(
1769+
[],
1770+
[],
1771+
id='empty_response_parts',
1772+
),
1773+
pytest.param(
1774+
[
1775+
ToolCallPart(tool_name='test_tool', args={'param': 'value'}, tool_call_id='call_123'),
1776+
TextPart(content='Here is the result:'),
1777+
],
1778+
[
1779+
{
1780+
'role': 'model',
1781+
'parts': [
1782+
{
1783+
'function_call': {
1784+
'args': {'param': 'value'},
1785+
'id': 'call_123',
1786+
'name': 'test_tool',
1787+
}
1788+
},
1789+
{'text': 'Here is the result:'},
1790+
],
1791+
}
1792+
],
1793+
id='function_call_with_text',
1794+
),
1795+
],
1796+
)
1797+
async def test_google_model_response_part_handling(
1798+
google_provider: GoogleProvider, model_parts: list[ModelResponsePart], expected_contents: list[dict[str, Any]]
1799+
):
1800+
"""Test Google model's handling of different response part combinations for API compatibility."""
1801+
model = GoogleModel('gemini-2.0-flash', provider=google_provider)
1802+
1803+
model_response = ModelResponse(
1804+
parts=model_parts,
1805+
usage=Usage(requests=1, request_tokens=10, response_tokens=5, total_tokens=15),
1806+
model_name='gemini-2.0-flash',
1807+
)
1808+
1809+
_, contents = await model._map_messages([model_response]) # pyright: ignore[reportPrivateUsage]
1810+
assert contents == expected_contents

0 commit comments

Comments
 (0)