diff --git a/homeassistant/components/eq3btsmart/__init__.py b/homeassistant/components/eq3btsmart/__init__.py
index 4493f944db3cf6..b4be3cf5ee9851 100644
--- a/homeassistant/components/eq3btsmart/__init__.py
+++ b/homeassistant/components/eq3btsmart/__init__.py
@@ -6,7 +6,6 @@
 from eq3btsmart import Thermostat
 from eq3btsmart.exceptions import Eq3Exception
-from eq3btsmart.thermostat_config import ThermostatConfig
 
 from homeassistant.components import bluetooth
 from homeassistant.config_entries import ConfigEntry
@@ -53,12 +52,7 @@ async def async_setup_entry(hass: HomeAssistant, entry: Eq3ConfigEntry) -> bool:
             f"[{eq3_config.mac_address}] Device could not be found"
         )
 
-    thermostat = Thermostat(
-        thermostat_config=ThermostatConfig(
-            mac_address=mac_address,
-        ),
-        ble_device=device,
-    )
+    thermostat = Thermostat(mac_address=device)  # type: ignore[arg-type]
 
     entry.runtime_data = Eq3ConfigEntryData(
         eq3_config=eq3_config, thermostat=thermostat
diff --git a/homeassistant/components/eq3btsmart/binary_sensor.py b/homeassistant/components/eq3btsmart/binary_sensor.py
index 55b1f4d6ced6dd..8cec495f017708 100644
--- a/homeassistant/components/eq3btsmart/binary_sensor.py
+++ b/homeassistant/components/eq3btsmart/binary_sensor.py
@@ -2,7 +2,6 @@
 
 from collections.abc import Callable
 from dataclasses import dataclass
-from typing import TYPE_CHECKING
 
 from eq3btsmart.models import Status
 
@@ -80,7 +79,4 @@ def __init__(
     def is_on(self) -> bool:
         """Return the state of the binary sensor."""
 
-        if TYPE_CHECKING:
-            assert self._thermostat.status is not None
-
         return self.entity_description.value_func(self._thermostat.status)
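The `__init__.py` hunk above replaces the 1.x `ThermostatConfig` wrapper with a direct constructor call. A minimal sketch of the before/after, assuming only the constructor signatures visible in the diff (the standalone `make_thermostat` helper is hypothetical, not part of the PR):

```python
from eq3btsmart import Thermostat

# eq3btsmart 1.x (removed above):
# thermostat = Thermostat(
#     thermostat_config=ThermostatConfig(mac_address=mac_address),
#     ble_device=device,
# )

def make_thermostat(device) -> Thermostat:
    # eq3btsmart 2.x takes the device directly; the integration hands over a
    # resolved BLEDevice, which is why the hunk carries a type: ignore.
    return Thermostat(mac_address=device)
```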
diff --git a/homeassistant/components/eq3btsmart/climate.py b/homeassistant/components/eq3btsmart/climate.py
index 738efa99187d7b..c11328c7ec3e58 100644
--- a/homeassistant/components/eq3btsmart/climate.py
+++ b/homeassistant/components/eq3btsmart/climate.py
@@ -1,9 +1,16 @@
 """Platform for eQ-3 climate entities."""
 
+from datetime import timedelta
 import logging
 from typing import Any
 
-from eq3btsmart.const import EQ3BT_MAX_TEMP, EQ3BT_OFF_TEMP, Eq3Preset, OperationMode
+from eq3btsmart.const import (
+    EQ3_DEFAULT_AWAY_TEMP,
+    EQ3_MAX_TEMP,
+    EQ3_OFF_TEMP,
+    Eq3OperationMode,
+    Eq3Preset,
+)
 from eq3btsmart.exceptions import Eq3Exception
 
 from homeassistant.components.climate import (
@@ -20,9 +27,11 @@
 from homeassistant.helpers import device_registry as dr
 from homeassistant.helpers.device_registry import CONNECTION_BLUETOOTH
 from homeassistant.helpers.entity_platform import AddConfigEntryEntitiesCallback
+import homeassistant.util.dt as dt_util
 
 from . import Eq3ConfigEntry
 from .const import (
+    DEFAULT_AWAY_HOURS,
     EQ_TO_HA_HVAC,
     HA_TO_EQ_HVAC,
     CurrentTemperatureSelector,
@@ -57,8 +66,8 @@ class Eq3Climate(Eq3Entity, ClimateEntity):
         | ClimateEntityFeature.TURN_ON
     )
     _attr_temperature_unit = UnitOfTemperature.CELSIUS
-    _attr_min_temp = EQ3BT_OFF_TEMP
-    _attr_max_temp = EQ3BT_MAX_TEMP
+    _attr_min_temp = EQ3_OFF_TEMP
+    _attr_max_temp = EQ3_MAX_TEMP
     _attr_precision = PRECISION_HALVES
     _attr_hvac_modes = list(HA_TO_EQ_HVAC.keys())
     _attr_preset_modes = list(Preset)
@@ -70,38 +79,21 @@ class Eq3Climate(Eq3Entity, ClimateEntity):
     _target_temperature: float | None = None
 
     @callback
-    def _async_on_updated(self) -> None:
-        """Handle updated data from the thermostat."""
-
-        if self._thermostat.status is not None:
-            self._async_on_status_updated()
-
-        if self._thermostat.device_data is not None:
-            self._async_on_device_updated()
-
-        super()._async_on_updated()
-
-    @callback
-    def _async_on_status_updated(self) -> None:
+    def _async_on_status_updated(self, data: Any) -> None:
         """Handle updated status from the thermostat."""
 
-        if self._thermostat.status is None:
-            return
-
-        self._target_temperature = self._thermostat.status.target_temperature.value
+        self._target_temperature = self._thermostat.status.target_temperature
         self._attr_hvac_mode = EQ_TO_HA_HVAC[self._thermostat.status.operation_mode]
         self._attr_current_temperature = self._get_current_temperature()
         self._attr_target_temperature = self._get_target_temperature()
         self._attr_preset_mode = self._get_current_preset_mode()
         self._attr_hvac_action = self._get_current_hvac_action()
+        super()._async_on_status_updated(data)
 
     @callback
-    def _async_on_device_updated(self) -> None:
+    def _async_on_device_updated(self, data: Any) -> None:
         """Handle updated device data from the thermostat."""
 
-        if self._thermostat.device_data is None:
-            return
-
         device_registry = dr.async_get(self.hass)
         if device := device_registry.async_get_device(
             connections={(CONNECTION_BLUETOOTH, self._eq3_config.mac_address)},
@@ -109,8 +101,9 @@ def _async_on_device_updated(self) -> None:
             device_registry.async_update_device(
                 device.id,
                 sw_version=str(self._thermostat.device_data.firmware_version),
-                serial_number=self._thermostat.device_data.device_serial.value,
+                serial_number=self._thermostat.device_data.device_serial,
             )
+        super()._async_on_device_updated(data)
 
     def _get_current_temperature(self) -> float | None:
         """Return the current temperature."""
@@ -119,17 +112,11 @@ def _get_current_temperature(self) -> float | None:
             case CurrentTemperatureSelector.NOTHING:
                 return None
             case CurrentTemperatureSelector.VALVE:
-                if self._thermostat.status is None:
-                    return None
-
                 return float(self._thermostat.status.valve_temperature)
             case CurrentTemperatureSelector.UI:
                 return self._target_temperature
             case CurrentTemperatureSelector.DEVICE:
-                if self._thermostat.status is None:
-                    return None
-
-                return float(self._thermostat.status.target_temperature.value)
+                return float(self._thermostat.status.target_temperature)
             case CurrentTemperatureSelector.ENTITY:
                 state = self.hass.states.get(self._eq3_config.external_temp_sensor)
                 if state is not None:
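The hunks above delete the `status is None` guards: in eq3btsmart 2.x, accessing `Thermostat.status` raises `Eq3Exception` until the first status has been received, instead of returning `None` (the `available` property in `entity.py` below relies on exactly this). A hedged sketch of the access pattern the integration now assumes:

```python
from eq3btsmart import Thermostat
from eq3btsmart.exceptions import Eq3Exception


def safe_target_temperature(thermostat: Thermostat) -> float | None:
    """Sketch: 2.x raises on early status access instead of returning None."""
    try:
        # Raises Eq3Exception while no status has been received yet.
        return float(thermostat.status.target_temperature)
    except Eq3Exception:
        return None
```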
"""Return the current preset mode.""" - if (status := self._thermostat.status) is None: - return PRESET_NONE + status = self._thermostat.status if status.is_window_open: return Preset.WINDOW_OPEN if status.is_boost: @@ -165,7 +148,7 @@ def _get_current_preset_mode(self) -> str: return Preset.LOW_BATTERY if status.is_away: return Preset.AWAY - if status.operation_mode is OperationMode.ON: + if status.operation_mode is Eq3OperationMode.ON: return Preset.OPEN if status.presets is None: return PRESET_NONE @@ -179,10 +162,7 @@ def _get_current_preset_mode(self) -> str: def _get_current_hvac_action(self) -> HVACAction: """Return the current hvac action.""" - if ( - self._thermostat.status is None - or self._thermostat.status.operation_mode is OperationMode.OFF - ): + if self._thermostat.status.operation_mode is Eq3OperationMode.OFF: return HVACAction.OFF if self._thermostat.status.valve == 0: return HVACAction.IDLE @@ -227,7 +207,7 @@ async def async_set_hvac_mode(self, hvac_mode: HVACMode) -> None: """Set new target hvac mode.""" if hvac_mode is HVACMode.OFF: - await self.async_set_temperature(temperature=EQ3BT_OFF_TEMP) + await self.async_set_temperature(temperature=EQ3_OFF_TEMP) try: await self._thermostat.async_set_mode(HA_TO_EQ_HVAC[hvac_mode]) @@ -241,10 +221,11 @@ async def async_set_preset_mode(self, preset_mode: str) -> None: case Preset.BOOST: await self._thermostat.async_set_boost(True) case Preset.AWAY: - await self._thermostat.async_set_away(True) + away_until = dt_util.now() + timedelta(hours=DEFAULT_AWAY_HOURS) + await self._thermostat.async_set_away(away_until, EQ3_DEFAULT_AWAY_TEMP) case Preset.ECO: await self._thermostat.async_set_preset(Eq3Preset.ECO) case Preset.COMFORT: await self._thermostat.async_set_preset(Eq3Preset.COMFORT) case Preset.OPEN: - await self._thermostat.async_set_mode(OperationMode.ON) + await self._thermostat.async_set_mode(Eq3OperationMode.ON) diff --git a/homeassistant/components/eq3btsmart/const.py b/homeassistant/components/eq3btsmart/const.py index a5f7ea2ff9518e..33698d2d076ade 100644 --- a/homeassistant/components/eq3btsmart/const.py +++ b/homeassistant/components/eq3btsmart/const.py @@ -2,7 +2,7 @@ from enum import Enum -from eq3btsmart.const import OperationMode +from eq3btsmart.const import Eq3OperationMode from homeassistant.components.climate import ( PRESET_AWAY, @@ -34,17 +34,17 @@ GET_DEVICE_TIMEOUT = 5 # seconds -EQ_TO_HA_HVAC: dict[OperationMode, HVACMode] = { - OperationMode.OFF: HVACMode.OFF, - OperationMode.ON: HVACMode.HEAT, - OperationMode.AUTO: HVACMode.AUTO, - OperationMode.MANUAL: HVACMode.HEAT, +EQ_TO_HA_HVAC: dict[Eq3OperationMode, HVACMode] = { + Eq3OperationMode.OFF: HVACMode.OFF, + Eq3OperationMode.ON: HVACMode.HEAT, + Eq3OperationMode.AUTO: HVACMode.AUTO, + Eq3OperationMode.MANUAL: HVACMode.HEAT, } HA_TO_EQ_HVAC = { - HVACMode.OFF: OperationMode.OFF, - HVACMode.AUTO: OperationMode.AUTO, - HVACMode.HEAT: OperationMode.MANUAL, + HVACMode.OFF: Eq3OperationMode.OFF, + HVACMode.AUTO: Eq3OperationMode.AUTO, + HVACMode.HEAT: Eq3OperationMode.MANUAL, } @@ -81,6 +81,7 @@ class TargetTemperatureSelector(str, Enum): DEFAULT_CURRENT_TEMP_SELECTOR = CurrentTemperatureSelector.DEVICE DEFAULT_TARGET_TEMP_SELECTOR = TargetTemperatureSelector.TARGET DEFAULT_SCAN_INTERVAL = 10 # seconds +DEFAULT_AWAY_HOURS = 30 * 24 SIGNAL_THERMOSTAT_DISCONNECTED = f"{DOMAIN}.thermostat_disconnected" SIGNAL_THERMOSTAT_CONNECTED = f"{DOMAIN}.thermostat_connected" diff --git a/homeassistant/components/eq3btsmart/entity.py 
diff --git a/homeassistant/components/eq3btsmart/entity.py b/homeassistant/components/eq3btsmart/entity.py
index e68545c08c7e59..e8dbb93428907d 100644
--- a/homeassistant/components/eq3btsmart/entity.py
+++ b/homeassistant/components/eq3btsmart/entity.py
@@ -1,5 +1,10 @@
 """Base class for all eQ-3 entities."""
 
+from typing import Any
+
+from eq3btsmart import Eq3Exception
+from eq3btsmart.const import Eq3Event
+
 from homeassistant.core import callback
 from homeassistant.helpers.device_registry import (
     CONNECTION_BLUETOOTH,
@@ -45,7 +50,15 @@ def __init__(
     async def async_added_to_hass(self) -> None:
         """Run when entity about to be added to hass."""
 
-        self._thermostat.register_update_callback(self._async_on_updated)
+        self._thermostat.register_callback(
+            Eq3Event.DEVICE_DATA_RECEIVED, self._async_on_device_updated
+        )
+        self._thermostat.register_callback(
+            Eq3Event.STATUS_RECEIVED, self._async_on_status_updated
+        )
+        self._thermostat.register_callback(
+            Eq3Event.SCHEDULE_RECEIVED, self._async_on_status_updated
+        )
 
         self.async_on_remove(
             async_dispatcher_connect(
@@ -65,10 +78,25 @@ async def async_added_to_hass(self) -> None:
     async def async_will_remove_from_hass(self) -> None:
         """Run when entity will be removed from hass."""
 
-        self._thermostat.unregister_update_callback(self._async_on_updated)
+        self._thermostat.unregister_callback(
+            Eq3Event.DEVICE_DATA_RECEIVED, self._async_on_device_updated
+        )
+        self._thermostat.unregister_callback(
+            Eq3Event.STATUS_RECEIVED, self._async_on_status_updated
+        )
+        self._thermostat.unregister_callback(
+            Eq3Event.SCHEDULE_RECEIVED, self._async_on_status_updated
+        )
+
+    @callback
+    def _async_on_status_updated(self, data: Any) -> None:
+        """Handle updated status from the thermostat."""
 
-    def _async_on_updated(self) -> None:
-        """Handle updated data from the thermostat."""
+        self.async_write_ha_state()
+
+    @callback
+    def _async_on_device_updated(self, data: Any) -> None:
+        """Handle updated device data from the thermostat."""
 
         self.async_write_ha_state()
 
@@ -90,4 +118,9 @@ def _async_on_connected(self) -> None:
     def available(self) -> bool:
         """Whether the entity is available."""
 
-        return self._thermostat.status is not None and self._attr_available
+        try:
+            _ = self._thermostat.status
+        except Eq3Exception:
+            return False
+
+        return self._attr_available
diff --git a/homeassistant/components/eq3btsmart/manifest.json b/homeassistant/components/eq3btsmart/manifest.json
index 889401ffc3ee37..62128077f2febe 100644
--- a/homeassistant/components/eq3btsmart/manifest.json
+++ b/homeassistant/components/eq3btsmart/manifest.json
@@ -22,5 +22,5 @@
   "integration_type": "device",
   "iot_class": "local_polling",
   "loggers": ["eq3btsmart"],
-  "requirements": ["eq3btsmart==1.4.1", "bleak-esphome==2.16.0"]
+  "requirements": ["eq3btsmart==2.1.0", "bleak-esphome==2.16.0"]
 }
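The single `register_update_callback` from 1.x becomes per-event registration, which is why the climate entity above now receives a `data` payload and calls `super()._async_on_status_updated(data)`. A minimal sketch of the new callback API, using only the events shown in the hunk (the handler body is illustrative):

```python
from eq3btsmart import Thermostat
from eq3btsmart.const import Eq3Event


def watch(thermostat: Thermostat) -> None:
    """Sketch: 2.x callbacks are keyed by event type."""

    def on_status(data) -> None:  # payload type depends on the event
        print("status received:", data)

    thermostat.register_callback(Eq3Event.STATUS_RECEIVED, on_status)
    # Symmetric teardown, as in async_will_remove_from_hass above:
    # thermostat.unregister_callback(Eq3Event.STATUS_RECEIVED, on_status)
```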
diff --git a/homeassistant/components/eq3btsmart/number.py b/homeassistant/components/eq3btsmart/number.py
index c3cbd8eae315b2..c9601a4437efa8 100644
--- a/homeassistant/components/eq3btsmart/number.py
+++ b/homeassistant/components/eq3btsmart/number.py
@@ -1,17 +1,12 @@
 """Platform for eq3 number entities."""
 
-from collections.abc import Awaitable, Callable
+from collections.abc import Callable, Coroutine
 from dataclasses import dataclass
 from typing import TYPE_CHECKING
 
 from eq3btsmart import Thermostat
-from eq3btsmart.const import (
-    EQ3BT_MAX_OFFSET,
-    EQ3BT_MAX_TEMP,
-    EQ3BT_MIN_OFFSET,
-    EQ3BT_MIN_TEMP,
-)
-from eq3btsmart.models import Presets
+from eq3btsmart.const import EQ3_MAX_OFFSET, EQ3_MAX_TEMP, EQ3_MIN_OFFSET, EQ3_MIN_TEMP
+from eq3btsmart.models import Presets, Status
 
 from homeassistant.components.number import (
     NumberDeviceClass,
@@ -42,7 +37,7 @@ class Eq3NumberEntityDescription(NumberEntityDescription):
     value_func: Callable[[Presets], float]
     value_set_func: Callable[
         [Thermostat],
-        Callable[[float], Awaitable[None]],
+        Callable[[float], Coroutine[None, None, Status]],
     ]
     mode: NumberMode = NumberMode.BOX
     entity_category: EntityCategory | None = EntityCategory.CONFIG
@@ -51,44 +46,44 @@ class Eq3NumberEntityDescription(NumberEntityDescription):
 NUMBER_ENTITY_DESCRIPTIONS = [
     Eq3NumberEntityDescription(
         key=ENTITY_KEY_COMFORT,
-        value_func=lambda presets: presets.comfort_temperature.value,
+        value_func=lambda presets: presets.comfort_temperature,
        value_set_func=lambda thermostat: thermostat.async_configure_comfort_temperature,
         translation_key=ENTITY_KEY_COMFORT,
-        native_min_value=EQ3BT_MIN_TEMP,
-        native_max_value=EQ3BT_MAX_TEMP,
+        native_min_value=EQ3_MIN_TEMP,
+        native_max_value=EQ3_MAX_TEMP,
         native_step=EQ3BT_STEP,
         native_unit_of_measurement=UnitOfTemperature.CELSIUS,
         device_class=NumberDeviceClass.TEMPERATURE,
     ),
     Eq3NumberEntityDescription(
         key=ENTITY_KEY_ECO,
-        value_func=lambda presets: presets.eco_temperature.value,
+        value_func=lambda presets: presets.eco_temperature,
         value_set_func=lambda thermostat: thermostat.async_configure_eco_temperature,
         translation_key=ENTITY_KEY_ECO,
-        native_min_value=EQ3BT_MIN_TEMP,
-        native_max_value=EQ3BT_MAX_TEMP,
+        native_min_value=EQ3_MIN_TEMP,
+        native_max_value=EQ3_MAX_TEMP,
         native_step=EQ3BT_STEP,
         native_unit_of_measurement=UnitOfTemperature.CELSIUS,
         device_class=NumberDeviceClass.TEMPERATURE,
     ),
     Eq3NumberEntityDescription(
         key=ENTITY_KEY_WINDOW_OPEN_TEMPERATURE,
-        value_func=lambda presets: presets.window_open_temperature.value,
+        value_func=lambda presets: presets.window_open_temperature,
         value_set_func=lambda thermostat: thermostat.async_configure_window_open_temperature,
         translation_key=ENTITY_KEY_WINDOW_OPEN_TEMPERATURE,
-        native_min_value=EQ3BT_MIN_TEMP,
-        native_max_value=EQ3BT_MAX_TEMP,
+        native_min_value=EQ3_MIN_TEMP,
+        native_max_value=EQ3_MAX_TEMP,
         native_step=EQ3BT_STEP,
         native_unit_of_measurement=UnitOfTemperature.CELSIUS,
         device_class=NumberDeviceClass.TEMPERATURE,
     ),
     Eq3NumberEntityDescription(
         key=ENTITY_KEY_OFFSET,
-        value_func=lambda presets: presets.offset_temperature.value,
+        value_func=lambda presets: presets.offset_temperature,
         value_set_func=lambda thermostat: thermostat.async_configure_temperature_offset,
         translation_key=ENTITY_KEY_OFFSET,
-        native_min_value=EQ3BT_MIN_OFFSET,
-        native_max_value=EQ3BT_MAX_OFFSET,
+        native_min_value=EQ3_MIN_OFFSET,
+        native_max_value=EQ3_MAX_OFFSET,
         native_step=EQ3BT_STEP,
         native_unit_of_measurement=UnitOfTemperature.CELSIUS,
         device_class=NumberDeviceClass.TEMPERATURE,
@@ -96,7 +91,7 @@ class Eq3NumberEntityDescription(NumberEntityDescription):
     Eq3NumberEntityDescription(
         key=ENTITY_KEY_WINDOW_OPEN_TIMEOUT,
         value_set_func=lambda thermostat: thermostat.async_configure_window_open_duration,
-        value_func=lambda presets: presets.window_open_time.value.total_seconds() / 60,
+        value_func=lambda presets: presets.window_open_time.total_seconds() / 60,
         translation_key=ENTITY_KEY_WINDOW_OPEN_TIMEOUT,
         native_min_value=0,
         native_max_value=60,
@@ -137,7 +132,6 @@ def native_value(self) -> float:
         """Return the state of the entity."""
 
         if TYPE_CHECKING:
-            assert self._thermostat.status is not None
             assert self._thermostat.status.presets is not None
 
         return self.entity_description.value_func(self._thermostat.status.presets)
@@ -152,7 +146,7 @@ def available(self) -> bool:
         """Return whether the entity is available."""
 
         return (
-            self._thermostat.status is not None
+            super().available
             and self._thermostat.status.presets is not None
             and self._attr_available
         )
diff --git a/homeassistant/components/eq3btsmart/schemas.py b/homeassistant/components/eq3btsmart/schemas.py
index 643bb4a02a6c65..daeed5a05e3cad 100644
--- a/homeassistant/components/eq3btsmart/schemas.py
+++ b/homeassistant/components/eq3btsmart/schemas.py
@@ -1,12 +1,12 @@
 """Voluptuous schemas for eq3btsmart."""
 
-from eq3btsmart.const import EQ3BT_MAX_TEMP, EQ3BT_MIN_TEMP
+from eq3btsmart.const import EQ3_MAX_TEMP, EQ3_MIN_TEMP
 import voluptuous as vol
 
 from homeassistant.const import CONF_MAC
 from homeassistant.helpers import config_validation as cv
 
-SCHEMA_TEMPERATURE = vol.Range(min=EQ3BT_MIN_TEMP, max=EQ3BT_MAX_TEMP)
+SCHEMA_TEMPERATURE = vol.Range(min=EQ3_MIN_TEMP, max=EQ3_MAX_TEMP)
 SCHEMA_DEVICE = vol.Schema({vol.Required(CONF_MAC): cv.string})
 SCHEMA_MAC = vol.Schema(
     {
diff --git a/homeassistant/components/eq3btsmart/sensor.py b/homeassistant/components/eq3btsmart/sensor.py
index aab3cbf19257f6..0f61ef22452050 100644
--- a/homeassistant/components/eq3btsmart/sensor.py
+++ b/homeassistant/components/eq3btsmart/sensor.py
@@ -3,7 +3,6 @@
 from collections.abc import Callable
 from dataclasses import dataclass
 from datetime import datetime
-from typing import TYPE_CHECKING
 
 from eq3btsmart.models import Status
 
@@ -40,9 +39,7 @@ class Eq3SensorEntityDescription(SensorEntityDescription):
     Eq3SensorEntityDescription(
         key=ENTITY_KEY_AWAY_UNTIL,
         translation_key=ENTITY_KEY_AWAY_UNTIL,
-        value_func=lambda status: (
-            status.away_until.value if status.away_until else None
-        ),
+        value_func=lambda status: (status.away_until if status.away_until else None),
         device_class=SensorDeviceClass.DATE,
     ),
 ]
@@ -78,7 +75,4 @@ def __init__(
     def native_value(self) -> int | datetime | None:
         """Return the value reported by the sensor."""
 
-        if TYPE_CHECKING:
-            assert self._thermostat.status is not None
-
         return self.entity_description.value_func(self._thermostat.status)
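In the number platform above, the 2.x preset fields are plain floats (no more `.value`), and the configure coroutines now return the fresh `Status`, which is why `value_set_func` is typed `Callable[[float], Coroutine[None, None, Status]]`. A hedged sketch of a write through that API (the `assert` mirrors the optional `presets` handling in the diff):

```python
from eq3btsmart import Thermostat
from eq3btsmart.models import Status


async def set_comfort(thermostat: Thermostat, temperature: float) -> Status:
    """Sketch: 2.x configure calls return the updated Status."""
    status = await thermostat.async_configure_comfort_temperature(temperature)
    assert status.presets is not None  # presets are optional on Status
    return status
```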
diff --git a/homeassistant/components/eq3btsmart/switch.py b/homeassistant/components/eq3btsmart/switch.py
index 61da133cb712f9..0d5521fee3259b 100644
--- a/homeassistant/components/eq3btsmart/switch.py
+++ b/homeassistant/components/eq3btsmart/switch.py
@@ -1,26 +1,45 @@
 """Platform for eq3 switch entities."""
 
-from collections.abc import Awaitable, Callable
+from collections.abc import Callable, Coroutine
 from dataclasses import dataclass
-from typing import TYPE_CHECKING, Any
+from datetime import timedelta
+from functools import partial
+from typing import Any
 
 from eq3btsmart import Thermostat
+from eq3btsmart.const import EQ3_DEFAULT_AWAY_TEMP, Eq3OperationMode
 from eq3btsmart.models import Status
 
 from homeassistant.components.switch import SwitchEntity, SwitchEntityDescription
 from homeassistant.core import HomeAssistant
 from homeassistant.helpers.entity_platform import AddConfigEntryEntitiesCallback
+import homeassistant.util.dt as dt_util
 
 from . import Eq3ConfigEntry
-from .const import ENTITY_KEY_AWAY, ENTITY_KEY_BOOST, ENTITY_KEY_LOCK
+from .const import (
+    DEFAULT_AWAY_HOURS,
+    ENTITY_KEY_AWAY,
+    ENTITY_KEY_BOOST,
+    ENTITY_KEY_LOCK,
+)
 from .entity import Eq3Entity
 
 
+async def async_set_away(thermostat: Thermostat, enable: bool) -> Status:
+    """Backport old async_set_away behavior."""
+
+    if not enable:
+        return await thermostat.async_set_mode(Eq3OperationMode.AUTO)
+
+    away_until = dt_util.now() + timedelta(hours=DEFAULT_AWAY_HOURS)
+    return await thermostat.async_set_away(away_until, EQ3_DEFAULT_AWAY_TEMP)
+
+
 @dataclass(frozen=True, kw_only=True)
 class Eq3SwitchEntityDescription(SwitchEntityDescription):
     """Entity description for eq3 switch entities."""
 
-    toggle_func: Callable[[Thermostat], Callable[[bool], Awaitable[None]]]
+    toggle_func: Callable[[Thermostat], Callable[[bool], Coroutine[None, None, Status]]]
     value_func: Callable[[Status], bool]
 
 
@@ -40,7 +59,7 @@ class Eq3SwitchEntityDescription(SwitchEntityDescription):
     Eq3SwitchEntityDescription(
         key=ENTITY_KEY_AWAY,
         translation_key=ENTITY_KEY_AWAY,
-        toggle_func=lambda thermostat: thermostat.async_set_away,
+        toggle_func=lambda thermostat: partial(async_set_away, thermostat),
         value_func=lambda status: status.is_away,
     ),
 ]
@@ -88,7 +107,4 @@ async def async_turn_off(self, **kwargs: Any) -> None:
     def is_on(self) -> bool:
         """Return the state of the switch."""
 
-        if TYPE_CHECKING:
-            assert self._thermostat.status is not None
-
         return self.entity_description.value_func(self._thermostat.status)
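`partial(async_set_away, thermostat)` adapts the new two-argument library call back to the `bool -> Status` signature the switch descriptions expect, so the entity code stays untouched. A hedged usage sketch (`thermostat` is assumed to be a set-up `Thermostat` instance, as in `__init__.py`):

```python
from functools import partial

# async_set_away is the backport helper defined in switch.py above.
toggle = partial(async_set_away, thermostat)

# In an async context:
# status = await toggle(True)   # away for DEFAULT_AWAY_HOURS at EQ3_DEFAULT_AWAY_TEMP
# status = await toggle(False)  # back to Eq3OperationMode.AUTO
```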
diff --git a/homeassistant/components/google_generative_ai_conversation/__init__.py b/homeassistant/components/google_generative_ai_conversation/__init__.py
index 79d092a60c306d..7e9ca550275a82 100644
--- a/homeassistant/components/google_generative_ai_conversation/__init__.py
+++ b/homeassistant/components/google_generative_ai_conversation/__init__.py
@@ -45,7 +45,10 @@
 CONF_FILENAMES = "filenames"
 
 CONFIG_SCHEMA = cv.config_entry_only_config_schema(DOMAIN)
-PLATFORMS = (Platform.CONVERSATION,)
+PLATFORMS = (
+    Platform.CONVERSATION,
+    Platform.TTS,
+)
 
 type GoogleGenerativeAIConfigEntry = ConfigEntry[Client]
 
diff --git a/homeassistant/components/google_generative_ai_conversation/const.py b/homeassistant/components/google_generative_ai_conversation/const.py
index 239b3ff763e729..831e7d8f5085f9 100644
--- a/homeassistant/components/google_generative_ai_conversation/const.py
+++ b/homeassistant/components/google_generative_ai_conversation/const.py
@@ -6,9 +6,11 @@
 LOGGER = logging.getLogger(__package__)
 
 CONF_PROMPT = "prompt"
+ATTR_MODEL = "model"
 CONF_RECOMMENDED = "recommended"
 CONF_CHAT_MODEL = "chat_model"
 RECOMMENDED_CHAT_MODEL = "models/gemini-2.0-flash"
+RECOMMENDED_TTS_MODEL = "gemini-2.5-flash-preview-tts"
 CONF_TEMPERATURE = "temperature"
 RECOMMENDED_TEMPERATURE = 1.0
 CONF_TOP_P = "top_p"
diff --git a/homeassistant/components/google_generative_ai_conversation/conversation.py b/homeassistant/components/google_generative_ai_conversation/conversation.py
index 1038377af68f33..726572fc5aee60 100644
--- a/homeassistant/components/google_generative_ai_conversation/conversation.py
+++ b/homeassistant/components/google_generative_ai_conversation/conversation.py
@@ -2,63 +2,18 @@
 
 from __future__ import annotations
 
-import codecs
-from collections.abc import AsyncGenerator, Callable
-from dataclasses import replace
-from typing import Any, Literal, cast
-
-from google.genai.errors import APIError, ClientError
-from google.genai.types import (
-    AutomaticFunctionCallingConfig,
-    Content,
-    FunctionDeclaration,
-    GenerateContentConfig,
-    GenerateContentResponse,
-    GoogleSearch,
-    HarmCategory,
-    Part,
-    SafetySetting,
-    Schema,
-    Tool,
-)
-from voluptuous_openapi import convert
+from typing import Literal
 
 from homeassistant.components import assist_pipeline, conversation
 from homeassistant.config_entries import ConfigEntry
 from homeassistant.const import CONF_LLM_HASS_API, MATCH_ALL
 from homeassistant.core import HomeAssistant
 from homeassistant.exceptions import HomeAssistantError
-from homeassistant.helpers import device_registry as dr, intent, llm
+from homeassistant.helpers import intent
 from homeassistant.helpers.entity_platform import AddConfigEntryEntitiesCallback
 
-from .const import (
-    CONF_CHAT_MODEL,
-    CONF_DANGEROUS_BLOCK_THRESHOLD,
-    CONF_HARASSMENT_BLOCK_THRESHOLD,
-    CONF_HATE_BLOCK_THRESHOLD,
-    CONF_MAX_TOKENS,
-    CONF_PROMPT,
-    CONF_SEXUAL_BLOCK_THRESHOLD,
-    CONF_TEMPERATURE,
-    CONF_TOP_K,
-    CONF_TOP_P,
-    CONF_USE_GOOGLE_SEARCH_TOOL,
-    DOMAIN,
-    LOGGER,
-    RECOMMENDED_CHAT_MODEL,
-    RECOMMENDED_HARM_BLOCK_THRESHOLD,
-    RECOMMENDED_MAX_TOKENS,
-    RECOMMENDED_TEMPERATURE,
-    RECOMMENDED_TOP_K,
-    RECOMMENDED_TOP_P,
-)
-
-# Max number of back and forth with the LLM to generate a response
-MAX_TOOL_ITERATIONS = 10
-
-ERROR_GETTING_RESPONSE = (
-    "Sorry, I had a problem getting a response from Google Generative AI."
-)
+from .const import CONF_PROMPT, DOMAIN, LOGGER
+from .entity import ERROR_GETTING_RESPONSE, GoogleGenerativeAILLMBaseEntity
 
 
 async def async_setup_entry(
@@ -71,267 +26,18 @@ async def async_setup_entry(
     async_add_entities([agent])
 
 
-SUPPORTED_SCHEMA_KEYS = {
-    # Gemini API does not support all of the OpenAPI schema
-    # SoT: https://ai.google.dev/api/caching#Schema
-    "type",
-    "format",
-    "description",
-    "nullable",
-    "enum",
-    "max_items",
-    "min_items",
-    "properties",
-    "required",
-    "items",
-}
-
-
-def _camel_to_snake(name: str) -> str:
-    """Convert camel case to snake case."""
-    return "".join(["_" + c.lower() if c.isupper() else c for c in name]).lstrip("_")
-
-
-def _format_schema(schema: dict[str, Any]) -> Schema:
-    """Format the schema to be compatible with Gemini API."""
-    if subschemas := schema.get("allOf"):
-        for subschema in subschemas:  # Gemini API does not support allOf keys
-            if "type" in subschema:  # Fallback to first subschema with 'type' field
-                return _format_schema(subschema)
-        return _format_schema(
-            subschemas[0]
-        )  # Or, if not found, to any of the subschemas
-
-    result = {}
-    for key, val in schema.items():
-        key = _camel_to_snake(key)
-        if key not in SUPPORTED_SCHEMA_KEYS:
-            continue
-        if key == "type":
-            val = val.upper()
-        elif key == "format":
-            # Gemini API does not support all formats, see: https://ai.google.dev/api/caching#Schema
-            # formats that are not supported are ignored
-            if schema.get("type") == "string" and val not in ("enum", "date-time"):
-                continue
-            if schema.get("type") == "number" and val not in ("float", "double"):
-                continue
-            if schema.get("type") == "integer" and val not in ("int32", "int64"):
-                continue
-            if schema.get("type") not in ("string", "number", "integer"):
-                continue
-        elif key == "items":
-            val = _format_schema(val)
-        elif key == "properties":
-            val = {k: _format_schema(v) for k, v in val.items()}
-        result[key] = val
-
-    if result.get("enum") and result.get("type") != "STRING":
-        # enum is only allowed for STRING type. This is safe as long as the schema
-        # contains vol.Coerce for the respective type, for example:
-        # vol.All(vol.Coerce(int), vol.In([1, 2, 3]))
-        result["type"] = "STRING"
-        result["enum"] = [str(item) for item in result["enum"]]
-
-    if result.get("type") == "OBJECT" and not result.get("properties"):
-        # An object with undefined properties is not supported by Gemini API.
-        # Fallback to JSON string. This will probably fail for most tools that want it,
-        # but we don't have a better fallback strategy so far.
-        result["properties"] = {"json": {"type": "STRING"}}
-        result["required"] = []
-    return cast(Schema, result)
-
-
-def _format_tool(
-    tool: llm.Tool, custom_serializer: Callable[[Any], Any] | None
-) -> Tool:
-    """Format tool specification."""
-
-    if tool.parameters.schema:
-        parameters = _format_schema(
-            convert(tool.parameters, custom_serializer=custom_serializer)
-        )
-    else:
-        parameters = None
-
-    return Tool(
-        function_declarations=[
-            FunctionDeclaration(
-                name=tool.name,
-                description=tool.description,
-                parameters=parameters,
-            )
-        ]
-    )
-
-
-def _escape_decode(value: Any) -> Any:
-    """Recursively call codecs.escape_decode on all values."""
-    if isinstance(value, str):
-        return codecs.escape_decode(bytes(value, "utf-8"))[0].decode("utf-8")  # type: ignore[attr-defined]
-    if isinstance(value, list):
-        return [_escape_decode(item) for item in value]
-    if isinstance(value, dict):
-        return {k: _escape_decode(v) for k, v in value.items()}
-    return value
-
-
-def _create_google_tool_response_parts(
-    parts: list[conversation.ToolResultContent],
-) -> list[Part]:
-    """Create Google tool response parts."""
-    return [
-        Part.from_function_response(
-            name=tool_result.tool_name, response=tool_result.tool_result
-        )
-        for tool_result in parts
-    ]
-
-
-def _create_google_tool_response_content(
-    content: list[conversation.ToolResultContent],
-) -> Content:
-    """Create a Google tool response content."""
-    return Content(
-        role="user",
-        parts=_create_google_tool_response_parts(content),
-    )
-
-
-def _convert_content(
-    content: (
-        conversation.UserContent
-        | conversation.AssistantContent
-        | conversation.SystemContent
-    ),
-) -> Content:
-    """Convert HA content to Google content."""
-    if content.role != "assistant" or not content.tool_calls:
-        role = "model" if content.role == "assistant" else content.role
-        return Content(
-            role=role,
-            parts=[
-                Part.from_text(text=content.content if content.content else ""),
-            ],
-        )
-
-    # Handle the Assistant content with tool calls.
-    assert type(content) is conversation.AssistantContent
-    parts: list[Part] = []
-
-    if content.content:
-        parts.append(Part.from_text(text=content.content))
-
-    if content.tool_calls:
-        parts.extend(
-            [
-                Part.from_function_call(
-                    name=tool_call.tool_name,
-                    args=_escape_decode(tool_call.tool_args),
-                )
-                for tool_call in content.tool_calls
-            ]
-        )
-
-    return Content(role="model", parts=parts)
-
-
-async def _transform_stream(
-    result: AsyncGenerator[GenerateContentResponse],
-) -> AsyncGenerator[conversation.AssistantContentDeltaDict]:
-    new_message = True
-    try:
-        async for response in result:
-            LOGGER.debug("Received response chunk: %s", response)
-            chunk: conversation.AssistantContentDeltaDict = {}
-
-            if new_message:
-                chunk["role"] = "assistant"
-                new_message = False
-
-            # According to the API docs, this would mean no candidate is returned, so we can safely throw an error here.
-            if response.prompt_feedback or not response.candidates:
-                reason = (
-                    response.prompt_feedback.block_reason_message
-                    if response.prompt_feedback
-                    else "unknown"
-                )
-                raise HomeAssistantError(
-                    f"The message got blocked due to content violations, reason: {reason}"
-                )
-
-            candidate = response.candidates[0]
-
-            if (
-                candidate.finish_reason is not None
-                and candidate.finish_reason != "STOP"
-            ):
-                # The message ended due to a content error as explained in: https://ai.google.dev/api/generate-content#FinishReason
-                LOGGER.error(
-                    "Error in Google Generative AI response: %s, see: https://ai.google.dev/api/generate-content#FinishReason",
-                    candidate.finish_reason,
-                )
-                raise HomeAssistantError(
-                    f"{ERROR_GETTING_RESPONSE} Reason: {candidate.finish_reason}"
-                )
-
-            response_parts = (
-                candidate.content.parts
-                if candidate.content is not None and candidate.content.parts is not None
-                else []
-            )
-
-            content = "".join([part.text for part in response_parts if part.text])
-            tool_calls = []
-            for part in response_parts:
-                if not part.function_call:
-                    continue
-                tool_call = part.function_call
-                tool_name = tool_call.name if tool_call.name else ""
-                tool_args = _escape_decode(tool_call.args)
-                tool_calls.append(
-                    llm.ToolInput(tool_name=tool_name, tool_args=tool_args)
-                )
-
-            if tool_calls:
-                chunk["tool_calls"] = tool_calls
-
-            chunk["content"] = content
-            yield chunk
-    except (
-        APIError,
-        ValueError,
-    ) as err:
-        LOGGER.error("Error sending message: %s %s", type(err), err)
-        if isinstance(err, APIError):
-            message = err.message
-        else:
-            message = type(err).__name__
-        error = f"{ERROR_GETTING_RESPONSE}: {message}"
-        raise HomeAssistantError(error) from err
-
-
 class GoogleGenerativeAIConversationEntity(
-    conversation.ConversationEntity, conversation.AbstractConversationAgent
+    conversation.ConversationEntity,
+    conversation.AbstractConversationAgent,
+    GoogleGenerativeAILLMBaseEntity,
 ):
     """Google Generative AI conversation agent."""
 
-    _attr_has_entity_name = True
-    _attr_name = None
     _attr_supports_streaming = True
 
     def __init__(self, entry: ConfigEntry) -> None:
         """Initialize the agent."""
-        self.entry = entry
-        self._genai_client = entry.runtime_data
-        self._attr_unique_id = entry.entry_id
-        self._attr_device_info = dr.DeviceInfo(
-            identifiers={(DOMAIN, entry.entry_id)},
-            name=entry.title,
-            manufacturer="Google",
-            model="Generative AI",
-            entry_type=dr.DeviceEntryType.SERVICE,
-        )
+        super().__init__(entry)
         if self.entry.options.get(CONF_LLM_HASS_API):
            self._attr_supported_features = (
                 conversation.ConversationEntityFeature.CONTROL
@@ -358,13 +64,6 @@ async def async_will_remove_from_hass(self) -> None:
         conversation.async_unset_agent(self.hass, self.entry)
         await super().async_will_remove_from_hass()
 
-    def _fix_tool_name(self, tool_name: str) -> str:
-        """Fix tool name if needed."""
-        # The Gemini 2.0+ tokenizer seemingly has a issue with the HassListAddItem tool
-        # name. This makes sure when it incorrectly changes the name, that we change it
-        # back for HA to call.
-        return tool_name if tool_name != "HasListAddItem" else "HassListAddItem"
-
     async def _async_handle_message(
         self,
         user_input: conversation.ConversationInput,
@@ -399,163 +98,6 @@ async def _async_handle_message(
             continue_conversation=chat_log.continue_conversation,
         )
 
-    async def _async_handle_chat_log(
-        self,
-        chat_log: conversation.ChatLog,
-    ) -> None:
-        """Generate an answer for the chat log."""
-        options = self.entry.options
-
-        tools: list[Tool | Callable[..., Any]] | None = None
-        if chat_log.llm_api:
-            tools = [
-                _format_tool(tool, chat_log.llm_api.custom_serializer)
-                for tool in chat_log.llm_api.tools
-            ]
-
-        # Using search grounding allows the model to retrieve information from the web,
-        # however, it may interfere with how the model decides to use some tools, or entities
-        # for example weather entity may be disregarded if the model chooses to Google it.
-        if options.get(CONF_USE_GOOGLE_SEARCH_TOOL) is True:
-            tools = tools or []
-            tools.append(Tool(google_search=GoogleSearch()))
-
-        model_name = self.entry.options.get(CONF_CHAT_MODEL, RECOMMENDED_CHAT_MODEL)
-        # Avoid INVALID_ARGUMENT Developer instruction is not enabled for
-        supports_system_instruction = (
-            "gemma" not in model_name
-            and "gemini-2.0-flash-preview-image-generation" not in model_name
-        )
-
-        prompt_content = cast(
-            conversation.SystemContent,
-            chat_log.content[0],
-        )
-
-        if prompt_content.content:
-            prompt = prompt_content.content
-        else:
-            raise HomeAssistantError("Invalid prompt content")
-
-        messages: list[Content] = []
-
-        # Google groups tool results, we do not. Group them before sending.
-        tool_results: list[conversation.ToolResultContent] = []
-
-        for chat_content in chat_log.content[1:-1]:
-            if chat_content.role == "tool_result":
-                tool_results.append(chat_content)
-                continue
-
-            if (
-                not isinstance(chat_content, conversation.ToolResultContent)
-                and chat_content.content == ""
-            ):
-                # Skipping is not possible since the number of function calls need to match the number of function responses
-                # and skipping one would mean removing the other and hence this would prevent a proper chat log
-                chat_content = replace(chat_content, content=" ")
-
-            if tool_results:
-                messages.append(_create_google_tool_response_content(tool_results))
-                tool_results.clear()
-
-            messages.append(_convert_content(chat_content))
-
-        # The SDK requires the first message to be a user message
-        # This is not the case if user used `start_conversation`
-        # Workaround from https://github.com/googleapis/python-genai/issues/529#issuecomment-2740964537
-        if messages and messages[0].role != "user":
-            messages.insert(
-                0,
-                Content(role="user", parts=[Part.from_text(text=" ")]),
-            )
-
-        if tool_results:
-            messages.append(_create_google_tool_response_content(tool_results))
-        generateContentConfig = GenerateContentConfig(
-            temperature=self.entry.options.get(
-                CONF_TEMPERATURE, RECOMMENDED_TEMPERATURE
-            ),
-            top_k=self.entry.options.get(CONF_TOP_K, RECOMMENDED_TOP_K),
-            top_p=self.entry.options.get(CONF_TOP_P, RECOMMENDED_TOP_P),
-            max_output_tokens=self.entry.options.get(
-                CONF_MAX_TOKENS, RECOMMENDED_MAX_TOKENS
-            ),
-            safety_settings=[
-                SafetySetting(
-                    category=HarmCategory.HARM_CATEGORY_HATE_SPEECH,
-                    threshold=self.entry.options.get(
-                        CONF_HATE_BLOCK_THRESHOLD, RECOMMENDED_HARM_BLOCK_THRESHOLD
-                    ),
-                ),
-                SafetySetting(
-                    category=HarmCategory.HARM_CATEGORY_HARASSMENT,
-                    threshold=self.entry.options.get(
-                        CONF_HARASSMENT_BLOCK_THRESHOLD,
-                        RECOMMENDED_HARM_BLOCK_THRESHOLD,
-                    ),
-                ),
-                SafetySetting(
-                    category=HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT,
-                    threshold=self.entry.options.get(
-                        CONF_DANGEROUS_BLOCK_THRESHOLD, RECOMMENDED_HARM_BLOCK_THRESHOLD
-                    ),
-                ),
-                SafetySetting(
-                    category=HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT,
-                    threshold=self.entry.options.get(
-                        CONF_SEXUAL_BLOCK_THRESHOLD, RECOMMENDED_HARM_BLOCK_THRESHOLD
-                    ),
-                ),
-            ],
-            tools=tools or None,
-            system_instruction=prompt if supports_system_instruction else None,
-            automatic_function_calling=AutomaticFunctionCallingConfig(
-                disable=True, maximum_remote_calls=None
-            ),
-        )
-
-        if not supports_system_instruction:
-            messages = [
-                Content(role="user", parts=[Part.from_text(text=prompt)]),
-                Content(role="model", parts=[Part.from_text(text="Ok")]),
-                *messages,
-            ]
-        chat = self._genai_client.aio.chats.create(
-            model=model_name, history=messages, config=generateContentConfig
-        )
-        user_message = chat_log.content[-1]
-        assert isinstance(user_message, conversation.UserContent)
-        chat_request: str | list[Part] = user_message.content
-        # To prevent infinite loops, we limit the number of iterations
-        for _iteration in range(MAX_TOOL_ITERATIONS):
-            try:
-                chat_response_generator = await chat.send_message_stream(
-                    message=chat_request
-                )
-            except (
-                APIError,
-                ClientError,
-                ValueError,
-            ) as err:
-                LOGGER.error("Error sending message: %s %s", type(err), err)
-                error = ERROR_GETTING_RESPONSE
-                raise HomeAssistantError(error) from err
-
-            chat_request = _create_google_tool_response_parts(
-                [
-                    content
-                    async for content in chat_log.async_add_delta_content_stream(
-                        self.entity_id,
-                        _transform_stream(chat_response_generator),
-                    )
-                    if isinstance(content, conversation.ToolResultContent)
-                ]
-            )
-
-            if not chat_log.unresponded_tool_results:
-                break
-
     async def _async_entry_update_listener(
         self, hass: HomeAssistant, entry: ConfigEntry
     ) -> None:
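With the shared helpers and chat-log handling removed from `conversation.py`, only agent-specific logic remains there; everything else moves into the new `entity.py` below, which the TTS entity in this PR (and any future platform) can reuse. A sketch of the intended class layout (the class bodies are illustrative, not copies of the integration code):

```python
# entity.py owns the client handle, device info, and _async_handle_chat_log;
# platform entities mix it in and add their own behavior on top.
class GoogleGenerativeAILLMBaseEntity(Entity):
    """Shared device info and chat-log handling (defined in entity.py below)."""


class GoogleGenerativeAIConversationEntity(
    conversation.ConversationEntity,
    conversation.AbstractConversationAgent,
    GoogleGenerativeAILLMBaseEntity,
):
    """Adds agent behavior; super().__init__(entry) wires up the shared state."""
```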
diff --git a/homeassistant/components/google_generative_ai_conversation/entity.py b/homeassistant/components/google_generative_ai_conversation/entity.py
new file mode 100644
index 00000000000000..7eef3dbacff425
--- /dev/null
+++ b/homeassistant/components/google_generative_ai_conversation/entity.py
@@ -0,0 +1,475 @@
+"""Base entity for the Google Generative AI Conversation integration."""
+
+from __future__ import annotations
+
+import codecs
+from collections.abc import AsyncGenerator, Callable
+from dataclasses import replace
+from typing import Any, cast
+
+from google.genai.errors import APIError, ClientError
+from google.genai.types import (
+    AutomaticFunctionCallingConfig,
+    Content,
+    FunctionDeclaration,
+    GenerateContentConfig,
+    GenerateContentResponse,
+    GoogleSearch,
+    HarmCategory,
+    Part,
+    SafetySetting,
+    Schema,
+    Tool,
+)
+from voluptuous_openapi import convert
+
+from homeassistant.components import conversation
+from homeassistant.config_entries import ConfigEntry
+from homeassistant.exceptions import HomeAssistantError
+from homeassistant.helpers import device_registry as dr, llm
+from homeassistant.helpers.entity import Entity
+
+from .const import (
+    CONF_CHAT_MODEL,
+    CONF_DANGEROUS_BLOCK_THRESHOLD,
+    CONF_HARASSMENT_BLOCK_THRESHOLD,
+    CONF_HATE_BLOCK_THRESHOLD,
+    CONF_MAX_TOKENS,
+    CONF_SEXUAL_BLOCK_THRESHOLD,
+    CONF_TEMPERATURE,
+    CONF_TOP_K,
+    CONF_TOP_P,
+    CONF_USE_GOOGLE_SEARCH_TOOL,
+    DOMAIN,
+    LOGGER,
+    RECOMMENDED_CHAT_MODEL,
+    RECOMMENDED_HARM_BLOCK_THRESHOLD,
+    RECOMMENDED_MAX_TOKENS,
+    RECOMMENDED_TEMPERATURE,
+    RECOMMENDED_TOP_K,
+    RECOMMENDED_TOP_P,
+)
+
+# Max number of back and forth with the LLM to generate a response
+MAX_TOOL_ITERATIONS = 10
+
+ERROR_GETTING_RESPONSE = (
+    "Sorry, I had a problem getting a response from Google Generative AI."
+)
+
+
+SUPPORTED_SCHEMA_KEYS = {
+    # Gemini API does not support all of the OpenAPI schema
+    # SoT: https://ai.google.dev/api/caching#Schema
+    "type",
+    "format",
+    "description",
+    "nullable",
+    "enum",
+    "max_items",
+    "min_items",
+    "properties",
+    "required",
+    "items",
+}
+
+
+def _camel_to_snake(name: str) -> str:
+    """Convert camel case to snake case."""
+    return "".join(["_" + c.lower() if c.isupper() else c for c in name]).lstrip("_")
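`_camel_to_snake` normalizes incoming OpenAPI keys before they are checked against `SUPPORTED_SCHEMA_KEYS`. A quick illustration of its behavior (the values are worked examples, not fixtures from the PR):

```python
assert _camel_to_snake("maxItems") == "max_items"
assert _camel_to_snake("type") == "type"  # lowercase input passes through
# Naive per-uppercase-letter splitting; acronyms get exploded:
assert _camel_to_snake("HTTPHeader") == "h_t_t_p_header"
```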
+ result["properties"] = {"json": {"type": "STRING"}} + result["required"] = [] + return cast(Schema, result) + + +def _format_tool( + tool: llm.Tool, custom_serializer: Callable[[Any], Any] | None +) -> Tool: + """Format tool specification.""" + + if tool.parameters.schema: + parameters = _format_schema( + convert(tool.parameters, custom_serializer=custom_serializer) + ) + else: + parameters = None + + return Tool( + function_declarations=[ + FunctionDeclaration( + name=tool.name, + description=tool.description, + parameters=parameters, + ) + ] + ) + + +def _escape_decode(value: Any) -> Any: + """Recursively call codecs.escape_decode on all values.""" + if isinstance(value, str): + return codecs.escape_decode(bytes(value, "utf-8"))[0].decode("utf-8") # type: ignore[attr-defined] + if isinstance(value, list): + return [_escape_decode(item) for item in value] + if isinstance(value, dict): + return {k: _escape_decode(v) for k, v in value.items()} + return value + + +def _create_google_tool_response_parts( + parts: list[conversation.ToolResultContent], +) -> list[Part]: + """Create Google tool response parts.""" + return [ + Part.from_function_response( + name=tool_result.tool_name, response=tool_result.tool_result + ) + for tool_result in parts + ] + + +def _create_google_tool_response_content( + content: list[conversation.ToolResultContent], +) -> Content: + """Create a Google tool response content.""" + return Content( + role="user", + parts=_create_google_tool_response_parts(content), + ) + + +def _convert_content( + content: ( + conversation.UserContent + | conversation.AssistantContent + | conversation.SystemContent + ), +) -> Content: + """Convert HA content to Google content.""" + if content.role != "assistant" or not content.tool_calls: + role = "model" if content.role == "assistant" else content.role + return Content( + role=role, + parts=[ + Part.from_text(text=content.content if content.content else ""), + ], + ) + + # Handle the Assistant content with tool calls. + assert type(content) is conversation.AssistantContent + parts: list[Part] = [] + + if content.content: + parts.append(Part.from_text(text=content.content)) + + if content.tool_calls: + parts.extend( + [ + Part.from_function_call( + name=tool_call.tool_name, + args=_escape_decode(tool_call.tool_args), + ) + for tool_call in content.tool_calls + ] + ) + + return Content(role="model", parts=parts) + + +async def _transform_stream( + result: AsyncGenerator[GenerateContentResponse], +) -> AsyncGenerator[conversation.AssistantContentDeltaDict]: + new_message = True + try: + async for response in result: + LOGGER.debug("Received response chunk: %s", response) + chunk: conversation.AssistantContentDeltaDict = {} + + if new_message: + chunk["role"] = "assistant" + new_message = False + + # According to the API docs, this would mean no candidate is returned, so we can safely throw an error here. 
+
+
+async def _transform_stream(
+    result: AsyncGenerator[GenerateContentResponse],
+) -> AsyncGenerator[conversation.AssistantContentDeltaDict]:
+    new_message = True
+    try:
+        async for response in result:
+            LOGGER.debug("Received response chunk: %s", response)
+            chunk: conversation.AssistantContentDeltaDict = {}
+
+            if new_message:
+                chunk["role"] = "assistant"
+                new_message = False
+
+            # According to the API docs, this would mean no candidate is returned, so we can safely throw an error here.
+            if response.prompt_feedback or not response.candidates:
+                reason = (
+                    response.prompt_feedback.block_reason_message
+                    if response.prompt_feedback
+                    else "unknown"
+                )
+                raise HomeAssistantError(
+                    f"The message got blocked due to content violations, reason: {reason}"
+                )
+
+            candidate = response.candidates[0]
+
+            if (
+                candidate.finish_reason is not None
+                and candidate.finish_reason != "STOP"
+            ):
+                # The message ended due to a content error as explained in: https://ai.google.dev/api/generate-content#FinishReason
+                LOGGER.error(
+                    "Error in Google Generative AI response: %s, see: https://ai.google.dev/api/generate-content#FinishReason",
+                    candidate.finish_reason,
+                )
+                raise HomeAssistantError(
+                    f"{ERROR_GETTING_RESPONSE} Reason: {candidate.finish_reason}"
+                )
+
+            response_parts = (
+                candidate.content.parts
+                if candidate.content is not None and candidate.content.parts is not None
+                else []
+            )
+
+            content = "".join([part.text for part in response_parts if part.text])
+            tool_calls = []
+            for part in response_parts:
+                if not part.function_call:
+                    continue
+                tool_call = part.function_call
+                tool_name = tool_call.name if tool_call.name else ""
+                tool_args = _escape_decode(tool_call.args)
+                tool_calls.append(
+                    llm.ToolInput(tool_name=tool_name, tool_args=tool_args)
+                )
+
+            if tool_calls:
+                chunk["tool_calls"] = tool_calls
+
+            chunk["content"] = content
+            yield chunk
+    except (
+        APIError,
+        ValueError,
+    ) as err:
+        LOGGER.error("Error sending message: %s %s", type(err), err)
+        if isinstance(err, APIError):
+            message = err.message
+        else:
+            message = type(err).__name__
+        error = f"{ERROR_GETTING_RESPONSE}: {message}"
+        raise HomeAssistantError(error) from err
+
+
+class GoogleGenerativeAILLMBaseEntity(Entity):
+    """Google Generative AI base entity."""
+
+    _attr_has_entity_name = True
+    _attr_name = None
+
+    def __init__(self, entry: ConfigEntry) -> None:
+        """Initialize the base entity."""
+        self.entry = entry
+        self._genai_client = entry.runtime_data
+        self._attr_unique_id = entry.entry_id
+        self._attr_device_info = dr.DeviceInfo(
+            identifiers={(DOMAIN, entry.entry_id)},
+            name=entry.title,
+            manufacturer="Google",
+            model="Generative AI",
+            entry_type=dr.DeviceEntryType.SERVICE,
+        )
+
+    async def _async_handle_chat_log(
+        self,
+        chat_log: conversation.ChatLog,
+    ) -> None:
+        """Generate an answer for the chat log."""
+        options = self.entry.options
+
+        tools: list[Tool | Callable[..., Any]] | None = None
+        if chat_log.llm_api:
+            tools = [
+                _format_tool(tool, chat_log.llm_api.custom_serializer)
+                for tool in chat_log.llm_api.tools
+            ]
+
+        # Using search grounding allows the model to retrieve information from the web;
+        # however, it may interfere with how the model decides to use some tools or entities.
+        # For example, a weather entity may be disregarded if the model chooses to Google it.
+        if options.get(CONF_USE_GOOGLE_SEARCH_TOOL) is True:
+            tools = tools or []
+            tools.append(Tool(google_search=GoogleSearch()))
+
+        model_name = self.entry.options.get(CONF_CHAT_MODEL, RECOMMENDED_CHAT_MODEL)
+        # Avoid the INVALID_ARGUMENT "Developer instruction is not enabled" error on these models
+        supports_system_instruction = (
+            "gemma" not in model_name
+            and "gemini-2.0-flash-preview-image-generation" not in model_name
+        )
+
+        prompt_content = cast(
+            conversation.SystemContent,
+            chat_log.content[0],
+        )
+
+        if prompt_content.content:
+            prompt = prompt_content.content
+        else:
+            raise HomeAssistantError("Invalid prompt content")
+
+        messages: list[Content] = []
+
+        # Google groups tool results, we do not. Group them before sending.
+        tool_results: list[conversation.ToolResultContent] = []
+
+        for chat_content in chat_log.content[1:-1]:
+            if chat_content.role == "tool_result":
+                tool_results.append(chat_content)
+                continue
+
+            if (
+                not isinstance(chat_content, conversation.ToolResultContent)
+                and chat_content.content == ""
+            ):
+                # Skipping is not possible since the number of function calls needs to match
+                # the number of function responses; skipping one would mean removing the other
+                # and would prevent a proper chat log
+                chat_content = replace(chat_content, content=" ")
+
+            if tool_results:
+                messages.append(_create_google_tool_response_content(tool_results))
+                tool_results.clear()
+
+            messages.append(_convert_content(chat_content))
+
+        # The SDK requires the first message to be a user message
+        # This is not the case if user used `start_conversation`
+        # Workaround from https://github.com/googleapis/python-genai/issues/529#issuecomment-2740964537
+        if messages and messages[0].role != "user":
+            messages.insert(
+                0,
+                Content(role="user", parts=[Part.from_text(text=" ")]),
+            )
+
+        if tool_results:
+            messages.append(_create_google_tool_response_content(tool_results))
+        generate_content_config = GenerateContentConfig(
+            temperature=self.entry.options.get(
+                CONF_TEMPERATURE, RECOMMENDED_TEMPERATURE
+            ),
+            top_k=self.entry.options.get(CONF_TOP_K, RECOMMENDED_TOP_K),
+            top_p=self.entry.options.get(CONF_TOP_P, RECOMMENDED_TOP_P),
+            max_output_tokens=self.entry.options.get(
+                CONF_MAX_TOKENS, RECOMMENDED_MAX_TOKENS
+            ),
+            safety_settings=[
+                SafetySetting(
+                    category=HarmCategory.HARM_CATEGORY_HATE_SPEECH,
+                    threshold=self.entry.options.get(
+                        CONF_HATE_BLOCK_THRESHOLD, RECOMMENDED_HARM_BLOCK_THRESHOLD
+                    ),
+                ),
+                SafetySetting(
+                    category=HarmCategory.HARM_CATEGORY_HARASSMENT,
+                    threshold=self.entry.options.get(
+                        CONF_HARASSMENT_BLOCK_THRESHOLD,
+                        RECOMMENDED_HARM_BLOCK_THRESHOLD,
+                    ),
+                ),
+                SafetySetting(
+                    category=HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT,
+                    threshold=self.entry.options.get(
+                        CONF_DANGEROUS_BLOCK_THRESHOLD, RECOMMENDED_HARM_BLOCK_THRESHOLD
+                    ),
+                ),
+                SafetySetting(
+                    category=HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT,
+                    threshold=self.entry.options.get(
+                        CONF_SEXUAL_BLOCK_THRESHOLD, RECOMMENDED_HARM_BLOCK_THRESHOLD
+                    ),
+                ),
+            ],
+            tools=tools or None,
+            system_instruction=prompt if supports_system_instruction else None,
+            automatic_function_calling=AutomaticFunctionCallingConfig(
+                disable=True, maximum_remote_calls=None
+            ),
+        )
+
+        if not supports_system_instruction:
+            messages = [
+                Content(role="user", parts=[Part.from_text(text=prompt)]),
+                Content(role="model", parts=[Part.from_text(text="Ok")]),
+                *messages,
+            ]
+        chat = self._genai_client.aio.chats.create(
+            model=model_name, history=messages, config=generate_content_config
+        )
+        user_message = chat_log.content[-1]
+        assert isinstance(user_message, conversation.UserContent)
+        chat_request: str | list[Part] = user_message.content
+        # To prevent infinite loops, we limit the number of iterations
+        for _iteration in range(MAX_TOOL_ITERATIONS):
+            try:
+                chat_response_generator = await chat.send_message_stream(
+                    message=chat_request
+                )
+            except (
+                APIError,
+                ClientError,
+                ValueError,
+            ) as err:
+                LOGGER.error("Error sending message: %s %s", type(err), err)
+                error = ERROR_GETTING_RESPONSE
+                raise HomeAssistantError(error) from err
+
+            chat_request = _create_google_tool_response_parts(
+                [
+                    content
+                    async for content in chat_log.async_add_delta_content_stream(
+                        self.entity_id,
+                        _transform_stream(chat_response_generator),
+                    )
+                    if isinstance(content, conversation.ToolResultContent)
+                ]
+            )
+
+            if not chat_log.unresponded_tool_results:
+                break
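`_async_handle_chat_log` loops because each round of tool calls yields tool responses that must be fed back to the model, capped by `MAX_TOOL_ITERATIONS`. A stripped-down sketch of that control flow, with hypothetical `send()` and `run_tools()` helpers standing in for the SDK calls above:

```python
MAX_TOOL_ITERATIONS = 10  # mirrors the module constant above


async def tool_loop(first_request):
    """Sketch: alternate model replies and tool results until none remain."""
    request = first_request
    for _ in range(MAX_TOOL_ITERATIONS):  # hard cap prevents infinite loops
        reply = await send(request)            # hypothetical: one model round-trip
        tool_results = await run_tools(reply)  # hypothetical: execute tool calls
        if not tool_results:
            break                              # plain text answer; we are done
        request = tool_results                 # feed results back as next message
```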
diff --git a/homeassistant/components/google_generative_ai_conversation/tts.py b/homeassistant/components/google_generative_ai_conversation/tts.py
new file mode 100644
index 00000000000000..160048e4897fe2
--- /dev/null
+++ b/homeassistant/components/google_generative_ai_conversation/tts.py
@@ -0,0 +1,216 @@
+"""Text to speech support for Google Generative AI."""
+
+from __future__ import annotations
+
+from contextlib import suppress
+import io
+import logging
+from typing import Any
+import wave
+
+from google.genai import types
+
+from homeassistant.components.tts import (
+    ATTR_VOICE,
+    TextToSpeechEntity,
+    TtsAudioType,
+    Voice,
+)
+from homeassistant.config_entries import ConfigEntry
+from homeassistant.core import HomeAssistant, callback
+from homeassistant.exceptions import HomeAssistantError
+from homeassistant.helpers import device_registry as dr
+from homeassistant.helpers.entity_platform import AddConfigEntryEntitiesCallback
+
+from .const import ATTR_MODEL, DOMAIN, RECOMMENDED_TTS_MODEL
+
+_LOGGER = logging.getLogger(__name__)
+
+
+async def async_setup_entry(
+    hass: HomeAssistant,
+    config_entry: ConfigEntry,
+    async_add_entities: AddConfigEntryEntitiesCallback,
+) -> None:
+    """Set up TTS entity."""
+    tts_entity = GoogleGenerativeAITextToSpeechEntity(config_entry)
+    async_add_entities([tts_entity])
+
+
+class GoogleGenerativeAITextToSpeechEntity(TextToSpeechEntity):
+    """Google Generative AI text-to-speech entity."""
+
+    _attr_supported_options = [ATTR_VOICE, ATTR_MODEL]
+    # See https://ai.google.dev/gemini-api/docs/speech-generation#languages
+    _attr_supported_languages = [
+        "ar-EG",
+        "bn-BD",
+        "de-DE",
+        "en-IN",
+        "en-US",
+        "es-US",
+        "fr-FR",
+        "hi-IN",
+        "id-ID",
+        "it-IT",
+        "ja-JP",
+        "ko-KR",
+        "mr-IN",
+        "nl-NL",
+        "pl-PL",
+        "pt-BR",
+        "ro-RO",
+        "ru-RU",
+        "ta-IN",
+        "te-IN",
+        "th-TH",
+        "tr-TR",
+        "uk-UA",
+        "vi-VN",
+    ]
+    _attr_default_language = "en-US"
+    # See https://ai.google.dev/gemini-api/docs/speech-generation#voices
+    _supported_voices = [
+        Voice(voice.split(" ", 1)[0].lower(), voice)
+        for voice in (
+            "Zephyr (Bright)",
+            "Puck (Upbeat)",
+            "Charon (Informative)",
+            "Kore (Firm)",
+            "Fenrir (Excitable)",
+            "Leda (Youthful)",
+            "Orus (Firm)",
+            "Aoede (Breezy)",
+            "Callirrhoe (Easy-going)",
+            "Autonoe (Bright)",
+            "Enceladus (Breathy)",
+            "Iapetus (Clear)",
+            "Umbriel (Easy-going)",
+            "Algieba (Smooth)",
+            "Despina (Smooth)",
+            "Erinome (Clear)",
+            "Algenib (Gravelly)",
+            "Rasalgethi (Informative)",
+            "Laomedeia (Upbeat)",
+            "Achernar (Soft)",
+            "Alnilam (Firm)",
+            "Schedar (Even)",
+            "Gacrux (Mature)",
+            "Pulcherrima (Forward)",
+            "Achird (Friendly)",
+            "Zubenelgenubi (Casual)",
+            "Vindemiatrix (Gentle)",
+            "Sadachbia (Lively)",
+            "Sadaltager (Knowledgeable)",
+            "Sulafat (Warm)",
+        )
+    ]
+
+    def __init__(self, entry: ConfigEntry) -> None:
+        """Initialize Google Generative AI Conversation speech entity."""
+        self.entry = entry
+        self._attr_name = "Google Generative AI TTS"
+        self._attr_unique_id = f"{entry.entry_id}_tts"
+        self._attr_device_info = dr.DeviceInfo(
+            identifiers={(DOMAIN, entry.entry_id)},
+            name=entry.title,
+            manufacturer="Google",
+            model="Generative AI",
+            entry_type=dr.DeviceEntryType.SERVICE,
+        )
+        self._genai_client = entry.runtime_data
+        self._default_voice_id = self._supported_voices[0].voice_id
+
+    @callback
+    def async_get_supported_voices(self, language: str) -> list[Voice] | None:
+        """Return a list of supported voices for a language."""
+        return self._supported_voices
+
+    async def async_get_tts_audio(
+        self, message: str, language: str, options: dict[str, Any]
+    ) -> TtsAudioType:
+        """Load tts audio file from the engine."""
+        try:
+            response = self._genai_client.models.generate_content(
+                model=options.get(ATTR_MODEL, RECOMMENDED_TTS_MODEL),
+                contents=message,
+                config=types.GenerateContentConfig(
+                    response_modalities=["AUDIO"],
+                    speech_config=types.SpeechConfig(
+                        voice_config=types.VoiceConfig(
+                            prebuilt_voice_config=types.PrebuiltVoiceConfig(
+                                voice_name=options.get(
+                                    ATTR_VOICE, self._default_voice_id
+                                )
+                            )
+                        )
+                    ),
+                ),
+            )
+
+            data = response.candidates[0].content.parts[0].inline_data.data
+            mime_type = response.candidates[0].content.parts[0].inline_data.mime_type
+        except Exception as exc:
+            _LOGGER.warning(
+                "Error during processing of TTS request %s", exc, exc_info=True
+            )
+            raise HomeAssistantError(exc) from exc
+        return "wav", self._convert_to_wav(data, mime_type)
+
+    def _convert_to_wav(self, audio_data: bytes, mime_type: str) -> bytes:
+        """Wrap raw audio data in a WAV container.
+
+        Args:
+            audio_data: The raw audio data as a bytes object.
+            mime_type: Mime type of the audio data.
+
+        Returns:
+            A bytes object containing the complete WAV file (header plus
+            audio data).
+
+        """
+        parameters = self._parse_audio_mime_type(mime_type)
+
+        wav_buffer = io.BytesIO()
+        with wave.open(wav_buffer, "wb") as wf:
+            wf.setnchannels(1)
+            wf.setsampwidth(parameters["bits_per_sample"] // 8)
+            wf.setframerate(parameters["rate"])
+            wf.writeframes(audio_data)
+
+        return wav_buffer.getvalue()
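Gemini returns raw PCM (e.g. `audio/L16;rate=24000`), so the entity must prepend a WAV header before handing the audio to Home Assistant. A self-contained sketch of the same wrapping with made-up sample data:

```python
import io
import wave


def pcm_to_wav(pcm: bytes, rate: int = 24000, bits: int = 16) -> bytes:
    """Wrap mono PCM samples in a WAV container (mirrors _convert_to_wav)."""
    buf = io.BytesIO()
    with wave.open(buf, "wb") as wf:
        wf.setnchannels(1)          # Gemini TTS output is mono
        wf.setsampwidth(bits // 8)  # L16 -> 2 bytes per sample
        wf.setframerate(rate)
        wf.writeframes(pcm)
    return buf.getvalue()


wav_bytes = pcm_to_wav(b"\x00\x00" * 2400)  # 0.1 s of silence at 24 kHz
```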
+
+    def _parse_audio_mime_type(self, mime_type: str) -> dict[str, int]:
+        """Parse bits per sample and rate from an audio MIME type string.
+
+        Assumes bits per sample is encoded like "L16" and rate as "rate=xxxxx".
+
+        Args:
+            mime_type: The audio MIME type string (e.g., "audio/L16;rate=24000").
+
+        Returns:
+            A dictionary with "bits_per_sample" and "rate" keys. Values are
+            parsed from the MIME type where possible, falling back to 16 bits
+            and 24000 Hz otherwise.
+
+        """
+        if not mime_type.startswith("audio/L"):
+            _LOGGER.warning("Received unexpected MIME type %s", mime_type)
+            raise HomeAssistantError(f"Unsupported audio MIME type: {mime_type}")
+
+        bits_per_sample = 16
+        rate = 24000
+
+        # Inspect every ";"-separated parameter, including the main type part,
+        # which carries the bits per sample (e.g. "audio/L16")
+        parts = mime_type.split(";")
+        for param in parts:
+            param = param.strip()
+            if param.lower().startswith("rate="):
+                # Handle cases like "rate=" with no value or a non-integer
+                # value and keep the default rate
+                with suppress(ValueError, IndexError):
+                    rate_str = param.split("=", 1)[1]
+                    rate = int(rate_str)
+            elif param.startswith("audio/L"):
+                # Keep bits_per_sample as default if conversion fails
+                with suppress(ValueError, IndexError):
+                    bits_per_sample = int(param.split("L", 1)[1])
+
+        return {"bits_per_sample": bits_per_sample, "rate": rate}
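
# Aside — expected behaviour of the parser above, sketched as illustrative
# calls with a hypothetical `parse` standing in for the bound
# _parse_audio_mime_type method:
#
#     parse("audio/L16;rate=24000")      -> {"bits_per_sample": 16, "rate": 24000}
#     parse("audio/L24;codec=pcm;rate=") -> {"bits_per_sample": 24, "rate": 24000}
#                                           (unparsable rate keeps the default)
#     parse("audio/mpeg")                raises HomeAssistantError
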
diff --git a/requirements_all.txt b/requirements_all.txt
index f2df39a74a2275..59e02cd3e6c916 100644
--- a/requirements_all.txt
+++ b/requirements_all.txt
@@ -893,7 +893,7 @@ epion==0.0.3
 epson-projector==0.5.1
 
 # homeassistant.components.eq3btsmart
-eq3btsmart==1.4.1
+eq3btsmart==2.1.0
 
 # homeassistant.components.esphome
 esphome-dashboard-api==1.3.0
diff --git a/requirements_test_all.txt b/requirements_test_all.txt
index de3f60b1edb21a..ddf736619f5451 100644
--- a/requirements_test_all.txt
+++ b/requirements_test_all.txt
@@ -772,7 +772,7 @@ epion==0.0.3
 epson-projector==0.5.1
 
 # homeassistant.components.eq3btsmart
-eq3btsmart==1.4.1
+eq3btsmart==2.1.0
 
 # homeassistant.components.esphome
 esphome-dashboard-api==1.3.0
diff --git a/script/hassfest/requirements.py b/script/hassfest/requirements.py
index b8265e4e58d0ed..b1aff0dc1fd44e 100644
--- a/script/hassfest/requirements.py
+++ b/script/hassfest/requirements.py
@@ -331,10 +331,6 @@
         # https://github.com/hbldh/bleak/pull/1718 (not yet released)
         "homeassistant": {"bleak"}
     },
-    "eq3btsmart": {
-        # https://github.com/EuleMitKeule/eq3btsmart/releases/tag/2.0.0
-        "homeassistant": {"eq3btsmart"}
-    },
     "python_script": {
         # Security audits are needed for each Python version
         "homeassistant": {"restrictedpython"}
diff --git a/tests/components/google_generative_ai_conversation/conftest.py b/tests/components/google_generative_ai_conversation/conftest.py
index 6ec147da2ab0b6..0d222c78472071 100644
--- a/tests/components/google_generative_ai_conversation/conftest.py
+++ b/tests/components/google_generative_ai_conversation/conftest.py
@@ -4,7 +4,7 @@
 
 import pytest
 
-from homeassistant.components.google_generative_ai_conversation.conversation import (
+from homeassistant.components.google_generative_ai_conversation.entity import (
     CONF_USE_GOOGLE_SEARCH_TOOL,
 )
 from homeassistant.config_entries import ConfigEntry
diff --git a/tests/components/google_generative_ai_conversation/test_conversation.py b/tests/components/google_generative_ai_conversation/test_conversation.py
index 2d1a46393fd6e7..a55a86b67c92bc 100644
--- a/tests/components/google_generative_ai_conversation/test_conversation.py
+++ b/tests/components/google_generative_ai_conversation/test_conversation.py
@@ -9,7 +9,7 @@
 from homeassistant.components import conversation
 from homeassistant.components.conversation import UserContent
-from homeassistant.components.google_generative_ai_conversation.conversation import (
+from homeassistant.components.google_generative_ai_conversation.entity import (
     ERROR_GETTING_RESPONSE,
     _escape_decode,
     _format_schema,
diff --git a/tests/components/google_generative_ai_conversation/test_tts.py b/tests/components/google_generative_ai_conversation/test_tts.py
new file mode 100644
index 00000000000000..5ea056307b5163
--- /dev/null
+++ b/tests/components/google_generative_ai_conversation/test_tts.py
@@ -0,0 +1,413 @@
+"""Tests for the Google Generative AI Conversation TTS entity."""
+
+from __future__ import annotations
+
+from collections.abc import Generator
+from http import HTTPStatus
+from pathlib import Path
+from typing import Any
+from unittest.mock import AsyncMock, MagicMock, Mock, patch
+
+from google.genai import types
+import pytest
+
+from homeassistant.components import tts
+from homeassistant.components.google_generative_ai_conversation.tts import (
+    ATTR_MODEL,
+    DOMAIN,
+    RECOMMENDED_TTS_MODEL,
+)
+from homeassistant.components.media_player import (
+    ATTR_MEDIA_CONTENT_ID,
+    DOMAIN as DOMAIN_MP,
+    SERVICE_PLAY_MEDIA,
+)
+from homeassistant.const import ATTR_ENTITY_ID, CONF_API_KEY, CONF_PLATFORM
+from homeassistant.core import HomeAssistant, ServiceCall
+from homeassistant.core_config import async_process_ha_core_config
+from homeassistant.setup import async_setup_component
+
+from . import API_ERROR_500
+
+from tests.common import MockConfigEntry, async_mock_service
+from tests.components.tts.common import retrieve_media
+from tests.typing import ClientSessionGenerator
+
+
+@pytest.fixture(autouse=True)
+def tts_mutagen_mock_fixture_autouse(tts_mutagen_mock: MagicMock) -> None:
+    """Mock writing tags."""
+
+
+@pytest.fixture(autouse=True)
+def mock_tts_cache_dir_autouse(mock_tts_cache_dir: Path) -> None:
+    """Mock the TTS cache dir with empty dir."""
+
+
+@pytest.fixture
+async def calls(hass: HomeAssistant) -> list[ServiceCall]:
+    """Mock media player calls."""
+    return async_mock_service(hass, DOMAIN_MP, SERVICE_PLAY_MEDIA)
+
+
+@pytest.fixture(autouse=True)
+async def setup_internal_url(hass: HomeAssistant) -> None:
+    """Set up internal url."""
+    await async_process_ha_core_config(
+        hass, {"internal_url": "http://example.local:8123"}
+    )
+
+
+@pytest.fixture
+def mock_genai_client() -> Generator[AsyncMock]:
+    """Mock genai_client."""
+    client = Mock()
+    client.aio.models.get = AsyncMock()
+    client.models.generate_content.return_value = types.GenerateContentResponse(
+        candidates=(
+            types.Candidate(
+                content=types.Content(
+                    parts=(
+                        types.Part(
+                            inline_data=types.Blob(
+                                data=b"raw-audio-bytes",
+                                mime_type="audio/L16;rate=24000",
+                            )
+                        ),
+                    )
+                )
+            ),
+        )
+    )
+    with patch(
+        "homeassistant.components.google_generative_ai_conversation.Client",
+        return_value=client,
+    ) as mock_client:
+        yield mock_client
+
+
+@pytest.fixture(name="setup")
+async def setup_fixture(
+    hass: HomeAssistant,
+    config: dict[str, Any],
+    request: pytest.FixtureRequest,
+    mock_genai_client: AsyncMock,
+) -> None:
+    """Set up the test environment."""
+    if request.param == "mock_setup":
+        await mock_setup(hass, config)
+    elif request.param == "mock_config_entry_setup":
+        await mock_config_entry_setup(hass, config)
+    else:
+        raise RuntimeError("Invalid setup fixture")
+
+    await hass.async_block_till_done()
+
+
+@pytest.fixture(name="config")
+def config_fixture() -> dict[str, Any]:
+    """Return config."""
+    return {
+        CONF_API_KEY: "bla",
+    }
+
+
+async def mock_setup(hass: HomeAssistant, config: dict[str, Any]) -> None:
+    """Mock setup."""
+    assert await async_setup_component(
+        hass, tts.DOMAIN, {tts.DOMAIN: {CONF_PLATFORM: DOMAIN} | config}
+    )
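
# Aside — the `setup` fixture above is selected through pytest's indirect
# parametrization. A minimal self-contained sketch of the same pattern, with
# hypothetical names:
import pytest

@pytest.fixture(name="backend")
def backend_fixture(request: pytest.FixtureRequest) -> str:
    if request.param == "fake":
        return "fake-backend"
    raise RuntimeError("Invalid backend fixture")

@pytest.mark.parametrize("backend", ["fake"], indirect=True)
def test_backend(backend: str) -> None:
    assert backend == "fake-backend"
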
+
+
+async def mock_config_entry_setup(hass: HomeAssistant, config: dict[str, Any]) -> None:
+    """Mock config entry setup."""
+    default_config = {tts.CONF_LANG: "en-US"}
+    config_entry = MockConfigEntry(domain=DOMAIN, data=default_config | config)
+
+    client_mock = Mock()
+    client_mock.models.get = None
+    client_mock.models.generate_content.return_value = types.GenerateContentResponse(
+        candidates=(
+            types.Candidate(
+                content=types.Content(
+                    parts=(
+                        types.Part(
+                            inline_data=types.Blob(
+                                data=b"raw-audio-bytes",
+                                mime_type="audio/L16;rate=24000",
+                            )
+                        ),
+                    )
+                )
+            ),
+        )
+    )
+    config_entry.runtime_data = client_mock
+    config_entry.add_to_hass(hass)
+
+    assert await hass.config_entries.async_setup(config_entry.entry_id)
+
+
+@pytest.mark.parametrize(
+    ("setup", "tts_service", "service_data"),
+    [
+        (
+            "mock_config_entry_setup",
+            "speak",
+            {
+                ATTR_ENTITY_ID: "tts.google_generative_ai_tts",
+                tts.ATTR_MEDIA_PLAYER_ENTITY_ID: "media_player.something",
+                tts.ATTR_MESSAGE: "There is a person at the front door.",
+                tts.ATTR_OPTIONS: {},
+            },
+        ),
+        (
+            "mock_config_entry_setup",
+            "speak",
+            {
+                ATTR_ENTITY_ID: "tts.google_generative_ai_tts",
+                tts.ATTR_MEDIA_PLAYER_ENTITY_ID: "media_player.something",
+                tts.ATTR_MESSAGE: "There is a person at the front door.",
+                tts.ATTR_OPTIONS: {tts.ATTR_VOICE: "voice2"},
+            },
+        ),
+        (
+            "mock_config_entry_setup",
+            "speak",
+            {
+                ATTR_ENTITY_ID: "tts.google_generative_ai_tts",
+                tts.ATTR_MEDIA_PLAYER_ENTITY_ID: "media_player.something",
+                tts.ATTR_MESSAGE: "There is a person at the front door.",
+                tts.ATTR_OPTIONS: {ATTR_MODEL: "model2"},
+            },
+        ),
+        (
+            "mock_config_entry_setup",
+            "speak",
+            {
+                ATTR_ENTITY_ID: "tts.google_generative_ai_tts",
+                tts.ATTR_MEDIA_PLAYER_ENTITY_ID: "media_player.something",
+                tts.ATTR_MESSAGE: "There is a person at the front door.",
+                tts.ATTR_OPTIONS: {tts.ATTR_VOICE: "voice2", ATTR_MODEL: "model2"},
+            },
+        ),
+    ],
+    indirect=["setup"],
+)
+async def test_tts_service_speak(
+    setup: AsyncMock,
+    hass: HomeAssistant,
+    hass_client: ClientSessionGenerator,
+    calls: list[ServiceCall],
+    tts_service: str,
+    service_data: dict[str, Any],
+) -> None:
+    """Test tts service."""
+    tts_entity = hass.data[tts.DOMAIN].get_entity(service_data[ATTR_ENTITY_ID])
+    tts_entity._genai_client.models.generate_content.reset_mock()
+
+    await hass.services.async_call(
+        tts.DOMAIN,
+        tts_service,
+        service_data,
+        blocking=True,
+    )
+
+    assert len(calls) == 1
+    assert (
+        await retrieve_media(hass, hass_client, calls[0].data[ATTR_MEDIA_CONTENT_ID])
+        == HTTPStatus.OK
+    )
+    voice_id = service_data[tts.ATTR_OPTIONS].get(tts.ATTR_VOICE, "zephyr")
+    model_id = service_data[tts.ATTR_OPTIONS].get(ATTR_MODEL, RECOMMENDED_TTS_MODEL)
+
+    tts_entity._genai_client.models.generate_content.assert_called_once_with(
+        model=model_id,
+        contents="There is a person at the front door.",
+        config=types.GenerateContentConfig(
+            response_modalities=["AUDIO"],
+            speech_config=types.SpeechConfig(
+                voice_config=types.VoiceConfig(
+                    prebuilt_voice_config=types.PrebuiltVoiceConfig(voice_name=voice_id)
+                )
+            ),
+        ),
+    )
+
+
+@pytest.mark.parametrize(
+    ("setup", "tts_service", "service_data"),
+    [
+        (
+            "mock_config_entry_setup",
+            "speak",
+            {
+                ATTR_ENTITY_ID: "tts.google_generative_ai_tts",
+                tts.ATTR_MEDIA_PLAYER_ENTITY_ID: "media_player.something",
+                tts.ATTR_MESSAGE: "There is a person at the front door.",
+                tts.ATTR_LANGUAGE: "de-DE",
+                tts.ATTR_OPTIONS: {tts.ATTR_VOICE: "voice1"},
+            },
+        ),
+        (
+            "mock_config_entry_setup",
+            "speak",
+            {
+                ATTR_ENTITY_ID: "tts.google_generative_ai_tts",
+                tts.ATTR_MEDIA_PLAYER_ENTITY_ID: "media_player.something",
+                tts.ATTR_MESSAGE: "There is a person at the front door.",
+                tts.ATTR_LANGUAGE: "it-IT",
+                tts.ATTR_OPTIONS: {tts.ATTR_VOICE: "voice1"},
+            },
+        ),
+    ],
+    indirect=["setup"],
+)
+async def test_tts_service_speak_lang_config(
+    setup: AsyncMock,
+    hass: HomeAssistant,
+    hass_client: ClientSessionGenerator,
+    calls: list[ServiceCall],
+    tts_service: str,
+    service_data: dict[str, Any],
+) -> None:
+    """Test service call with languages in the config."""
+    tts_entity = hass.data[tts.DOMAIN].get_entity(service_data[ATTR_ENTITY_ID])
+    tts_entity._genai_client.models.generate_content.reset_mock()
+
+    await hass.services.async_call(
+        tts.DOMAIN,
+        tts_service,
+        service_data,
+        blocking=True,
+    )
+
+    assert len(calls) == 1
+    assert (
+        await retrieve_media(hass, hass_client, calls[0].data[ATTR_MEDIA_CONTENT_ID])
+        == HTTPStatus.OK
+    )
+
+    tts_entity._genai_client.models.generate_content.assert_called_once_with(
+        model=RECOMMENDED_TTS_MODEL,
+        contents="There is a person at the front door.",
+        config=types.GenerateContentConfig(
+            response_modalities=["AUDIO"],
+            speech_config=types.SpeechConfig(
+                voice_config=types.VoiceConfig(
+                    prebuilt_voice_config=types.PrebuiltVoiceConfig(voice_name="voice1")
+                )
+            ),
+        ),
+    )
+
+
+@pytest.mark.parametrize(
+    ("setup", "tts_service", "service_data"),
+    [
+        (
+            "mock_config_entry_setup",
+            "speak",
+            {
+                ATTR_ENTITY_ID: "tts.google_generative_ai_tts",
+                tts.ATTR_MEDIA_PLAYER_ENTITY_ID: "media_player.something",
+                tts.ATTR_MESSAGE: "There is a person at the front door.",
+                tts.ATTR_OPTIONS: {tts.ATTR_VOICE: "voice1"},
+            },
+        ),
+    ],
+    indirect=["setup"],
+)
+async def test_tts_service_speak_error(
+    setup: AsyncMock,
+    hass: HomeAssistant,
+    hass_client: ClientSessionGenerator,
+    calls: list[ServiceCall],
+    tts_service: str,
+    service_data: dict[str, Any],
+) -> None:
+    """Test service call with HTTP response 500."""
+    tts_entity = hass.data[tts.DOMAIN].get_entity(service_data[ATTR_ENTITY_ID])
+    tts_entity._genai_client.models.generate_content.reset_mock()
+    tts_entity._genai_client.models.generate_content.side_effect = API_ERROR_500
+
+    await hass.services.async_call(
+        tts.DOMAIN,
+        tts_service,
+        service_data,
+        blocking=True,
+    )
+
+    assert len(calls) == 1
+    assert (
+        await retrieve_media(hass, hass_client, calls[0].data[ATTR_MEDIA_CONTENT_ID])
+        == HTTPStatus.INTERNAL_SERVER_ERROR
+    )
+
+    tts_entity._genai_client.models.generate_content.assert_called_once_with(
+        model=RECOMMENDED_TTS_MODEL,
+        contents="There is a person at the front door.",
+        config=types.GenerateContentConfig(
+            response_modalities=["AUDIO"],
+            speech_config=types.SpeechConfig(
+                voice_config=types.VoiceConfig(
+                    prebuilt_voice_config=types.PrebuiltVoiceConfig(voice_name="voice1")
+                )
+            ),
+        ),
+    )
+
+
+@pytest.mark.parametrize(
+    ("setup", "tts_service", "service_data"),
+    [
+        (
+            "mock_config_entry_setup",
+            "speak",
+            {
+                ATTR_ENTITY_ID: "tts.google_generative_ai_tts",
+                tts.ATTR_MEDIA_PLAYER_ENTITY_ID: "media_player.something",
+                tts.ATTR_MESSAGE: "There is a person at the front door.",
+                tts.ATTR_OPTIONS: {},
+            },
+        ),
+    ],
+    indirect=["setup"],
+)
+async def test_tts_service_speak_without_options(
+    setup: AsyncMock,
+    hass: HomeAssistant,
+    hass_client: ClientSessionGenerator,
+    calls: list[ServiceCall],
+    tts_service: str,
+    service_data: dict[str, Any],
+) -> None:
+    """Test service call with HTTP response 200."""
+    tts_entity = hass.data[tts.DOMAIN].get_entity(service_data[ATTR_ENTITY_ID])
+    tts_entity._genai_client.models.generate_content.reset_mock()
+
+    await hass.services.async_call(
+        tts.DOMAIN,
+        tts_service,
+        service_data,
+        blocking=True,
+    )
+
+    assert len(calls) == 1
+    assert (
+        await retrieve_media(hass, hass_client, calls[0].data[ATTR_MEDIA_CONTENT_ID])
+        == HTTPStatus.OK
+    )
+
+    tts_entity._genai_client.models.generate_content.assert_called_once_with(
+        model=RECOMMENDED_TTS_MODEL,
+        contents="There is a person at the front door.",
+        config=types.GenerateContentConfig(
+            response_modalities=["AUDIO"],
+            speech_config=types.SpeechConfig(
+                voice_config=types.VoiceConfig(
+                    prebuilt_voice_config=types.PrebuiltVoiceConfig(voice_name="zephyr")
+                )
+            ),
+        ),
+    )
diff --git a/tests/components/telegram_bot/test_config_flow.py b/tests/components/telegram_bot/test_config_flow.py
index 47b6d99b9ce9fe..0287ccc5dfafc7 100644
--- a/tests/components/telegram_bot/test_config_flow.py
+++ b/tests/components/telegram_bot/test_config_flow.py
@@ -1,6 +1,6 @@
 """Config flow tests for the Telegram Bot integration."""
 
-from unittest.mock import patch
+from unittest.mock import AsyncMock, patch
 
 from telegram import ChatFullInfo, User
 from telegram.constants import AccentColor
@@ -305,10 +305,19 @@ async def test_reauth_flow(
 
     # test: valid
 
-    with patch(
-        "homeassistant.components.telegram_bot.config_flow.Bot.get_me",
-        return_value=User(123456, "Testbot", True),
+    with (
+        patch(
+            "homeassistant.components.telegram_bot.config_flow.Bot.get_me",
+            return_value=User(123456, "Testbot", True),
+        ),
+        patch(
+            "homeassistant.components.telegram_bot.webhooks.PushBot",
+        ) as mock_pushbot,
     ):
+        mock_pushbot.return_value.start_application = AsyncMock()
+        mock_pushbot.return_value.register_webhook = AsyncMock()
+        mock_pushbot.return_value.shutdown = AsyncMock()
+
         result = await hass.config_entries.flow.async_configure(
             result["flow_id"],
             {CONF_API_KEY: "new mock api key"},