Skip to content

Commit 2997f66

Browse files
committed
Handle MALFORMED_FUNCTION_CALL in GoogleModel and add tests
Updated GoogleModel to handle 'MALFORMED_FUNCTION_CALL' finish reason by initializing parts as an empty list. Added and extended tests to cover retry logic for both streaming and non-streaming responses when a malformed function call is encountered.
1 parent 86b3cf3 commit 2997f66

File tree

2 files changed

+127
-8
lines changed

2 files changed

+127
-8
lines changed

pydantic_ai_slim/pydantic_ai/models/google.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -453,11 +453,14 @@ def _process_response(self, response: GenerateContentResponse) -> ModelResponse:
453453
if candidate.content is None or candidate.content.parts is None:
454454
if candidate.finish_reason == 'SAFETY':
455455
raise UnexpectedModelBehavior('Safety settings triggered', str(response))
456+
elif candidate.finish_reason == 'MALFORMED_FUNCTION_CALL':
457+
parts = []
456458
else:
457459
raise UnexpectedModelBehavior(
458460
'Content field missing from Gemini response', str(response)
459461
) # pragma: no cover
460-
parts = candidate.content.parts or []
462+
else:
463+
parts = candidate.content.parts or []
461464

462465
vendor_id = response.response_id
463466
vendor_details: dict[str, Any] | None = None

tests/models/test_google.py

