Skip to content

Commit a68b631

Browse files
committed
Simplify Gemini response handling and remove unused tests
Refactored Gemini response handling in GoogleModel and GeminiStreamedResponse to treat missing or malformed function call content uniformly, avoiding unnecessary exceptions. Removed obsolete or redundant test helpers and tests related to malformed function call streaming and retry logic.
1 parent 4288e9e commit a68b631

File tree

2 files changed

+4
-215
lines changed

2 files changed

+4
-215
lines changed

pydantic_ai_slim/pydantic_ai/models/google.py

Lines changed: 3 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -471,12 +471,7 @@ def _process_response(self, response: GenerateContentResponse) -> ModelResponse:
471471
raise UnexpectedModelBehavior(
472472
f'Content filter {raw_finish_reason.value!r} triggered', response.model_dump_json()
473473
)
474-
elif raw_finish_reason and raw_finish_reason.value == 'MALFORMED_FUNCTION_CALL':
475-
parts = []
476-
else:
477-
raise UnexpectedModelBehavior(
478-
'Content field missing from Gemini response', response.model_dump_json()
479-
) # pragma: no cover
474+
parts = []
480475
else:
481476
parts = candidate.content.parts or []
482477

@@ -652,19 +647,12 @@ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
652647
# )
653648

654649
if candidate.content is None or candidate.content.parts is None:
655-
if self.finish_reason == 'stop': # pragma: no cover
656-
# Normal completion - skip this chunk
657-
continue
658-
elif raw_finish_reason and raw_finish_reason.value == 'MALFORMED_FUNCTION_CALL':
659-
continue
660-
elif self.finish_reason == 'content_filter' and raw_finish_reason: # pragma: no cover
650+
if self.finish_reason == 'content_filter' and raw_finish_reason: # pragma: no cover
661651
raise UnexpectedModelBehavior(
662652
f'Content filter {raw_finish_reason.value!r} triggered', chunk.model_dump_json()
663653
)
664654
else: # pragma: no cover
665-
raise UnexpectedModelBehavior(
666-
'Content field missing from streaming Gemini response', chunk.model_dump_json()
667-
)
655+
continue
668656

669657
parts = candidate.content.parts
670658
if not parts:

tests/models/test_google.py

Lines changed: 1 addition & 200 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@
1010
from inline_snapshot import Is, snapshot
1111
from pydantic import BaseModel
1212
from typing_extensions import TypedDict
13-
from unittest.mock import AsyncMock, MagicMock
1413

