Skip to content

Commit 86b3cf3

Browse files
committed
Add tests for Gemini MALFORMED_FUNCTION_CALL error
1 parent 6166f3d commit 86b3cf3

File tree

2 files changed

+154
-1
lines changed

2 files changed

+154
-1
lines changed

pydantic_ai_slim/pydantic_ai/models/google.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -640,7 +640,7 @@ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
640640
if candidate.finish_reason == 'STOP': # pragma: no cover
641641
# Normal completion - skip this chunk
642642
continue
643-
elif candidate.finish_reason == 'MALFORMED_FUNCTION_CALL': # pragma: no cover
643+
elif candidate.finish_reason == 'MALFORMED_FUNCTION_CALL':
644644
continue
645645
elif candidate.finish_reason == 'SAFETY': # pragma: no cover
646646
raise UnexpectedModelBehavior('Safety settings triggered', str(chunk))

tests/models/test_google.py

Lines changed: 153 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@
1010
from inline_snapshot import Is, snapshot
1111
from pydantic import BaseModel
1212
from typing_extensions import TypedDict
13+
from unittest.mock import AsyncMock
14+
from uuid import uuid4
1315

1416
from pydantic_ai import (
1517
AudioUrl,
@@ -2929,3 +2931,154 @@ async def test_google_vertexai_image_generation(allow_model_requests: None, vert
29292931
identifier='f3edd8',
29302932
)
29312933
)
2934+
2935+
2936+
async def test_google_model_stream_malformed_function_call_continue(
2937+
allow_model_requests: None, google_provider: GoogleProvider, monkeypatch: pytest.MonkeyPatch
2938+
):
2939+
"""Test that the model can continue streaming after a malformed function call."""
2940+
from unittest.mock import MagicMock
2941+
2942+
from google.genai.types import Candidate, Content, GenerateContentResponse, Part
2943+
2944+
mock_stream = AsyncMock()
2945+
2946+
def create_mock_response(**kwargs):
2947+
mock = MagicMock(spec=GenerateContentResponse)
2948+
for key, value in kwargs.items():
2949+
setattr(mock, key, value)
2950+
mock.model_version = 'gemini-1.5-flash'
2951+
mock.create_time = None
2952+
mock.response_id = 'test-response-id'
2953+
return mock
2954+
2955+
def create_mock_candidate(**kwargs):
2956+
mock = MagicMock(spec=Candidate)
2957+
for key, value in kwargs.items():
2958+
setattr(mock, key, value)
2959+
return mock
2960+
2961+
def create_mock_content(*parts):
2962+
mock = MagicMock(spec=Content)
2963+
mock.parts = list(parts)
2964+
return mock
2965+
2966+
def create_mock_part(**kwargs):
2967+
mock = MagicMock(spec=Part)
2968+
mock.text = None
2969+
mock.function_call = None
2970+
for key, value in kwargs.items():
2971+
setattr(mock, key, value)
2972+
return mock
2973+
2974+
2975+
async def mock_stream_iterator():
2976+
# This part is a bit of a fiction, in reality the tool call and return would not be in the same stream.
2977+
# But for testing the malformed function call in between, this is fine.
2978+
yield create_mock_response(
2979+
candidates=[
2980+
create_mock_candidate(
2981+
content=create_mock_content(
2982+
create_mock_part(function_call=MagicMock(name='get_capital', args={'country': 'France'}))
2983+
),
2984+
finish_reason=None,
2985+
)
2986+
]
2987+
)
2988+
2989+
yield create_mock_response(
2990+
candidates=[create_mock_candidate(content=None, finish_reason='MALFORMED_FUNCTION_CALL')]
2991+
)
2992+
2993+
yield create_mock_response(
2994+
candidates=[
2995+
create_mock_candidate(
2996+
content=create_mock_content(create_mock_part(text='Here is the capital.')), finish_reason=None
2997+
)
2998+
]
2999+
)
3000+
3001+
yield create_mock_response(candidates=[create_mock_candidate(content=None, finish_reason='STOP')])
3002+
3003+
mock_stream.return_value = mock_stream_iterator()
3004+
model = GoogleModel('gemini-1.5-flash', provider=google_provider)
3005+
monkeypatch.setattr(model.client.aio.models, 'generate_content_stream', mock_stream)
3006+
3007+
agent = Agent(model)
3008+
3009+
@agent.tool_plain
3010+
def get_capital(country: str) -> str:
3011+
"""Get the capital of a country."""
3012+
return 'Paris' if country == 'France' else 'Unknown'
3013+
3014+
async with agent.run_stream('What is the capital of France?') as result:
3015+
data = await result.get_output()
3016+
3017+
assert data == 'Here is the capital.'
3018+
assert mock_stream.call_count == 1
3019+
3020+
3021+
async def test_google_model_stream_malformed_function_call_retry(
3022+
allow_model_requests: None, google_provider: GoogleProvider, monkeypatch: pytest.MonkeyPatch
3023+
):
3024+
"""Test that the model retries when the stream only contains a malformed function call."""
3025+
from unittest.mock import MagicMock
3026+
3027+
from google.genai.types import Candidate, Content, GenerateContentResponse, Part
3028+
3029+
mock_stream = AsyncMock()
3030+
3031+
def create_mock_response(**kwargs):
3032+
mock = MagicMock(spec=GenerateContentResponse)
3033+
for key, value in kwargs.items():
3034+
setattr(mock, key, value)
3035+
mock.model_version = 'gemini-1.5-flash'
3036+
mock.create_time = None
3037+
mock.response_id = 'test-response-id'
3038+
return mock
3039+
3040+
def create_mock_candidate(**kwargs):
3041+
mock = MagicMock(spec=Candidate)
3042+
for key, value in kwargs.items():
3043+
setattr(mock, key, value)
3044+
return mock
3045+
3046+
def create_mock_content(*parts):
3047+
mock = MagicMock(spec=Content)
3048+
mock.parts = list(parts)
3049+
return mock
3050+
3051+
def create_mock_part(**kwargs):
3052+
mock = MagicMock(spec=Part)
3053+
mock.text = None
3054+
for key, value in kwargs.items():
3055+
setattr(mock, key, value)
3056+
return mock
3057+
3058+
async def first_call_iterator():
3059+
yield create_mock_response(
3060+
candidates=[create_mock_candidate(content=None, finish_reason='MALFORMED_FUNCTION_CALL')]
3061+
)
3062+
3063+
async def second_call_iterator():
3064+
yield create_mock_response(
3065+
candidates=[
3066+
create_mock_candidate(
3067+
content=create_mock_content(create_mock_part(text='Successful response')), finish_reason=None
3068+
)
3069+
]
3070+
)
3071+
yield create_mock_response(candidates=[create_mock_candidate(content=None, finish_reason='STOP')])
3072+
3073+
mock_stream.side_effect = [first_call_iterator(), second_call_iterator()]
3074+
3075+
model = GoogleModel('gemini-1.5-flash', provider=google_provider)
3076+
monkeypatch.setattr(model.client.aio.models, 'generate_content_stream', mock_stream)
3077+
3078+
agent = Agent(model, retries=1)
3079+
3080+
async with agent.run_stream('Some prompt') as result:
3081+
data = await result.get_output()
3082+
3083+
assert data == 'Successful response'
3084+
assert mock_stream.call_count == 2

0 commit comments

Comments
 (0)