Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from __future__ import annotations

import asyncio
from functools import partial
import mimetypes
from pathlib import Path
from types import MappingProxyType
Expand Down Expand Up @@ -37,11 +38,13 @@

from .const import (
CONF_PROMPT,
DEFAULT_AI_TASK_NAME,
DEFAULT_TITLE,
DEFAULT_TTS_NAME,
DOMAIN,
FILE_POLLING_INTERVAL_SECONDS,
LOGGER,
RECOMMENDED_AI_TASK_OPTIONS,
RECOMMENDED_CHAT_MODEL,
RECOMMENDED_TTS_OPTIONS,
TIMEOUT_MILLIS,
Expand All @@ -53,6 +56,7 @@

CONFIG_SCHEMA = cv.config_entry_only_config_schema(DOMAIN)
PLATFORMS = (
Platform.AI_TASK,
Platform.CONVERSATION,
Platform.TTS,
)
Expand Down Expand Up @@ -187,11 +191,9 @@ async def async_setup_entry(
"""Set up Google Generative AI Conversation from a config entry."""

try:

def _init_client() -> Client:
return Client(api_key=entry.data[CONF_API_KEY])

client = await hass.async_add_executor_job(_init_client)
client = await hass.async_add_executor_job(
partial(Client, api_key=entry.data[CONF_API_KEY])
)
await client.aio.models.get(
model=RECOMMENDED_CHAT_MODEL,
config={"http_options": {"timeout": TIMEOUT_MILLIS}},
Expand Down Expand Up @@ -350,6 +352,19 @@ async def async_migrate_entry(

hass.config_entries.async_update_entry(entry, minor_version=2)

if entry.version == 2 and entry.minor_version == 2:
# Add AI Task subentry with default options
hass.config_entries.async_add_subentry(
entry,
ConfigSubentry(
data=MappingProxyType(RECOMMENDED_AI_TASK_OPTIONS),
subentry_type="ai_task_data",
title=DEFAULT_AI_TASK_NAME,
unique_id=None,
),
)
hass.config_entries.async_update_entry(entry, minor_version=3)

LOGGER.debug(
"Migration to version %s:%s successful", entry.version, entry.minor_version
)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
"""AI Task integration for Google Generative AI Conversation."""

from __future__ import annotations

from homeassistant.components import ai_task, conversation
from homeassistant.config_entries import ConfigEntry
from homeassistant.core import HomeAssistant
from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers.entity_platform import AddConfigEntryEntitiesCallback

from .const import LOGGER
from .entity import ERROR_GETTING_RESPONSE, GoogleGenerativeAILLMBaseEntity


async def async_setup_entry(
hass: HomeAssistant,
config_entry: ConfigEntry,
async_add_entities: AddConfigEntryEntitiesCallback,
) -> None:
"""Set up AI Task entities."""
for subentry in config_entry.subentries.values():
if subentry.subentry_type != "ai_task_data":
continue

async_add_entities(
[GoogleGenerativeAITaskEntity(config_entry, subentry)],
config_subentry_id=subentry.subentry_id,
)


class GoogleGenerativeAITaskEntity(
ai_task.AITaskEntity,
GoogleGenerativeAILLMBaseEntity,
):
"""Google Generative AI AI Task entity."""

_attr_supported_features = ai_task.AITaskEntityFeature.GENERATE_DATA

async def _async_generate_data(
self,
task: ai_task.GenDataTask,
chat_log: conversation.ChatLog,
) -> ai_task.GenDataTaskResult:
"""Handle a generate data task."""
await self._async_handle_chat_log(chat_log)

if not isinstance(chat_log.content[-1], conversation.AssistantContent):
LOGGER.error(
"Last content in chat log is not an AssistantContent: %s. This could be due to the model not returning a valid response",
chat_log.content[-1],
)
raise HomeAssistantError(ERROR_GETTING_RESPONSE)

return ai_task.GenDataTaskResult(
conversation_id=chat_log.conversation_id,
data=chat_log.content[-1].content or "",
)
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from __future__ import annotations

from collections.abc import Mapping
from functools import partial
import logging
from typing import Any, cast

Expand Down Expand Up @@ -46,10 +47,12 @@
CONF_TOP_K,
CONF_TOP_P,
CONF_USE_GOOGLE_SEARCH_TOOL,
DEFAULT_AI_TASK_NAME,
DEFAULT_CONVERSATION_NAME,
DEFAULT_TITLE,
DEFAULT_TTS_NAME,
DOMAIN,
RECOMMENDED_AI_TASK_OPTIONS,
RECOMMENDED_CHAT_MODEL,
RECOMMENDED_CONVERSATION_OPTIONS,
RECOMMENDED_HARM_BLOCK_THRESHOLD,
Expand All @@ -72,12 +75,14 @@
)


async def validate_input(data: dict[str, Any]) -> None:
async def validate_input(hass: HomeAssistant, data: dict[str, Any]) -> None:
"""Validate the user input allows us to connect.

Data has the keys from STEP_USER_DATA_SCHEMA with values provided by the user.
"""
client = genai.Client(api_key=data[CONF_API_KEY])
client = await hass.async_add_executor_job(
partial(genai.Client, api_key=data[CONF_API_KEY])
)
await client.aio.models.list(
config={
"http_options": {
Expand All @@ -92,7 +97,7 @@ class GoogleGenerativeAIConfigFlow(ConfigFlow, domain=DOMAIN):
"""Handle a config flow for Google Generative AI Conversation."""

VERSION = 2
MINOR_VERSION = 2
MINOR_VERSION = 3

async def async_step_api(
self, user_input: dict[str, Any] | None = None
Expand All @@ -102,7 +107,7 @@ async def async_step_api(
if user_input is not None:
self._async_abort_entries_match(user_input)
try:
await validate_input(user_input)
await validate_input(self.hass, user_input)
except (APIError, Timeout) as err:
if isinstance(err, ClientError) and "API_KEY_INVALID" in str(err):
errors["base"] = "invalid_auth"
Expand Down Expand Up @@ -133,6 +138,12 @@ async def async_step_api(
"title": DEFAULT_TTS_NAME,
"unique_id": None,
},
{
"subentry_type": "ai_task_data",
"data": RECOMMENDED_AI_TASK_OPTIONS,
"title": DEFAULT_AI_TASK_NAME,
"unique_id": None,
},
],
)
return self.async_show_form(
Expand Down Expand Up @@ -181,6 +192,7 @@ def async_get_supported_subentry_types(
return {
"conversation": LLMSubentryFlowHandler,
"tts": LLMSubentryFlowHandler,
"ai_task_data": LLMSubentryFlowHandler,
}


Expand Down Expand Up @@ -214,6 +226,8 @@ async def async_step_set_options(
options: dict[str, Any]
if self._subentry_type == "tts":
options = RECOMMENDED_TTS_OPTIONS.copy()
elif self._subentry_type == "ai_task_data":
options = RECOMMENDED_AI_TASK_OPTIONS.copy()
else:
options = RECOMMENDED_CONVERSATION_OPTIONS.copy()
else:
Expand Down Expand Up @@ -288,6 +302,8 @@ async def google_generative_ai_config_option_schema(
default_name = options[CONF_NAME]
elif subentry_type == "tts":
default_name = DEFAULT_TTS_NAME
elif subentry_type == "ai_task_data":
default_name = DEFAULT_AI_TASK_NAME
else:
default_name = DEFAULT_CONVERSATION_NAME
schema: dict[vol.Required | vol.Optional, Any] = {
Expand Down Expand Up @@ -315,6 +331,7 @@ async def google_generative_ai_config_option_schema(
),
}
)

schema.update(
{
vol.Required(
Expand Down Expand Up @@ -443,4 +460,5 @@ async def google_generative_ai_config_option_schema(
): bool,
}
)

return schema
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

DEFAULT_CONVERSATION_NAME = "Google AI Conversation"
DEFAULT_TTS_NAME = "Google AI TTS"
DEFAULT_AI_TASK_NAME = "Google AI Task"

CONF_RECOMMENDED = "recommended"
CONF_CHAT_MODEL = "chat_model"
Expand All @@ -35,6 +36,7 @@

TIMEOUT_MILLIS = 10000
FILE_POLLING_INTERVAL_SECONDS = 0.05

RECOMMENDED_CONVERSATION_OPTIONS = {
CONF_PROMPT: llm.DEFAULT_INSTRUCTIONS_PROMPT,
CONF_LLM_HASS_API: [llm.LLM_API_ASSIST],
Expand All @@ -44,3 +46,7 @@
RECOMMENDED_TTS_OPTIONS = {
CONF_RECOMMENDED: True,
}

RECOMMENDED_AI_TASK_OPTIONS = {
CONF_RECOMMENDED: True,
}
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,34 @@
"entry_not_loaded": "[%key:component::google_generative_ai_conversation::config_subentries::conversation::abort::entry_not_loaded%]",
"reconfigure_successful": "[%key:common::config_flow::abort::reconfigure_successful%]"
}
},
"ai_task_data": {
"initiate_flow": {
"user": "Add Generate data with AI service",
"reconfigure": "Reconfigure Generate data with AI service"
},
"entry_type": "Generate data with AI service",
"step": {
"set_options": {
"data": {
"name": "[%key:common::config_flow::data::name%]",
"recommended": "[%key:component::google_generative_ai_conversation::config_subentries::conversation::step::set_options::data::recommended%]",
"chat_model": "[%key:common::generic::model%]",
"temperature": "[%key:component::google_generative_ai_conversation::config_subentries::conversation::step::set_options::data::temperature%]",
"top_p": "[%key:component::google_generative_ai_conversation::config_subentries::conversation::step::set_options::data::top_p%]",
"top_k": "[%key:component::google_generative_ai_conversation::config_subentries::conversation::step::set_options::data::top_k%]",
"max_tokens": "[%key:component::google_generative_ai_conversation::config_subentries::conversation::step::set_options::data::max_tokens%]",
"harassment_block_threshold": "[%key:component::google_generative_ai_conversation::config_subentries::conversation::step::set_options::data::harassment_block_threshold%]",
"hate_block_threshold": "[%key:component::google_generative_ai_conversation::config_subentries::conversation::step::set_options::data::hate_block_threshold%]",
"sexual_block_threshold": "[%key:component::google_generative_ai_conversation::config_subentries::conversation::step::set_options::data::sexual_block_threshold%]",
"dangerous_block_threshold": "[%key:component::google_generative_ai_conversation::config_subentries::conversation::step::set_options::data::dangerous_block_threshold%]"
}
}
},
"abort": {
"entry_not_loaded": "[%key:component::google_generative_ai_conversation::config_subentries::conversation::abort::entry_not_loaded%]",
"reconfigure_successful": "[%key:common::config_flow::abort::reconfigure_successful%]"
}
}
},
"services": {
Expand Down
Loading
Loading