Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 16 additions & 8 deletions pydantic_ai_slim/pydantic_ai/models/google.py
Original file line number Diff line number Diff line change
Expand Up @@ -677,13 +677,18 @@ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
provider_name=self.provider_name,
)

if part.text:
if part.thought:
yield self._parts_manager.handle_thinking_delta(vendor_part_id='thinking', content=part.text)
else:
maybe_event = self._parts_manager.handle_text_delta(vendor_part_id='content', content=part.text)
if maybe_event is not None: # pragma: no branch
yield maybe_event
if part.text is not None:
if len(part.text) > 0:
if part.thought:
yield self._parts_manager.handle_thinking_delta(
vendor_part_id='thinking', content=part.text
)
else:
maybe_event = self._parts_manager.handle_text_delta(
vendor_part_id='content', content=part.text
)
if maybe_event is not None: # pragma: no branch
yield maybe_event
elif part.function_call:
maybe_event = self._parts_manager.handle_tool_call_delta(
vendor_part_id=uuid4(),
Expand Down Expand Up @@ -822,7 +827,10 @@ def _process_response_from_parts(
elif part.code_execution_result is not None:
assert code_execution_tool_call_id is not None
item = _map_code_execution_result(part.code_execution_result, provider_name, code_execution_tool_call_id)
elif part.text:
elif part.text is not None:
# Google sometimes sends empty text parts, we don't want to add them to the response
if len(part.text) == 0:
continue
if part.thought:
item = ThinkingPart(content=part.text)
else:
Expand Down
59 changes: 58 additions & 1 deletion tests/models/test_google.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import datetime
import os
import re
from collections.abc import AsyncIterator
from typing import Any

import pytest
Expand Down Expand Up @@ -47,6 +48,7 @@
BuiltinToolCallEvent, # pyright: ignore[reportDeprecated]
BuiltinToolResultEvent, # pyright: ignore[reportDeprecated]
)
from pydantic_ai.models import ModelRequestParameters
from pydantic_ai.output import NativeOutput, PromptedOutput, TextOutput, ToolOutput
from pydantic_ai.settings import ModelSettings
from pydantic_ai.usage import RequestUsage, RunUsage, UsageLimits
Expand All @@ -56,6 +58,7 @@

with try_import() as imports_successful:
from google.genai.types import (
FinishReason as GoogleFinishReason,
GenerateContentResponse,
GenerateContentResponseUsageMetadata,
HarmBlockThreshold,
Expand All @@ -64,7 +67,12 @@
ModalityTokenCount,
)

from pydantic_ai.models.google import GoogleModel, GoogleModelSettings, _metadata_as_usage # type: ignore
from pydantic_ai.models.google import (
GeminiStreamedResponse,
GoogleModel,
GoogleModelSettings,
_metadata_as_usage, # pyright: ignore[reportPrivateUsage]
)
from pydantic_ai.models.openai import OpenAIResponsesModel, OpenAIResponsesModelSettings
from pydantic_ai.providers.google import GoogleProvider
from pydantic_ai.providers.openai import OpenAIProvider
Expand Down Expand Up @@ -3063,3 +3071,52 @@ async def test_google_httpx_client_is_not_closed(allow_model_requests: None, gem
agent = Agent(GoogleModel('gemini-2.5-flash-lite', provider=GoogleProvider(api_key=gemini_api_key)))
result = await agent.run('What is the capital of Mexico?')
assert result.output == snapshot('The capital of Mexico is **Mexico City**.')


def test_google_process_response_filters_empty_text_parts(google_provider: GoogleProvider):
    """Non-streamed responses drop the empty text parts Google sometimes sends.

    Builds a response whose candidate interleaves empty and non-empty text
    parts and checks only the non-empty ones survive `_process_response`.
    """
    m = GoogleModel('gemini-2.5-pro', provider=google_provider)
    raw_response = _generate_response_with_texts(response_id='resp-123', texts=['', 'first', '', 'second'])

    processed = m._process_response(raw_response)  # pyright: ignore[reportPrivateUsage]

    assert processed.parts == snapshot([TextPart(content='first'), TextPart(content='second')])


async def test_gemini_streamed_response_emits_text_events_for_non_empty_parts():
    """Streaming skips empty text parts: only non-empty content yields an event.

    Feeds a single chunk containing an empty and a non-empty text part through
    `GeminiStreamedResponse._get_event_iterator` and checks that exactly one
    start event is emitted, for the non-empty part.
    """
    response_chunk = _generate_response_with_texts('stream-1', ['', 'streamed text'])

    async def single_chunk_stream() -> AsyncIterator[GenerateContentResponse]:
        yield response_chunk

    stream = GeminiStreamedResponse(
        model_request_parameters=ModelRequestParameters(),
        _model_name='gemini-test',
        _response=single_chunk_stream(),
        _timestamp=datetime.datetime.now(datetime.timezone.utc),
        _provider_name='test-provider',
    )

    collected = [event async for event in stream._get_event_iterator()]  # pyright: ignore[reportPrivateUsage]
    assert collected == snapshot([PartStartEvent(index=0, part=TextPart(content='streamed text'))])


def _generate_response_with_texts(response_id: str, texts: list[str]) -> GenerateContentResponse:
    """Build a minimal `GenerateContentResponse` with one text part per entry in `texts`.

    The response has a single STOP-finished candidate and zeroed usage
    metadata, which is enough for the empty-text-part filtering tests.
    """
    text_parts = [{'text': text} for text in texts]
    candidate = {
        'finish_reason': GoogleFinishReason.STOP,
        'content': {
            'role': 'model',
            'parts': text_parts,
        },
    }
    payload = {
        'response_id': response_id,
        'model_version': 'gemini-test',
        'usage_metadata': GenerateContentResponseUsageMetadata(
            prompt_token_count=0,
            candidates_token_count=0,
        ),
        'candidates': [candidate],
    }
    return GenerateContentResponse.model_validate(payload)