Skip to content

Commit cacaf7b

Browse files
committed
Handle function calls without text (#3)
1 parent f25a4e1 commit cacaf7b

File tree

2 files changed

+211
-2
lines changed

2 files changed

+211
-2
lines changed

pydantic_ai_slim/pydantic_ai/models/google.py

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -413,7 +413,7 @@ async def _process_streamed_response(
413413
_timestamp=first_chunk.create_time or _utils.now_utc(),
414414
)
415415

416-
async def _map_messages(self, messages: list[ModelMessage]) -> tuple[ContentDict | None, list[ContentUnionDict]]:
416+
async def _map_messages(self, messages: list[ModelMessage]) -> tuple[ContentDict | None, list[ContentUnionDict]]: # noqa: C901 # noqa: C901
417417
contents: list[ContentUnionDict] = []
418418
system_parts: list[PartDict] = []
419419

@@ -457,7 +457,27 @@ async def _map_messages(self, messages: list[ModelMessage]) -> tuple[ContentDict
457457
message_parts = [{'text': ''}]
458458
contents.append({'role': 'user', 'parts': message_parts})
459459
elif isinstance(m, ModelResponse):
460-
contents.append(_content_model_response(m))
460+
model_content = _content_model_response(m)
461+
# Skip model responses with empty parts (e.g., thinking-only responses)
462+
if model_content.get('parts'):
463+
# Check if the model response contains only function calls without text
464+
if parts := model_content.get('parts', []):
465+
has_function_calls = False
466+
has_text_parts = False
467+
for part in parts:
468+
if isinstance(part, dict):
469+
if 'function_call' in part:
470+
has_function_calls = True
471+
if 'text' in part:
472+
has_text_parts = True
473+
474+
# If we only have function calls without text, add minimal text to satisfy Google API
475+
if has_function_calls and not has_text_parts:
476+
# Add a minimal text part to make the conversation valid for Google API
477+
parts.append({'text': 'I have completed the function calls above.'})
478+
model_content['parts'] = parts
479+
480+
contents.append(model_content)
461481
else:
462482
assert_never(m)
463483
if instructions := self._get_instructions(messages):

tests/models/test_google.py

Lines changed: 189 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
ImageUrl,
2727
ModelRequest,
2828
ModelResponse,
29+
ModelResponsePart,
2930
PartDeltaEvent,
3031
PartStartEvent,
3132
RetryPromptPart,
@@ -1740,3 +1741,191 @@ async def get_user_country() -> str:
17401741
'What is the largest city in the user country? Use the get_user_country tool and then your own world knowledge.',
17411742
usage_limits=UsageLimits(total_tokens_limit=9, count_tokens_before_request=True),
17421743
)
1744+
1745+
1746+
@pytest.mark.parametrize(
1747+
'model_parts,expected_contents',
1748+
[
1749+
pytest.param(
1750+
[ToolCallPart(tool_name='test_tool', args={'param': 'value'}, tool_call_id='call_123')],
1751+
[
1752+
{
1753+
'role': 'model',
1754+
'parts': [
1755+
{
1756+
'function_call': {
1757+
'args': {'param': 'value'},
1758+
'id': 'call_123',
1759+
'name': 'test_tool',
1760+
}
1761+
},
1762+
{'text': 'I have completed the function calls above.'},
1763+
],
1764+
}
1765+
],
1766+
id='function_call_without_text',
1767+
),
1768+
pytest.param(
1769+
[],
1770+
[],
1771+
id='empty_response_parts',
1772+
),
1773+
pytest.param(
1774+
[
1775+
ToolCallPart(tool_name='test_tool', args={'param': 'value'}, tool_call_id='call_123'),
1776+
TextPart(content='Here is the result:'),
1777+
],
1778+
[
1779+
{
1780+
'role': 'model',
1781+
'parts': [
1782+
{
1783+
'function_call': {
1784+
'args': {'param': 'value'},
1785+
'id': 'call_123',
1786+
'name': 'test_tool',
1787+
}
1788+
},
1789+
{'text': 'Here is the result:'},
1790+
],
1791+
}
1792+
],
1793+
id='function_call_with_text',
1794+
),
1795+
pytest.param(
1796+
[ThinkingPart(content='Let me think about this...')],
1797+
[],
1798+
id='thinking_only_response_skipped',
1799+
),
1800+
],
1801+
)
1802+
async def test_google_model_response_part_handling(
1803+
google_provider: GoogleProvider, model_parts: list[ModelResponsePart], expected_contents: list[dict[str, Any]]
1804+
):
1805+
"""Test Google model's handling of different response part combinations for API compatibility."""
1806+
model = GoogleModel('gemini-2.0-flash', provider=google_provider)
1807+
1808+
model_response = ModelResponse(
1809+
parts=model_parts,
1810+
usage=Usage(requests=1, request_tokens=10, response_tokens=5, total_tokens=15),
1811+
model_name='gemini-2.0-flash',
1812+
)
1813+
1814+
_, contents = await model._map_messages([model_response]) # pyright: ignore[reportPrivateUsage]
1815+
assert contents == expected_contents
1816+
1817+
1818+
class FunctionCallDict(TypedDict):
1819+
name: str
1820+
args: dict[str, Any]
1821+
id: str
1822+
1823+
1824+
class FunctionCallPartDict(TypedDict):
1825+
function_call: FunctionCallDict
1826+
1827+
1828+
class TextPartDict(TypedDict):
1829+
text: str
1830+
1831+
1832+
class OtherPartDict(TypedDict, total=False):
1833+
other_field: str
1834+
1835+
1836+
# Union of all possible part types we're testing
1837+
TestPartDict = FunctionCallPartDict | TextPartDict | OtherPartDict | str # str for non-dict parts
1838+
1839+
1840+
class MockContentResponse(TypedDict, total=False):
1841+
role: str
1842+
parts: list[TestPartDict]
1843+
1844+
1845+
class ExpectedContent(TypedDict, total=False):
1846+
role: str
1847+
parts: list[TestPartDict]
1848+
1849+
1850+
@pytest.mark.parametrize(
1851+
'mock_content_response,expected_contents',
1852+
[
1853+
pytest.param(
1854+
MockContentResponse(
1855+
{
1856+
'role': 'model',
1857+
'parts': [
1858+
'not_a_dict', # Non-dict part to test isinstance check
1859+
{'function_call': {'name': 'test', 'args': {}, 'id': '123'}},
1860+
],
1861+
}
1862+
),
1863+
[
1864+
ExpectedContent(
1865+
{
1866+
'role': 'model',
1867+
'parts': [
1868+
'not_a_dict',
1869+
{'function_call': {'name': 'test', 'args': {}, 'id': '123'}},
1870+
{'text': 'I have completed the function calls above.'},
1871+
],
1872+
}
1873+
)
1874+
],
1875+
id='non_dict_parts_with_function_call',
1876+
),
1877+
pytest.param(
1878+
MockContentResponse(
1879+
{
1880+
'role': 'model',
1881+
'parts': [
1882+
{'other_field': 'value'}, # Dict without function_call or text
1883+
{'function_call': {'name': 'test', 'args': {}, 'id': '123'}},
1884+
],
1885+
}
1886+
),
1887+
[
1888+
ExpectedContent(
1889+
{
1890+
'role': 'model',
1891+
'parts': [
1892+
{'other_field': 'value'},
1893+
{'function_call': {'name': 'test', 'args': {}, 'id': '123'}},
1894+
{'text': 'I have completed the function calls above.'},
1895+
],
1896+
}
1897+
)
1898+
],
1899+
id='dict_parts_without_function_call_or_text',
1900+
),
1901+
pytest.param(
1902+
MockContentResponse({'role': 'model'}), # No 'parts' key
1903+
[],
1904+
id='no_parts_key',
1905+
),
1906+
pytest.param(
1907+
MockContentResponse({'role': 'model', 'parts': []}), # Empty parts
1908+
[],
1909+
id='empty_parts_list',
1910+
),
1911+
],
1912+
)
1913+
async def test_google_model_response_edge_cases(
1914+
google_provider: GoogleProvider,
1915+
mock_content_response: MockContentResponse,
1916+
expected_contents: list[ExpectedContent],
1917+
):
1918+
"""Test Google model's _map_messages method with various edge cases for function call handling."""
1919+
from unittest.mock import patch
1920+
1921+
model = GoogleModel('gemini-2.0-flash', provider=google_provider)
1922+
model_response = ModelResponse(
1923+
parts=[ToolCallPart(tool_name='test_tool', args={'param': 'value'}, tool_call_id='call_123')],
1924+
usage=Usage(requests=1, request_tokens=10, response_tokens=5, total_tokens=15),
1925+
model_name='gemini-2.0-flash',
1926+
)
1927+
1928+
with patch('pydantic_ai.models.google._content_model_response') as mock_content:
1929+
mock_content.return_value = mock_content_response
1930+
_, contents = await model._map_messages([model_response]) # pyright: ignore[reportPrivateUsage]
1931+
assert contents == expected_contents

0 commit comments

Comments
 (0)