Skip to content

Commit e456ad3

Browse files
committed
move logic inside _content_model_response
1 parent e436234 commit e456ad3

File tree

2 files changed

+95
-105
lines changed

2 files changed

+95
-105
lines changed

pydantic_ai_slim/pydantic_ai/models/google.py

Lines changed: 10 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -460,23 +460,6 @@ async def _map_messages(self, messages: list[ModelMessage]) -> tuple[ContentDict
460460
model_content = _content_model_response(m)
461461
# Skip model responses with empty parts (e.g., thinking-only responses)
462462
if model_content.get('parts'):
463-
# Check if the model response contains only function calls without text
464-
if parts := model_content.get('parts', []):
465-
has_function_calls = False
466-
has_text_parts = False
467-
for part in parts:
468-
if isinstance(part, dict):
469-
if 'function_call' in part:
470-
has_function_calls = True
471-
if 'text' in part:
472-
has_text_parts = True
473-
474-
# If we only have function calls without text, add minimal text to satisfy Google API
475-
if has_function_calls and not has_text_parts:
476-
# Add a minimal text part to make the conversation valid for Google API
477-
parts.append({'text': 'I have completed the function calls above.'})
478-
model_content['parts'] = parts
479-
480463
contents.append(model_content)
481464
else:
482465
assert_never(m)
@@ -583,12 +566,17 @@ def timestamp(self) -> datetime:
583566

584567
def _content_model_response(m: ModelResponse) -> ContentDict:
585568
parts: list[PartDict] = []
569+
has_function_calls = False
570+
has_text_parts = False
571+
586572
for item in m.parts:
587573
if isinstance(item, ToolCallPart):
588574
function_call = FunctionCallDict(name=item.tool_name, args=item.args_as_dict(), id=item.tool_call_id)
589575
parts.append({'function_call': function_call})
576+
has_function_calls = True
590577
elif isinstance(item, TextPart):
591578
parts.append({'text': item.content})
579+
has_text_parts = True
592580
elif isinstance(item, ThinkingPart): # pragma: no cover
593581
# NOTE: We don't send ThinkingPart to the providers yet. If you are unsatisfied with this,
594582
# please open an issue. The below code is the code to send thinking to the provider.
@@ -604,6 +592,11 @@ def _content_model_response(m: ModelResponse) -> ContentDict:
604592
parts.append({'code_execution_result': item.content})
605593
else:
606594
assert_never(item)
595+
596+
# If we only have function calls without text, add minimal text to satisfy Google API
597+
if has_function_calls and not has_text_parts:
598+
parts.append({'text': 'I have completed the function calls above.'})
599+
607600
return ContentDict(role='model', parts=parts)
608601

609602

tests/models/test_google.py

Lines changed: 85 additions & 88 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
import datetime
44
import os
5-
from typing import Any, Union
5+
from typing import Any
66

77
import pytest
88
from httpx import Timeout
@@ -50,7 +50,11 @@
5050
with try_import() as imports_successful:
5151
from google.genai.types import CodeExecutionResult, HarmBlockThreshold, HarmCategory, Language, Outcome
5252

53-
from pydantic_ai.models.google import GoogleModel, GoogleModelSettings
53+
from pydantic_ai.models.google import (
54+
GoogleModel,
55+
GoogleModelSettings,
56+
_content_model_response, # pyright: ignore[reportPrivateUsage]
57+
)
5458
from pydantic_ai.providers.google import GoogleProvider
5559

5660
pytestmark = [
@@ -1815,117 +1819,110 @@ async def test_google_model_response_part_handling(
18151819
assert contents == expected_contents
18161820

18171821

1818-
class FunctionCallDict(TypedDict):
1819-
name: str
1820-
args: dict[str, Any]
1821-
id: str
1822-
1823-
1824-
class FunctionCallPartDict(TypedDict):
1825-
function_call: FunctionCallDict
1826-
1827-
1828-
class TextPartDict(TypedDict):
1829-
text: str
1830-
1831-
1832-
class OtherPartDict(TypedDict, total=False):
1833-
other_field: str
1834-
1835-
1836-
# Union of all possible part types we're testing
1837-
TestPartDict = Union[FunctionCallPartDict, TextPartDict, OtherPartDict, str] # str for non-dict parts
1838-
1839-
1840-
class MockContentResponse(TypedDict, total=False):
1841-
role: str
1842-
parts: list[TestPartDict]
1843-
1844-
1845-
class ExpectedContent(TypedDict, total=False):
1846-
role: str
1847-
parts: list[TestPartDict]
1848-
1849-
18501822
@pytest.mark.parametrize(
1851-
'mock_content_response,expected_contents',
1823+
'model_parts,expected_parts',
18521824
[
18531825
pytest.param(
1854-
MockContentResponse(
1855-
{
1856-
'role': 'model',
1857-
'parts': [
1858-
'not_a_dict', # Non-dict part to test isinstance check
1859-
{'function_call': {'name': 'test', 'args': {}, 'id': '123'}},
1860-
],
1861-
}
1862-
),
1826+
[ToolCallPart(tool_name='test_tool', args={'param': 'value'}, tool_call_id='call_123')],
18631827
[
1864-
ExpectedContent(
1865-
{
1866-
'role': 'model',
1867-
'parts': [
1868-
'not_a_dict',
1869-
{'function_call': {'name': 'test', 'args': {}, 'id': '123'}},
1870-
{'text': 'I have completed the function calls above.'},
1871-
],
1828+
{
1829+
'function_call': {
1830+
'args': {'param': 'value'},
1831+
'id': 'call_123',
1832+
'name': 'test_tool',
18721833
}
1873-
)
1834+
},
1835+
{'text': 'I have completed the function calls above.'},
18741836
],
1875-
id='non_dict_parts_with_function_call',
1837+
id='function_call_only_adds_text',
18761838
),
18771839
pytest.param(
1878-
MockContentResponse(
1840+
[
1841+
ToolCallPart(tool_name='test_tool', args={'param': 'value'}, tool_call_id='call_123'),
1842+
TextPart(content='Here is the result:'),
1843+
],
1844+
[
18791845
{
1880-
'role': 'model',
1881-
'parts': [
1882-
{'other_field': 'value'}, # Dict without function_call or text
1883-
{'function_call': {'name': 'test', 'args': {}, 'id': '123'}},
1884-
],
1885-
}
1886-
),
1846+
'function_call': {
1847+
'args': {'param': 'value'},
1848+
'id': 'call_123',
1849+
'name': 'test_tool',
1850+
}
1851+
},
1852+
{'text': 'Here is the result:'},
1853+
],
1854+
id='function_call_with_text_no_addition',
1855+
),
1856+
pytest.param(
1857+
[TextPart(content='Just text response')],
1858+
[{'text': 'Just text response'}],
1859+
id='text_only_no_addition',
1860+
),
1861+
pytest.param(
1862+
[
1863+
ToolCallPart(tool_name='tool1', args={'a': 1}, tool_call_id='call_1'),
1864+
ToolCallPart(tool_name='tool2', args={'b': 2}, tool_call_id='call_2'),
1865+
],
18871866
[
1888-
ExpectedContent(
1889-
{
1890-
'role': 'model',
1891-
'parts': [
1892-
{'other_field': 'value'},
1893-
{'function_call': {'name': 'test', 'args': {}, 'id': '123'}},
1894-
{'text': 'I have completed the function calls above.'},
1895-
],
1867+
{
1868+
'function_call': {
1869+
'args': {'a': 1},
1870+
'id': 'call_1',
1871+
'name': 'tool1',
18961872
}
1897-
)
1873+
},
1874+
{
1875+
'function_call': {
1876+
'args': {'b': 2},
1877+
'id': 'call_2',
1878+
'name': 'tool2',
1879+
}
1880+
},
1881+
{'text': 'I have completed the function calls above.'},
18981882
],
1899-
id='dict_parts_without_function_call_or_text',
1883+
id='multiple_function_calls_only',
19001884
),
19011885
pytest.param(
1902-
MockContentResponse({'role': 'model'}), # No 'parts' key
1886+
[ThinkingPart(content='Let me think...')],
19031887
[],
1904-
id='no_parts_key',
1888+
id='thinking_only_empty_parts',
1889+
),
1890+
pytest.param(
1891+
[
1892+
ThinkingPart(content='Let me think...'),
1893+
ToolCallPart(tool_name='test_tool', args={'param': 'value'}, tool_call_id='call_123'),
1894+
],
1895+
[
1896+
{
1897+
'function_call': {
1898+
'args': {'param': 'value'},
1899+
'id': 'call_123',
1900+
'name': 'test_tool',
1901+
}
1902+
},
1903+
{'text': 'I have completed the function calls above.'},
1904+
],
1905+
id='thinking_and_function_call',
19051906
),
19061907
pytest.param(
1907-
MockContentResponse({'role': 'model', 'parts': []}), # Empty parts
19081908
[],
1909-
id='empty_parts_list',
1909+
[],
1910+
id='empty_parts',
19101911
),
19111912
],
19121913
)
1913-
async def test_google_model_response_edge_cases(
1914-
google_provider: GoogleProvider,
1915-
mock_content_response: MockContentResponse,
1916-
expected_contents: list[ExpectedContent],
1914+
def test_content_model_response_function_call_handling(
1915+
model_parts: list[ModelResponsePart], expected_parts: list[dict[str, Any]]
19171916
):
1918-
"""Test Google model's _map_messages method with various edge cases for function call handling."""
1919-
from unittest.mock import patch
1917+
"""Test _content_model_response function's handling of function calls without text."""
19201918

1921-
model = GoogleModel('gemini-2.0-flash', provider=google_provider)
19221919
model_response = ModelResponse(
1923-
parts=[ToolCallPart(tool_name='test_tool', args={'param': 'value'}, tool_call_id='call_123')],
1920+
parts=model_parts,
19241921
usage=Usage(requests=1, request_tokens=10, response_tokens=5, total_tokens=15),
19251922
model_name='gemini-2.0-flash',
19261923
)
19271924

1928-
with patch('pydantic_ai.models.google._content_model_response') as mock_content:
1929-
mock_content.return_value = mock_content_response
1930-
_, contents = await model._map_messages([model_response]) # pyright: ignore[reportPrivateUsage]
1931-
assert contents == expected_contents
1925+
result = _content_model_response(model_response)
1926+
1927+
expected_result = {'role': 'model', 'parts': expected_parts}
1928+
assert result == expected_result

0 commit comments

Comments
 (0)