Skip to content
7 changes: 6 additions & 1 deletion pydantic_ai_slim/pydantic_ai/models/google.py
Original file line number Diff line number Diff line change
Expand Up @@ -471,11 +471,14 @@ def _process_response(self, response: GenerateContentResponse) -> ModelResponse:
raise UnexpectedModelBehavior(
f'Content filter {raw_finish_reason.value!r} triggered', response.model_dump_json()
)
elif raw_finish_reason and raw_finish_reason.value == 'MALFORMED_FUNCTION_CALL':
parts = []
else:
raise UnexpectedModelBehavior(
'Content field missing from Gemini response', response.model_dump_json()
) # pragma: no cover
parts = candidate.content.parts or []
else:
parts = candidate.content.parts or []
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note that we could get here if candidate.content is None

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If candidate.content is None, parts would be [] and we would continue with that, no?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, you're totally right, I wasn't reading right


usage = _metadata_as_usage(response)
return _process_response_from_parts(
Expand Down Expand Up @@ -652,6 +655,8 @@ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
if self.finish_reason == 'stop': # pragma: no cover
# Normal completion - skip this chunk
continue
elif raw_finish_reason and raw_finish_reason.value == 'MALFORMED_FUNCTION_CALL':
continue
elif self.finish_reason == 'content_filter' and raw_finish_reason: # pragma: no cover
raise UnexpectedModelBehavior(
f'Content filter {raw_finish_reason.value!r} triggered', chunk.model_dump_json()
Expand Down
208 changes: 207 additions & 1 deletion tests/models/test_google.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from inline_snapshot import Is, snapshot
from pydantic import BaseModel
from typing_extensions import TypedDict
from unittest.mock import AsyncMock, MagicMock

from pydantic_ai import (
AudioUrl,
Expand Down Expand Up @@ -3081,7 +3082,212 @@ async def test_google_vertexai_image_generation(allow_model_requests: None, vert
agent = Agent(model, output_type=BinaryImage)

result = await agent.run('Generate an image of an axolotl.')
assert result.output == snapshot(BinaryImage(data=IsBytes(), media_type='image/png', identifier='b037a4'))
assert result.output == snapshot(
BinaryImage(
data=IsBytes(),
media_type='image/png',
identifier='f3edd8',
)
)


def _create_google_stream_mock_helpers():
    """Build factory functions for mocking ``google.genai`` streaming responses.

    Returns a 4-tuple of factories ``(response, candidate, content, part)``.
    Each factory produces a ``MagicMock`` spec'd against the corresponding
    ``google.genai`` type, pre-populated with the defaults the model code reads.
    """
    from google.genai.types import Candidate, Content, GenerateContentResponse, Part

    class _FR(str):
        """A str that also exposes a ``.value`` property, mimicking an enum member."""

        def __new__(cls, name: str):
            return str.__new__(cls, name)

        @property
        def value(self) -> str:
            return str(self)

    def create_mock_response(**kwargs: Any) -> MagicMock:
        """Mock a GenerateContentResponse with fixed metadata defaults."""
        response = MagicMock(spec=GenerateContentResponse)
        for attr, val in kwargs.items():
            setattr(response, attr, val)
        response.model_version = 'gemini-1.5-flash'
        response.create_time = None
        response.response_id = 'test-response-id'
        response.usage_metadata = None
        return response

    def create_mock_candidate(**kwargs: Any) -> MagicMock:
        """Mock a Candidate; a plain-str finish_reason is wrapped so ``.value`` works."""
        candidate = MagicMock(spec=Candidate)
        candidate.grounding_metadata = None
        if isinstance(kwargs.get('finish_reason'), str):
            kwargs['finish_reason'] = _FR(kwargs['finish_reason'])
        for attr, val in kwargs.items():
            setattr(candidate, attr, val)
        return candidate

    def create_mock_content(*parts: MagicMock) -> MagicMock:
        """Mock a Content wrapping the given parts."""
        content = MagicMock(spec=Content)
        content.parts = list(parts)
        return content

    def create_mock_part(**kwargs: Any) -> MagicMock:
        """Mock a Part with every field the model code inspects defaulted to None."""
        part = MagicMock(spec=Part)
        for attr in (
            'text',
            'function_call',
            'thought_signature',
            'thought',
            'inline_data',
            'executable_code',
            'code_execution_result',
            'function_response',
        ):
            setattr(part, attr, None)
        for attr, val in kwargs.items():
            setattr(part, attr, val)
        return part

    return create_mock_response, create_mock_candidate, create_mock_content, create_mock_part


async def test_google_model_stream_malformed_function_call_continue(
    allow_model_requests: None, google_provider: GoogleProvider, monkeypatch: pytest.MonkeyPatch
):
    """Test that the model can continue streaming after a malformed function call.

    The mocked stream yields a tool call, a MALFORMED_FUNCTION_CALL chunk (which
    must be skipped rather than raising), text, and a STOP chunk; streaming should
    keep going and produce the text output from a single stream request.
    """
    mock_stream = AsyncMock()
    (
        create_mock_response,
        create_mock_candidate,
        create_mock_content,
        create_mock_part,
    ) = _create_google_stream_mock_helpers()
    # Loosen the factory types for the type checker; no runtime effect.
    create_mock_response: Any
    create_mock_candidate: Any
    create_mock_content: Any
    create_mock_part: Any

    # BUG FIX: `name` is a reserved MagicMock constructor argument (it sets the
    # mock's repr name, not a `.name` attribute), so the tool-call name must be
    # assigned explicitly after construction for `function_call.name` to be a str.
    function_call = MagicMock(args={'country': 'France'})
    function_call.name = 'get_capital'

    async def mock_stream_iterator():
        # This part is a bit of a fiction, in reality the tool call and return would not be in the same stream.
        # But for testing the malformed function call in between, this is fine.
        yield create_mock_response(
            candidates=[
                create_mock_candidate(
                    content=create_mock_content(create_mock_part(function_call=function_call)),
                    finish_reason=None,
                )
            ]
        )

        # Malformed chunk carries no content; the model code should skip it.
        yield create_mock_response(
            candidates=[create_mock_candidate(content=None, finish_reason='MALFORMED_FUNCTION_CALL')]
        )

        yield create_mock_response(
            candidates=[
                create_mock_candidate(
                    content=create_mock_content(create_mock_part(text='Here is the capital.')), finish_reason=None
                )
            ]
        )

        yield create_mock_response(candidates=[create_mock_candidate(content=None, finish_reason='STOP')])

    mock_stream.return_value = mock_stream_iterator()
    model = GoogleModel('gemini-2.5-flash', provider=google_provider)
    monkeypatch.setattr(model.client.aio.models, 'generate_content_stream', mock_stream)

    agent = Agent(model)

    @agent.tool_plain
    def get_capital(country: str) -> str:
        """Get the capital of a country."""
        return 'Paris' if country == 'France' else 'Unknown'

    async with agent.run_stream('What is the capital of France?') as result:
        data = await result.get_output()

    assert data == 'Here is the capital.'
    # The malformed chunk must not have triggered a second stream request.
    assert mock_stream.call_count == 1


async def test_google_model_stream_malformed_function_call_retry(
    allow_model_requests: None, google_provider: GoogleProvider, monkeypatch: pytest.MonkeyPatch
):
    """Test that the model retries when the stream only contains a malformed function call."""
    (
        make_response,
        make_candidate,
        make_content,
        make_part,
    ) = _create_google_stream_mock_helpers()

    async def failing_stream():
        # First attempt: nothing but a malformed function call, which should trigger a retry.
        yield make_response(candidates=[make_candidate(content=None, finish_reason='MALFORMED_FUNCTION_CALL')])

    async def succeeding_stream():
        # Second attempt: normal text followed by a clean stop.
        yield make_response(
            candidates=[
                make_candidate(content=make_content(make_part(text='Successful response')), finish_reason=None)
            ]
        )
        yield make_response(candidates=[make_candidate(content=None, finish_reason='STOP')])

    mock_stream = AsyncMock(side_effect=[failing_stream(), succeeding_stream()])

    model = GoogleModel('gemini-2.5-flash', provider=google_provider)
    monkeypatch.setattr(model.client.aio.models, 'generate_content_stream', mock_stream)

    agent = Agent(model, retries=1)

    async with agent.run_stream('Some prompt') as result:
        data = await result.get_output()

    assert data == 'Successful response'
    assert mock_stream.call_count == 2


async def test_google_model_malformed_function_call_retry(
    allow_model_requests: None, google_provider: GoogleProvider, monkeypatch: pytest.MonkeyPatch
):
    """Test that the model retries when non-streamed response contains a malformed function call."""
    (
        make_response,
        make_candidate,
        make_content,
        make_part,
    ) = _create_google_stream_mock_helpers()

    # First call yields only a malformed function call; the second call succeeds.
    malformed_response = make_response(
        candidates=[make_candidate(content=None, finish_reason='MALFORMED_FUNCTION_CALL')]
    )
    successful_response = make_response(
        candidates=[
            make_candidate(content=make_content(make_part(text='Successful response')), finish_reason='STOP')
        ]
    )

    mock_generate = AsyncMock(side_effect=[malformed_response, successful_response])

    model = GoogleModel('gemini-2.5-flash', provider=google_provider)
    monkeypatch.setattr(model.client.aio.models, 'generate_content', mock_generate)

    agent = Agent(model, retries=1)

    result = await agent.run('Some prompt')

    assert result.output == 'Successful response'
    assert mock_generate.call_count == 2


async def test_google_httpx_client_is_not_closed(allow_model_requests: None, gemini_api_key: str):
Expand Down
Loading