Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,11 @@ async def generate_content(call: ServiceCall) -> ServiceResponse:
f"Error generating content due to content violations, reason: {response.prompt_feedback.block_reason_message}"
)

if not response.candidates[0].content.parts:
if (
not response.candidates
or not response.candidates[0].content
or not response.candidates[0].content.parts
):
raise HomeAssistantError("Unknown error generating content")

return {"text": response.text}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -377,7 +377,7 @@ async def google_generative_ai_config_option_schema(
value=api_model.name,
)
for api_model in sorted(
api_models, key=lambda x: x.name.lstrip("models/") or ""
api_models, key=lambda x: (x.name or "").lstrip("models/")
)
if (
api_model.name
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

import asyncio
import codecs
from collections.abc import AsyncGenerator, Callable
from collections.abc import AsyncGenerator, AsyncIterator, Callable
from dataclasses import replace
import mimetypes
from pathlib import Path
Expand All @@ -15,6 +15,7 @@
from google.genai.types import (
AutomaticFunctionCallingConfig,
Content,
ContentDict,
File,
FileState,
FunctionDeclaration,
Expand All @@ -23,9 +24,11 @@
GoogleSearch,
HarmCategory,
Part,
PartUnionDict,
SafetySetting,
Schema,
Tool,
ToolListUnion,
)
import voluptuous as vol
from voluptuous_openapi import convert
Expand Down Expand Up @@ -237,7 +240,7 @@ def _convert_content(


async def _transform_stream(
result: AsyncGenerator[GenerateContentResponse],
result: AsyncIterator[GenerateContentResponse],
) -> AsyncGenerator[conversation.AssistantContentDeltaDict]:
new_message = True
try:
Expand Down Expand Up @@ -342,7 +345,7 @@ async def _async_handle_chat_log(
"""Generate an answer for the chat log."""
options = self.subentry.data

tools: list[Tool | Callable[..., Any]] | None = None
tools: ToolListUnion | None = None
if chat_log.llm_api:
tools = [
_format_tool(tool, chat_log.llm_api.custom_serializer)
Expand Down Expand Up @@ -373,7 +376,7 @@ async def _async_handle_chat_log(
else:
raise HomeAssistantError("Invalid prompt content")

messages: list[Content] = []
messages: list[Content | ContentDict] = []

# Google groups tool results, we do not. Group them before sending.
tool_results: list[conversation.ToolResultContent] = []
Expand All @@ -400,7 +403,10 @@ async def _async_handle_chat_log(
# The SDK requires the first message to be a user message
# This is not the case if user used `start_conversation`
# Workaround from https://github.com/googleapis/python-genai/issues/529#issuecomment-2740964537
if messages and messages[0].role != "user":
if messages and (
(isinstance(messages[0], Content) and messages[0].role != "user")
or (isinstance(messages[0], dict) and messages[0]["role"] != "user")
):
messages.insert(
0,
Content(role="user", parts=[Part.from_text(text=" ")]),
Expand Down Expand Up @@ -440,14 +446,14 @@ async def _async_handle_chat_log(
)
user_message = chat_log.content[-1]
assert isinstance(user_message, conversation.UserContent)
chat_request: str | list[Part] = user_message.content
chat_request: list[PartUnionDict] = [user_message.content]
if user_message.attachments:
files = await async_prepare_files_for_prompt(
self.hass,
self._genai_client,
[a.path for a in user_message.attachments],
)
chat_request = [chat_request, *files]
chat_request = [*chat_request, *files]

# To prevent infinite loops, we limit the number of iterations
for _iteration in range(MAX_TOOL_ITERATIONS):
Expand All @@ -464,15 +470,17 @@ async def _async_handle_chat_log(
error = ERROR_GETTING_RESPONSE
raise HomeAssistantError(error) from err

chat_request = _create_google_tool_response_parts(
[
content
async for content in chat_log.async_add_delta_content_stream(
self.entity_id,
_transform_stream(chat_response_generator),
)
if isinstance(content, conversation.ToolResultContent)
]
chat_request = list(
_create_google_tool_response_parts(
[
content
async for content in chat_log.async_add_delta_content_stream(
self.entity_id,
_transform_stream(chat_response_generator),
)
if isinstance(content, conversation.ToolResultContent)
]
)
)

if not chat_log.unresponded_tool_results:
Expand Down Expand Up @@ -559,13 +567,13 @@ async def wait_for_file_processing(uploaded_file: File) -> None:
await asyncio.sleep(FILE_POLLING_INTERVAL_SECONDS)

uploaded_file = await client.aio.files.get(
name=uploaded_file.name,
name=uploaded_file.name or "",
config={"http_options": {"timeout": TIMEOUT_MILLIS}},
)

if uploaded_file.state == FileState.FAILED:
raise HomeAssistantError(
f"File `{uploaded_file.name}` processing failed, reason: {uploaded_file.error.message}"
f"File `{uploaded_file.name}` processing failed, reason: {uploaded_file.error.message if uploaded_file.error else 'unknown'}"
)

prompt_parts = await hass.async_add_executor_job(upload_files)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,5 +8,5 @@
"documentation": "https://www.home-assistant.io/integrations/google_generative_ai_conversation",
"integration_type": "service",
"iot_class": "cloud_polling",
"requirements": ["google-genai==1.7.0"]
"requirements": ["google-genai==1.29.0"]
}
32 changes: 29 additions & 3 deletions homeassistant/components/google_generative_ai_conversation/tts.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,15 +146,41 @@ async def async_get_tts_audio(
)
)
)

def _extract_audio_parts(
response: types.GenerateContentResponse,
) -> tuple[bytes, str]:
if (
not response.candidates
or not response.candidates[0].content
or not response.candidates[0].content.parts
or not response.candidates[0].content.parts[0].inline_data
):
raise ValueError("No content returned from TTS generation")

data = response.candidates[0].content.parts[0].inline_data.data
mime_type = response.candidates[0].content.parts[0].inline_data.mime_type

if not isinstance(data, bytes):
raise TypeError(
f"Expected bytes for audio data, got {type(data).__name__}"
)
if not isinstance(mime_type, str):
raise TypeError(
f"Expected str for mime_type, got {type(mime_type).__name__}"
)

return data, mime_type

try:
response = await self._genai_client.aio.models.generate_content(
model=self.subentry.data.get(CONF_CHAT_MODEL, RECOMMENDED_TTS_MODEL),
contents=message,
config=config,
)
data = response.candidates[0].content.parts[0].inline_data.data
mime_type = response.candidates[0].content.parts[0].inline_data.mime_type
except (APIError, ClientError, ValueError) as exc:

data, mime_type = _extract_audio_parts(response)
except (APIError, ClientError, ValueError, TypeError) as exc:
LOGGER.error("Error during TTS: %s", exc, exc_info=True)
raise HomeAssistantError(exc) from exc
return "wav", convert_to_wav(data, mime_type)
2 changes: 1 addition & 1 deletion requirements_all.txt

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion requirements_test_all.txt

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

33 changes: 3 additions & 30 deletions tests/components/google_generative_ai_conversation/__init__.py
Original file line number Diff line number Diff line change
@@ -1,43 +1,16 @@
"""Tests for the Google Generative AI Conversation integration."""

from unittest.mock import Mock

from google.genai.errors import APIError, ClientError
import httpx

API_ERROR_500 = APIError(
500,
Mock(
__class__=httpx.Response,
json=Mock(
return_value={
"message": "Internal Server Error",
"status": "internal-error",
}
),
),
{"message": "Internal Server Error", "status": "internal-error"},
)
CLIENT_ERROR_BAD_REQUEST = ClientError(
400,
Mock(
__class__=httpx.Response,
json=Mock(
return_value={
"message": "Bad Request",
"status": "invalid-argument",
}
),
),
{"message": "Bad Request", "status": "invalid-argument"},
)
CLIENT_ERROR_API_KEY_INVALID = ClientError(
400,
Mock(
__class__=httpx.Response,
json=Mock(
return_value={
"message": "'reason': API_KEY_INVALID",
"status": "unauthorized",
}
),
),
{"message": "'reason': API_KEY_INVALID", "status": "unauthorized"},
)
Original file line number Diff line number Diff line change
Expand Up @@ -128,8 +128,14 @@
dict({
'contents': list([
'Describe this image from my doorbell camera',
File(name='doorbell_snapshot.jpg', display_name=None, mime_type=None, size_bytes=None, create_time=None, expiration_time=None, update_time=None, sha256_hash=None, uri=None, download_uri=None, state=<FileState.ACTIVE: 'ACTIVE'>, source=None, video_metadata=None, error=None),
File(name='context.txt', display_name=None, mime_type=None, size_bytes=None, create_time=None, expiration_time=None, update_time=None, sha256_hash=None, uri=None, download_uri=None, state=<FileState.PROCESSING: 'PROCESSING'>, source=None, video_metadata=None, error=None),
File(
name='doorbell_snapshot.jpg',
state=<FileState.ACTIVE: 'ACTIVE'>
),
File(
name='context.txt',
state=<FileState.PROCESSING: 'PROCESSING'>
),
]),
'model': 'models/gemini-2.5-flash',
}),
Expand All @@ -145,8 +151,14 @@
dict({
'contents': list([
'Describe this image from my doorbell camera',
File(name='doorbell_snapshot.jpg', display_name=None, mime_type=None, size_bytes=None, create_time=None, expiration_time=None, update_time=None, sha256_hash=None, uri=None, download_uri=None, state=<FileState.ACTIVE: 'ACTIVE'>, source=None, video_metadata=None, error=None),
File(name='context.txt', display_name=None, mime_type=None, size_bytes=None, create_time=None, expiration_time=None, update_time=None, sha256_hash=None, uri=None, download_uri=None, state=<FileState.ACTIVE: 'ACTIVE'>, source=None, video_metadata=None, error=None),
File(
name='doorbell_snapshot.jpg',
state=<FileState.ACTIVE: 'ACTIVE'>
),
File(
name='context.txt',
state=<FileState.ACTIVE: 'ACTIVE'>
),
]),
'model': 'models/gemini-2.5-flash',
}),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -195,10 +195,13 @@ async def test_function_call(
"response": {
"result": "Test response",
},
"scheduling": None,
"will_continue": None,
},
"inline_data": None,
"text": None,
"thought": None,
"thought_signature": None,
"video_metadata": None,
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
from tests.components.tts.common import retrieve_media
from tests.typing import ClientSessionGenerator

API_ERROR_500 = APIError("test", response=MagicMock())
API_ERROR_500 = APIError("test", response_json={})
TEST_CHAT_MODEL = "models/some-tts-model"


Expand Down
Loading