1514
from pydantic_ai import (
1615
AudioUrl,
@@ -1896,6 +1895,7 @@ def dummy() -> None: ... # pragma: no cover
18961895
PartDeltaEvent(index=1, delta=TextPartDelta(content_delta=IsStr())),
18971896
PartDeltaEvent(index=1, delta=TextPartDelta(content_delta=IsStr())),
18981897
PartDeltaEvent(index=1, delta=TextPartDelta(content_delta=IsStr())),
1898+
PartDeltaEvent(index=1, delta=TextPartDelta(content_delta=IsStr())),
18991899
PartEndEvent(
19001900
index=1,
19011901
part=TextPart(
@@ -3091,205 +3091,6 @@ async def test_google_vertexai_image_generation(allow_model_requests: None, vert
30913091
)
30923092

30933093

3094-
def _create_google_stream_mock_helpers():
3095-
from google.genai.types import Candidate, Content, GenerateContentResponse, Part
3096-
3097-
class _FR(str):
3098-
def __new__(cls, name: str):
3099-
obj = str.__new__(cls, name)
3100-
return obj
3101-
3102-
@property
3103-
def value(self) -> str:
3104-
return str(self)
3105-
3106-
def create_mock_response(**kwargs: Any) -> MagicMock:
3107-
mock = MagicMock(spec=GenerateContentResponse)
3108-
for key, value in kwargs.items():
3109-
setattr(mock, key, value)
3110-
mock.model_version = 'gemini-1.5-flash'
3111-
mock.create_time = None
3112-
mock.response_id = 'test-response-id'
3113-
mock.usage_metadata = None
3114-
return mock
3115-
3116-
def create_mock_candidate(**kwargs: Any) -> MagicMock:
3117-
mock = MagicMock(spec=Candidate)
3118-
mock.grounding_metadata = None
3119-
if 'finish_reason' in kwargs and isinstance(kwargs['finish_reason'], str):
3120-
kwargs['finish_reason'] = _FR(kwargs['finish_reason'])
3121-
for key, value in kwargs.items():
3122-
setattr(mock, key, value)
3123-
return mock
3124-
3125-
def create_mock_content(*parts: MagicMock) -> MagicMock:
3126-
mock = MagicMock(spec=Content)
3127-
mock.parts = list(parts)
3128-
return mock
3129-
3130-
def create_mock_part(**kwargs: Any) -> MagicMock:
3131-
mock = MagicMock(spec=Part)
3132-
mock.text = None
3133-
mock.function_call = None
3134-
mock.thought_signature = None
3135-
mock.thought = None
3136-
mock.inline_data = None
3137-
mock.executable_code = None
3138-
mock.code_execution_result = None
3139-
mock.function_response = None
3140-
3141-
for key, value in kwargs.items():
3142-
setattr(mock, key, value)
3143-
return mock
3144-
3145-
return create_mock_response, create_mock_candidate, create_mock_content, create_mock_part
3146-
3147-
3148-
async def test_google_model_stream_malformed_function_call_continue(
3149-
allow_model_requests: None, google_provider: GoogleProvider, monkeypatch: pytest.MonkeyPatch
3150-
):
3151-
"""Test that the model can continue streaming after a malformed function call."""
3152-
mock_stream = AsyncMock()
3153-
(
3154-
create_mock_response,
3155-
create_mock_candidate,
3156-
create_mock_content,
3157-
create_mock_part,
3158-
) = _create_google_stream_mock_helpers()
3159-
create_mock_response: Any
3160-
create_mock_candidate: Any
3161-
create_mock_content: Any
3162-
create_mock_part: Any
3163-
async def mock_stream_iterator():
3164-
# This part is a bit of a fiction, in reality the tool call and return would not be in the same stream.
3165-
# But for testing the malformed function call in between, this is fine.
3166-
yield create_mock_response(
3167-
candidates=[
3168-
create_mock_candidate(
3169-
content=create_mock_content(
3170-
create_mock_part(function_call=MagicMock(name='get_capital', args={'country': 'France'}))
3171-
),
3172-
finish_reason=None,
3173-
)
3174-
]
3175-
)
3176-
3177-
yield create_mock_response(
3178-
candidates=[create_mock_candidate(content=None, finish_reason='MALFORMED_FUNCTION_CALL')]
3179-
)
3180-
3181-
yield create_mock_response(
3182-
candidates=[
3183-
create_mock_candidate(
3184-
content=create_mock_content(create_mock_part(text='Here is the capital.')), finish_reason=None
3185-
)
3186-
]
3187-
)
3188-
3189-
yield create_mock_response(candidates=[create_mock_candidate(content=None, finish_reason='STOP')])
3190-
3191-
mock_stream.return_value = mock_stream_iterator()
3192-
model = GoogleModel('gemini-2.5-flash', provider=google_provider)
3193-
monkeypatch.setattr(model.client.aio.models, 'generate_content_stream', mock_stream)
3194-
3195-
agent = Agent(model)
3196-
3197-
@agent.tool_plain
3198-
def get_capital(country: str) -> str:
3199-
"""Get the capital of a country."""
3200-
return 'Paris' if country == 'France' else 'Unknown'
3201-
3202-
async with agent.run_stream('What is the capital of France?') as result:
3203-
data = await result.get_output()
3204-
3205-
assert data == 'Here is the capital.'
3206-
assert mock_stream.call_count == 1
3207-
3208-
3209-
async def test_google_model_stream_malformed_function_call_retry(
3210-
allow_model_requests: None, google_provider: GoogleProvider, monkeypatch: pytest.MonkeyPatch
3211-
):
3212-
"""Test that the model retries when the stream only contains a malformed function call."""
3213-
mock_stream = AsyncMock()
3214-
(
3215-
create_mock_response,
3216-
create_mock_candidate,
3217-
create_mock_content,
3218-
create_mock_part,
3219-
) = _create_google_stream_mock_helpers()
3220-
create_mock_response: Any
3221-
create_mock_candidate: Any
3222-
create_mock_content: Any
3223-
create_mock_part: Any
3224-
async def first_call_iterator():
3225-
yield create_mock_response(
3226-
candidates=[create_mock_candidate(content=None, finish_reason='MALFORMED_FUNCTION_CALL')]
3227-
)
3228-
3229-
async def second_call_iterator():
3230-
yield create_mock_response(
3231-
candidates=[
3232-
create_mock_candidate(
3233-
content=create_mock_content(create_mock_part(text='Successful response')), finish_reason=None
3234-
)
3235-
]
3236-
)
3237-
yield create_mock_response(candidates=[create_mock_candidate(content=None, finish_reason='STOP')])
3238-
3239-
mock_stream.side_effect = [first_call_iterator(), second_call_iterator()]
3240-
3241-
model = GoogleModel('gemini-2.5-flash', provider=google_provider)
3242-
monkeypatch.setattr(model.client.aio.models, 'generate_content_stream', mock_stream)
3243-
3244-
agent = Agent(model, retries=1)
3245-
3246-
async with agent.run_stream('Some prompt') as result:
3247-
data = await result.get_output()
3248-
3249-
assert data == 'Successful response'
3250-
assert mock_stream.call_count == 2
3251-
3252-
3253-
async def test_google_model_malformed_function_call_retry(
3254-
allow_model_requests: None, google_provider: GoogleProvider, monkeypatch: pytest.MonkeyPatch
3255-
):
3256-
"""Test that the model retries when non-streamed response contains a malformed function call."""
3257-
mock_generate = AsyncMock()
3258-
(
3259-
create_mock_response,
3260-
create_mock_candidate,
3261-
create_mock_content,
3262-
create_mock_part,
3263-
) = _create_google_stream_mock_helpers()
3264-
create_mock_response: Any
3265-
create_mock_candidate: Any
3266-
create_mock_content: Any
3267-
create_mock_part: Any
3268-
first_response: MagicMock = create_mock_response(
3269-
candidates=[create_mock_candidate(content=None, finish_reason='MALFORMED_FUNCTION_CALL')]
3270-
)
3271-
3272-
second_response: MagicMock = create_mock_response(
3273-
candidates=[
3274-
create_mock_candidate(
3275-
content=create_mock_content(create_mock_part(text='Successful response')), finish_reason='STOP'
3276-
)
3277-
]
3278-
)
3279-
3280-
mock_generate.side_effect = [first_response, second_response]
3281-
3282-
model = GoogleModel('gemini-2.5-flash', provider=google_provider)
3283-
monkeypatch.setattr(model.client.aio.models, 'generate_content', mock_generate)
3284-
3285-
agent = Agent(model, retries=1)
3286-
3287-
result = await agent.run('Some prompt')
3288-
3289-
assert result.output == 'Successful response'
3290-
assert mock_generate.call_count == 2
3291-
3292-
32933094
async def test_google_httpx_client_is_not_closed(allow_model_requests: None, gemini_api_key: str):
32943095
# This should not raise any errors, see https://github.com/pydantic/pydantic-ai/issues/3242.
32953096
agent = Agent(GoogleModel('gemini-2.5-flash-lite', provider=GoogleProvider(api_key=gemini_api_key)))

0 commit comments

Comments
 (0)