Lines changed: 123 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2943,21 +2943,32 @@ async def test_google_model_stream_malformed_function_call_continue(
29432943

29442944
mock_stream = AsyncMock()
29452945

2946+
class _FinishReason(str):
2947+
"""Mock finish reason with value attribute."""
2948+
2949+
@property
2950+
def value(self):
2951+
return str(self)
2952+
29462953
def create_mock_response(**kwargs):
29472954
mock = MagicMock(spec=GenerateContentResponse)
29482955
for key, value in kwargs.items():
29492956
setattr(mock, key, value)
2950-
mock.model_version = 'gemini-1.5-flash'
2957+
mock.model_version = 'gemini-2.5-flash'
29512958
mock.create_time = None
29522959
mock.response_id = 'test-response-id'
2960+
mock.usage_metadata = None
29532961
return mock
29542962

29552963
def create_mock_candidate(**kwargs):
29562964
mock = MagicMock(spec=Candidate)
2965+
# Convert string finish_reason to our mock enum
2966+
if 'finish_reason' in kwargs and isinstance(kwargs['finish_reason'], str):
2967+
kwargs['finish_reason'] = _FinishReason(kwargs['finish_reason'])
29572968
for key, value in kwargs.items():
29582969
setattr(mock, key, value)
29592970
return mock
2960-
2971+
29612972
def create_mock_content(*parts):
29622973
mock = MagicMock(spec=Content)
29632974
mock.parts = list(parts)
@@ -2967,6 +2978,12 @@ def create_mock_part(**kwargs):
29672978
mock = MagicMock(spec=Part)
29682979
mock.text = None
29692980
mock.function_call = None
2981+
mock.thought_signature = None
2982+
mock.thought = None
2983+
mock.inline_data = None
2984+
mock.executable_code = None
2985+
mock.code_execution_result = None
2986+
mock.function_response = None
29702987
for key, value in kwargs.items():
29712988
setattr(mock, key, value)
29722989
return mock
@@ -2989,7 +3006,7 @@ async def mock_stream_iterator():
29893006
yield create_mock_response(
29903007
candidates=[create_mock_candidate(content=None, finish_reason='MALFORMED_FUNCTION_CALL')]
29913008
)
2992-
3009+
29933010
yield create_mock_response(
29943011
candidates=[
29953012
create_mock_candidate(
@@ -3001,7 +3018,7 @@ async def mock_stream_iterator():
30013018
yield create_mock_response(candidates=[create_mock_candidate(content=None, finish_reason='STOP')])
30023019

30033020
mock_stream.return_value = mock_stream_iterator()
3004-
model = GoogleModel('gemini-1.5-flash', provider=google_provider)
3021+
model = GoogleModel('gemini-2.5-flash', provider=google_provider)
30053022
monkeypatch.setattr(model.client.aio.models, 'generate_content_stream', mock_stream)
30063023

30073024
agent = Agent(model)
@@ -3028,21 +3045,32 @@ async def test_google_model_stream_malformed_function_call_retry(
30283045

30293046
mock_stream = AsyncMock()
30303047

3048+
class _FinishReason(str):
3049+
"""Mock finish reason with value attribute."""
3050+
3051+
@property
3052+
def value(self):
3053+
return str(self)
3054+
30313055
def create_mock_response(**kwargs):
30323056
mock = MagicMock(spec=GenerateContentResponse)
30333057
for key, value in kwargs.items():
30343058
setattr(mock, key, value)
3035-
mock.model_version = 'gemini-1.5-flash'
3059+
mock.model_version = 'gemini-2.5-flash'
30363060
mock.create_time = None
30373061
mock.response_id = 'test-response-id'
3062+
mock.usage_metadata = None
30383063
return mock
30393064

30403065
def create_mock_candidate(**kwargs):
30413066
mock = MagicMock(spec=Candidate)
3067+
# Convert string finish_reason to our mock enum
3068+
if 'finish_reason' in kwargs and isinstance(kwargs['finish_reason'], str):
3069+
kwargs['finish_reason'] = _FinishReason(kwargs['finish_reason'])
30423070
for key, value in kwargs.items():
30433071
setattr(mock, key, value)
30443072
return mock
3045-
3073+
30463074
def create_mock_content(*parts):
30473075
mock = MagicMock(spec=Content)
30483076
mock.parts = list(parts)
@@ -3051,6 +3079,13 @@ def create_mock_content(*parts):
30513079
def create_mock_part(**kwargs):
30523080
mock = MagicMock(spec=Part)
30533081
mock.text = None
3082+
mock.function_call = None
3083+
mock.thought_signature = None
3084+
mock.thought = None
3085+
mock.inline_data = None
3086+
mock.executable_code = None
3087+
mock.code_execution_result = None
3088+
mock.function_response = None
30543089
for key, value in kwargs.items():
30553090
setattr(mock, key, value)
30563091
return mock
@@ -3072,7 +3107,7 @@ async def second_call_iterator():
30723107

30733108
mock_stream.side_effect = [first_call_iterator(), second_call_iterator()]
30743109

3075-
model = GoogleModel('gemini-1.5-flash', provider=google_provider)
3110+
model = GoogleModel('gemini-2.5-flash', provider=google_provider)
30763111
monkeypatch.setattr(model.client.aio.models, 'generate_content_stream', mock_stream)
30773112

30783113
agent = Agent(model, retries=1)
@@ -3082,3 +3117,84 @@ async def second_call_iterator():
30823117

30833118
assert data == 'Successful response'
30843119
assert mock_stream.call_count == 2
3120+
3121+
3122+
async def test_google_model_malformed_function_call_retry(
3123+
allow_model_requests: None, google_provider: GoogleProvider, monkeypatch: pytest.MonkeyPatch
3124+
):
3125+
"""Test that the model retries when non-streamed response contains a malformed function call."""
3126+
from unittest.mock import MagicMock
3127+
3128+
from google.genai.types import Candidate, Content, GenerateContentResponse, Part
3129+
3130+
mock_generate = AsyncMock()
3131+
3132+
class _FinishReason(str):
3133+
"""Mock finish reason with value attribute."""
3134+
3135+
@property
3136+
def value(self):
3137+
return str(self)
3138+
3139+
def create_mock_response(**kwargs):
3140+
mock = MagicMock(spec=GenerateContentResponse)
3141+
for key, value in kwargs.items():
3142+
setattr(mock, key, value)
3143+
mock.model_version = 'gemini-2.5-flash'
3144+
mock.create_time = None
3145+
mock.response_id = 'test-response-id'
3146+
mock.usage_metadata = None
3147+
return mock
3148+
3149+
def create_mock_candidate(**kwargs):
3150+
mock = MagicMock(spec=Candidate)
3151+
# Convert string finish_reason to our mock enum
3152+
if 'finish_reason' in kwargs and isinstance(kwargs['finish_reason'], str):
3153+
kwargs['finish_reason'] = _FinishReason(kwargs['finish_reason'])
3154+
mock.grounding_metadata = None
3155+
for key, value in kwargs.items():
3156+
setattr(mock, key, value)
3157+
return mock
3158+
3159+
def create_mock_content(*parts):
3160+
mock = MagicMock(spec=Content)
3161+
mock.parts = list(parts)
3162+
return mock
3163+
3164+
def create_mock_part(**kwargs):
3165+
mock = MagicMock(spec=Part)
3166+
mock.text = None
3167+
mock.function_call = None
3168+
mock.thought_signature = None
3169+
mock.thought = None
3170+
mock.inline_data = None
3171+
mock.executable_code = None
3172+
mock.code_execution_result = None
3173+
mock.function_response = None
3174+
for key, value in kwargs.items():
3175+
setattr(mock, key, value)
3176+
return mock
3177+
3178+
first_response = create_mock_response(
3179+
candidates=[create_mock_candidate(content=None, finish_reason='MALFORMED_FUNCTION_CALL')]
3180+
)
3181+
3182+
second_response = create_mock_response(
3183+
candidates=[
3184+
create_mock_candidate(
3185+
content=create_mock_content(create_mock_part(text='Successful response')), finish_reason='STOP'
3186+
)
3187+
]
3188+
)
3189+
3190+
mock_generate.side_effect = [first_response, second_response]
3191+
3192+
model = GoogleModel('gemini-2.5-flash', provider=google_provider)
3193+
monkeypatch.setattr(model.client.aio.models, 'generate_content', mock_generate)
3194+
3195+
agent = Agent(model, retries=1)
3196+
3197+
result = await agent.run('Some prompt')
3198+
3199+
assert result.output == 'Successful response'
3200+
assert mock_generate.call_count == 2

0 commit comments

Comments
 (0)