diff --git a/homeassistant/components/assist_pipeline/pipeline.py b/homeassistant/components/assist_pipeline/pipeline.py index 93e857f4b2bff5..a1b6ea53445146 100644 --- a/homeassistant/components/assist_pipeline/pipeline.py +++ b/homeassistant/components/assist_pipeline/pipeline.py @@ -1119,6 +1119,7 @@ async def recognize_intent( ) is not None: # Sentence trigger matched agent_id = "sentence_trigger" + processed_locally = True intent_response = intent.IntentResponse( self.pipeline.conversation_language ) diff --git a/homeassistant/components/assist_satellite/services.yaml b/homeassistant/components/assist_satellite/services.yaml index c5484e22dad954..6beb099186189b 100644 --- a/homeassistant/components/assist_satellite/services.yaml +++ b/homeassistant/components/assist_satellite/services.yaml @@ -86,3 +86,17 @@ ask_question: required: false selector: object: + label_field: sentences + description_field: id + multiple: true + translation_key: answers + fields: + id: + required: true + selector: + text: + sentences: + required: true + selector: + text: + multiple: true diff --git a/homeassistant/components/assist_satellite/strings.json b/homeassistant/components/assist_satellite/strings.json index e0bf2bcfb94915..52df24924802a0 100644 --- a/homeassistant/components/assist_satellite/strings.json +++ b/homeassistant/components/assist_satellite/strings.json @@ -90,5 +90,13 @@ } } } + }, + "selector": { + "answers": { + "fields": { + "id": "Answer ID", + "sentences": "Sentences" + } + } } } diff --git a/homeassistant/components/derivative/sensor.py b/homeassistant/components/derivative/sensor.py index 60f4611c5eb56d..0639826b1ee1d1 100644 --- a/homeassistant/components/derivative/sensor.py +++ b/homeassistant/components/derivative/sensor.py @@ -336,13 +336,7 @@ def calc_derivative( "" if unit is None else unit ) - # filter out all derivatives older than `time_window` from our window list - self._state_list = [ - (time_start, time_end, state) - for time_start, time_end, state in self._state_list - if (new_state.last_reported - time_end).total_seconds() - < self._time_window - ] + self._prune_state_list(new_state.last_reported) try: elapsed_time = ( @@ -380,25 +374,14 @@ def calc_derivative( (old_last_reported, new_state.last_reported, new_derivative) ) - def calculate_weight( - start: datetime, end: datetime, now: datetime - ) -> float: - window_start = now - timedelta(seconds=self._time_window) - if start < window_start: - weight = (end - window_start).total_seconds() / self._time_window - else: - weight = (end - start).total_seconds() / self._time_window - return weight - # If outside of time window just report derivative (is the same as modeling it in the window), # otherwise take the weighted average with the previous derivatives if elapsed_time > self._time_window: derivative = new_derivative else: - derivative = Decimal("0.00") - for start, end, value in self._state_list: - weight = calculate_weight(start, end, new_state.last_reported) - derivative = derivative + (value * Decimal(weight)) + derivative = self._calc_derivative_from_state_list( + new_state.last_reported + ) self._attr_native_value = round(derivative, self._round_digits) self.async_write_ha_state() diff --git a/homeassistant/components/devolo_home_control/climate.py b/homeassistant/components/devolo_home_control/climate.py index 3fdfa60870afd4..95db596c3ef631 100644 --- a/homeassistant/components/devolo_home_control/climate.py +++ b/homeassistant/components/devolo_home_control/climate.py @@ -18,7 +18,7 @@ from homeassistant.helpers.entity_platform import AddConfigEntryEntitiesCallback from . import DevoloHomeControlConfigEntry -from .devolo_multi_level_switch import DevoloMultiLevelSwitchDeviceEntity +from .entity import DevoloMultiLevelSwitchDeviceEntity async def async_setup_entry( diff --git a/homeassistant/components/devolo_home_control/cover.py b/homeassistant/components/devolo_home_control/cover.py index f23244f1b500c9..bafef2b02c98f4 100644 --- a/homeassistant/components/devolo_home_control/cover.py +++ b/homeassistant/components/devolo_home_control/cover.py @@ -13,7 +13,7 @@ from homeassistant.helpers.entity_platform import AddConfigEntryEntitiesCallback from . import DevoloHomeControlConfigEntry -from .devolo_multi_level_switch import DevoloMultiLevelSwitchDeviceEntity +from .entity import DevoloMultiLevelSwitchDeviceEntity async def async_setup_entry( diff --git a/homeassistant/components/devolo_home_control/devolo_multi_level_switch.py b/homeassistant/components/devolo_home_control/devolo_multi_level_switch.py deleted file mode 100644 index 3e2d551d1f8626..00000000000000 --- a/homeassistant/components/devolo_home_control/devolo_multi_level_switch.py +++ /dev/null @@ -1,27 +0,0 @@ -"""Base class for multi level switches in devolo Home Control.""" - -from devolo_home_control_api.devices.zwave import Zwave -from devolo_home_control_api.homecontrol import HomeControl - -from .entity import DevoloDeviceEntity - - -class DevoloMultiLevelSwitchDeviceEntity(DevoloDeviceEntity): - """Representation of a multi level switch device within devolo Home Control. Something like a dimmer or a thermostat.""" - - _attr_name = None - - def __init__( - self, homecontrol: HomeControl, device_instance: Zwave, element_uid: str - ) -> None: - """Initialize a multi level switch within devolo Home Control.""" - super().__init__( - homecontrol=homecontrol, - device_instance=device_instance, - element_uid=element_uid, - ) - self._multi_level_switch_property = device_instance.multi_level_switch_property[ - element_uid - ] - - self._value = self._multi_level_switch_property.value diff --git a/homeassistant/components/devolo_home_control/entity.py b/homeassistant/components/devolo_home_control/entity.py index 26b450a2cf26c2..dbe53c214121cf 100644 --- a/homeassistant/components/devolo_home_control/entity.py +++ b/homeassistant/components/devolo_home_control/entity.py @@ -90,3 +90,24 @@ def _generic_message(self, message: tuple) -> None: self._attr_available = self._device_instance.is_online() else: _LOGGER.debug("No valid message received: %s", message) + + +class DevoloMultiLevelSwitchDeviceEntity(DevoloDeviceEntity): + """Representation of a multi level switch device within devolo Home Control. Something like a dimmer or a thermostat.""" + + _attr_name = None + + def __init__( + self, homecontrol: HomeControl, device_instance: Zwave, element_uid: str + ) -> None: + """Initialize a multi level switch within devolo Home Control.""" + super().__init__( + homecontrol=homecontrol, + device_instance=device_instance, + element_uid=element_uid, + ) + self._multi_level_switch_property = device_instance.multi_level_switch_property[ + element_uid + ] + + self._value = self._multi_level_switch_property.value diff --git a/homeassistant/components/devolo_home_control/light.py b/homeassistant/components/devolo_home_control/light.py index 8a88081ed058e0..907a46ec27b9d1 100644 --- a/homeassistant/components/devolo_home_control/light.py +++ b/homeassistant/components/devolo_home_control/light.py @@ -12,7 +12,7 @@ from homeassistant.helpers.entity_platform import AddConfigEntryEntitiesCallback from . import DevoloHomeControlConfigEntry -from .devolo_multi_level_switch import DevoloMultiLevelSwitchDeviceEntity +from .entity import DevoloMultiLevelSwitchDeviceEntity async def async_setup_entry( diff --git a/homeassistant/components/devolo_home_control/siren.py b/homeassistant/components/devolo_home_control/siren.py index 5e4df944b3c5ea..e3f91ca4d7d9a3 100644 --- a/homeassistant/components/devolo_home_control/siren.py +++ b/homeassistant/components/devolo_home_control/siren.py @@ -10,7 +10,7 @@ from homeassistant.helpers.entity_platform import AddConfigEntryEntitiesCallback from . import DevoloHomeControlConfigEntry -from .devolo_multi_level_switch import DevoloMultiLevelSwitchDeviceEntity +from .entity import DevoloMultiLevelSwitchDeviceEntity async def async_setup_entry( diff --git a/homeassistant/components/esphome/assist_satellite.py b/homeassistant/components/esphome/assist_satellite.py index f63671654004c7..adddacd3998285 100644 --- a/homeassistant/components/esphome/assist_satellite.py +++ b/homeassistant/components/esphome/assist_satellite.py @@ -284,11 +284,15 @@ def on_pipeline_event(self, event: PipelineEvent) -> None: assert event.data is not None data_to_send = {"text": event.data["stt_output"]["text"]} elif event_type == VoiceAssistantEventType.VOICE_ASSISTANT_INTENT_PROGRESS: - data_to_send = { - "tts_start_streaming": "1" - if (event.data and event.data.get("tts_start_streaming")) - else "0", - } + if ( + not event.data + or ("tts_start_streaming" not in event.data) + or (not event.data["tts_start_streaming"]) + ): + # ESPHome only needs to know if early TTS streaming is available + return + + data_to_send = {"tts_start_streaming": "1"} elif event_type == VoiceAssistantEventType.VOICE_ASSISTANT_INTENT_END: assert event.data is not None data_to_send = { diff --git a/homeassistant/components/homeassistant_hardware/firmware_config_flow.py b/homeassistant/components/homeassistant_hardware/firmware_config_flow.py index 1b4840e5a9839a..7519e0ae39471b 100644 --- a/homeassistant/components/homeassistant_hardware/firmware_config_flow.py +++ b/homeassistant/components/homeassistant_hardware/firmware_config_flow.py @@ -7,6 +7,8 @@ import logging from typing import Any +from ha_silabs_firmware_client import FirmwareUpdateClient + from homeassistant.components.hassio import ( AddonError, AddonInfo, @@ -22,17 +24,17 @@ ) from homeassistant.core import callback from homeassistant.data_entry_flow import AbortFlow +from homeassistant.helpers.aiohttp_client import async_get_clientsession from homeassistant.helpers.hassio import is_hassio -from . import silabs_multiprotocol_addon from .const import OTBR_DOMAIN, ZHA_DOMAIN from .util import ( ApplicationType, FirmwareInfo, OwningAddon, OwningIntegration, + async_flash_silabs_firmware, get_otbr_addon_manager, - get_zigbee_flasher_addon_manager, guess_firmware_info, guess_hardware_owners, probe_silabs_firmware_info, @@ -61,6 +63,7 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: self.addon_install_task: asyncio.Task | None = None self.addon_start_task: asyncio.Task | None = None self.addon_uninstall_task: asyncio.Task | None = None + self.firmware_install_task: asyncio.Task | None = None def _get_translation_placeholders(self) -> dict[str, str]: """Shared translation placeholders.""" @@ -77,22 +80,6 @@ def _get_translation_placeholders(self) -> dict[str, str]: return placeholders - async def _async_set_addon_config( - self, config: dict, addon_manager: AddonManager - ) -> None: - """Set add-on config.""" - try: - await addon_manager.async_set_addon_options(config) - except AddonError as err: - _LOGGER.error(err) - raise AbortFlow( - "addon_set_config_failed", - description_placeholders={ - **self._get_translation_placeholders(), - "addon_name": addon_manager.addon_name, - }, - ) from err - async def _async_get_addon_info(self, addon_manager: AddonManager) -> AddonInfo: """Return add-on info.""" try: @@ -150,98 +137,72 @@ async def _probe_firmware_info( ) ) - async def async_step_pick_firmware_zigbee( - self, user_input: dict[str, Any] | None = None - ) -> ConfigFlowResult: - """Pick Zigbee firmware.""" - if not await self._probe_firmware_info(): - return self.async_abort( - reason="unsupported_firmware", - description_placeholders=self._get_translation_placeholders(), - ) - - # Allow the stick to be used with ZHA without flashing - if ( - self._probed_firmware_info is not None - and self._probed_firmware_info.firmware_type == ApplicationType.EZSP - ): - return await self.async_step_confirm_zigbee() - - if not is_hassio(self.hass): - return self.async_abort( - reason="not_hassio", - description_placeholders=self._get_translation_placeholders(), - ) - - # Only flash new firmware if we need to - fw_flasher_manager = get_zigbee_flasher_addon_manager(self.hass) - addon_info = await self._async_get_addon_info(fw_flasher_manager) - - if addon_info.state == AddonState.NOT_INSTALLED: - return await self.async_step_install_zigbee_flasher_addon() - - if addon_info.state == AddonState.NOT_RUNNING: - return await self.async_step_run_zigbee_flasher_addon() - - # If the addon is already installed and running, fail - return self.async_abort( - reason="addon_already_running", - description_placeholders={ - **self._get_translation_placeholders(), - "addon_name": fw_flasher_manager.addon_name, - }, - ) - - async def async_step_install_zigbee_flasher_addon( - self, user_input: dict[str, Any] | None = None - ) -> ConfigFlowResult: - """Show progress dialog for installing the Zigbee flasher addon.""" - return await self._install_addon( - get_zigbee_flasher_addon_manager(self.hass), - "install_zigbee_flasher_addon", - "run_zigbee_flasher_addon", - ) - - async def _install_addon( + async def _install_firmware_step( self, - addon_manager: silabs_multiprotocol_addon.WaitingAddonManager, + fw_update_url: str, + fw_type: str, + firmware_name: str, + expected_installed_firmware_type: ApplicationType, step_id: str, next_step_id: str, ) -> ConfigFlowResult: - """Show progress dialog for installing an addon.""" - addon_info = await self._async_get_addon_info(addon_manager) + assert self._device is not None - _LOGGER.debug("Flasher addon state: %s", addon_info) + if not self.firmware_install_task: + session = async_get_clientsession(self.hass) + client = FirmwareUpdateClient(fw_update_url, session) + manifest = await client.async_update_data() - if not self.addon_install_task: - self.addon_install_task = self.hass.async_create_task( - addon_manager.async_install_addon_waiting(), - "Addon install", + fw_meta = next( + fw for fw in manifest.firmwares if fw.filename.startswith(fw_type) ) - if not self.addon_install_task.done(): + fw_data = await client.async_fetch_firmware(fw_meta) + self.firmware_install_task = self.hass.async_create_task( + async_flash_silabs_firmware( + hass=self.hass, + device=self._device, + fw_data=fw_data, + expected_installed_firmware_type=expected_installed_firmware_type, + bootloader_reset_type=None, + progress_callback=lambda offset, total: self.async_update_progress( + offset / total + ), + ), + f"Flash {firmware_name} firmware", + ) + + if not self.firmware_install_task.done(): return self.async_show_progress( step_id=step_id, - progress_action="install_addon", + progress_action="install_firmware", description_placeholders={ **self._get_translation_placeholders(), - "addon_name": addon_manager.addon_name, + "firmware_name": firmware_name, }, - progress_task=self.addon_install_task, + progress_task=self.firmware_install_task, ) - try: - await self.addon_install_task - except AddonError as err: - _LOGGER.error(err) - self._failed_addon_name = addon_manager.addon_name - self._failed_addon_reason = "addon_install_failed" - return self.async_show_progress_done(next_step_id="addon_operation_failed") - finally: - self.addon_install_task = None - return self.async_show_progress_done(next_step_id=next_step_id) + async def async_step_pick_firmware_zigbee( + self, user_input: dict[str, Any] | None = None + ) -> ConfigFlowResult: + """Pick Zigbee firmware.""" + if not await self._probe_firmware_info(): + return self.async_abort( + reason="unsupported_firmware", + description_placeholders=self._get_translation_placeholders(), + ) + + return await self.async_step_install_zigbee_firmware() + + async def async_step_install_zigbee_firmware( + self, user_input: dict[str, Any] | None = None + ) -> ConfigFlowResult: + """Install Zigbee firmware.""" + raise NotImplementedError + async def async_step_addon_operation_failed( self, user_input: dict[str, Any] | None = None ) -> ConfigFlowResult: @@ -254,130 +215,73 @@ async def async_step_addon_operation_failed( }, ) - async def async_step_run_zigbee_flasher_addon( + async def async_step_confirm_zigbee( self, user_input: dict[str, Any] | None = None ) -> ConfigFlowResult: - """Configure the flasher addon to point to the SkyConnect and run it.""" - fw_flasher_manager = get_zigbee_flasher_addon_manager(self.hass) - addon_info = await self._async_get_addon_info(fw_flasher_manager) - + """Confirm Zigbee setup.""" assert self._device is not None - new_addon_config = { - **addon_info.options, - "device": self._device, - "baudrate": 115200, - "bootloader_baudrate": 115200, - "flow_control": True, - } - - _LOGGER.debug("Reconfiguring flasher addon with %s", new_addon_config) - await self._async_set_addon_config(new_addon_config, fw_flasher_manager) - - if not self.addon_start_task: - - async def start_and_wait_until_done() -> None: - await fw_flasher_manager.async_start_addon_waiting() - # Now that the addon is running, wait for it to finish - await fw_flasher_manager.async_wait_until_addon_state( - AddonState.NOT_RUNNING - ) - - self.addon_start_task = self.hass.async_create_task( - start_and_wait_until_done() - ) + assert self._hardware_name is not None - if not self.addon_start_task.done(): - return self.async_show_progress( - step_id="run_zigbee_flasher_addon", - progress_action="run_zigbee_flasher_addon", - description_placeholders={ - **self._get_translation_placeholders(), - "addon_name": fw_flasher_manager.addon_name, - }, - progress_task=self.addon_start_task, + if user_input is None: + return self.async_show_form( + step_id="confirm_zigbee", + description_placeholders=self._get_translation_placeholders(), ) - try: - await self.addon_start_task - except (AddonError, AbortFlow) as err: - _LOGGER.error(err) - self._failed_addon_name = fw_flasher_manager.addon_name - self._failed_addon_reason = "addon_start_failed" - return self.async_show_progress_done(next_step_id="addon_operation_failed") - finally: - self.addon_start_task = None - - return self.async_show_progress_done( - next_step_id="uninstall_zigbee_flasher_addon" - ) - - async def async_step_uninstall_zigbee_flasher_addon( - self, user_input: dict[str, Any] | None = None - ) -> ConfigFlowResult: - """Uninstall the flasher addon.""" - fw_flasher_manager = get_zigbee_flasher_addon_manager(self.hass) - - if not self.addon_uninstall_task: - _LOGGER.debug("Uninstalling flasher addon") - self.addon_uninstall_task = self.hass.async_create_task( - fw_flasher_manager.async_uninstall_addon_waiting() + if not await self._probe_firmware_info(probe_methods=(ApplicationType.EZSP,)): + return self.async_abort( + reason="unsupported_firmware", + description_placeholders=self._get_translation_placeholders(), ) - if not self.addon_uninstall_task.done(): - return self.async_show_progress( - step_id="uninstall_zigbee_flasher_addon", - progress_action="uninstall_zigbee_flasher_addon", - description_placeholders={ - **self._get_translation_placeholders(), - "addon_name": fw_flasher_manager.addon_name, + await self.hass.config_entries.flow.async_init( + ZHA_DOMAIN, + context={"source": "hardware"}, + data={ + "name": self._hardware_name, + "port": { + "path": self._device, + "baudrate": 115200, + "flow_control": "hardware", }, - progress_task=self.addon_uninstall_task, - ) + "radio_type": "ezsp", + }, + ) - try: - await self.addon_uninstall_task - except (AddonError, AbortFlow) as err: - _LOGGER.error(err) - # The uninstall failing isn't critical so we can just continue - finally: - self.addon_uninstall_task = None + return self._async_flow_finished() - return self.async_show_progress_done(next_step_id="confirm_zigbee") + async def _ensure_thread_addon_setup(self) -> ConfigFlowResult | None: + """Ensure the OTBR addon is set up and not running.""" - async def async_step_confirm_zigbee( - self, user_input: dict[str, Any] | None = None - ) -> ConfigFlowResult: - """Confirm Zigbee setup.""" - assert self._device is not None - assert self._hardware_name is not None - - if not await self._probe_firmware_info(probe_methods=(ApplicationType.EZSP,)): + # We install the OTBR addon no matter what, since it is required to use Thread + if not is_hassio(self.hass): return self.async_abort( - reason="unsupported_firmware", + reason="not_hassio_thread", description_placeholders=self._get_translation_placeholders(), ) - if user_input is not None: - await self.hass.config_entries.flow.async_init( - ZHA_DOMAIN, - context={"source": "hardware"}, - data={ - "name": self._hardware_name, - "port": { - "path": self._device, - "baudrate": 115200, - "flow_control": "hardware", + otbr_manager = get_otbr_addon_manager(self.hass) + addon_info = await self._async_get_addon_info(otbr_manager) + + if addon_info.state == AddonState.NOT_INSTALLED: + return await self.async_step_install_otbr_addon() + + if addon_info.state == AddonState.RUNNING: + # We only fail setup if we have an instance of OTBR running *and* it's + # pointing to different hardware + if addon_info.options["device"] != self._device: + return self.async_abort( + reason="otbr_addon_already_running", + description_placeholders={ + **self._get_translation_placeholders(), + "addon_name": otbr_manager.addon_name, }, - "radio_type": "ezsp", - }, - ) + ) - return self._async_flow_finished() + # Otherwise, stop the addon before continuing to flash firmware + await otbr_manager.async_stop_addon() - return self.async_show_form( - step_id="confirm_zigbee", - description_placeholders=self._get_translation_placeholders(), - ) + return None async def async_step_pick_firmware_thread( self, user_input: dict[str, Any] | None = None @@ -389,59 +293,97 @@ async def async_step_pick_firmware_thread( description_placeholders=self._get_translation_placeholders(), ) - # We install the OTBR addon no matter what, since it is required to use Thread - if not is_hassio(self.hass): - return self.async_abort( - reason="not_hassio_thread", - description_placeholders=self._get_translation_placeholders(), - ) + if result := await self._ensure_thread_addon_setup(): + return result - otbr_manager = get_otbr_addon_manager(self.hass) - addon_info = await self._async_get_addon_info(otbr_manager) + return await self.async_step_install_thread_firmware() - if addon_info.state == AddonState.NOT_INSTALLED: - return await self.async_step_install_otbr_addon() - - if addon_info.state == AddonState.NOT_RUNNING: - return await self.async_step_start_otbr_addon() - - # If the addon is already installed and running, fail - return self.async_abort( - reason="otbr_addon_already_running", - description_placeholders={ - **self._get_translation_placeholders(), - "addon_name": otbr_manager.addon_name, - }, - ) + async def async_step_install_thread_firmware( + self, user_input: dict[str, Any] | None = None + ) -> ConfigFlowResult: + """Install Thread firmware.""" + raise NotImplementedError async def async_step_install_otbr_addon( self, user_input: dict[str, Any] | None = None ) -> ConfigFlowResult: """Show progress dialog for installing the OTBR addon.""" - return await self._install_addon( - get_otbr_addon_manager(self.hass), "install_otbr_addon", "start_otbr_addon" - ) + addon_manager = get_otbr_addon_manager(self.hass) + addon_info = await self._async_get_addon_info(addon_manager) + + _LOGGER.debug("OTBR addon info: %s", addon_info) + + if not self.addon_install_task: + self.addon_install_task = self.hass.async_create_task( + addon_manager.async_install_addon_waiting(), + "OTBR addon install", + ) + + if not self.addon_install_task.done(): + return self.async_show_progress( + step_id="install_otbr_addon", + progress_action="install_addon", + description_placeholders={ + **self._get_translation_placeholders(), + "addon_name": addon_manager.addon_name, + }, + progress_task=self.addon_install_task, + ) + + try: + await self.addon_install_task + except AddonError as err: + _LOGGER.error(err) + self._failed_addon_name = addon_manager.addon_name + self._failed_addon_reason = "addon_install_failed" + return self.async_show_progress_done(next_step_id="addon_operation_failed") + finally: + self.addon_install_task = None + + return self.async_show_progress_done(next_step_id="install_thread_firmware") async def async_step_start_otbr_addon( self, user_input: dict[str, Any] | None = None ) -> ConfigFlowResult: """Configure OTBR to point to the SkyConnect and run the addon.""" otbr_manager = get_otbr_addon_manager(self.hass) - addon_info = await self._async_get_addon_info(otbr_manager) - assert self._device is not None - new_addon_config = { - **addon_info.options, - "device": self._device, - "baudrate": 460800, - "flow_control": True, - "autoflash_firmware": True, - } + if not self.addon_start_task: + # Before we start the addon, confirm that the correct firmware is running + # and populate `self._probed_firmware_info` with the correct information + if not await self._probe_firmware_info( + probe_methods=(ApplicationType.SPINEL,) + ): + return self.async_abort( + reason="unsupported_firmware", + description_placeholders=self._get_translation_placeholders(), + ) - _LOGGER.debug("Reconfiguring OTBR addon with %s", new_addon_config) - await self._async_set_addon_config(new_addon_config, otbr_manager) + addon_info = await self._async_get_addon_info(otbr_manager) + + assert self._device is not None + new_addon_config = { + **addon_info.options, + "device": self._device, + "baudrate": 460800, + "flow_control": True, + "autoflash_firmware": False, + } + + _LOGGER.debug("Reconfiguring OTBR addon with %s", new_addon_config) + + try: + await otbr_manager.async_set_addon_options(new_addon_config) + except AddonError as err: + _LOGGER.error(err) + raise AbortFlow( + "addon_set_config_failed", + description_placeholders={ + **self._get_translation_placeholders(), + "addon_name": otbr_manager.addon_name, + }, + ) from err - if not self.addon_start_task: self.addon_start_task = self.hass.async_create_task( otbr_manager.async_start_addon_waiting() ) @@ -475,20 +417,14 @@ async def async_step_confirm_otbr( """Confirm OTBR setup.""" assert self._device is not None - if not await self._probe_firmware_info(probe_methods=(ApplicationType.SPINEL,)): - return self.async_abort( - reason="unsupported_firmware", + if user_input is None: + return self.async_show_form( + step_id="confirm_otbr", description_placeholders=self._get_translation_placeholders(), ) - if user_input is not None: - # OTBR discovery is done automatically via hassio - return self._async_flow_finished() - - return self.async_show_form( - step_id="confirm_otbr", - description_placeholders=self._get_translation_placeholders(), - ) + # OTBR discovery is done automatically via hassio + return self._async_flow_finished() @abstractmethod def _async_flow_finished(self) -> ConfigFlowResult: diff --git a/homeassistant/components/homeassistant_hardware/strings.json b/homeassistant/components/homeassistant_hardware/strings.json index e184f9b3a85647..99172c963b8836 100644 --- a/homeassistant/components/homeassistant_hardware/strings.json +++ b/homeassistant/components/homeassistant_hardware/strings.json @@ -10,22 +10,6 @@ "pick_firmware_thread": "Thread" } }, - "install_zigbee_flasher_addon": { - "title": "Installing flasher", - "description": "Installing the Silicon Labs Flasher add-on." - }, - "run_zigbee_flasher_addon": { - "title": "Installing Zigbee firmware", - "description": "Installing Zigbee firmware. This will take about a minute." - }, - "uninstall_zigbee_flasher_addon": { - "title": "Removing flasher", - "description": "Removing the Silicon Labs Flasher add-on." - }, - "zigbee_flasher_failed": { - "title": "Zigbee installation failed", - "description": "The Zigbee firmware installation process was unsuccessful. Ensure no other software is trying to communicate with the {model} and try again." - }, "confirm_zigbee": { "title": "Zigbee setup complete", "description": "Your {model} is now a Zigbee coordinator and will be shown as discovered by the Zigbee Home Automation integration." @@ -55,9 +39,7 @@ "unsupported_firmware": "The radio firmware on your {model} could not be determined. Make sure that no other integration or add-on is currently trying to communicate with the device. If you are running Home Assistant OS in a virtual machine or in Docker, please make sure that permissions are set correctly for the device." }, "progress": { - "install_zigbee_flasher_addon": "The Silicon Labs Flasher add-on is installed, this may take a few minutes.", - "run_zigbee_flasher_addon": "Please wait while Zigbee firmware is installed to your {model}, this will take a few minutes. Do not make any changes to your hardware or software until this finishes.", - "uninstall_zigbee_flasher_addon": "The Silicon Labs Flasher add-on is being removed." + "install_firmware": "Please wait while {firmware_name} firmware is installed to your {model}, this will take a few minutes. Do not make any changes to your hardware or software until this finishes." } } }, @@ -110,16 +92,6 @@ "data": { "disable_multi_pan": "Disable multiprotocol support" } - }, - "install_flasher_addon": { - "title": "The Silicon Labs Flasher add-on installation has started" - }, - "configure_flasher_addon": { - "title": "The Silicon Labs Flasher add-on installation has started" - }, - "start_flasher_addon": { - "title": "Installing firmware", - "description": "Zigbee firmware is now being installed. This will take a few minutes." } }, "error": { diff --git a/homeassistant/components/homeassistant_hardware/update.py b/homeassistant/components/homeassistant_hardware/update.py index 1b0f15ca02182c..831d9f3f4da78c 100644 --- a/homeassistant/components/homeassistant_hardware/update.py +++ b/homeassistant/components/homeassistant_hardware/update.py @@ -2,15 +2,12 @@ from __future__ import annotations -from collections.abc import AsyncIterator, Callable -from contextlib import AsyncExitStack, asynccontextmanager +from collections.abc import Callable from dataclasses import dataclass import logging from typing import Any, cast from ha_silabs_firmware_client import FirmwareManifest, FirmwareMetadata -from universal_silabs_flasher.firmware import parse_firmware_image -from universal_silabs_flasher.flasher import Flasher from yarl import URL from homeassistant.components.update import ( @@ -20,18 +17,12 @@ ) from homeassistant.config_entries import ConfigEntry from homeassistant.core import CALLBACK_TYPE, callback -from homeassistant.exceptions import HomeAssistantError from homeassistant.helpers.restore_state import ExtraStoredData from homeassistant.helpers.update_coordinator import CoordinatorEntity from .coordinator import FirmwareUpdateCoordinator from .helpers import async_register_firmware_info_callback -from .util import ( - ApplicationType, - FirmwareInfo, - guess_firmware_info, - probe_silabs_firmware_info, -) +from .util import ApplicationType, FirmwareInfo, async_flash_silabs_firmware _LOGGER = logging.getLogger(__name__) @@ -249,19 +240,11 @@ def _update_progress(self, offset: int, total_size: int) -> None: self._attr_update_percentage = round((offset * 100) / total_size) self.async_write_ha_state() - @asynccontextmanager - async def _temporarily_stop_hardware_owners( - self, device: str - ) -> AsyncIterator[None]: - """Temporarily stop addons and integrations communicating with the device.""" - firmware_info = await guess_firmware_info(self.hass, device) - _LOGGER.debug("Identified firmware info: %s", firmware_info) - - async with AsyncExitStack() as stack: - for owner in firmware_info.owners: - await stack.enter_async_context(owner.temporarily_stop(self.hass)) - - yield + # Switch to an indeterminate progress bar after installation is complete, since + # we probe the firmware after flashing + if offset == total_size: + self._attr_update_percentage = None + self.async_write_ha_state() async def async_install( self, version: str | None, backup: bool, **kwargs: Any @@ -278,49 +261,18 @@ async def async_install( fw_data = await self.coordinator.client.async_fetch_firmware( self._latest_firmware ) - fw_image = await self.hass.async_add_executor_job(parse_firmware_image, fw_data) - - device = self._current_device - - flasher = Flasher( - device=device, - probe_methods=( - ApplicationType.GECKO_BOOTLOADER.as_flasher_application_type(), - ApplicationType.EZSP.as_flasher_application_type(), - ApplicationType.SPINEL.as_flasher_application_type(), - ApplicationType.CPC.as_flasher_application_type(), - ), - bootloader_reset=self.bootloader_reset_type, - ) - async with self._temporarily_stop_hardware_owners(device): - try: - try: - # Enter the bootloader with indeterminate progress - await flasher.enter_bootloader() - - # Flash the firmware, with progress - await flasher.flash_firmware( - fw_image, progress_callback=self._update_progress - ) - except Exception as err: - raise HomeAssistantError("Failed to flash firmware") from err - - # Probe the running application type with indeterminate progress - self._attr_update_percentage = None - self.async_write_ha_state() - - firmware_info = await probe_silabs_firmware_info( - device, - probe_methods=(self.entity_description.expected_firmware_type,), - ) - - if firmware_info is None: - raise HomeAssistantError( - "Failed to probe the firmware after flashing" - ) + try: + firmware_info = await async_flash_silabs_firmware( + hass=self.hass, + device=self._current_device, + fw_data=fw_data, + expected_installed_firmware_type=self.entity_description.expected_firmware_type, + bootloader_reset_type=self.bootloader_reset_type, + progress_callback=self._update_progress, + ) + finally: + self._attr_in_progress = False + self.async_write_ha_state() - self._firmware_info_callback(firmware_info) - finally: - self._attr_in_progress = False - self.async_write_ha_state() + self._firmware_info_callback(firmware_info) diff --git a/homeassistant/components/homeassistant_hardware/util.py b/homeassistant/components/homeassistant_hardware/util.py index 64f363e4f23983..d84f4f75ff7083 100644 --- a/homeassistant/components/homeassistant_hardware/util.py +++ b/homeassistant/components/homeassistant_hardware/util.py @@ -4,18 +4,20 @@ import asyncio from collections import defaultdict -from collections.abc import AsyncIterator, Iterable -from contextlib import asynccontextmanager +from collections.abc import AsyncIterator, Callable, Iterable +from contextlib import AsyncExitStack, asynccontextmanager from dataclasses import dataclass from enum import StrEnum import logging from universal_silabs_flasher.const import ApplicationType as FlasherApplicationType +from universal_silabs_flasher.firmware import parse_firmware_image from universal_silabs_flasher.flasher import Flasher from homeassistant.components.hassio import AddonError, AddonManager, AddonState from homeassistant.config_entries import ConfigEntryState from homeassistant.core import HomeAssistant, callback +from homeassistant.exceptions import HomeAssistantError from homeassistant.helpers.hassio import is_hassio from homeassistant.helpers.singleton import singleton @@ -333,3 +335,52 @@ async def probe_silabs_firmware_type( return None return fw_info.firmware_type + + +async def async_flash_silabs_firmware( + hass: HomeAssistant, + device: str, + fw_data: bytes, + expected_installed_firmware_type: ApplicationType, + bootloader_reset_type: str | None = None, + progress_callback: Callable[[int, int], None] | None = None, +) -> FirmwareInfo: + """Flash firmware to the SiLabs device.""" + firmware_info = await guess_firmware_info(hass, device) + _LOGGER.debug("Identified firmware info: %s", firmware_info) + + fw_image = await hass.async_add_executor_job(parse_firmware_image, fw_data) + + flasher = Flasher( + device=device, + probe_methods=( + ApplicationType.GECKO_BOOTLOADER.as_flasher_application_type(), + ApplicationType.EZSP.as_flasher_application_type(), + ApplicationType.SPINEL.as_flasher_application_type(), + ApplicationType.CPC.as_flasher_application_type(), + ), + bootloader_reset=bootloader_reset_type, + ) + + async with AsyncExitStack() as stack: + for owner in firmware_info.owners: + await stack.enter_async_context(owner.temporarily_stop(hass)) + + try: + # Enter the bootloader with indeterminate progress + await flasher.enter_bootloader() + + # Flash the firmware, with progress + await flasher.flash_firmware(fw_image, progress_callback=progress_callback) + except Exception as err: + raise HomeAssistantError("Failed to flash firmware") from err + + probed_firmware_info = await probe_silabs_firmware_info( + device, + probe_methods=(expected_installed_firmware_type,), + ) + + if probed_firmware_info is None: + raise HomeAssistantError("Failed to probe the firmware after flashing") + + return probed_firmware_info diff --git a/homeassistant/components/homeassistant_sky_connect/config_flow.py b/homeassistant/components/homeassistant_sky_connect/config_flow.py index eb5ea214b3e7c4..997edb54b18d92 100644 --- a/homeassistant/components/homeassistant_sky_connect/config_flow.py +++ b/homeassistant/components/homeassistant_sky_connect/config_flow.py @@ -32,6 +32,7 @@ FIRMWARE, FIRMWARE_VERSION, MANUFACTURER, + NABU_CASA_FIRMWARE_RELEASES_URL, PID, PRODUCT, SERIAL_NUMBER, @@ -45,19 +46,29 @@ if TYPE_CHECKING: - class TranslationPlaceholderProtocol(Protocol): - """Protocol describing `BaseFirmwareInstallFlow`'s translation placeholders.""" + class FirmwareInstallFlowProtocol(Protocol): + """Protocol describing `BaseFirmwareInstallFlow` for a mixin.""" def _get_translation_placeholders(self) -> dict[str, str]: return {} + async def _install_firmware_step( + self, + fw_update_url: str, + fw_type: str, + firmware_name: str, + expected_installed_firmware_type: ApplicationType, + step_id: str, + next_step_id: str, + ) -> ConfigFlowResult: ... + else: # Multiple inheritance with `Protocol` seems to break - TranslationPlaceholderProtocol = object + FirmwareInstallFlowProtocol = object -class SkyConnectTranslationMixin(ConfigEntryBaseFlow, TranslationPlaceholderProtocol): - """Translation placeholder mixin for Home Assistant SkyConnect.""" +class SkyConnectFirmwareMixin(ConfigEntryBaseFlow, FirmwareInstallFlowProtocol): + """Mixin for Home Assistant SkyConnect firmware methods.""" context: ConfigFlowContext @@ -72,9 +83,35 @@ def _get_translation_placeholders(self) -> dict[str, str]: return placeholders + async def async_step_install_zigbee_firmware( + self, user_input: dict[str, Any] | None = None + ) -> ConfigFlowResult: + """Install Zigbee firmware.""" + return await self._install_firmware_step( + fw_update_url=NABU_CASA_FIRMWARE_RELEASES_URL, + fw_type="skyconnect_zigbee_ncp", + firmware_name="Zigbee", + expected_installed_firmware_type=ApplicationType.EZSP, + step_id="install_zigbee_firmware", + next_step_id="confirm_zigbee", + ) + + async def async_step_install_thread_firmware( + self, user_input: dict[str, Any] | None = None + ) -> ConfigFlowResult: + """Install Thread firmware.""" + return await self._install_firmware_step( + fw_update_url=NABU_CASA_FIRMWARE_RELEASES_URL, + fw_type="skyconnect_openthread_rcp", + firmware_name="OpenThread", + expected_installed_firmware_type=ApplicationType.SPINEL, + step_id="install_thread_firmware", + next_step_id="start_otbr_addon", + ) + class HomeAssistantSkyConnectConfigFlow( - SkyConnectTranslationMixin, + SkyConnectFirmwareMixin, firmware_config_flow.BaseFirmwareConfigFlow, domain=DOMAIN, ): @@ -207,7 +244,7 @@ async def async_step_flashing_complete( class HomeAssistantSkyConnectOptionsFlowHandler( - SkyConnectTranslationMixin, firmware_config_flow.BaseFirmwareOptionsFlow + SkyConnectFirmwareMixin, firmware_config_flow.BaseFirmwareOptionsFlow ): """Zigbee and Thread options flow handlers.""" diff --git a/homeassistant/components/homeassistant_sky_connect/strings.json b/homeassistant/components/homeassistant_sky_connect/strings.json index a990f025e8dd09..08c8a56c30d166 100644 --- a/homeassistant/components/homeassistant_sky_connect/strings.json +++ b/homeassistant/components/homeassistant_sky_connect/strings.json @@ -48,16 +48,6 @@ "disable_multi_pan": "[%key:component::homeassistant_hardware::silabs_multiprotocol_hardware::options::step::uninstall_addon::data::disable_multi_pan%]" } }, - "install_flasher_addon": { - "title": "[%key:component::homeassistant_hardware::silabs_multiprotocol_hardware::options::step::install_flasher_addon::title%]" - }, - "configure_flasher_addon": { - "title": "[%key:component::homeassistant_hardware::silabs_multiprotocol_hardware::options::step::configure_flasher_addon::title%]" - }, - "start_flasher_addon": { - "title": "[%key:component::homeassistant_hardware::silabs_multiprotocol_hardware::options::step::start_flasher_addon::title%]", - "description": "[%key:component::homeassistant_hardware::silabs_multiprotocol_hardware::options::step::start_flasher_addon::description%]" - }, "pick_firmware": { "title": "[%key:component::homeassistant_hardware::firmware_picker::options::step::pick_firmware::title%]", "description": "[%key:component::homeassistant_hardware::firmware_picker::options::step::pick_firmware::description%]", @@ -66,18 +56,6 @@ "pick_firmware_zigbee": "[%key:component::homeassistant_hardware::firmware_picker::options::step::pick_firmware::menu_options::pick_firmware_zigbee%]" } }, - "install_zigbee_flasher_addon": { - "title": "[%key:component::homeassistant_hardware::firmware_picker::options::step::install_zigbee_flasher_addon::title%]", - "description": "[%key:component::homeassistant_hardware::firmware_picker::options::step::install_zigbee_flasher_addon::description%]" - }, - "run_zigbee_flasher_addon": { - "title": "[%key:component::homeassistant_hardware::firmware_picker::options::step::run_zigbee_flasher_addon::title%]", - "description": "[%key:component::homeassistant_hardware::firmware_picker::options::step::run_zigbee_flasher_addon::description%]" - }, - "zigbee_flasher_failed": { - "title": "[%key:component::homeassistant_hardware::firmware_picker::options::step::zigbee_flasher_failed::title%]", - "description": "[%key:component::homeassistant_hardware::firmware_picker::options::step::zigbee_flasher_failed::description%]" - }, "confirm_zigbee": { "title": "[%key:component::homeassistant_hardware::firmware_picker::options::step::confirm_zigbee::title%]", "description": "[%key:component::homeassistant_hardware::firmware_picker::options::step::confirm_zigbee::description%]" @@ -120,9 +98,7 @@ "install_addon": "[%key:component::homeassistant_hardware::silabs_multiprotocol_hardware::options::progress::install_addon%]", "start_addon": "[%key:component::homeassistant_hardware::silabs_multiprotocol_hardware::options::progress::start_addon%]", "start_otbr_addon": "[%key:component::homeassistant_hardware::silabs_multiprotocol_hardware::options::progress::start_addon%]", - "install_zigbee_flasher_addon": "[%key:component::homeassistant_hardware::firmware_picker::options::progress::install_zigbee_flasher_addon%]", - "run_zigbee_flasher_addon": "[%key:component::homeassistant_hardware::firmware_picker::options::progress::run_zigbee_flasher_addon%]", - "uninstall_zigbee_flasher_addon": "[%key:component::homeassistant_hardware::firmware_picker::options::progress::uninstall_zigbee_flasher_addon%]" + "install_firmware": "[%key:component::homeassistant_hardware::firmware_picker::options::progress::install_firmware%]" } }, "config": { @@ -136,22 +112,6 @@ "pick_firmware_thread": "[%key:component::homeassistant_hardware::firmware_picker::options::step::pick_firmware::menu_options::pick_firmware_thread%]" } }, - "install_zigbee_flasher_addon": { - "title": "[%key:component::homeassistant_hardware::firmware_picker::options::step::install_zigbee_flasher_addon::title%]", - "description": "[%key:component::homeassistant_hardware::firmware_picker::options::step::install_zigbee_flasher_addon::description%]" - }, - "run_zigbee_flasher_addon": { - "title": "[%key:component::homeassistant_hardware::firmware_picker::options::step::run_zigbee_flasher_addon::title%]", - "description": "[%key:component::homeassistant_hardware::firmware_picker::options::step::run_zigbee_flasher_addon::description%]" - }, - "uninstall_zigbee_flasher_addon": { - "title": "[%key:component::homeassistant_hardware::firmware_picker::options::step::uninstall_zigbee_flasher_addon::title%]", - "description": "[%key:component::homeassistant_hardware::firmware_picker::options::step::uninstall_zigbee_flasher_addon::description%]" - }, - "zigbee_flasher_failed": { - "title": "[%key:component::homeassistant_hardware::firmware_picker::options::step::zigbee_flasher_failed::title%]", - "description": "[%key:component::homeassistant_hardware::firmware_picker::options::step::zigbee_flasher_failed::description%]" - }, "confirm_zigbee": { "title": "[%key:component::homeassistant_hardware::firmware_picker::options::step::confirm_zigbee::title%]", "description": "[%key:component::homeassistant_hardware::firmware_picker::options::step::confirm_zigbee::description%]" @@ -191,9 +151,7 @@ "install_addon": "[%key:component::homeassistant_hardware::silabs_multiprotocol_hardware::options::progress::install_addon%]", "start_addon": "[%key:component::homeassistant_hardware::silabs_multiprotocol_hardware::options::progress::start_addon%]", "start_otbr_addon": "[%key:component::homeassistant_hardware::silabs_multiprotocol_hardware::options::progress::start_addon%]", - "install_zigbee_flasher_addon": "[%key:component::homeassistant_hardware::firmware_picker::options::progress::install_zigbee_flasher_addon%]", - "run_zigbee_flasher_addon": "[%key:component::homeassistant_hardware::firmware_picker::options::progress::run_zigbee_flasher_addon%]", - "uninstall_zigbee_flasher_addon": "[%key:component::homeassistant_hardware::firmware_picker::options::progress::uninstall_zigbee_flasher_addon%]" + "install_firmware": "[%key:component::homeassistant_hardware::firmware_picker::options::progress::install_firmware%]" } }, "exceptions": { diff --git a/homeassistant/components/homeassistant_yellow/config_flow.py b/homeassistant/components/homeassistant_yellow/config_flow.py index 1fac6bcac9688a..db844d0b0e9813 100644 --- a/homeassistant/components/homeassistant_yellow/config_flow.py +++ b/homeassistant/components/homeassistant_yellow/config_flow.py @@ -5,7 +5,7 @@ from abc import ABC, abstractmethod import asyncio import logging -from typing import Any, final +from typing import TYPE_CHECKING, Any, Protocol, final import aiohttp import voluptuous as vol @@ -31,6 +31,7 @@ from homeassistant.config_entries import ( SOURCE_HARDWARE, ConfigEntry, + ConfigEntryBaseFlow, ConfigFlowResult, OptionsFlow, ) @@ -41,6 +42,7 @@ DOMAIN, FIRMWARE, FIRMWARE_VERSION, + NABU_CASA_FIRMWARE_RELEASES_URL, RADIO_DEVICE, ZHA_DOMAIN, ZHA_HW_DISCOVERY_DATA, @@ -57,8 +59,59 @@ } ) +if TYPE_CHECKING: -class HomeAssistantYellowConfigFlow(BaseFirmwareConfigFlow, domain=DOMAIN): + class FirmwareInstallFlowProtocol(Protocol): + """Protocol describing `BaseFirmwareInstallFlow` for a mixin.""" + + async def _install_firmware_step( + self, + fw_update_url: str, + fw_type: str, + firmware_name: str, + expected_installed_firmware_type: ApplicationType, + step_id: str, + next_step_id: str, + ) -> ConfigFlowResult: ... + +else: + # Multiple inheritance with `Protocol` seems to break + FirmwareInstallFlowProtocol = object + + +class YellowFirmwareMixin(ConfigEntryBaseFlow, FirmwareInstallFlowProtocol): + """Mixin for Home Assistant Yellow firmware methods.""" + + async def async_step_install_zigbee_firmware( + self, user_input: dict[str, Any] | None = None + ) -> ConfigFlowResult: + """Install Zigbee firmware.""" + return await self._install_firmware_step( + fw_update_url=NABU_CASA_FIRMWARE_RELEASES_URL, + fw_type="yellow_zigbee_ncp", + firmware_name="Zigbee", + expected_installed_firmware_type=ApplicationType.EZSP, + step_id="install_zigbee_firmware", + next_step_id="confirm_zigbee", + ) + + async def async_step_install_thread_firmware( + self, user_input: dict[str, Any] | None = None + ) -> ConfigFlowResult: + """Install Thread firmware.""" + return await self._install_firmware_step( + fw_update_url=NABU_CASA_FIRMWARE_RELEASES_URL, + fw_type="yellow_openthread_rcp", + firmware_name="OpenThread", + expected_installed_firmware_type=ApplicationType.SPINEL, + step_id="install_thread_firmware", + next_step_id="start_otbr_addon", + ) + + +class HomeAssistantYellowConfigFlow( + YellowFirmwareMixin, BaseFirmwareConfigFlow, domain=DOMAIN +): """Handle a config flow for Home Assistant Yellow.""" VERSION = 1 @@ -275,7 +328,9 @@ async def async_step_flashing_complete( class HomeAssistantYellowOptionsFlowHandler( - BaseHomeAssistantYellowOptionsFlow, BaseFirmwareOptionsFlow + YellowFirmwareMixin, + BaseHomeAssistantYellowOptionsFlow, + BaseFirmwareOptionsFlow, ): """Handle a firmware options flow for Home Assistant Yellow.""" diff --git a/homeassistant/components/homeassistant_yellow/strings.json b/homeassistant/components/homeassistant_yellow/strings.json index 41c1438b234d33..980052f9ffbfcf 100644 --- a/homeassistant/components/homeassistant_yellow/strings.json +++ b/homeassistant/components/homeassistant_yellow/strings.json @@ -71,16 +71,6 @@ "disable_multi_pan": "[%key:component::homeassistant_hardware::silabs_multiprotocol_hardware::options::step::uninstall_addon::data::disable_multi_pan%]" } }, - "install_flasher_addon": { - "title": "[%key:component::homeassistant_hardware::silabs_multiprotocol_hardware::options::step::install_flasher_addon::title%]" - }, - "configure_flasher_addon": { - "title": "[%key:component::homeassistant_hardware::silabs_multiprotocol_hardware::options::step::configure_flasher_addon::title%]" - }, - "start_flasher_addon": { - "title": "[%key:component::homeassistant_hardware::silabs_multiprotocol_hardware::options::step::start_flasher_addon::title%]", - "description": "[%key:component::homeassistant_hardware::silabs_multiprotocol_hardware::options::step::start_flasher_addon::description%]" - }, "pick_firmware": { "title": "[%key:component::homeassistant_hardware::firmware_picker::options::step::pick_firmware::title%]", "description": "[%key:component::homeassistant_hardware::firmware_picker::options::step::pick_firmware::description%]", @@ -89,18 +79,6 @@ "pick_firmware_zigbee": "[%key:component::homeassistant_hardware::firmware_picker::options::step::pick_firmware::menu_options::pick_firmware_zigbee%]" } }, - "install_zigbee_flasher_addon": { - "title": "[%key:component::homeassistant_hardware::firmware_picker::options::step::install_zigbee_flasher_addon::title%]", - "description": "[%key:component::homeassistant_hardware::firmware_picker::options::step::install_zigbee_flasher_addon::description%]" - }, - "run_zigbee_flasher_addon": { - "title": "[%key:component::homeassistant_hardware::firmware_picker::options::step::run_zigbee_flasher_addon::title%]", - "description": "[%key:component::homeassistant_hardware::firmware_picker::options::step::run_zigbee_flasher_addon::description%]" - }, - "zigbee_flasher_failed": { - "title": "[%key:component::homeassistant_hardware::firmware_picker::options::step::zigbee_flasher_failed::title%]", - "description": "[%key:component::homeassistant_hardware::firmware_picker::options::step::zigbee_flasher_failed::description%]" - }, "confirm_zigbee": { "title": "[%key:component::homeassistant_hardware::firmware_picker::options::step::confirm_zigbee::title%]", "description": "[%key:component::homeassistant_hardware::firmware_picker::options::step::confirm_zigbee::description%]" @@ -145,9 +123,7 @@ "install_addon": "[%key:component::homeassistant_hardware::silabs_multiprotocol_hardware::options::progress::install_addon%]", "start_addon": "[%key:component::homeassistant_hardware::silabs_multiprotocol_hardware::options::progress::start_addon%]", "start_otbr_addon": "[%key:component::homeassistant_hardware::silabs_multiprotocol_hardware::options::progress::start_addon%]", - "install_zigbee_flasher_addon": "[%key:component::homeassistant_hardware::firmware_picker::options::progress::install_zigbee_flasher_addon%]", - "run_zigbee_flasher_addon": "[%key:component::homeassistant_hardware::firmware_picker::options::progress::run_zigbee_flasher_addon%]", - "uninstall_zigbee_flasher_addon": "[%key:component::homeassistant_hardware::firmware_picker::options::progress::uninstall_zigbee_flasher_addon%]" + "install_firmware": "[%key:component::homeassistant_hardware::firmware_picker::options::progress::install_firmware%]" } }, "entity": { diff --git a/homeassistant/components/husqvarna_automower_ble/__init__.py b/homeassistant/components/husqvarna_automower_ble/__init__.py index ca07d1ab8d2f62..f168e84be4ce9a 100644 --- a/homeassistant/components/husqvarna_automower_ble/__init__.py +++ b/homeassistant/components/husqvarna_automower_ble/__init__.py @@ -15,12 +15,14 @@ from .const import LOGGER from .coordinator import HusqvarnaCoordinator +type HusqvarnaConfigEntry = ConfigEntry[HusqvarnaCoordinator] + PLATFORMS = [ Platform.LAWN_MOWER, ] -async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: +async def async_setup_entry(hass: HomeAssistant, entry: HusqvarnaConfigEntry) -> bool: """Set up Husqvarna Autoconnect Bluetooth from a config entry.""" address = entry.data[CONF_ADDRESS] channel_id = entry.data[CONF_CLIENT_ID] @@ -54,7 +56,7 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: return True -async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: +async def async_unload_entry(hass: HomeAssistant, entry: HusqvarnaConfigEntry) -> bool: """Unload a config entry.""" if unload_ok := await hass.config_entries.async_unload_platforms(entry, PLATFORMS): coordinator: HusqvarnaCoordinator = entry.runtime_data diff --git a/homeassistant/components/husqvarna_automower_ble/coordinator.py b/homeassistant/components/husqvarna_automower_ble/coordinator.py index dde3462c081160..c7781becd76553 100644 --- a/homeassistant/components/husqvarna_automower_ble/coordinator.py +++ b/homeassistant/components/husqvarna_automower_ble/coordinator.py @@ -3,30 +3,31 @@ from __future__ import annotations from datetime import timedelta +from typing import TYPE_CHECKING from automower_ble.mower import Mower from bleak import BleakError from bleak_retry_connector import close_stale_connections_by_address from homeassistant.components import bluetooth -from homeassistant.config_entries import ConfigEntry from homeassistant.core import HomeAssistant from homeassistant.helpers.update_coordinator import DataUpdateCoordinator, UpdateFailed from .const import DOMAIN, LOGGER +if TYPE_CHECKING: + from . import HusqvarnaConfigEntry + SCAN_INTERVAL = timedelta(seconds=60) class HusqvarnaCoordinator(DataUpdateCoordinator[dict[str, bytes]]): """Class to manage fetching data.""" - config_entry: ConfigEntry - def __init__( self, hass: HomeAssistant, - config_entry: ConfigEntry, + config_entry: HusqvarnaConfigEntry, mower: Mower, address: str, channel_id: str, diff --git a/homeassistant/components/husqvarna_automower_ble/lawn_mower.py b/homeassistant/components/husqvarna_automower_ble/lawn_mower.py index 4b239394c2d142..4b4a16ba1dbea2 100644 --- a/homeassistant/components/husqvarna_automower_ble/lawn_mower.py +++ b/homeassistant/components/husqvarna_automower_ble/lawn_mower.py @@ -10,10 +10,10 @@ LawnMowerEntity, LawnMowerEntityFeature, ) -from homeassistant.config_entries import ConfigEntry from homeassistant.core import HomeAssistant, callback from homeassistant.helpers.entity_platform import AddConfigEntryEntitiesCallback +from . import HusqvarnaConfigEntry from .const import LOGGER from .coordinator import HusqvarnaCoordinator from .entity import HusqvarnaAutomowerBleEntity @@ -21,11 +21,11 @@ async def async_setup_entry( hass: HomeAssistant, - config_entry: ConfigEntry, + config_entry: HusqvarnaConfigEntry, async_add_entities: AddConfigEntryEntitiesCallback, ) -> None: """Set up AutomowerLawnMower integration from a config entry.""" - coordinator: HusqvarnaCoordinator = config_entry.runtime_data + coordinator = config_entry.runtime_data address = coordinator.address async_add_entities( diff --git a/homeassistant/components/ista_ecotrend/manifest.json b/homeassistant/components/ista_ecotrend/manifest.json index baa5fbde9c0a3d..53638ac9a29bda 100644 --- a/homeassistant/components/ista_ecotrend/manifest.json +++ b/homeassistant/components/ista_ecotrend/manifest.json @@ -7,5 +7,6 @@ "documentation": "https://www.home-assistant.io/integrations/ista_ecotrend", "iot_class": "cloud_polling", "loggers": ["pyecotrend_ista"], + "quality_scale": "gold", "requirements": ["pyecotrend-ista==3.3.1"] } diff --git a/homeassistant/components/ista_ecotrend/quality_scale.yaml b/homeassistant/components/ista_ecotrend/quality_scale.yaml index a06aef7297fa77..ef665b04d416c6 100644 --- a/homeassistant/components/ista_ecotrend/quality_scale.yaml +++ b/homeassistant/components/ista_ecotrend/quality_scale.yaml @@ -50,14 +50,18 @@ rules: discovery: status: exempt comment: The integration is a web service, there are no discoverable devices. - docs-data-update: todo - docs-examples: todo + docs-data-update: done + docs-examples: + status: done + comment: describes how to use the integration with the statistics dashboard docs-known-limitations: done docs-supported-devices: done docs-supported-functions: done - docs-troubleshooting: todo + docs-troubleshooting: done docs-use-cases: done - dynamic-devices: todo + dynamic-devices: + status: exempt + comment: changes are very rare (usually takes years) entity-category: status: done comment: The default category is appropriate. @@ -67,8 +71,12 @@ rules: exception-translations: done icon-translations: done reconfiguration-flow: done - repair-issues: todo - stale-devices: todo + repair-issues: + status: exempt + comment: integration has no repairs + stale-devices: + status: exempt + comment: integration has no stale devices # Platinum async-dependency: todo diff --git a/homeassistant/components/litterrobot/icons.json b/homeassistant/components/litterrobot/icons.json index 163ad80c0a80c4..2e0cafe43d97ba 100644 --- a/homeassistant/components/litterrobot/icons.json +++ b/homeassistant/components/litterrobot/icons.json @@ -46,6 +46,9 @@ "motor_fault_short": "mdi:flash-off", "motor_ot_amps": "mdi:flash-alert" } + }, + "total_cycles": { + "default": "mdi:counter" } }, "switch": { diff --git a/homeassistant/components/litterrobot/manifest.json b/homeassistant/components/litterrobot/manifest.json index 81f987f8c1f699..a8945e482bf206 100644 --- a/homeassistant/components/litterrobot/manifest.json +++ b/homeassistant/components/litterrobot/manifest.json @@ -13,5 +13,5 @@ "iot_class": "cloud_push", "loggers": ["pylitterbot"], "quality_scale": "bronze", - "requirements": ["pylitterbot==2024.2.0"] + "requirements": ["pylitterbot==2024.2.1"] } diff --git a/homeassistant/components/litterrobot/sensor.py b/homeassistant/components/litterrobot/sensor.py index cdd9a1c08a5a5f..b7ddf3c3249bac 100644 --- a/homeassistant/components/litterrobot/sensor.py +++ b/homeassistant/components/litterrobot/sensor.py @@ -115,6 +115,14 @@ class RobotSensorEntityDescription(SensorEntityDescription, Generic[_WhiskerEnti lambda robot: status.lower() if (status := robot.status_code) else None ), ), + RobotSensorEntityDescription[LitterRobot]( + key="total_cycles", + translation_key="total_cycles", + entity_category=EntityCategory.DIAGNOSTIC, + entity_registry_enabled_default=False, + state_class=SensorStateClass.TOTAL_INCREASING, + value_fn=lambda robot: robot.cycle_count, + ), ], LitterRobot4: [ RobotSensorEntityDescription[LitterRobot4]( diff --git a/homeassistant/components/litterrobot/strings.json b/homeassistant/components/litterrobot/strings.json index ba5472918d365a..d9931d71a0dbc6 100644 --- a/homeassistant/components/litterrobot/strings.json +++ b/homeassistant/components/litterrobot/strings.json @@ -118,6 +118,10 @@ "spf": "Pinch detect at startup" } }, + "total_cycles": { + "name": "Total cycles", + "unit_of_measurement": "cycles" + }, "waste_drawer": { "name": "Waste drawer" } diff --git a/homeassistant/components/nfandroidtv/__init__.py b/homeassistant/components/nfandroidtv/__init__.py index 50674a7ed4635d..bdda0d30356b04 100644 --- a/homeassistant/components/nfandroidtv/__init__.py +++ b/homeassistant/components/nfandroidtv/__init__.py @@ -1,11 +1,8 @@ """The NFAndroidTV integration.""" -from notifications_android_tv.notifications import ConnectError, Notifications - from homeassistant.config_entries import ConfigEntry from homeassistant.const import CONF_HOST, Platform from homeassistant.core import HomeAssistant -from homeassistant.exceptions import ConfigEntryNotReady from homeassistant.helpers import config_validation as cv, discovery from homeassistant.helpers.typing import ConfigType @@ -25,14 +22,8 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool: async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: """Set up NFAndroidTV from a config entry.""" - try: - await hass.async_add_executor_job(Notifications, entry.data[CONF_HOST]) - except ConnectError as ex: - raise ConfigEntryNotReady( - f"Failed to connect to host: {entry.data[CONF_HOST]}" - ) from ex - hass.data.setdefault(DOMAIN, {}) + hass.data[DOMAIN][entry.entry_id] = entry.data[CONF_HOST] hass.async_create_task( discovery.async_load_platform( diff --git a/homeassistant/components/nfandroidtv/notify.py b/homeassistant/components/nfandroidtv/notify.py index f6d9bcde4325e7..c1c19a600b98d3 100644 --- a/homeassistant/components/nfandroidtv/notify.py +++ b/homeassistant/components/nfandroidtv/notify.py @@ -6,7 +6,7 @@ import logging from typing import Any -from notifications_android_tv import Notifications +from notifications_android_tv.notifications import ConnectError, Notifications import requests from requests.auth import HTTPBasicAuth, HTTPDigestAuth import voluptuous as vol @@ -19,7 +19,7 @@ ) from homeassistant.const import CONF_HOST from homeassistant.core import HomeAssistant -from homeassistant.exceptions import ServiceValidationError +from homeassistant.exceptions import HomeAssistantError, ServiceValidationError from homeassistant.helpers import config_validation as cv from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType @@ -59,9 +59,9 @@ async def async_get_service( """Get the NFAndroidTV notification service.""" if discovery_info is None: return None - notify = await hass.async_add_executor_job(Notifications, discovery_info[CONF_HOST]) + return NFAndroidTVNotificationService( - notify, + discovery_info[CONF_HOST], hass.config.is_allowed_path, ) @@ -71,15 +71,24 @@ class NFAndroidTVNotificationService(BaseNotificationService): def __init__( self, - notify: Notifications, + host: str, is_allowed_path: Any, ) -> None: """Initialize the service.""" - self.notify = notify + self.host = host self.is_allowed_path = is_allowed_path + self.notify: Notifications | None = None def send_message(self, message: str, **kwargs: Any) -> None: - """Send a message to a Android TV device.""" + """Send a message to an Android TV device.""" + if self.notify is None: + try: + self.notify = Notifications(self.host) + except ConnectError as err: + raise HomeAssistantError( + f"Failed to connect to host: {self.host}" + ) from err + data: dict | None = kwargs.get(ATTR_DATA) title = kwargs.get(ATTR_TITLE, ATTR_TITLE_DEFAULT) duration = None @@ -178,18 +187,22 @@ def send_message(self, message: str, **kwargs: Any) -> None: translation_key="invalid_notification_icon", translation_placeholders={"type": type(icondata).__name__}, ) - self.notify.send( - message, - title=title, - duration=duration, - fontsize=fontsize, - position=position, - bkgcolor=bkgcolor, - transparency=transparency, - interrupt=interrupt, - icon=icon, - image_file=image_file, - ) + + try: + self.notify.send( + message, + title=title, + duration=duration, + fontsize=fontsize, + position=position, + bkgcolor=bkgcolor, + transparency=transparency, + interrupt=interrupt, + icon=icon, + image_file=image_file, + ) + except ConnectError as err: + raise HomeAssistantError(f"Failed to connect to host: {self.host}") from err def load_file( self, diff --git a/homeassistant/components/ollama/__init__.py b/homeassistant/components/ollama/__init__.py index c828ee0af9f764..90d2012766decd 100644 --- a/homeassistant/components/ollama/__init__.py +++ b/homeassistant/components/ollama/__init__.py @@ -8,11 +8,16 @@ import httpx import ollama -from homeassistant.config_entries import ConfigEntry +from homeassistant.config_entries import ConfigEntry, ConfigSubentry from homeassistant.const import CONF_URL, Platform from homeassistant.core import HomeAssistant from homeassistant.exceptions import ConfigEntryNotReady -from homeassistant.helpers import config_validation as cv +from homeassistant.helpers import ( + config_validation as cv, + device_registry as dr, + entity_registry as er, +) +from homeassistant.helpers.typing import ConfigType from homeassistant.util.ssl import get_default_context from .const import ( @@ -42,8 +47,16 @@ CONFIG_SCHEMA = cv.config_entry_only_config_schema(DOMAIN) PLATFORMS = (Platform.CONVERSATION,) +type OllamaConfigEntry = ConfigEntry[ollama.AsyncClient] + + +async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool: + """Set up Ollama.""" + await async_migrate_integration(hass) + return True + -async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: +async def async_setup_entry(hass: HomeAssistant, entry: OllamaConfigEntry) -> bool: """Set up Ollama from a config entry.""" settings = {**entry.data, **entry.options} client = ollama.AsyncClient(host=settings[CONF_URL], verify=get_default_context()) @@ -53,8 +66,7 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: except (TimeoutError, httpx.ConnectError) as err: raise ConfigEntryNotReady(err) from err - hass.data.setdefault(DOMAIN, {})[entry.entry_id] = client - + entry.runtime_data = client await hass.config_entries.async_forward_entry_setups(entry, PLATFORMS) return True @@ -63,5 +75,69 @@ async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: """Unload Ollama.""" if not await hass.config_entries.async_unload_platforms(entry, PLATFORMS): return False - hass.data[DOMAIN].pop(entry.entry_id) return True + + +async def async_migrate_integration(hass: HomeAssistant) -> None: + """Migrate integration entry structure.""" + + entries = hass.config_entries.async_entries(DOMAIN) + if not any(entry.version == 1 for entry in entries): + return + + api_keys_entries: dict[str, ConfigEntry] = {} + entity_registry = er.async_get(hass) + device_registry = dr.async_get(hass) + + for entry in entries: + use_existing = False + subentry = ConfigSubentry( + data=entry.options, + subentry_type="conversation", + title=entry.title, + unique_id=None, + ) + if entry.data[CONF_URL] not in api_keys_entries: + use_existing = True + api_keys_entries[entry.data[CONF_URL]] = entry + + parent_entry = api_keys_entries[entry.data[CONF_URL]] + + hass.config_entries.async_add_subentry(parent_entry, subentry) + conversation_entity = entity_registry.async_get_entity_id( + "conversation", + DOMAIN, + entry.entry_id, + ) + if conversation_entity is not None: + entity_registry.async_update_entity( + conversation_entity, + config_entry_id=parent_entry.entry_id, + config_subentry_id=subentry.subentry_id, + new_unique_id=subentry.subentry_id, + ) + + device = device_registry.async_get_device( + identifiers={(DOMAIN, entry.entry_id)} + ) + if device is not None: + device_registry.async_update_device( + device.id, + new_identifiers={(DOMAIN, subentry.subentry_id)}, + add_config_subentry_id=subentry.subentry_id, + add_config_entry_id=parent_entry.entry_id, + ) + if parent_entry.entry_id != entry.entry_id: + device_registry.async_update_device( + device.id, + remove_config_entry_id=entry.entry_id, + ) + + if not use_existing: + await hass.config_entries.async_remove(entry.entry_id) + else: + hass.config_entries.async_update_entry( + entry, + options={}, + version=2, + ) diff --git a/homeassistant/components/ollama/config_flow.py b/homeassistant/components/ollama/config_flow.py index b94a0fc621d76f..58b557549e1594 100644 --- a/homeassistant/components/ollama/config_flow.py +++ b/homeassistant/components/ollama/config_flow.py @@ -14,12 +14,14 @@ from homeassistant.config_entries import ( ConfigEntry, + ConfigEntryState, ConfigFlow, ConfigFlowResult, - OptionsFlow, + ConfigSubentryFlow, + SubentryFlowResult, ) -from homeassistant.const import CONF_LLM_HASS_API, CONF_URL -from homeassistant.core import HomeAssistant +from homeassistant.const import CONF_LLM_HASS_API, CONF_NAME, CONF_URL +from homeassistant.core import HomeAssistant, callback from homeassistant.helpers import llm from homeassistant.helpers.selector import ( BooleanSelector, @@ -70,7 +72,7 @@ class OllamaConfigFlow(ConfigFlow, domain=DOMAIN): """Handle a config flow for Ollama.""" - VERSION = 1 + VERSION = 2 def __init__(self) -> None: """Initialize config flow.""" @@ -94,6 +96,8 @@ async def async_step_user( errors = {} + self._async_abort_entries_match({CONF_URL: self.url}) + try: self.client = ollama.AsyncClient( host=self.url, verify=get_default_context() @@ -146,8 +150,16 @@ async def async_step_user( return await self.async_step_download() return self.async_create_entry( - title=_get_title(self.model), + title=self.url, data={CONF_URL: self.url, CONF_MODEL: self.model}, + subentries=[ + { + "subentry_type": "conversation", + "data": {}, + "title": _get_title(self.model), + "unique_id": None, + } + ], ) async def async_step_download( @@ -189,6 +201,14 @@ async def async_step_finish( return self.async_create_entry( title=_get_title(self.model), data={CONF_URL: self.url, CONF_MODEL: self.model}, + subentries=[ + { + "subentry_type": "conversation", + "data": {}, + "title": _get_title(self.model), + "unique_id": None, + } + ], ) async def async_step_failed( @@ -197,41 +217,62 @@ async def async_step_failed( """Step after model downloading has failed.""" return self.async_abort(reason="download_failed") - @staticmethod - def async_get_options_flow( - config_entry: ConfigEntry, - ) -> OptionsFlow: - """Create the options flow.""" - return OllamaOptionsFlow(config_entry) + @classmethod + @callback + def async_get_supported_subentry_types( + cls, config_entry: ConfigEntry + ) -> dict[str, type[ConfigSubentryFlow]]: + """Return subentries supported by this integration.""" + return {"conversation": ConversationSubentryFlowHandler} -class OllamaOptionsFlow(OptionsFlow): - """Ollama options flow.""" +class ConversationSubentryFlowHandler(ConfigSubentryFlow): + """Flow for managing conversation subentries.""" - def __init__(self, config_entry: ConfigEntry) -> None: - """Initialize options flow.""" - self.url: str = config_entry.data[CONF_URL] - self.model: str = config_entry.data[CONF_MODEL] + @property + def _is_new(self) -> bool: + """Return if this is a new subentry.""" + return self.source == "user" - async def async_step_init( + async def async_step_set_options( self, user_input: dict[str, Any] | None = None - ) -> ConfigFlowResult: - """Manage the options.""" - if user_input is not None: + ) -> SubentryFlowResult: + """Set conversation options.""" + # abort if entry is not loaded + if self._get_entry().state != ConfigEntryState.LOADED: + return self.async_abort(reason="entry_not_loaded") + + errors: dict[str, str] = {} + + if user_input is None: + if self._is_new: + options = {} + else: + options = self._get_reconfigure_subentry().data.copy() + + elif self._is_new: return self.async_create_entry( - title=_get_title(self.model), data=user_input + title=user_input.pop(CONF_NAME), + data=user_input, + ) + else: + return self.async_update_and_abort( + self._get_entry(), + self._get_reconfigure_subentry(), + data=user_input, ) - options: Mapping[str, Any] = self.config_entry.options or {} - schema = ollama_config_option_schema(self.hass, options) + schema = ollama_config_option_schema(self.hass, self._is_new, options) return self.async_show_form( - step_id="init", - data_schema=vol.Schema(schema), + step_id="set_options", data_schema=vol.Schema(schema), errors=errors ) + async_step_user = async_step_set_options + async_step_reconfigure = async_step_set_options + def ollama_config_option_schema( - hass: HomeAssistant, options: Mapping[str, Any] + hass: HomeAssistant, is_new: bool, options: Mapping[str, Any] ) -> dict: """Ollama options schema.""" hass_apis: list[SelectOptionDict] = [ @@ -242,54 +283,72 @@ def ollama_config_option_schema( for api in llm.async_get_apis(hass) ] - return { - vol.Optional( - CONF_PROMPT, - description={ - "suggested_value": options.get( - CONF_PROMPT, llm.DEFAULT_INSTRUCTIONS_PROMPT + if is_new: + schema: dict[vol.Required | vol.Optional, Any] = { + vol.Required(CONF_NAME, default="Ollama Conversation"): str, + } + else: + schema = {} + + schema.update( + { + vol.Optional( + CONF_PROMPT, + description={ + "suggested_value": options.get( + CONF_PROMPT, llm.DEFAULT_INSTRUCTIONS_PROMPT + ) + }, + ): TemplateSelector(), + vol.Optional( + CONF_LLM_HASS_API, + description={"suggested_value": options.get(CONF_LLM_HASS_API)}, + ): SelectSelector(SelectSelectorConfig(options=hass_apis, multiple=True)), + vol.Optional( + CONF_NUM_CTX, + description={ + "suggested_value": options.get(CONF_NUM_CTX, DEFAULT_NUM_CTX) + }, + ): NumberSelector( + NumberSelectorConfig( + min=MIN_NUM_CTX, + max=MAX_NUM_CTX, + step=1, + mode=NumberSelectorMode.BOX, ) - }, - ): TemplateSelector(), - vol.Optional( - CONF_LLM_HASS_API, - description={"suggested_value": options.get(CONF_LLM_HASS_API)}, - ): SelectSelector(SelectSelectorConfig(options=hass_apis, multiple=True)), - vol.Optional( - CONF_NUM_CTX, - description={"suggested_value": options.get(CONF_NUM_CTX, DEFAULT_NUM_CTX)}, - ): NumberSelector( - NumberSelectorConfig( - min=MIN_NUM_CTX, max=MAX_NUM_CTX, step=1, mode=NumberSelectorMode.BOX - ) - ), - vol.Optional( - CONF_MAX_HISTORY, - description={ - "suggested_value": options.get(CONF_MAX_HISTORY, DEFAULT_MAX_HISTORY) - }, - ): NumberSelector( - NumberSelectorConfig( - min=0, max=sys.maxsize, step=1, mode=NumberSelectorMode.BOX - ) - ), - vol.Optional( - CONF_KEEP_ALIVE, - description={ - "suggested_value": options.get(CONF_KEEP_ALIVE, DEFAULT_KEEP_ALIVE) - }, - ): NumberSelector( - NumberSelectorConfig( - min=-1, max=sys.maxsize, step=1, mode=NumberSelectorMode.BOX - ) - ), - vol.Optional( - CONF_THINK, - description={ - "suggested_value": options.get("think", DEFAULT_THINK), - }, - ): BooleanSelector(), - } + ), + vol.Optional( + CONF_MAX_HISTORY, + description={ + "suggested_value": options.get( + CONF_MAX_HISTORY, DEFAULT_MAX_HISTORY + ) + }, + ): NumberSelector( + NumberSelectorConfig( + min=0, max=sys.maxsize, step=1, mode=NumberSelectorMode.BOX + ) + ), + vol.Optional( + CONF_KEEP_ALIVE, + description={ + "suggested_value": options.get(CONF_KEEP_ALIVE, DEFAULT_KEEP_ALIVE) + }, + ): NumberSelector( + NumberSelectorConfig( + min=-1, max=sys.maxsize, step=1, mode=NumberSelectorMode.BOX + ) + ), + vol.Optional( + CONF_THINK, + description={ + "suggested_value": options.get("think", DEFAULT_THINK), + }, + ): BooleanSelector(), + } + ) + + return schema def _get_title(model: str) -> str: diff --git a/homeassistant/components/ollama/conversation.py b/homeassistant/components/ollama/conversation.py index 1717d0b24b2f46..beedb61f942acc 100644 --- a/homeassistant/components/ollama/conversation.py +++ b/homeassistant/components/ollama/conversation.py @@ -2,7 +2,7 @@ from __future__ import annotations -from collections.abc import AsyncGenerator, Callable +from collections.abc import AsyncGenerator, AsyncIterator, Callable import json import logging from typing import Any, Literal @@ -11,13 +11,14 @@ from voluptuous_openapi import convert from homeassistant.components import assist_pipeline, conversation -from homeassistant.config_entries import ConfigEntry +from homeassistant.config_entries import ConfigEntry, ConfigSubentry from homeassistant.const import CONF_LLM_HASS_API, MATCH_ALL from homeassistant.core import HomeAssistant from homeassistant.exceptions import HomeAssistantError -from homeassistant.helpers import intent, llm +from homeassistant.helpers import device_registry as dr, intent, llm from homeassistant.helpers.entity_platform import AddConfigEntryEntitiesCallback +from . import OllamaConfigEntry from .const import ( CONF_KEEP_ALIVE, CONF_MAX_HISTORY, @@ -40,12 +41,18 @@ async def async_setup_entry( hass: HomeAssistant, - config_entry: ConfigEntry, + config_entry: OllamaConfigEntry, async_add_entities: AddConfigEntryEntitiesCallback, ) -> None: """Set up conversation entities.""" - agent = OllamaConversationEntity(config_entry) - async_add_entities([agent]) + for subentry in config_entry.subentries.values(): + if subentry.subentry_type != "conversation": + continue + + async_add_entities( + [OllamaConversationEntity(config_entry, subentry)], + config_subentry_id=subentry.subentry_id, + ) def _format_tool( @@ -130,7 +137,7 @@ def _convert_content( async def _transform_stream( - result: AsyncGenerator[ollama.Message], + result: AsyncIterator[ollama.ChatResponse], ) -> AsyncGenerator[conversation.AssistantContentDeltaDict]: """Transform the response stream into HA format. @@ -174,17 +181,22 @@ class OllamaConversationEntity( ): """Ollama conversation agent.""" - _attr_has_entity_name = True _attr_supports_streaming = True - def __init__(self, entry: ConfigEntry) -> None: + def __init__(self, entry: OllamaConfigEntry, subentry: ConfigSubentry) -> None: """Initialize the agent.""" self.entry = entry - - # conversation id -> message history - self._attr_name = entry.title - self._attr_unique_id = entry.entry_id - if self.entry.options.get(CONF_LLM_HASS_API): + self.subentry = subentry + self._attr_name = subentry.title + self._attr_unique_id = subentry.subentry_id + self._attr_device_info = dr.DeviceInfo( + identifiers={(DOMAIN, subentry.subentry_id)}, + name=subentry.title, + manufacturer="Ollama", + model=entry.data[CONF_MODEL], + entry_type=dr.DeviceEntryType.SERVICE, + ) + if self.subentry.data.get(CONF_LLM_HASS_API): self._attr_supported_features = ( conversation.ConversationEntityFeature.CONTROL ) @@ -216,7 +228,7 @@ async def _async_handle_message( chat_log: conversation.ChatLog, ) -> conversation.ConversationResult: """Call the API.""" - settings = {**self.entry.data, **self.entry.options} + settings = {**self.entry.data, **self.subentry.data} try: await chat_log.async_provide_llm_data( @@ -248,9 +260,9 @@ async def _async_handle_chat_log( chat_log: conversation.ChatLog, ) -> None: """Generate an answer for the chat log.""" - settings = {**self.entry.data, **self.entry.options} + settings = {**self.entry.data, **self.subentry.data} - client = self.hass.data[DOMAIN][self.entry.entry_id] + client = self.entry.runtime_data model = settings[CONF_MODEL] tools: list[dict[str, Any]] | None = None diff --git a/homeassistant/components/ollama/strings.json b/homeassistant/components/ollama/strings.json index c60b0ef7ebdaae..74a5eaff454582 100644 --- a/homeassistant/components/ollama/strings.json +++ b/homeassistant/components/ollama/strings.json @@ -12,7 +12,8 @@ } }, "abort": { - "download_failed": "Model downloading failed" + "download_failed": "Model downloading failed", + "already_configured": "[%key:common::config_flow::abort::already_configured_service%]" }, "error": { "cannot_connect": "[%key:common::config_flow::error::cannot_connect%]", @@ -22,23 +23,35 @@ "download": "Please wait while the model is downloaded, which may take a very long time. Check your Ollama server logs for more details." } }, - "options": { - "step": { - "init": { - "data": { - "prompt": "Instructions", - "llm_hass_api": "[%key:common::config_flow::data::llm_hass_api%]", - "max_history": "Max history messages", - "num_ctx": "Context window size", - "keep_alive": "Keep alive", - "think": "Think before responding" - }, - "data_description": { - "prompt": "Instruct how the LLM should respond. This can be a template.", - "keep_alive": "Duration in seconds for Ollama to keep model in memory. -1 = indefinite, 0 = never.", - "num_ctx": "Maximum number of text tokens the model can process. Lower to reduce Ollama RAM, or increase for a large number of exposed entities.", - "think": "If enabled, the LLM will think before responding. This can improve response quality but may increase latency." + "config_subentries": { + "conversation": { + "initiate_flow": { + "user": "Add conversation agent", + "reconfigure": "Reconfigure conversation agent" + }, + "entry_type": "Conversation agent", + "step": { + "set_options": { + "data": { + "name": "[%key:common::config_flow::data::name%]", + "prompt": "Instructions", + "llm_hass_api": "[%key:common::config_flow::data::llm_hass_api%]", + "max_history": "Max history messages", + "num_ctx": "Context window size", + "keep_alive": "Keep alive", + "think": "Think before responding" + }, + "data_description": { + "prompt": "Instruct how the LLM should respond. This can be a template.", + "keep_alive": "Duration in seconds for Ollama to keep model in memory. -1 = indefinite, 0 = never.", + "num_ctx": "Maximum number of text tokens the model can process. Lower to reduce Ollama RAM, or increase for a large number of exposed entities.", + "think": "If enabled, the LLM will think before responding. This can improve response quality but may increase latency." + } } + }, + "abort": { + "reconfigure_successful": "[%key:common::config_flow::abort::reconfigure_successful%]", + "entry_not_loaded": "Cannot add things while the configuration is disabled." } } } diff --git a/homeassistant/components/openai_conversation/__init__.py b/homeassistant/components/openai_conversation/__init__.py index 71effe83884b8c..a5b13ded3756a1 100644 --- a/homeassistant/components/openai_conversation/__init__.py +++ b/homeassistant/components/openai_conversation/__init__.py @@ -19,7 +19,7 @@ ) import voluptuous as vol -from homeassistant.config_entries import ConfigEntry +from homeassistant.config_entries import ConfigEntry, ConfigSubentry from homeassistant.const import CONF_API_KEY, Platform from homeassistant.core import ( HomeAssistant, @@ -32,7 +32,12 @@ HomeAssistantError, ServiceValidationError, ) -from homeassistant.helpers import config_validation as cv, selector +from homeassistant.helpers import ( + config_validation as cv, + device_registry as dr, + entity_registry as er, + selector, +) from homeassistant.helpers.httpx_client import get_async_client from homeassistant.helpers.typing import ConfigType @@ -73,6 +78,7 @@ def encode_file(file_path: str) -> tuple[str, str]: async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool: """Set up OpenAI Conversation.""" + await async_migrate_integration(hass) async def render_image(call: ServiceCall) -> ServiceResponse: """Render an image with dall-e.""" @@ -118,7 +124,21 @@ async def send_prompt(call: ServiceCall) -> ServiceResponse: translation_placeholders={"config_entry": entry_id}, ) - model: str = entry.options.get(CONF_CHAT_MODEL, RECOMMENDED_CHAT_MODEL) + # Get first conversation subentry for options + conversation_subentry = next( + ( + sub + for sub in entry.subentries.values() + if sub.subentry_type == "conversation" + ), + None, + ) + if not conversation_subentry: + raise ServiceValidationError("No conversation configuration found") + + model: str = conversation_subentry.data.get( + CONF_CHAT_MODEL, RECOMMENDED_CHAT_MODEL + ) client: openai.AsyncClient = entry.runtime_data content: ResponseInputMessageContentListParam = [ @@ -169,11 +189,11 @@ def append_files_to_content() -> None: model_args = { "model": model, "input": messages, - "max_output_tokens": entry.options.get( + "max_output_tokens": conversation_subentry.data.get( CONF_MAX_TOKENS, RECOMMENDED_MAX_TOKENS ), - "top_p": entry.options.get(CONF_TOP_P, RECOMMENDED_TOP_P), - "temperature": entry.options.get( + "top_p": conversation_subentry.data.get(CONF_TOP_P, RECOMMENDED_TOP_P), + "temperature": conversation_subentry.data.get( CONF_TEMPERATURE, RECOMMENDED_TEMPERATURE ), "user": call.context.user_id, @@ -182,7 +202,7 @@ def append_files_to_content() -> None: if model.startswith("o"): model_args["reasoning"] = { - "effort": entry.options.get( + "effort": conversation_subentry.data.get( CONF_REASONING_EFFORT, RECOMMENDED_REASONING_EFFORT ) } @@ -269,3 +289,68 @@ async def async_setup_entry(hass: HomeAssistant, entry: OpenAIConfigEntry) -> bo async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: """Unload OpenAI.""" return await hass.config_entries.async_unload_platforms(entry, PLATFORMS) + + +async def async_migrate_integration(hass: HomeAssistant) -> None: + """Migrate integration entry structure.""" + + entries = hass.config_entries.async_entries(DOMAIN) + if not any(entry.version == 1 for entry in entries): + return + + api_keys_entries: dict[str, ConfigEntry] = {} + entity_registry = er.async_get(hass) + device_registry = dr.async_get(hass) + + for entry in entries: + use_existing = False + subentry = ConfigSubentry( + data=entry.options, + subentry_type="conversation", + title=entry.title, + unique_id=None, + ) + if entry.data[CONF_API_KEY] not in api_keys_entries: + use_existing = True + api_keys_entries[entry.data[CONF_API_KEY]] = entry + + parent_entry = api_keys_entries[entry.data[CONF_API_KEY]] + + hass.config_entries.async_add_subentry(parent_entry, subentry) + conversation_entity = entity_registry.async_get_entity_id( + "conversation", + DOMAIN, + entry.entry_id, + ) + if conversation_entity is not None: + entity_registry.async_update_entity( + conversation_entity, + config_entry_id=parent_entry.entry_id, + config_subentry_id=subentry.subentry_id, + new_unique_id=subentry.subentry_id, + ) + + device = device_registry.async_get_device( + identifiers={(DOMAIN, entry.entry_id)} + ) + if device is not None: + device_registry.async_update_device( + device.id, + new_identifiers={(DOMAIN, subentry.subentry_id)}, + add_config_subentry_id=subentry.subentry_id, + add_config_entry_id=parent_entry.entry_id, + ) + if parent_entry.entry_id != entry.entry_id: + device_registry.async_update_device( + device.id, + remove_config_entry_id=entry.entry_id, + ) + + if not use_existing: + await hass.config_entries.async_remove(entry.entry_id) + else: + hass.config_entries.async_update_entry( + entry, + options={}, + version=2, + ) diff --git a/homeassistant/components/openai_conversation/config_flow.py b/homeassistant/components/openai_conversation/config_flow.py index 60d81bf6745c55..a9a444cf3ddc7b 100644 --- a/homeassistant/components/openai_conversation/config_flow.py +++ b/homeassistant/components/openai_conversation/config_flow.py @@ -13,17 +13,20 @@ from homeassistant.components.zone import ENTITY_ID_HOME from homeassistant.config_entries import ( ConfigEntry, + ConfigEntryState, ConfigFlow, ConfigFlowResult, - OptionsFlow, + ConfigSubentryFlow, + SubentryFlowResult, ) from homeassistant.const import ( ATTR_LATITUDE, ATTR_LONGITUDE, CONF_API_KEY, CONF_LLM_HASS_API, + CONF_NAME, ) -from homeassistant.core import HomeAssistant +from homeassistant.core import HomeAssistant, callback from homeassistant.helpers import llm from homeassistant.helpers.httpx_client import get_async_client from homeassistant.helpers.selector import ( @@ -52,6 +55,7 @@ CONF_WEB_SEARCH_REGION, CONF_WEB_SEARCH_TIMEZONE, CONF_WEB_SEARCH_USER_LOCATION, + DEFAULT_CONVERSATION_NAME, DOMAIN, RECOMMENDED_CHAT_MODEL, RECOMMENDED_MAX_TOKENS, @@ -94,7 +98,7 @@ async def validate_input(hass: HomeAssistant, data: dict[str, Any]) -> None: class OpenAIConfigFlow(ConfigFlow, domain=DOMAIN): """Handle a config flow for OpenAI Conversation.""" - VERSION = 1 + VERSION = 2 async def async_step_user( self, user_input: dict[str, Any] | None = None @@ -107,6 +111,7 @@ async def async_step_user( errors: dict[str, str] = {} + self._async_abort_entries_match(user_input) try: await validate_input(self.hass, user_input) except openai.APIConnectionError: @@ -120,32 +125,61 @@ async def async_step_user( return self.async_create_entry( title="ChatGPT", data=user_input, - options=RECOMMENDED_OPTIONS, + subentries=[ + { + "subentry_type": "conversation", + "data": RECOMMENDED_OPTIONS, + "title": DEFAULT_CONVERSATION_NAME, + "unique_id": None, + } + ], ) return self.async_show_form( step_id="user", data_schema=STEP_USER_DATA_SCHEMA, errors=errors ) - @staticmethod - def async_get_options_flow( - config_entry: ConfigEntry, - ) -> OptionsFlow: - """Create the options flow.""" - return OpenAIOptionsFlow(config_entry) + @classmethod + @callback + def async_get_supported_subentry_types( + cls, config_entry: ConfigEntry + ) -> dict[str, type[ConfigSubentryFlow]]: + """Return subentries supported by this integration.""" + return {"conversation": ConversationSubentryFlowHandler} + +class ConversationSubentryFlowHandler(ConfigSubentryFlow): + """Flow for managing conversation subentries.""" -class OpenAIOptionsFlow(OptionsFlow): - """OpenAI config flow options handler.""" + last_rendered_recommended = False + options: dict[str, Any] + + @property + def _is_new(self) -> bool: + """Return if this is a new subentry.""" + return self.source == "user" + + async def async_step_user( + self, user_input: dict[str, Any] | None = None + ) -> SubentryFlowResult: + """Add a subentry.""" + self.options = RECOMMENDED_OPTIONS.copy() + return await self.async_step_init() - def __init__(self, config_entry: ConfigEntry) -> None: - """Initialize options flow.""" - self.options = config_entry.options.copy() + async def async_step_reconfigure( + self, user_input: dict[str, Any] | None = None + ) -> SubentryFlowResult: + """Handle reconfiguration of a subentry.""" + self.options = self._get_reconfigure_subentry().data.copy() + return await self.async_step_init() async def async_step_init( self, user_input: dict[str, Any] | None = None - ) -> ConfigFlowResult: + ) -> SubentryFlowResult: """Manage initial options.""" + # abort if entry is not loaded + if self._get_entry().state != ConfigEntryState.LOADED: + return self.async_abort(reason="entry_not_loaded") options = self.options hass_apis: list[SelectOptionDict] = [ @@ -160,25 +194,47 @@ async def async_step_init( ): options[CONF_LLM_HASS_API] = [suggested_llm_apis] - step_schema: VolDictType = { - vol.Optional( - CONF_PROMPT, - description={"suggested_value": llm.DEFAULT_INSTRUCTIONS_PROMPT}, - ): TemplateSelector(), - vol.Optional(CONF_LLM_HASS_API): SelectSelector( - SelectSelectorConfig(options=hass_apis, multiple=True) - ), - vol.Required( - CONF_RECOMMENDED, default=options.get(CONF_RECOMMENDED, False) - ): bool, - } + step_schema: VolDictType = {} + + if self._is_new: + step_schema[vol.Required(CONF_NAME, default=DEFAULT_CONVERSATION_NAME)] = ( + str + ) + + step_schema.update( + { + vol.Optional( + CONF_PROMPT, + description={ + "suggested_value": options.get( + CONF_PROMPT, llm.DEFAULT_INSTRUCTIONS_PROMPT + ) + }, + ): TemplateSelector(), + vol.Optional(CONF_LLM_HASS_API): SelectSelector( + SelectSelectorConfig(options=hass_apis, multiple=True) + ), + vol.Required( + CONF_RECOMMENDED, default=options.get(CONF_RECOMMENDED, False) + ): bool, + } + ) if user_input is not None: if not user_input.get(CONF_LLM_HASS_API): user_input.pop(CONF_LLM_HASS_API, None) if user_input[CONF_RECOMMENDED]: - return self.async_create_entry(title="", data=user_input) + if self._is_new: + return self.async_create_entry( + title=user_input.pop(CONF_NAME), + data=user_input, + ) + return self.async_update_and_abort( + self._get_entry(), + self._get_reconfigure_subentry(), + data=user_input, + ) options.update(user_input) if CONF_LLM_HASS_API in options and CONF_LLM_HASS_API not in user_input: @@ -194,7 +250,7 @@ async def async_step_init( async def async_step_advanced( self, user_input: dict[str, Any] | None = None - ) -> ConfigFlowResult: + ) -> SubentryFlowResult: """Manage advanced options.""" options = self.options errors: dict[str, str] = {} @@ -236,7 +292,7 @@ async def async_step_advanced( async def async_step_model( self, user_input: dict[str, Any] | None = None - ) -> ConfigFlowResult: + ) -> SubentryFlowResult: """Manage model-specific options.""" options = self.options errors: dict[str, str] = {} @@ -303,7 +359,16 @@ async def async_step_model( } if not step_schema: - return self.async_create_entry(title="", data=options) + if self._is_new: + return self.async_create_entry( + title=options.pop(CONF_NAME, DEFAULT_CONVERSATION_NAME), + data=options, + ) + return self.async_update_and_abort( + self._get_entry(), + self._get_reconfigure_subentry(), + data=options, + ) if user_input is not None: if user_input.get(CONF_WEB_SEARCH): @@ -316,7 +381,16 @@ async def async_step_model( options.pop(CONF_WEB_SEARCH_TIMEZONE, None) options.update(user_input) - return self.async_create_entry(title="", data=options) + if self._is_new: + return self.async_create_entry( + title=options.pop(CONF_NAME, DEFAULT_CONVERSATION_NAME), + data=options, + ) + return self.async_update_and_abort( + self._get_entry(), + self._get_reconfigure_subentry(), + data=options, + ) return self.async_show_form( step_id="model", @@ -332,7 +406,7 @@ async def _get_location_data(self) -> dict[str, str]: zone_home = self.hass.states.get(ENTITY_ID_HOME) if zone_home is not None: client = openai.AsyncOpenAI( - api_key=self.config_entry.data[CONF_API_KEY], + api_key=self._get_entry().data[CONF_API_KEY], http_client=get_async_client(self.hass), ) location_schema = vol.Schema( diff --git a/homeassistant/components/openai_conversation/const.py b/homeassistant/components/openai_conversation/const.py index f022b4840eb75d..f90c05eed79a70 100644 --- a/homeassistant/components/openai_conversation/const.py +++ b/homeassistant/components/openai_conversation/const.py @@ -5,6 +5,8 @@ DOMAIN = "openai_conversation" LOGGER: logging.Logger = logging.getLogger(__package__) +DEFAULT_CONVERSATION_NAME = "OpenAI Conversation" + CONF_CHAT_MODEL = "chat_model" CONF_FILENAMES = "filenames" CONF_MAX_TOKENS = "max_tokens" diff --git a/homeassistant/components/openai_conversation/conversation.py b/homeassistant/components/openai_conversation/conversation.py index 8fea4613ce0400..e63bbf32c3504f 100644 --- a/homeassistant/components/openai_conversation/conversation.py +++ b/homeassistant/components/openai_conversation/conversation.py @@ -34,7 +34,7 @@ from voluptuous_openapi import convert from homeassistant.components import assist_pipeline, conversation -from homeassistant.config_entries import ConfigEntry +from homeassistant.config_entries import ConfigEntry, ConfigSubentry from homeassistant.const import CONF_LLM_HASS_API, MATCH_ALL from homeassistant.core import HomeAssistant from homeassistant.exceptions import HomeAssistantError @@ -76,8 +76,14 @@ async def async_setup_entry( async_add_entities: AddConfigEntryEntitiesCallback, ) -> None: """Set up conversation entities.""" - agent = OpenAIConversationEntity(config_entry) - async_add_entities([agent]) + for subentry in config_entry.subentries.values(): + if subentry.subentry_type != "conversation": + continue + + async_add_entities( + [OpenAIConversationEntity(config_entry, subentry)], + config_subentry_id=subentry.subentry_id, + ) def _format_tool( @@ -229,22 +235,22 @@ class OpenAIConversationEntity( ): """OpenAI conversation agent.""" - _attr_has_entity_name = True - _attr_name = None _attr_supports_streaming = True - def __init__(self, entry: OpenAIConfigEntry) -> None: + def __init__(self, entry: OpenAIConfigEntry, subentry: ConfigSubentry) -> None: """Initialize the agent.""" self.entry = entry - self._attr_unique_id = entry.entry_id + self.subentry = subentry + self._attr_name = subentry.title + self._attr_unique_id = subentry.subentry_id self._attr_device_info = dr.DeviceInfo( - identifiers={(DOMAIN, entry.entry_id)}, - name=entry.title, + identifiers={(DOMAIN, subentry.subentry_id)}, + name=subentry.title, manufacturer="OpenAI", - model="ChatGPT", + model=entry.data.get(CONF_CHAT_MODEL, RECOMMENDED_CHAT_MODEL), entry_type=dr.DeviceEntryType.SERVICE, ) - if self.entry.options.get(CONF_LLM_HASS_API): + if self.subentry.data.get(CONF_LLM_HASS_API): self._attr_supported_features = ( conversation.ConversationEntityFeature.CONTROL ) @@ -276,7 +282,7 @@ async def _async_handle_message( chat_log: conversation.ChatLog, ) -> conversation.ConversationResult: """Process the user input and call the API.""" - options = self.entry.options + options = self.subentry.data try: await chat_log.async_provide_llm_data( @@ -304,7 +310,7 @@ async def _async_handle_chat_log( chat_log: conversation.ChatLog, ) -> None: """Generate an answer for the chat log.""" - options = self.entry.options + options = self.subentry.data tools: list[ToolParam] | None = None if chat_log.llm_api: diff --git a/homeassistant/components/openai_conversation/strings.json b/homeassistant/components/openai_conversation/strings.json index 351e82ec11f294..ffbe84337b7f53 100644 --- a/homeassistant/components/openai_conversation/strings.json +++ b/homeassistant/components/openai_conversation/strings.json @@ -11,47 +11,63 @@ "cannot_connect": "[%key:common::config_flow::error::cannot_connect%]", "invalid_auth": "[%key:common::config_flow::error::invalid_auth%]", "unknown": "[%key:common::config_flow::error::unknown%]" + }, + "abort": { + "already_configured": "[%key:common::config_flow::abort::already_configured_service%]" } }, - "options": { - "step": { - "init": { - "data": { - "prompt": "Instructions", - "llm_hass_api": "[%key:common::config_flow::data::llm_hass_api%]", - "recommended": "Recommended model settings" + "config_subentries": { + "conversation": { + "initiate_flow": { + "user": "Add conversation agent", + "reconfigure": "Reconfigure conversation agent" + }, + "entry_type": "Conversation agent", + + "step": { + "init": { + "data": { + "name": "[%key:common::config_flow::data::name%]", + "prompt": "Instructions", + "llm_hass_api": "[%key:common::config_flow::data::llm_hass_api%]", + "recommended": "Recommended model settings" + }, + "data_description": { + "prompt": "Instruct how the LLM should respond. This can be a template." + } + }, + "advanced": { + "title": "Advanced settings", + "data": { + "chat_model": "[%key:common::generic::model%]", + "max_tokens": "Maximum tokens to return in response", + "temperature": "Temperature", + "top_p": "Top P" + } }, - "data_description": { - "prompt": "Instruct how the LLM should respond. This can be a template." + "model": { + "title": "Model-specific options", + "data": { + "reasoning_effort": "Reasoning effort", + "web_search": "Enable web search", + "search_context_size": "Search context size", + "user_location": "Include home location" + }, + "data_description": { + "reasoning_effort": "How many reasoning tokens the model should generate before creating a response to the prompt", + "web_search": "Allow the model to search the web for the latest information before generating a response", + "search_context_size": "High level guidance for the amount of context window space to use for the search", + "user_location": "Refine search results based on geography" + } } }, - "advanced": { - "title": "Advanced settings", - "data": { - "chat_model": "[%key:common::generic::model%]", - "max_tokens": "Maximum tokens to return in response", - "temperature": "Temperature", - "top_p": "Top P" - } + "abort": { + "reconfigure_successful": "[%key:common::config_flow::abort::reconfigure_successful%]", + "entry_not_loaded": "Cannot add things while the configuration is disabled." }, - "model": { - "title": "Model-specific options", - "data": { - "reasoning_effort": "Reasoning effort", - "web_search": "Enable web search", - "search_context_size": "Search context size", - "user_location": "Include home location" - }, - "data_description": { - "reasoning_effort": "How many reasoning tokens the model should generate before creating a response to the prompt", - "web_search": "Allow the model to search the web for the latest information before generating a response", - "search_context_size": "High level guidance for the amount of context window space to use for the search", - "user_location": "Refine search results based on geography" - } + "error": { + "model_not_supported": "This model is not supported, please select a different model" } - }, - "error": { - "model_not_supported": "This model is not supported, please select a different model" } }, "selector": { diff --git a/homeassistant/components/samsungtv/manifest.json b/homeassistant/components/samsungtv/manifest.json index dc8133a1b1ffbc..a2ab8e6e466ac0 100644 --- a/homeassistant/components/samsungtv/manifest.json +++ b/homeassistant/components/samsungtv/manifest.json @@ -34,6 +34,7 @@ "integration_type": "device", "iot_class": "local_push", "loggers": ["samsungctl", "samsungtvws"], + "quality_scale": "bronze", "requirements": [ "getmac==0.9.5", "samsungctl[websocket]==0.7.1", diff --git a/homeassistant/components/samsungtv/quality_scale.yaml b/homeassistant/components/samsungtv/quality_scale.yaml new file mode 100644 index 00000000000000..845ebfe6e464c1 --- /dev/null +++ b/homeassistant/components/samsungtv/quality_scale.yaml @@ -0,0 +1,96 @@ +rules: + # Bronze + action-setup: + status: exempt + comment: no custom actions + appropriate-polling: done + brands: done + common-modules: done + config-flow-test-coverage: done + config-flow: done + dependency-transparency: done + docs-actions: + status: exempt + comment: no actions + docs-high-level-description: done + docs-installation-instructions: done + docs-removal-instructions: done + entity-event-setup: + status: exempt + comment: no events + entity-unique-id: done + has-entity-name: done + runtime-data: done + test-before-configure: done + test-before-setup: done + unique-config-entry: done + + # Silver + action-exceptions: done + config-entry-unloading: done + docs-configuration-parameters: + status: exempt + comment: no configuration options so far + docs-installation-parameters: done + entity-unavailable: + status: todo + comment: check super().unavailable + integration-owner: done + log-when-unavailable: done + parallel-updates: done + reauthentication-flow: done + test-coverage: done + + # Gold + devices: done + diagnostics: done + discovery-update-info: done + discovery: done + docs-data-update: + status: todo + comment: add info about polling the bridge every 10 seconds + docs-examples: done + docs-known-limitations: done + docs-supported-devices: + status: todo + comment: be more specific about supported devices + docs-supported-functions: + status: todo + comment: be more specific about supported functions + docs-troubleshooting: + status: todo + comment: split that up to proper troubleshooting and known limitations section + docs-use-cases: done + dynamic-devices: + status: exempt + comment: device type integration + entity-category: + status: exempt + comment: no config or diagnostic entities + entity-device-class: done + entity-disabled-by-default: + status: exempt + comment: only 2 main entities + entity-translations: + status: exempt + comment: using only device name + exception-translations: done + icon-translations: + status: done + comment: no custom icons, only default icons + reconfiguration-flow: + status: todo + comment: handle at least host change + repair-issues: + status: exempt + comment: no known repair use case so far + stale-devices: + status: exempt + comment: device type integration + + # Platinum + async-dependency: done + inject-websession: done + strict-typing: + status: todo + comment: Requirements 'getmac==0.9.5', 'samsungctl[websocket]==0.7.1' and 'wakeonlan==2.1.0' appear untyped diff --git a/homeassistant/components/shelly/manifest.json b/homeassistant/components/shelly/manifest.json index 78e01e6d8a66d5..c6a255b1bbb1c7 100644 --- a/homeassistant/components/shelly/manifest.json +++ b/homeassistant/components/shelly/manifest.json @@ -9,7 +9,7 @@ "iot_class": "local_push", "loggers": ["aioshelly"], "quality_scale": "silver", - "requirements": ["aioshelly==13.6.0"], + "requirements": ["aioshelly==13.7.0"], "zeroconf": [ { "type": "_http._tcp.local.", diff --git a/homeassistant/components/wiz/__init__.py b/homeassistant/components/wiz/__init__.py index 0e986aaefa2c59..43a9b863d202e1 100644 --- a/homeassistant/components/wiz/__init__.py +++ b/homeassistant/components/wiz/__init__.py @@ -63,12 +63,12 @@ async def _async_discovery(*_: Any) -> None: return True -async def _async_update_listener(hass: HomeAssistant, entry: ConfigEntry) -> None: +async def _async_update_listener(hass: HomeAssistant, entry: WizConfigEntry) -> None: """Handle options update.""" await hass.config_entries.async_reload(entry.entry_id) -async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: +async def async_setup_entry(hass: HomeAssistant, entry: WizConfigEntry) -> bool: """Set up the wiz integration from a config entry.""" ip_address = entry.data[CONF_HOST] _LOGGER.debug("Get bulb with IP: %s", ip_address) @@ -145,7 +145,7 @@ def _async_push_update(state: PilotParser) -> None: return True -async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: +async def async_unload_entry(hass: HomeAssistant, entry: WizConfigEntry) -> bool: """Unload a config entry.""" if unload_ok := await hass.config_entries.async_unload_platforms(entry, PLATFORMS): await entry.runtime_data.bulb.async_close() diff --git a/homeassistant/components/wyoming/assist_satellite.py b/homeassistant/components/wyoming/assist_satellite.py index 1a1a67bf1de169..03470dbe555892 100644 --- a/homeassistant/components/wyoming/assist_satellite.py +++ b/homeassistant/components/wyoming/assist_satellite.py @@ -132,6 +132,10 @@ def __init__( # Used to ensure TTS timeout is acted on correctly. self._run_loop_id: str | None = None + # TTS streaming + self._tts_stream_token: str | None = None + self._is_tts_streaming: bool = False + @property def pipeline_entity_id(self) -> str | None: """Return the entity ID of the pipeline to use for the next conversation.""" @@ -179,11 +183,20 @@ def on_pipeline_event(self, event: PipelineEvent) -> None: """Set state based on pipeline stage.""" assert self._client is not None - if event.type == assist_pipeline.PipelineEventType.RUN_END: + if event.type == assist_pipeline.PipelineEventType.RUN_START: + if event.data and (tts_output := event.data["tts_output"]): + # Get stream token early. + # If "tts_start_streaming" is True in INTENT_PROGRESS event, we + # can start streaming TTS before the TTS_END event. + self._tts_stream_token = tts_output["token"] + self._is_tts_streaming = False + elif event.type == assist_pipeline.PipelineEventType.RUN_END: # Pipeline run is complete self._is_pipeline_running = False self._pipeline_ended_event.set() self.device.set_is_active(False) + self._tts_stream_token = None + self._is_tts_streaming = False elif event.type == assist_pipeline.PipelineEventType.WAKE_WORD_START: self.config_entry.async_create_background_task( self.hass, @@ -245,6 +258,20 @@ def on_pipeline_event(self, event: PipelineEvent) -> None: self._client.write_event(Transcript(text=stt_text).event()), f"{self.entity_id} {event.type}", ) + elif event.type == assist_pipeline.PipelineEventType.INTENT_PROGRESS: + if ( + event.data + and event.data.get("tts_start_streaming") + and self._tts_stream_token + and (stream := tts.async_get_stream(self.hass, self._tts_stream_token)) + ): + # Start streaming TTS early (before TTS_END). + self._is_tts_streaming = True + self.config_entry.async_create_background_task( + self.hass, + self._stream_tts(stream), + f"{self.entity_id} {event.type}", + ) elif event.type == assist_pipeline.PipelineEventType.TTS_START: # Text-to-speech text if event.data: @@ -267,8 +294,10 @@ def on_pipeline_event(self, event: PipelineEvent) -> None: if ( event.data and (tts_output := event.data["tts_output"]) + and not self._is_tts_streaming and (stream := tts.async_get_stream(self.hass, tts_output["token"])) ): + # Send TTS only if we haven't already started streaming it in INTENT_PROGRESS. self.config_entry.async_create_background_task( self.hass, self._stream_tts(stream), @@ -711,39 +740,62 @@ async def _stream_tts(self, tts_result: tts.ResultStream) -> None: start_time = time.monotonic() try: - data = b"".join([chunk async for chunk in tts_result.async_stream_result()]) - - with io.BytesIO(data) as wav_io, wave.open(wav_io, "rb") as wav_file: - sample_rate = wav_file.getframerate() - sample_width = wav_file.getsampwidth() - sample_channels = wav_file.getnchannels() - _LOGGER.debug("Streaming %s TTS sample(s)", wav_file.getnframes()) - - timestamp = 0 - await self._client.write_event( - AudioStart( - rate=sample_rate, - width=sample_width, - channels=sample_channels, - timestamp=timestamp, - ).event() + header_data = b"" + header_complete = False + sample_rate: int | None = None + sample_width: int | None = None + sample_channels: int | None = None + timestamp = 0 + + async for data_chunk in tts_result.async_stream_result(): + if not header_complete: + # Accumulate data until we can parse the header and get + # sample rate, etc. + header_data += data_chunk + # Most WAVE headers are 44 bytes in length + if (len(header_data) >= 44) and ( + audio_info := _try_parse_wav_header(header_data) + ): + # Overwrite chunk with audio after header + sample_rate, sample_width, sample_channels, data_chunk = ( + audio_info + ) + await self._client.write_event( + AudioStart( + rate=sample_rate, + width=sample_width, + channels=sample_channels, + timestamp=timestamp, + ).event() + ) + header_complete = True + + if not data_chunk: + # No audio after header + continue + else: + # Header is incomplete + continue + + # Streaming audio + assert sample_rate is not None + assert sample_width is not None + assert sample_channels is not None + + audio_chunk = AudioChunk( + rate=sample_rate, + width=sample_width, + channels=sample_channels, + audio=data_chunk, + timestamp=timestamp, ) - # Stream audio chunks - while audio_bytes := wav_file.readframes(_SAMPLES_PER_CHUNK): - chunk = AudioChunk( - rate=sample_rate, - width=sample_width, - channels=sample_channels, - audio=audio_bytes, - timestamp=timestamp, - ) - await self._client.write_event(chunk.event()) - timestamp += chunk.milliseconds - total_seconds += chunk.seconds + await self._client.write_event(audio_chunk.event()) + timestamp += audio_chunk.milliseconds + total_seconds += audio_chunk.seconds - await self._client.write_event(AudioStop(timestamp=timestamp).event()) - _LOGGER.debug("TTS streaming complete") + await self._client.write_event(AudioStop(timestamp=timestamp).event()) + _LOGGER.debug("TTS streaming complete") finally: send_duration = time.monotonic() - start_time timeout_seconds = max(0, total_seconds - send_duration + _TTS_TIMEOUT_EXTRA) @@ -812,3 +864,25 @@ def _handle_timer( self.config_entry.async_create_background_task( self.hass, self._client.write_event(event), "wyoming timer event" ) + + +def _try_parse_wav_header(header_data: bytes) -> tuple[int, int, int, bytes] | None: + """Try to parse a WAV header from a buffer. + + If successful, return (rate, width, channels, audio). + """ + try: + with io.BytesIO(header_data) as wav_io: + wav_file: wave.Wave_read = wave.open(wav_io, "rb") + with wav_file: + return ( + wav_file.getframerate(), + wav_file.getsampwidth(), + wav_file.getnchannels(), + wav_file.readframes(wav_file.getnframes()), + ) + except wave.Error: + # Ignore errors and return None + pass + + return None diff --git a/homeassistant/components/wyoming/tts.py b/homeassistant/components/wyoming/tts.py index 79e431fee98381..cf088c04d9f1b3 100644 --- a/homeassistant/components/wyoming/tts.py +++ b/homeassistant/components/wyoming/tts.py @@ -1,13 +1,21 @@ """Support for Wyoming text-to-speech services.""" from collections import defaultdict +from collections.abc import AsyncGenerator import io import logging import wave -from wyoming.audio import AudioChunk, AudioStop +from wyoming.audio import AudioChunk, AudioStart, AudioStop from wyoming.client import AsyncTcpClient -from wyoming.tts import Synthesize, SynthesizeVoice +from wyoming.tts import ( + Synthesize, + SynthesizeChunk, + SynthesizeStart, + SynthesizeStop, + SynthesizeStopped, + SynthesizeVoice, +) from homeassistant.components import tts from homeassistant.config_entries import ConfigEntry @@ -45,6 +53,7 @@ def __init__( service: WyomingService, ) -> None: """Set up provider.""" + self.config_entry = config_entry self.service = service self._tts_service = next(tts for tts in service.info.tts if tts.installed) @@ -150,3 +159,98 @@ async def async_get_tts_audio(self, message, language, options): return (None, None) return ("wav", data) + + def async_supports_streaming_input(self) -> bool: + """Return if the TTS engine supports streaming input.""" + return self._tts_service.supports_synthesize_streaming + + async def async_stream_tts_audio( + self, request: tts.TTSAudioRequest + ) -> tts.TTSAudioResponse: + """Generate speech from an incoming message.""" + voice_name: str | None = request.options.get(tts.ATTR_VOICE) + voice_speaker: str | None = request.options.get(ATTR_SPEAKER) + voice: SynthesizeVoice | None = None + if voice_name is not None: + voice = SynthesizeVoice(name=voice_name, speaker=voice_speaker) + + client = AsyncTcpClient(self.service.host, self.service.port) + await client.connect() + + # Stream text chunks to client + self.config_entry.async_create_background_task( + self.hass, + self._write_tts_message(request.message_gen, client, voice), + "wyoming tts write", + ) + + async def data_gen(): + # Stream audio bytes from client + try: + async for data_chunk in self._read_tts_audio(client): + yield data_chunk + finally: + await client.disconnect() + + return tts.TTSAudioResponse("wav", data_gen()) + + async def _write_tts_message( + self, + message_gen: AsyncGenerator[str], + client: AsyncTcpClient, + voice: SynthesizeVoice | None, + ) -> None: + """Write text chunks to the client.""" + try: + # Start stream + await client.write_event(SynthesizeStart(voice=voice).event()) + + # Accumulate entire message for synthesize event. + message = "" + async for message_chunk in message_gen: + message += message_chunk + + await client.write_event(SynthesizeChunk(text=message_chunk).event()) + + # Send entire message for backwards compatibility + await client.write_event(Synthesize(text=message, voice=voice).event()) + + # End stream + await client.write_event(SynthesizeStop().event()) + except (OSError, WyomingError): + # Disconnected + _LOGGER.warning("Unexpected disconnection from TTS client") + + async def _read_tts_audio(self, client: AsyncTcpClient) -> AsyncGenerator[bytes]: + """Read audio events from the client and yield WAV audio chunks. + + The WAV header is sent first with a frame count of 0 to indicate that + we're streaming and don't know the number of frames ahead of time. + """ + wav_header_sent = False + + try: + while event := await client.read_event(): + if wav_header_sent and AudioChunk.is_type(event.type): + # PCM audio + yield AudioChunk.from_event(event).audio + elif (not wav_header_sent) and AudioStart.is_type(event.type): + # WAV header with nframes = 0 for streaming + audio_start = AudioStart.from_event(event) + with io.BytesIO() as wav_io: + wav_file: wave.Wave_write = wave.open(wav_io, "wb") + with wav_file: + wav_file.setframerate(audio_start.rate) + wav_file.setsampwidth(audio_start.width) + wav_file.setnchannels(audio_start.channels) + + wav_io.seek(0) + yield wav_io.getvalue() + + wav_header_sent = True + elif SynthesizeStopped.is_type(event.type): + # All TTS audio has been received + break + except (OSError, WyomingError): + # Disconnected + _LOGGER.warning("Unexpected disconnection from TTS client") diff --git a/homeassistant/helpers/selector.py b/homeassistant/helpers/selector.py index 6f8df828c37d7b..acb91ddc148bfa 100644 --- a/homeassistant/helpers/selector.py +++ b/homeassistant/helpers/selector.py @@ -1010,9 +1010,11 @@ def __call__(self, data: Any) -> dict[str, float]: return location -class MediaSelectorConfig(BaseSelectorConfig): +class MediaSelectorConfig(BaseSelectorConfig, total=False): """Class to represent a media selector config.""" + accept: list[str] + @SELECTORS.register("media") class MediaSelector(Selector[MediaSelectorConfig]): diff --git a/requirements_all.txt b/requirements_all.txt index 80f543f790f1d4..b9712860354d1c 100644 --- a/requirements_all.txt +++ b/requirements_all.txt @@ -381,7 +381,7 @@ aioruuvigateway==0.1.0 aiosenz==1.0.0 # homeassistant.components.shelly -aioshelly==13.6.0 +aioshelly==13.7.0 # homeassistant.components.skybell aioskybell==22.7.0 @@ -2121,7 +2121,7 @@ pylibrespot-java==0.1.1 pylitejet==0.6.3 # homeassistant.components.litterrobot -pylitterbot==2024.2.0 +pylitterbot==2024.2.1 # homeassistant.components.lutron_caseta pylutron-caseta==0.24.0 diff --git a/requirements_test_all.txt b/requirements_test_all.txt index e1a546dfe2fe85..d2055bba4ef796 100644 --- a/requirements_test_all.txt +++ b/requirements_test_all.txt @@ -363,7 +363,7 @@ aioruuvigateway==0.1.0 aiosenz==1.0.0 # homeassistant.components.shelly -aioshelly==13.6.0 +aioshelly==13.7.0 # homeassistant.components.skybell aioskybell==22.7.0 @@ -1763,7 +1763,7 @@ pylibrespot-java==0.1.1 pylitejet==0.6.3 # homeassistant.components.litterrobot -pylitterbot==2024.2.0 +pylitterbot==2024.2.1 # homeassistant.components.lutron_caseta pylutron-caseta==0.24.0 diff --git a/script/hassfest/quality_scale.py b/script/hassfest/quality_scale.py index 73505e805bc3d2..ff6fbcad85e2ea 100644 --- a/script/hassfest/quality_scale.py +++ b/script/hassfest/quality_scale.py @@ -865,7 +865,6 @@ class Rule: "ruuvitag_ble", "rympro", "saj", - "samsungtv", "sanix", "satel_integra", "schlage", @@ -1573,7 +1572,6 @@ class Rule: "iqvia", "irish_rail_transport", "isal", - "ista_ecotrend", "iskra", "islamic_prayer_times", "israel_rail", @@ -1926,7 +1924,6 @@ class Rule: "ruuvitag_ble", "rympro", "saj", - "samsungtv", "sanix", "satel_integra", "schlage", diff --git a/script/hassfest/translations.py b/script/hassfest/translations.py index f4c05f504ca626..34c06abb4513f3 100644 --- a/script/hassfest/translations.py +++ b/script/hassfest/translations.py @@ -306,10 +306,11 @@ def gen_strings_schema(config: Config, integration: Integration) -> vol.Schema: ), vol.Optional("selector"): cv.schema_with_slug_keys( { - "options": cv.schema_with_slug_keys( + vol.Optional("options"): cv.schema_with_slug_keys( translation_value_validator, slug_validator=translation_key_validator, - ) + ), + vol.Optional("fields"): cv.schema_with_slug_keys(str), }, slug_validator=vol.Any("_", cv.slug), ), diff --git a/tests/components/assist_pipeline/test_pipeline.py b/tests/components/assist_pipeline/test_pipeline.py index 9ea3802d9f6e1f..1302925dab9aa9 100644 --- a/tests/components/assist_pipeline/test_pipeline.py +++ b/tests/components/assist_pipeline/test_pipeline.py @@ -1110,6 +1110,7 @@ async def test_sentence_trigger_overrides_conversation_agent( None, ) assert (intent_end_event is not None) and intent_end_event.data + assert intent_end_event.data["processed_locally"] is True assert ( intent_end_event.data["intent_output"]["response"]["speech"]["plain"][ "speech" @@ -1192,6 +1193,7 @@ async def async_handle( None, ) assert (intent_end_event is not None) and intent_end_event.data + assert intent_end_event.data["processed_locally"] is True assert ( intent_end_event.data["intent_output"]["response"]["speech"]["plain"][ "speech" diff --git a/tests/components/esphome/test_assist_satellite.py b/tests/components/esphome/test_assist_satellite.py index 3acdc1f20291dd..bfcc35b2e6afb2 100644 --- a/tests/components/esphome/test_assist_satellite.py +++ b/tests/components/esphome/test_assist_satellite.py @@ -1776,6 +1776,78 @@ async def test_get_set_configuration( assert satellite.async_get_configuration() == updated_config +async def test_intent_progress_optimization( + hass: HomeAssistant, + mock_client: APIClient, + mock_esphome_device: MockESPHomeDeviceType, +) -> None: + """Test that intent progress events are only sent when early TTS streaming is available.""" + mock_device = await mock_esphome_device( + mock_client=mock_client, + device_info={ + "voice_assistant_feature_flags": VoiceAssistantFeature.VOICE_ASSISTANT + }, + ) + await hass.async_block_till_done() + + satellite = get_satellite_entity(hass, mock_device.device_info.mac_address) + assert satellite is not None + + # Test that intent progress without tts_start_streaming is not sent + mock_client.send_voice_assistant_event.reset_mock() + satellite.on_pipeline_event( + PipelineEvent( + type=PipelineEventType.INTENT_PROGRESS, + data={"some_other_key": "value"}, + ) + ) + mock_client.send_voice_assistant_event.assert_not_called() + + # Test that intent progress with tts_start_streaming=False is not sent + satellite.on_pipeline_event( + PipelineEvent( + type=PipelineEventType.INTENT_PROGRESS, + data={"tts_start_streaming": False}, + ) + ) + mock_client.send_voice_assistant_event.assert_not_called() + + # Test that intent progress with tts_start_streaming=True is sent + satellite.on_pipeline_event( + PipelineEvent( + type=PipelineEventType.INTENT_PROGRESS, + data={"tts_start_streaming": True}, + ) + ) + assert mock_client.send_voice_assistant_event.call_args_list[-1].args == ( + VoiceAssistantEventType.VOICE_ASSISTANT_INTENT_PROGRESS, + {"tts_start_streaming": "1"}, + ) + + # Test that intent progress with tts_start_streaming as string "1" is sent + mock_client.send_voice_assistant_event.reset_mock() + satellite.on_pipeline_event( + PipelineEvent( + type=PipelineEventType.INTENT_PROGRESS, + data={"tts_start_streaming": "1"}, + ) + ) + assert mock_client.send_voice_assistant_event.call_args_list[-1].args == ( + VoiceAssistantEventType.VOICE_ASSISTANT_INTENT_PROGRESS, + {"tts_start_streaming": "1"}, + ) + + # Test that intent progress with no data is *not* sent + mock_client.send_voice_assistant_event.reset_mock() + satellite.on_pipeline_event( + PipelineEvent( + type=PipelineEventType.INTENT_PROGRESS, + data=None, + ) + ) + mock_client.send_voice_assistant_event.assert_not_called() + + async def test_wake_word_select( hass: HomeAssistant, mock_client: APIClient, diff --git a/tests/components/homeassistant_hardware/test_config_flow.py b/tests/components/homeassistant_hardware/test_config_flow.py index 2d5067bea3eadd..530308fdf41a72 100644 --- a/tests/components/homeassistant_hardware/test_config_flow.py +++ b/tests/components/homeassistant_hardware/test_config_flow.py @@ -4,9 +4,15 @@ from collections.abc import Awaitable, Callable, Generator, Iterator import contextlib from typing import Any -from unittest.mock import AsyncMock, Mock, call, patch +from unittest.mock import AsyncMock, MagicMock, Mock, call, patch +from ha_silabs_firmware_client import ( + FirmwareManifest, + FirmwareMetadata, + FirmwareUpdateClient, +) import pytest +from yarl import URL from homeassistant.components.hassio import AddonInfo, AddonState from homeassistant.components.homeassistant_hardware.firmware_config_flow import ( @@ -19,12 +25,13 @@ ApplicationType, FirmwareInfo, get_otbr_addon_manager, - get_zigbee_flasher_addon_manager, ) from homeassistant.config_entries import ConfigEntry, ConfigFlowResult, OptionsFlow from homeassistant.core import HomeAssistant, callback from homeassistant.data_entry_flow import FlowResultType +from homeassistant.exceptions import HomeAssistantError from homeassistant.setup import async_setup_component +from homeassistant.util.dt import utcnow from tests.common import ( MockConfigEntry, @@ -37,6 +44,7 @@ TEST_DOMAIN = "test_firmware_domain" TEST_DEVICE = "/dev/SomeDevice123" TEST_HARDWARE_NAME = "Some Hardware Name" +TEST_RELEASES_URL = URL("http://invalid/releases") class FakeFirmwareConfigFlow(BaseFirmwareConfigFlow, domain=TEST_DOMAIN): @@ -62,6 +70,32 @@ async def async_step_hardware( return await self.async_step_confirm() + async def async_step_install_zigbee_firmware( + self, user_input: dict[str, Any] | None = None + ) -> ConfigFlowResult: + """Install Zigbee firmware.""" + return await self._install_firmware_step( + fw_update_url=TEST_RELEASES_URL, + fw_type="fake_zigbee_ncp", + firmware_name="Zigbee", + expected_installed_firmware_type=ApplicationType.EZSP, + step_id="install_zigbee_firmware", + next_step_id="confirm_zigbee", + ) + + async def async_step_install_thread_firmware( + self, user_input: dict[str, Any] | None = None + ) -> ConfigFlowResult: + """Install Thread firmware.""" + return await self._install_firmware_step( + fw_update_url=TEST_RELEASES_URL, + fw_type="fake_openthread_rcp", + firmware_name="Thread", + expected_installed_firmware_type=ApplicationType.SPINEL, + step_id="install_thread_firmware", + next_step_id="start_otbr_addon", + ) + def _async_flow_finished(self) -> ConfigFlowResult: """Create the config entry.""" assert self._device is not None @@ -99,6 +133,18 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: # Regenerate the translation placeholders self._get_translation_placeholders() + async def async_step_install_zigbee_firmware( + self, user_input: dict[str, Any] | None = None + ) -> ConfigFlowResult: + """Install Zigbee firmware.""" + return await self.async_step_confirm_zigbee() + + async def async_step_install_thread_firmware( + self, user_input: dict[str, Any] | None = None + ) -> ConfigFlowResult: + """Install Thread firmware.""" + return await self.async_step_start_otbr_addon() + def _async_flow_finished(self) -> ConfigFlowResult: """Create the config entry.""" assert self._probed_firmware_info is not None @@ -146,12 +192,22 @@ async def side_effect(*args: Any, **kwargs: Any) -> None: return side_effect +def create_mock_owner() -> Mock: + """Mock for OwningAddon / OwningIntegration.""" + owner = Mock() + owner.is_running = AsyncMock(return_value=True) + owner.temporarily_stop = MagicMock() + owner.temporarily_stop.return_value.__aenter__.return_value = AsyncMock() + + return owner + + @contextlib.contextmanager -def mock_addon_info( +def mock_firmware_info( hass: HomeAssistant, *, is_hassio: bool = True, - app_type: ApplicationType | None = ApplicationType.EZSP, + probe_app_type: ApplicationType | None = ApplicationType.EZSP, otbr_addon_info: AddonInfo = AddonInfo( available=True, hostname=None, @@ -160,29 +216,9 @@ def mock_addon_info( update_available=False, version=None, ), - flasher_addon_info: AddonInfo = AddonInfo( - available=True, - hostname=None, - options={}, - state=AddonState.NOT_INSTALLED, - update_available=False, - version=None, - ), + flash_app_type: ApplicationType = ApplicationType.EZSP, ) -> Iterator[tuple[Mock, Mock]]: """Mock the main addon states for the config flow.""" - mock_flasher_manager = Mock(spec_set=get_zigbee_flasher_addon_manager(hass)) - mock_flasher_manager.addon_name = "Silicon Labs Flasher" - mock_flasher_manager.async_start_addon_waiting = AsyncMock( - side_effect=delayed_side_effect() - ) - mock_flasher_manager.async_install_addon_waiting = AsyncMock( - side_effect=delayed_side_effect() - ) - mock_flasher_manager.async_uninstall_addon_waiting = AsyncMock( - side_effect=delayed_side_effect() - ) - mock_flasher_manager.async_get_addon_info.return_value = flasher_addon_info - mock_otbr_manager = Mock(spec_set=get_otbr_addon_manager(hass)) mock_otbr_manager.addon_name = "OpenThread Border Router" mock_otbr_manager.async_install_addon_waiting = AsyncMock( @@ -196,17 +232,73 @@ def mock_addon_info( ) mock_otbr_manager.async_get_addon_info.return_value = otbr_addon_info - if app_type is None: - firmware_info_result = None + mock_update_client = AsyncMock(spec_set=FirmwareUpdateClient) + mock_update_client.async_update_data.return_value = FirmwareManifest( + url=TEST_RELEASES_URL, + html_url=TEST_RELEASES_URL / "html", + created_at=utcnow(), + firmwares=[ + FirmwareMetadata( + filename="fake_openthread_rcp_7.4.4.0_variant.gbl", + checksum="sha256:1234567890abcdef1234567890abcdef1234567890abcdef1234567890abcdef", + size=123, + release_notes="Some release notes", + metadata={}, + url=TEST_RELEASES_URL / "fake_openthread_rcp_7.4.4.0_variant.gbl", + ), + FirmwareMetadata( + filename="fake_zigbee_ncp_7.4.4.0_variant.gbl", + checksum="sha256:1234567890abcdef1234567890abcdef1234567890abcdef1234567890abcdef", + size=123, + release_notes="Some release notes", + metadata={}, + url=TEST_RELEASES_URL / "fake_zigbee_ncp_7.4.4.0_variant.gbl", + ), + ], + ) + + if probe_app_type is None: + probed_firmware_info = None else: - firmware_info_result = FirmwareInfo( + probed_firmware_info = FirmwareInfo( device="/dev/ttyUSB0", # Not used - firmware_type=app_type, + firmware_type=probe_app_type, firmware_version=None, owners=[], source="probe", ) + if flash_app_type is None: + flashed_firmware_info = None + else: + flashed_firmware_info = FirmwareInfo( + device=TEST_DEVICE, + firmware_type=flash_app_type, + firmware_version="7.4.4.0", + owners=[create_mock_owner()], + source="probe", + ) + + async def mock_flash_firmware( + hass: HomeAssistant, + device: str, + fw_data: bytes, + expected_installed_firmware_type: ApplicationType, + bootloader_reset_type: str | None = None, + progress_callback: Callable[[int, int], None] | None = None, + ) -> FirmwareInfo: + await asyncio.sleep(0) + progress_callback(0, 100) + await asyncio.sleep(0) + progress_callback(50, 100) + await asyncio.sleep(0) + progress_callback(100, 100) + + if flashed_firmware_info is None: + raise HomeAssistantError("Failed to probe the firmware after flashing") + + return flashed_firmware_info + with ( patch( "homeassistant.components.homeassistant_hardware.firmware_config_flow.get_otbr_addon_manager", @@ -216,10 +308,6 @@ def mock_addon_info( "homeassistant.components.homeassistant_hardware.util.get_otbr_addon_manager", return_value=mock_otbr_manager, ), - patch( - "homeassistant.components.homeassistant_hardware.firmware_config_flow.get_zigbee_flasher_addon_manager", - return_value=mock_flasher_manager, - ), patch( "homeassistant.components.homeassistant_hardware.firmware_config_flow.is_hassio", return_value=is_hassio, @@ -229,81 +317,85 @@ def mock_addon_info( return_value=is_hassio, ), patch( + # We probe once before installation and once after "homeassistant.components.homeassistant_hardware.firmware_config_flow.probe_silabs_firmware_info", - return_value=firmware_info_result, + side_effect=(probed_firmware_info, flashed_firmware_info), + ), + patch( + "homeassistant.components.homeassistant_hardware.firmware_config_flow.FirmwareUpdateClient", + return_value=mock_update_client, + ), + patch( + "homeassistant.components.homeassistant_hardware.util.parse_firmware_image" + ), + patch( + "homeassistant.components.homeassistant_hardware.firmware_config_flow.async_flash_silabs_firmware", + side_effect=mock_flash_firmware, ), ): - yield mock_otbr_manager, mock_flasher_manager + yield mock_otbr_manager + + +async def consume_progress_flow( + hass: HomeAssistant, + flow_id: str, + valid_step_ids: tuple[str], +) -> ConfigFlowResult: + """Consume a progress flow until it is done.""" + while True: + result = await hass.config_entries.flow.async_configure(flow_id) + flow_id = result["flow_id"] + + if result["type"] != FlowResultType.SHOW_PROGRESS: + break + + assert result["type"] is FlowResultType.SHOW_PROGRESS + assert result["step_id"] in valid_step_ids + + await asyncio.sleep(0.1) + + return result async def test_config_flow_zigbee(hass: HomeAssistant) -> None: """Test the config flow.""" - result = await hass.config_entries.flow.async_init( + init_result = await hass.config_entries.flow.async_init( TEST_DOMAIN, context={"source": "hardware"} ) - assert result["type"] is FlowResultType.MENU - assert result["step_id"] == "pick_firmware" + assert init_result["type"] is FlowResultType.MENU + assert init_result["step_id"] == "pick_firmware" - with mock_addon_info( + with mock_firmware_info( hass, - app_type=ApplicationType.SPINEL, - ) as (mock_otbr_manager, mock_flasher_manager): - # Pick the menu option: we are now installing the addon - result = await hass.config_entries.flow.async_configure( - result["flow_id"], + probe_app_type=ApplicationType.SPINEL, + flash_app_type=ApplicationType.EZSP, + ): + # Pick the menu option: we are flashing the firmware + pick_result = await hass.config_entries.flow.async_configure( + init_result["flow_id"], user_input={"next_step_id": STEP_PICK_FIRMWARE_ZIGBEE}, ) - assert result["type"] is FlowResultType.SHOW_PROGRESS - assert result["progress_action"] == "install_addon" - assert result["step_id"] == "install_zigbee_flasher_addon" - assert result["description_placeholders"]["firmware_type"] == "spinel" - - await hass.async_block_till_done(wait_background_tasks=True) - - # Progress the flow, we are now configuring the addon and running it - result = await hass.config_entries.flow.async_configure(result["flow_id"]) - assert result["type"] is FlowResultType.SHOW_PROGRESS - assert result["step_id"] == "run_zigbee_flasher_addon" - assert result["progress_action"] == "run_zigbee_flasher_addon" - assert mock_flasher_manager.async_set_addon_options.mock_calls == [ - call( - { - "device": TEST_DEVICE, - "baudrate": 115200, - "bootloader_baudrate": 115200, - "flow_control": True, - } - ) - ] - - await hass.async_block_till_done(wait_background_tasks=True) - # Progress the flow, we are now uninstalling the addon - result = await hass.config_entries.flow.async_configure(result["flow_id"]) - assert result["type"] is FlowResultType.SHOW_PROGRESS - assert result["step_id"] == "uninstall_zigbee_flasher_addon" - assert result["progress_action"] == "uninstall_zigbee_flasher_addon" + assert pick_result["type"] is FlowResultType.SHOW_PROGRESS + assert pick_result["progress_action"] == "install_firmware" + assert pick_result["step_id"] == "install_zigbee_firmware" - await hass.async_block_till_done(wait_background_tasks=True) + confirm_result = await consume_progress_flow( + hass, + flow_id=pick_result["flow_id"], + valid_step_ids=("install_zigbee_firmware",), + ) - # We are finally done with the addon - assert mock_flasher_manager.async_uninstall_addon_waiting.mock_calls == [call()] + assert confirm_result["type"] is FlowResultType.FORM + assert confirm_result["step_id"] == "confirm_zigbee" - result = await hass.config_entries.flow.async_configure(result["flow_id"]) - assert result["type"] is FlowResultType.FORM - assert result["step_id"] == "confirm_zigbee" - - with mock_addon_info( - hass, - app_type=ApplicationType.EZSP, - ): - result = await hass.config_entries.flow.async_configure( - result["flow_id"], user_input={} + create_result = await hass.config_entries.flow.async_configure( + confirm_result["flow_id"], user_input={} ) - assert result["type"] is FlowResultType.CREATE_ENTRY + assert create_result["type"] is FlowResultType.CREATE_ENTRY - config_entry = result["result"] + config_entry = create_result["result"] assert config_entry.data == { "firmware": "ezsp", "device": TEST_DEVICE, @@ -328,52 +420,20 @@ async def test_config_flow_zigbee_skip_step_if_installed(hass: HomeAssistant) -> assert result["type"] is FlowResultType.MENU assert result["step_id"] == "pick_firmware" - with mock_addon_info( - hass, - app_type=ApplicationType.SPINEL, - flasher_addon_info=AddonInfo( - available=True, - hostname=None, - options={ - "device": "", - "baudrate": 115200, - "bootloader_baudrate": 115200, - "flow_control": True, - }, - state=AddonState.NOT_RUNNING, - update_available=False, - version="1.2.3", - ), - ) as (mock_otbr_manager, mock_flasher_manager): + with mock_firmware_info(hass, probe_app_type=ApplicationType.SPINEL): # Pick the menu option: we skip installation, instead we directly run it result = await hass.config_entries.flow.async_configure( result["flow_id"], user_input={"next_step_id": STEP_PICK_FIRMWARE_ZIGBEE}, ) - assert result["type"] is FlowResultType.SHOW_PROGRESS - assert result["step_id"] == "run_zigbee_flasher_addon" - assert result["progress_action"] == "run_zigbee_flasher_addon" - assert result["description_placeholders"]["firmware_type"] == "spinel" - assert mock_flasher_manager.async_set_addon_options.mock_calls == [ - call( - { - "device": TEST_DEVICE, - "baudrate": 115200, - "bootloader_baudrate": 115200, - "flow_control": True, - } - ) - ] - - # Uninstall the addon - await hass.async_block_till_done(wait_background_tasks=True) + # Confirm result = await hass.config_entries.flow.async_configure(result["flow_id"]) # Done - with mock_addon_info( + with mock_firmware_info( hass, - app_type=ApplicationType.EZSP, + probe_app_type=ApplicationType.EZSP, ): await hass.async_block_till_done(wait_background_tasks=True) result = await hass.config_entries.flow.async_configure(result["flow_id"]) @@ -409,28 +469,29 @@ async def test_config_flow_auto_confirm_if_running(hass: HomeAssistant) -> None: async def test_config_flow_thread(hass: HomeAssistant) -> None: """Test the config flow.""" - result = await hass.config_entries.flow.async_init( + init_result = await hass.config_entries.flow.async_init( TEST_DOMAIN, context={"source": "hardware"} ) - assert result["type"] is FlowResultType.MENU - assert result["step_id"] == "pick_firmware" + assert init_result["type"] is FlowResultType.MENU + assert init_result["step_id"] == "pick_firmware" - with mock_addon_info( + with mock_firmware_info( hass, - app_type=ApplicationType.EZSP, - ) as (mock_otbr_manager, mock_flasher_manager): + probe_app_type=ApplicationType.EZSP, + flash_app_type=ApplicationType.SPINEL, + ) as mock_otbr_manager: # Pick the menu option - result = await hass.config_entries.flow.async_configure( - result["flow_id"], + pick_result = await hass.config_entries.flow.async_configure( + init_result["flow_id"], user_input={"next_step_id": STEP_PICK_FIRMWARE_THREAD}, ) - assert result["type"] is FlowResultType.SHOW_PROGRESS - assert result["progress_action"] == "install_addon" - assert result["step_id"] == "install_otbr_addon" - assert result["description_placeholders"]["firmware_type"] == "ezsp" - assert result["description_placeholders"]["model"] == TEST_HARDWARE_NAME + assert pick_result["type"] is FlowResultType.SHOW_PROGRESS + assert pick_result["progress_action"] == "install_addon" + assert pick_result["step_id"] == "install_otbr_addon" + assert pick_result["description_placeholders"]["firmware_type"] == "ezsp" + assert pick_result["description_placeholders"]["model"] == TEST_HARDWARE_NAME await hass.async_block_till_done(wait_background_tasks=True) @@ -441,19 +502,37 @@ async def test_config_flow_thread(hass: HomeAssistant) -> None: "device": "", "baudrate": 460800, "flow_control": True, - "autoflash_firmware": True, + "autoflash_firmware": False, }, state=AddonState.NOT_RUNNING, update_available=False, version="1.2.3", ) - # Progress the flow, it is now configuring the addon and running it - result = await hass.config_entries.flow.async_configure(result["flow_id"]) + # Progress the flow, it is now installing firmware + confirm_otbr_result = await consume_progress_flow( + hass, + flow_id=pick_result["flow_id"], + valid_step_ids=( + "pick_firmware_thread", + "install_otbr_addon", + "install_thread_firmware", + "start_otbr_addon", + ), + ) - assert result["type"] is FlowResultType.SHOW_PROGRESS - assert result["step_id"] == "start_otbr_addon" - assert result["progress_action"] == "start_otbr_addon" + # Installation will conclude with the config entry being created + create_result = await hass.config_entries.flow.async_configure( + confirm_otbr_result["flow_id"], user_input={} + ) + assert create_result["type"] is FlowResultType.CREATE_ENTRY + + config_entry = create_result["result"] + assert config_entry.data == { + "firmware": "spinel", + "device": TEST_DEVICE, + "hardware": TEST_HARDWARE_NAME, + } assert mock_otbr_manager.async_set_addon_options.mock_calls == [ call( @@ -461,44 +540,22 @@ async def test_config_flow_thread(hass: HomeAssistant) -> None: "device": TEST_DEVICE, "baudrate": 460800, "flow_control": True, - "autoflash_firmware": True, + "autoflash_firmware": False, } ) ] - await hass.async_block_till_done(wait_background_tasks=True) - - # The addon is now running - result = await hass.config_entries.flow.async_configure(result["flow_id"]) - assert result["type"] is FlowResultType.FORM - assert result["step_id"] == "confirm_otbr" - - with mock_addon_info( - hass, - app_type=ApplicationType.SPINEL, - ): - result = await hass.config_entries.flow.async_configure( - result["flow_id"], user_input={} - ) - assert result["type"] is FlowResultType.CREATE_ENTRY - - config_entry = result["result"] - assert config_entry.data == { - "firmware": "spinel", - "device": TEST_DEVICE, - "hardware": TEST_HARDWARE_NAME, - } - async def test_config_flow_thread_addon_already_installed(hass: HomeAssistant) -> None: """Test the Thread config flow, addon is already installed.""" - result = await hass.config_entries.flow.async_init( + init_result = await hass.config_entries.flow.async_init( TEST_DOMAIN, context={"source": "hardware"} ) - with mock_addon_info( + with mock_firmware_info( hass, - app_type=ApplicationType.EZSP, + probe_app_type=ApplicationType.EZSP, + flash_app_type=ApplicationType.SPINEL, otbr_addon_info=AddonInfo( available=True, hostname=None, @@ -507,81 +564,50 @@ async def test_config_flow_thread_addon_already_installed(hass: HomeAssistant) - update_available=False, version=None, ), - ) as (mock_otbr_manager, mock_flasher_manager): + ) as mock_otbr_manager: # Pick the menu option - result = await hass.config_entries.flow.async_configure( - result["flow_id"], + pick_result = await hass.config_entries.flow.async_configure( + init_result["flow_id"], user_input={"next_step_id": STEP_PICK_FIRMWARE_THREAD}, ) - assert result["type"] is FlowResultType.SHOW_PROGRESS - assert result["step_id"] == "start_otbr_addon" - assert result["progress_action"] == "start_otbr_addon" + # Progress + confirm_otbr_result = await consume_progress_flow( + hass, + flow_id=pick_result["flow_id"], + valid_step_ids=( + "pick_firmware_thread", + "install_thread_firmware", + "start_otbr_addon", + ), + ) + + # We're now waiting to confirm OTBR + assert confirm_otbr_result["type"] is FlowResultType.FORM + assert confirm_otbr_result["step_id"] == "confirm_otbr" + + # The addon has been installed assert mock_otbr_manager.async_set_addon_options.mock_calls == [ call( { "device": TEST_DEVICE, "baudrate": 460800, "flow_control": True, - "autoflash_firmware": True, + "autoflash_firmware": False, # And firmware flashing is disabled } ) ] - await hass.async_block_till_done(wait_background_tasks=True) - - # The addon is now running - result = await hass.config_entries.flow.async_configure(result["flow_id"]) - assert result["type"] is FlowResultType.FORM - assert result["step_id"] == "confirm_otbr" - - with mock_addon_info( - hass, - app_type=ApplicationType.SPINEL, - ): - result = await hass.config_entries.flow.async_configure( - result["flow_id"], user_input={} + # Finally, create the config entry + create_result = await hass.config_entries.flow.async_configure( + confirm_otbr_result["flow_id"], user_input={} ) - assert result["type"] is FlowResultType.CREATE_ENTRY - - -async def test_config_flow_zigbee_not_hassio(hass: HomeAssistant) -> None: - """Test when the stick is used with a non-hassio setup.""" - result = await hass.config_entries.flow.async_init( - TEST_DOMAIN, context={"source": "hardware"} - ) - - with mock_addon_info( - hass, - is_hassio=False, - app_type=ApplicationType.EZSP, - ) as (mock_otbr_manager, mock_flasher_manager): - result = await hass.config_entries.flow.async_configure( - result["flow_id"], - user_input={"next_step_id": STEP_PICK_FIRMWARE_ZIGBEE}, - ) - assert result["type"] is FlowResultType.FORM - assert result["step_id"] == "confirm_zigbee" - - result = await hass.config_entries.flow.async_configure( - result["flow_id"], user_input={} - ) - assert result["type"] is FlowResultType.CREATE_ENTRY - - config_entry = result["result"] - assert config_entry.data == { - "firmware": "ezsp", - "device": TEST_DEVICE, - "hardware": TEST_HARDWARE_NAME, - } - - # Ensure a ZHA discovery flow has been created - flows = hass.config_entries.flow.async_progress() - assert len(flows) == 1 - zha_flow = flows[0] - assert zha_flow["handler"] == "zha" - assert zha_flow["context"]["source"] == "hardware" - assert zha_flow["step_id"] == "confirm" + assert create_result["type"] is FlowResultType.CREATE_ENTRY + assert create_result["result"].data == { + "firmware": "spinel", + "device": TEST_DEVICE, + "hardware": TEST_HARDWARE_NAME, + } @pytest.mark.usefixtures("addon_store_info") @@ -601,10 +627,11 @@ async def test_options_flow_zigbee_to_thread(hass: HomeAssistant) -> None: assert await hass.config_entries.async_setup(config_entry.entry_id) - with mock_addon_info( + with mock_firmware_info( hass, - app_type=ApplicationType.EZSP, - ) as (mock_otbr_manager, mock_flasher_manager): + probe_app_type=ApplicationType.EZSP, + flash_app_type=ApplicationType.SPINEL, + ) as mock_otbr_manager: # First step is confirmation result = await hass.config_entries.options.async_init(config_entry.entry_id) assert result["type"] is FlowResultType.MENU @@ -630,7 +657,7 @@ async def test_options_flow_zigbee_to_thread(hass: HomeAssistant) -> None: "device": "", "baudrate": 460800, "flow_control": True, - "autoflash_firmware": True, + "autoflash_firmware": False, }, state=AddonState.NOT_RUNNING, update_available=False, @@ -650,7 +677,7 @@ async def test_options_flow_zigbee_to_thread(hass: HomeAssistant) -> None: "device": TEST_DEVICE, "baudrate": 460800, "flow_control": True, - "autoflash_firmware": True, + "autoflash_firmware": False, } ) ] @@ -662,10 +689,6 @@ async def test_options_flow_zigbee_to_thread(hass: HomeAssistant) -> None: assert result["type"] is FlowResultType.FORM assert result["step_id"] == "confirm_otbr" - with mock_addon_info( - hass, - app_type=ApplicationType.SPINEL, - ): # We are now done result = await hass.config_entries.options.async_configure( result["flow_id"], user_input={} @@ -700,57 +723,23 @@ async def test_options_flow_thread_to_zigbee(hass: HomeAssistant) -> None: assert result["description_placeholders"]["firmware_type"] == "spinel" assert result["description_placeholders"]["model"] == TEST_HARDWARE_NAME - with mock_addon_info( + with mock_firmware_info( hass, - app_type=ApplicationType.SPINEL, - ) as (mock_otbr_manager, mock_flasher_manager): + probe_app_type=ApplicationType.SPINEL, + ): # Pick the menu option: we are now installing the addon result = await hass.config_entries.options.async_configure( result["flow_id"], user_input={"next_step_id": STEP_PICK_FIRMWARE_ZIGBEE}, ) - assert result["type"] is FlowResultType.SHOW_PROGRESS - assert result["progress_action"] == "install_addon" - assert result["step_id"] == "install_zigbee_flasher_addon" - - await hass.async_block_till_done(wait_background_tasks=True) - - # Progress the flow, we are now configuring the addon and running it - result = await hass.config_entries.options.async_configure(result["flow_id"]) - assert result["type"] is FlowResultType.SHOW_PROGRESS - assert result["step_id"] == "run_zigbee_flasher_addon" - assert result["progress_action"] == "run_zigbee_flasher_addon" - assert mock_flasher_manager.async_set_addon_options.mock_calls == [ - call( - { - "device": TEST_DEVICE, - "baudrate": 115200, - "bootloader_baudrate": 115200, - "flow_control": True, - } - ) - ] - - await hass.async_block_till_done(wait_background_tasks=True) - - # Progress the flow, we are now uninstalling the addon - result = await hass.config_entries.options.async_configure(result["flow_id"]) - assert result["type"] is FlowResultType.SHOW_PROGRESS - assert result["step_id"] == "uninstall_zigbee_flasher_addon" - assert result["progress_action"] == "uninstall_zigbee_flasher_addon" - - await hass.async_block_till_done(wait_background_tasks=True) - - # We are finally done with the addon - assert mock_flasher_manager.async_uninstall_addon_waiting.mock_calls == [call()] result = await hass.config_entries.options.async_configure(result["flow_id"]) assert result["type"] is FlowResultType.FORM assert result["step_id"] == "confirm_zigbee" - with mock_addon_info( + with mock_firmware_info( hass, - app_type=ApplicationType.EZSP, + probe_app_type=ApplicationType.EZSP, ): # We are now done result = await hass.config_entries.options.async_configure( diff --git a/tests/components/homeassistant_hardware/test_config_flow_failures.py b/tests/components/homeassistant_hardware/test_config_flow_failures.py index 38c2696a62a5ef..65a5f58b17da78 100644 --- a/tests/components/homeassistant_hardware/test_config_flow_failures.py +++ b/tests/components/homeassistant_hardware/test_config_flow_failures.py @@ -21,8 +21,8 @@ TEST_DEVICE, TEST_DOMAIN, TEST_HARDWARE_NAME, - delayed_side_effect, - mock_addon_info, + consume_progress_flow, + mock_firmware_info, mock_test_firmware_platform, # noqa: F401 ) @@ -51,10 +51,10 @@ async def test_config_flow_cannot_probe_firmware( ) -> None: """Test failure case when firmware cannot be probed.""" - with mock_addon_info( + with mock_firmware_info( hass, - app_type=None, - ) as (mock_otbr_manager, mock_flasher_manager): + probe_app_type=None, + ): # Start the flow result = await hass.config_entries.flow.async_init( TEST_DOMAIN, context={"source": "hardware"} @@ -69,283 +69,6 @@ async def test_config_flow_cannot_probe_firmware( assert result["reason"] == "unsupported_firmware" -@pytest.mark.parametrize( - "ignore_translations_for_mock_domains", - ["test_firmware_domain"], -) -async def test_config_flow_zigbee_not_hassio_wrong_firmware( - hass: HomeAssistant, -) -> None: - """Test when the stick is used with a non-hassio setup but the firmware is bad.""" - result = await hass.config_entries.flow.async_init( - TEST_DOMAIN, context={"source": "hardware"} - ) - - with mock_addon_info( - hass, - app_type=ApplicationType.SPINEL, - is_hassio=False, - ) as (mock_otbr_manager, mock_flasher_manager): - result = await hass.config_entries.flow.async_configure( - result["flow_id"], user_input={} - ) - - result = await hass.config_entries.flow.async_configure( - result["flow_id"], - user_input={"next_step_id": STEP_PICK_FIRMWARE_ZIGBEE}, - ) - assert result["type"] is FlowResultType.ABORT - assert result["reason"] == "not_hassio" - - -@pytest.mark.parametrize( - "ignore_translations_for_mock_domains", - ["test_firmware_domain"], -) -async def test_config_flow_zigbee_flasher_addon_already_running( - hass: HomeAssistant, -) -> None: - """Test failure case when flasher addon is already running.""" - result = await hass.config_entries.flow.async_init( - TEST_DOMAIN, context={"source": "hardware"} - ) - - with mock_addon_info( - hass, - app_type=ApplicationType.SPINEL, - flasher_addon_info=AddonInfo( - available=True, - hostname=None, - options={}, - state=AddonState.RUNNING, - update_available=False, - version="1.0.0", - ), - ) as (mock_otbr_manager, mock_flasher_manager): - result = await hass.config_entries.flow.async_configure( - result["flow_id"], user_input={} - ) - - result = await hass.config_entries.flow.async_configure( - result["flow_id"], - user_input={"next_step_id": STEP_PICK_FIRMWARE_ZIGBEE}, - ) - - # Cannot get addon info - assert result["type"] == FlowResultType.ABORT - assert result["reason"] == "addon_already_running" - - -@pytest.mark.parametrize( - "ignore_translations_for_mock_domains", - ["test_firmware_domain"], -) -async def test_config_flow_zigbee_flasher_addon_info_fails(hass: HomeAssistant) -> None: - """Test failure case when flasher addon cannot be installed.""" - result = await hass.config_entries.flow.async_init( - TEST_DOMAIN, context={"source": "hardware"} - ) - - with mock_addon_info( - hass, - app_type=ApplicationType.SPINEL, - flasher_addon_info=AddonInfo( - available=True, - hostname=None, - options={}, - state=AddonState.RUNNING, - update_available=False, - version="1.0.0", - ), - ) as (mock_otbr_manager, mock_flasher_manager): - mock_flasher_manager.async_get_addon_info.side_effect = AddonError() - - result = await hass.config_entries.flow.async_configure( - result["flow_id"], user_input={} - ) - result = await hass.config_entries.flow.async_configure( - result["flow_id"], - user_input={"next_step_id": STEP_PICK_FIRMWARE_ZIGBEE}, - ) - - # Cannot get addon info - assert result["type"] == FlowResultType.ABORT - assert result["reason"] == "addon_info_failed" - - -@pytest.mark.parametrize( - "ignore_translations_for_mock_domains", - ["test_firmware_domain"], -) -async def test_config_flow_zigbee_flasher_addon_install_fails( - hass: HomeAssistant, -) -> None: - """Test failure case when flasher addon cannot be installed.""" - result = await hass.config_entries.flow.async_init( - TEST_DOMAIN, context={"source": "hardware"} - ) - - with mock_addon_info( - hass, - app_type=ApplicationType.SPINEL, - ) as (mock_otbr_manager, mock_flasher_manager): - mock_flasher_manager.async_install_addon_waiting = AsyncMock( - side_effect=AddonError() - ) - - result = await hass.config_entries.flow.async_configure( - result["flow_id"], user_input={} - ) - - result = await hass.config_entries.flow.async_configure( - result["flow_id"], - user_input={"next_step_id": STEP_PICK_FIRMWARE_ZIGBEE}, - ) - - # Cannot install addon - assert result["type"] == FlowResultType.ABORT - assert result["reason"] == "addon_install_failed" - - -@pytest.mark.parametrize( - "ignore_translations_for_mock_domains", - ["test_firmware_domain"], -) -async def test_config_flow_zigbee_flasher_addon_set_config_fails( - hass: HomeAssistant, -) -> None: - """Test failure case when flasher addon cannot be configured.""" - result = await hass.config_entries.flow.async_init( - TEST_DOMAIN, context={"source": "hardware"} - ) - - with mock_addon_info( - hass, - app_type=ApplicationType.SPINEL, - ) as (mock_otbr_manager, mock_flasher_manager): - mock_flasher_manager.async_install_addon_waiting = AsyncMock( - side_effect=delayed_side_effect() - ) - mock_flasher_manager.async_set_addon_options = AsyncMock( - side_effect=AddonError() - ) - - result = await hass.config_entries.flow.async_configure( - result["flow_id"], user_input={} - ) - result = await hass.config_entries.flow.async_configure( - result["flow_id"], - user_input={"next_step_id": STEP_PICK_FIRMWARE_ZIGBEE}, - ) - result = await hass.config_entries.flow.async_configure(result["flow_id"]) - await hass.async_block_till_done(wait_background_tasks=True) - - result = await hass.config_entries.flow.async_configure(result["flow_id"]) - assert result["type"] == FlowResultType.ABORT - assert result["reason"] == "addon_set_config_failed" - - -@pytest.mark.parametrize( - "ignore_translations_for_mock_domains", - ["test_firmware_domain"], -) -async def test_config_flow_zigbee_flasher_run_fails(hass: HomeAssistant) -> None: - """Test failure case when flasher addon fails to run.""" - result = await hass.config_entries.flow.async_init( - TEST_DOMAIN, context={"source": "hardware"} - ) - - with mock_addon_info( - hass, - app_type=ApplicationType.SPINEL, - ) as (mock_otbr_manager, mock_flasher_manager): - mock_flasher_manager.async_start_addon_waiting = AsyncMock( - side_effect=AddonError() - ) - - result = await hass.config_entries.flow.async_configure( - result["flow_id"], user_input={} - ) - result = await hass.config_entries.flow.async_configure( - result["flow_id"], - user_input={"next_step_id": STEP_PICK_FIRMWARE_ZIGBEE}, - ) - result = await hass.config_entries.flow.async_configure(result["flow_id"]) - await hass.async_block_till_done(wait_background_tasks=True) - - result = await hass.config_entries.flow.async_configure(result["flow_id"]) - assert result["type"] == FlowResultType.ABORT - assert result["reason"] == "addon_start_failed" - - -async def test_config_flow_zigbee_flasher_uninstall_fails(hass: HomeAssistant) -> None: - """Test failure case when flasher addon uninstall fails.""" - result = await hass.config_entries.flow.async_init( - TEST_DOMAIN, context={"source": "hardware"} - ) - - with mock_addon_info( - hass, - app_type=ApplicationType.SPINEL, - ) as (mock_otbr_manager, mock_flasher_manager): - mock_flasher_manager.async_uninstall_addon_waiting = AsyncMock( - side_effect=AddonError() - ) - result = await hass.config_entries.flow.async_configure( - result["flow_id"], user_input={} - ) - result = await hass.config_entries.flow.async_configure( - result["flow_id"], - user_input={"next_step_id": STEP_PICK_FIRMWARE_ZIGBEE}, - ) - result = await hass.config_entries.flow.async_configure(result["flow_id"]) - await hass.async_block_till_done(wait_background_tasks=True) - - result = await hass.config_entries.flow.async_configure(result["flow_id"]) - await hass.async_block_till_done(wait_background_tasks=True) - - # Uninstall failure isn't critical - result = await hass.config_entries.flow.async_configure(result["flow_id"]) - assert result["type"] is FlowResultType.FORM - assert result["step_id"] == "confirm_zigbee" - - -@pytest.mark.parametrize( - "ignore_translations_for_mock_domains", - ["test_firmware_domain"], -) -async def test_config_flow_zigbee_confirmation_fails(hass: HomeAssistant) -> None: - """Test the config flow failing due to Zigbee firmware not being detected.""" - result = await hass.config_entries.flow.async_init( - TEST_DOMAIN, context={"source": "hardware"} - ) - - assert result["type"] is FlowResultType.MENU - assert result["step_id"] == "pick_firmware" - - with mock_addon_info( - hass, - app_type=ApplicationType.EZSP, - ) as (mock_otbr_manager, mock_flasher_manager): - # Pick the menu option: we are now installing the addon - result = await hass.config_entries.flow.async_configure( - result["flow_id"], - user_input={"next_step_id": STEP_PICK_FIRMWARE_ZIGBEE}, - ) - assert result["type"] is FlowResultType.FORM - assert result["step_id"] == "confirm_zigbee" - - with mock_addon_info( - hass, - app_type=None, # Probing fails - ) as (mock_otbr_manager, mock_flasher_manager): - result = await hass.config_entries.flow.async_configure( - result["flow_id"], user_input={} - ) - assert result["type"] is FlowResultType.ABORT - assert result["reason"] == "unsupported_firmware" - - @pytest.mark.parametrize( "ignore_translations_for_mock_domains", ["test_firmware_domain"], @@ -356,11 +79,11 @@ async def test_config_flow_thread_not_hassio(hass: HomeAssistant) -> None: TEST_DOMAIN, context={"source": "hardware"} ) - with mock_addon_info( + with mock_firmware_info( hass, is_hassio=False, - app_type=ApplicationType.EZSP, - ) as (mock_otbr_manager, mock_flasher_manager): + probe_app_type=ApplicationType.EZSP, + ): result = await hass.config_entries.flow.async_configure( result["flow_id"], user_input={} ) @@ -383,10 +106,10 @@ async def test_config_flow_thread_addon_info_fails(hass: HomeAssistant) -> None: TEST_DOMAIN, context={"source": "hardware"} ) - with mock_addon_info( + with mock_firmware_info( hass, - app_type=ApplicationType.EZSP, - ) as (mock_otbr_manager, mock_flasher_manager): + probe_app_type=ApplicationType.EZSP, + ) as mock_otbr_manager: mock_otbr_manager.async_get_addon_info.side_effect = AddonError() result = await hass.config_entries.flow.async_configure( result["flow_id"], user_input={} @@ -405,24 +128,26 @@ async def test_config_flow_thread_addon_info_fails(hass: HomeAssistant) -> None: "ignore_translations_for_mock_domains", ["test_firmware_domain"], ) -async def test_config_flow_thread_addon_already_running(hass: HomeAssistant) -> None: +async def test_config_flow_thread_addon_already_configured(hass: HomeAssistant) -> None: """Test failure case when the Thread addon is already running.""" result = await hass.config_entries.flow.async_init( TEST_DOMAIN, context={"source": "hardware"} ) - with mock_addon_info( + with mock_firmware_info( hass, - app_type=ApplicationType.EZSP, + probe_app_type=ApplicationType.EZSP, otbr_addon_info=AddonInfo( available=True, hostname=None, - options={}, + options={ + "device": TEST_DEVICE + "2", # A different device + }, state=AddonState.RUNNING, update_available=False, version="1.0.0", ), - ) as (mock_otbr_manager, mock_flasher_manager): + ) as mock_otbr_manager: mock_otbr_manager.async_install_addon_waiting = AsyncMock( side_effect=AddonError() ) @@ -450,10 +175,10 @@ async def test_config_flow_thread_addon_install_fails(hass: HomeAssistant) -> No TEST_DOMAIN, context={"source": "hardware"} ) - with mock_addon_info( + with mock_firmware_info( hass, - app_type=ApplicationType.EZSP, - ) as (mock_otbr_manager, mock_flasher_manager): + probe_app_type=ApplicationType.EZSP, + ) as mock_otbr_manager: mock_otbr_manager.async_install_addon_waiting = AsyncMock( side_effect=AddonError() ) @@ -477,29 +202,51 @@ async def test_config_flow_thread_addon_install_fails(hass: HomeAssistant) -> No ) async def test_config_flow_thread_addon_set_config_fails(hass: HomeAssistant) -> None: """Test failure case when flasher addon cannot be configured.""" - result = await hass.config_entries.flow.async_init( + init_result = await hass.config_entries.flow.async_init( TEST_DOMAIN, context={"source": "hardware"} ) - with mock_addon_info( + with mock_firmware_info( hass, - app_type=ApplicationType.EZSP, - ) as (mock_otbr_manager, mock_flasher_manager): + probe_app_type=ApplicationType.EZSP, + ) as mock_otbr_manager: + + async def install_addon() -> None: + mock_otbr_manager.async_get_addon_info.return_value = AddonInfo( + available=True, + hostname=None, + options={"device": TEST_DEVICE}, + state=AddonState.NOT_RUNNING, + update_available=False, + version="1.0.0", + ) + + mock_otbr_manager.async_install_addon_waiting = AsyncMock( + side_effect=install_addon + ) mock_otbr_manager.async_set_addon_options = AsyncMock(side_effect=AddonError()) - result = await hass.config_entries.flow.async_configure( - result["flow_id"], user_input={} + confirm_result = await hass.config_entries.flow.async_configure( + init_result["flow_id"], user_input={} ) - result = await hass.config_entries.flow.async_configure( - result["flow_id"], + + pick_thread_result = await hass.config_entries.flow.async_configure( + confirm_result["flow_id"], user_input={"next_step_id": STEP_PICK_FIRMWARE_THREAD}, ) - result = await hass.config_entries.flow.async_configure(result["flow_id"]) - await hass.async_block_till_done(wait_background_tasks=True) - result = await hass.config_entries.flow.async_configure(result["flow_id"]) - assert result["type"] == FlowResultType.ABORT - assert result["reason"] == "addon_set_config_failed" + pick_thread_progress_result = await consume_progress_flow( + hass, + flow_id=pick_thread_result["flow_id"], + valid_step_ids=( + "pick_firmware_thread", + "install_thread_firmware", + "start_otbr_addon", + ), + ) + + assert pick_thread_progress_result["type"] == FlowResultType.ABORT + assert pick_thread_progress_result["reason"] == "addon_set_config_failed" @pytest.mark.parametrize( @@ -508,63 +255,45 @@ async def test_config_flow_thread_addon_set_config_fails(hass: HomeAssistant) -> ) async def test_config_flow_thread_flasher_run_fails(hass: HomeAssistant) -> None: """Test failure case when flasher addon fails to run.""" - result = await hass.config_entries.flow.async_init( + init_result = await hass.config_entries.flow.async_init( TEST_DOMAIN, context={"source": "hardware"} ) - with mock_addon_info( + with mock_firmware_info( hass, - app_type=ApplicationType.EZSP, - ) as (mock_otbr_manager, mock_flasher_manager): + probe_app_type=ApplicationType.EZSP, + otbr_addon_info=AddonInfo( + available=True, + hostname=None, + options={"device": TEST_DEVICE}, + state=AddonState.NOT_RUNNING, + update_available=False, + version="1.0.0", + ), + ) as mock_otbr_manager: mock_otbr_manager.async_start_addon_waiting = AsyncMock( side_effect=AddonError() ) - result = await hass.config_entries.flow.async_configure( - result["flow_id"], user_input={} + confirm_result = await hass.config_entries.flow.async_configure( + init_result["flow_id"], user_input={} ) - result = await hass.config_entries.flow.async_configure( - result["flow_id"], + pick_thread_result = await hass.config_entries.flow.async_configure( + confirm_result["flow_id"], user_input={"next_step_id": STEP_PICK_FIRMWARE_THREAD}, ) - result = await hass.config_entries.flow.async_configure(result["flow_id"]) - await hass.async_block_till_done(wait_background_tasks=True) - - result = await hass.config_entries.flow.async_configure(result["flow_id"]) - assert result["type"] == FlowResultType.ABORT - assert result["reason"] == "addon_start_failed" - - -async def test_config_flow_thread_flasher_uninstall_fails(hass: HomeAssistant) -> None: - """Test failure case when flasher addon uninstall fails.""" - result = await hass.config_entries.flow.async_init( - TEST_DOMAIN, context={"source": "hardware"} - ) - - with mock_addon_info( - hass, - app_type=ApplicationType.EZSP, - ) as (mock_otbr_manager, mock_flasher_manager): - mock_otbr_manager.async_uninstall_addon_waiting = AsyncMock( - side_effect=AddonError() - ) - result = await hass.config_entries.flow.async_configure( - result["flow_id"], user_input={} + pick_thread_progress_result = await consume_progress_flow( + hass, + flow_id=pick_thread_result["flow_id"], + valid_step_ids=( + "pick_firmware_thread", + "install_thread_firmware", + "start_otbr_addon", + ), ) - result = await hass.config_entries.flow.async_configure( - result["flow_id"], - user_input={"next_step_id": STEP_PICK_FIRMWARE_THREAD}, - ) - result = await hass.config_entries.flow.async_configure(result["flow_id"]) - await hass.async_block_till_done(wait_background_tasks=True) - - result = await hass.config_entries.flow.async_configure(result["flow_id"]) - await hass.async_block_till_done(wait_background_tasks=True) - # Uninstall failure isn't critical - result = await hass.config_entries.flow.async_configure(result["flow_id"]) - assert result["type"] is FlowResultType.FORM - assert result["step_id"] == "confirm_otbr" + assert pick_thread_progress_result["type"] == FlowResultType.ABORT + assert pick_thread_progress_result["reason"] == "addon_start_failed" @pytest.mark.parametrize( @@ -573,40 +302,43 @@ async def test_config_flow_thread_flasher_uninstall_fails(hass: HomeAssistant) - ) async def test_config_flow_thread_confirmation_fails(hass: HomeAssistant) -> None: """Test the config flow failing due to OpenThread firmware not being detected.""" - result = await hass.config_entries.flow.async_init( + init_result = await hass.config_entries.flow.async_init( TEST_DOMAIN, context={"source": "hardware"} ) - with mock_addon_info( + with mock_firmware_info( hass, - app_type=ApplicationType.EZSP, - ) as (mock_otbr_manager, mock_flasher_manager): - result = await hass.config_entries.flow.async_configure( - result["flow_id"], user_input={} + probe_app_type=ApplicationType.EZSP, + flash_app_type=None, + otbr_addon_info=AddonInfo( + available=True, + hostname=None, + options={"device": TEST_DEVICE}, + state=AddonState.RUNNING, + update_available=False, + version="1.0.0", + ), + ): + confirm_result = await hass.config_entries.flow.async_configure( + init_result["flow_id"], user_input={} ) - result = await hass.config_entries.flow.async_configure( - result["flow_id"], + pick_thread_result = await hass.config_entries.flow.async_configure( + confirm_result["flow_id"], user_input={"next_step_id": STEP_PICK_FIRMWARE_THREAD}, ) - result = await hass.config_entries.flow.async_configure(result["flow_id"]) - await hass.async_block_till_done(wait_background_tasks=True) - - result = await hass.config_entries.flow.async_configure(result["flow_id"]) - await hass.async_block_till_done(wait_background_tasks=True) - result = await hass.config_entries.flow.async_configure(result["flow_id"]) - assert result["type"] is FlowResultType.FORM - assert result["step_id"] == "confirm_otbr" - - with mock_addon_info( - hass, - app_type=None, # Probing fails - ) as (mock_otbr_manager, mock_flasher_manager): - result = await hass.config_entries.flow.async_configure( - result["flow_id"], user_input={} + pick_thread_progress_result = await consume_progress_flow( + hass, + flow_id=pick_thread_result["flow_id"], + valid_step_ids=( + "pick_firmware_thread", + "install_thread_firmware", + "start_otbr_addon", + ), ) - assert result["type"] is FlowResultType.ABORT - assert result["reason"] == "unsupported_firmware" + + assert pick_thread_progress_result["type"] is FlowResultType.ABORT + assert pick_thread_progress_result["reason"] == "unsupported_firmware" @pytest.mark.parametrize( @@ -683,9 +415,9 @@ async def test_options_flow_thread_to_zigbee_otbr_configured( # Confirm options flow result = await hass.config_entries.options.async_init(config_entry.entry_id) - with mock_addon_info( + with mock_firmware_info( hass, - app_type=ApplicationType.SPINEL, + probe_app_type=ApplicationType.SPINEL, otbr_addon_info=AddonInfo( available=True, hostname=None, @@ -694,7 +426,7 @@ async def test_options_flow_thread_to_zigbee_otbr_configured( update_available=False, version="1.0.0", ), - ) as (mock_otbr_manager, mock_flasher_manager): + ): result = await hass.config_entries.options.async_configure( result["flow_id"], user_input={"next_step_id": STEP_PICK_FIRMWARE_ZIGBEE}, diff --git a/tests/components/homeassistant_hardware/test_update.py b/tests/components/homeassistant_hardware/test_update.py index 81c6f2e0459de0..aacc064e4f24a5 100644 --- a/tests/components/homeassistant_hardware/test_update.py +++ b/tests/components/homeassistant_hardware/test_update.py @@ -3,10 +3,10 @@ from __future__ import annotations import asyncio -from collections.abc import AsyncGenerator +from collections.abc import AsyncGenerator, Callable import dataclasses import logging -from unittest.mock import AsyncMock, Mock, patch +from unittest.mock import Mock, patch import aiohttp from ha_silabs_firmware_client import FirmwareManifest, FirmwareMetadata @@ -355,10 +355,14 @@ async def test_update_entity_installation( "https://example.org/release_notes" ) - mock_firmware = Mock() - mock_flasher = AsyncMock() - - async def mock_flash_firmware(fw_image, progress_callback): + async def mock_flash_firmware( + hass: HomeAssistant, + device: str, + fw_data: bytes, + expected_installed_firmware_type: ApplicationType, + bootloader_reset_type: str | None = None, + progress_callback: Callable[[int, int], None] | None = None, + ) -> FirmwareInfo: await asyncio.sleep(0) progress_callback(0, 100) await asyncio.sleep(0) @@ -366,31 +370,20 @@ async def mock_flash_firmware(fw_image, progress_callback): await asyncio.sleep(0) progress_callback(100, 100) - mock_flasher.flash_firmware = mock_flash_firmware + return FirmwareInfo( + device=TEST_DEVICE, + firmware_type=ApplicationType.EZSP, + firmware_version="7.4.4.0 build 0", + owners=[], + source="probe", + ) # When we install it, the other integration is reloaded with ( patch( - "homeassistant.components.homeassistant_hardware.update.parse_firmware_image", - return_value=mock_firmware, - ), - patch( - "homeassistant.components.homeassistant_hardware.update.Flasher", - return_value=mock_flasher, - ), - patch( - "homeassistant.components.homeassistant_hardware.update.probe_silabs_firmware_info", - return_value=FirmwareInfo( - device=TEST_DEVICE, - firmware_type=ApplicationType.EZSP, - firmware_version="7.4.4.0 build 0", - owners=[], - source="probe", - ), + "homeassistant.components.homeassistant_hardware.update.async_flash_silabs_firmware", + side_effect=mock_flash_firmware, ), - patch.object( - owning_config_entry, "async_unload", wraps=owning_config_entry.async_unload - ) as owning_config_entry_unload, ): state_changes: list[Event[EventStateChangedData]] = async_capture_events( hass, EVENT_STATE_CHANGED @@ -423,9 +416,6 @@ async def mock_flash_firmware(fw_image, progress_callback): assert state_changes[6].data["new_state"].attributes["update_percentage"] is None assert state_changes[6].data["new_state"].attributes["in_progress"] is False - # The owning integration was unloaded and is again running - assert len(owning_config_entry_unload.mock_calls) == 1 - # After the firmware update, the entity has the new version and the correct state state_after_install = hass.states.get(TEST_UPDATE_ENTITY_ID) assert state_after_install is not None @@ -456,19 +446,10 @@ async def test_update_entity_installation_failure( assert state_before_install.attributes["installed_version"] == "7.3.1.0" assert state_before_install.attributes["latest_version"] == "7.4.4.0" - mock_flasher = AsyncMock() - mock_flasher.flash_firmware.side_effect = RuntimeError( - "Something broke during flashing!" - ) - with ( patch( - "homeassistant.components.homeassistant_hardware.update.parse_firmware_image", - return_value=Mock(), - ), - patch( - "homeassistant.components.homeassistant_hardware.update.Flasher", - return_value=mock_flasher, + "homeassistant.components.homeassistant_hardware.update.async_flash_silabs_firmware", + side_effect=HomeAssistantError("Failed to flash firmware"), ), pytest.raises(HomeAssistantError, match="Failed to flash firmware"), ): @@ -511,16 +492,10 @@ async def test_update_entity_installation_probe_failure( with ( patch( - "homeassistant.components.homeassistant_hardware.update.parse_firmware_image", - return_value=Mock(), - ), - patch( - "homeassistant.components.homeassistant_hardware.update.Flasher", - return_value=AsyncMock(), - ), - patch( - "homeassistant.components.homeassistant_hardware.update.probe_silabs_firmware_info", - return_value=None, + "homeassistant.components.homeassistant_hardware.update.async_flash_silabs_firmware", + side_effect=HomeAssistantError( + "Failed to probe the firmware after flashing" + ), ), pytest.raises( HomeAssistantError, match="Failed to probe the firmware after flashing" diff --git a/tests/components/homeassistant_hardware/test_util.py b/tests/components/homeassistant_hardware/test_util.py index 1b7bfe4a8acb0d..048bf998d1317c 100644 --- a/tests/components/homeassistant_hardware/test_util.py +++ b/tests/components/homeassistant_hardware/test_util.py @@ -1,10 +1,13 @@ """Test hardware utilities.""" -from unittest.mock import AsyncMock, MagicMock, patch +import asyncio +from collections.abc import Callable +from unittest.mock import ANY, AsyncMock, MagicMock, Mock, call, patch import pytest from universal_silabs_flasher.common import Version as FlasherVersion from universal_silabs_flasher.const import ApplicationType as FlasherApplicationType +from universal_silabs_flasher.firmware import GBLImage from homeassistant.components.hassio import ( AddonError, @@ -20,6 +23,7 @@ FirmwareInfo, OwningAddon, OwningIntegration, + async_flash_silabs_firmware, get_otbr_addon_firmware_info, guess_firmware_info, probe_silabs_firmware_info, @@ -27,8 +31,11 @@ ) from homeassistant.config_entries import ConfigEntry, ConfigEntryState from homeassistant.core import HomeAssistant +from homeassistant.exceptions import HomeAssistantError from homeassistant.setup import async_setup_component +from .test_config_flow import create_mock_owner + from tests.common import MockConfigEntry ZHA_CONFIG_ENTRY = MockConfigEntry( @@ -526,3 +533,201 @@ async def test_probe_silabs_firmware_type( ): result = await probe_silabs_firmware_type("/dev/ttyUSB0") assert result == expected + + +async def test_async_flash_silabs_firmware(hass: HomeAssistant) -> None: + """Test async_flash_silabs_firmware.""" + owner1 = create_mock_owner() + owner2 = create_mock_owner() + + progress_callback = Mock() + + async def mock_flash_firmware( + fw_image: GBLImage, progress_callback: Callable[[int, int], None] + ) -> None: + """Mock flash firmware function.""" + await asyncio.sleep(0) + progress_callback(0, 100) + await asyncio.sleep(0) + progress_callback(50, 100) + await asyncio.sleep(0) + progress_callback(100, 100) + await asyncio.sleep(0) + + mock_flasher = Mock() + mock_flasher.enter_bootloader = AsyncMock() + mock_flasher.flash_firmware = AsyncMock(side_effect=mock_flash_firmware) + + expected_firmware_info = FirmwareInfo( + device="/dev/ttyUSB0", + firmware_type=ApplicationType.SPINEL, + firmware_version=None, + source="probe", + owners=[], + ) + + with ( + patch( + "homeassistant.components.homeassistant_hardware.util.guess_firmware_info", + return_value=FirmwareInfo( + device="/dev/ttyUSB0", + firmware_type=ApplicationType.EZSP, + firmware_version=None, + source="unknown", + owners=[owner1, owner2], + ), + ), + patch( + "homeassistant.components.homeassistant_hardware.util.Flasher", + return_value=mock_flasher, + ), + patch( + "homeassistant.components.homeassistant_hardware.util.parse_firmware_image" + ), + patch( + "homeassistant.components.homeassistant_hardware.util.probe_silabs_firmware_info", + return_value=expected_firmware_info, + ), + ): + after_flash_info = await async_flash_silabs_firmware( + hass=hass, + device="/dev/ttyUSB0", + fw_data=b"firmware contents", + expected_installed_firmware_type=ApplicationType.SPINEL, + bootloader_reset_type=None, + progress_callback=progress_callback, + ) + + assert progress_callback.mock_calls == [call(0, 100), call(50, 100), call(100, 100)] + assert after_flash_info == expected_firmware_info + + # Both owning integrations/addons are stopped and restarted + assert owner1.temporarily_stop.mock_calls == [ + call(hass), + # pylint: disable-next=unnecessary-dunder-call + call().__aenter__(ANY), + # pylint: disable-next=unnecessary-dunder-call + call().__aexit__(ANY, None, None, None), + ] + + assert owner2.temporarily_stop.mock_calls == [ + call(hass), + # pylint: disable-next=unnecessary-dunder-call + call().__aenter__(ANY), + # pylint: disable-next=unnecessary-dunder-call + call().__aexit__(ANY, None, None, None), + ] + + +async def test_async_flash_silabs_firmware_flash_failure(hass: HomeAssistant) -> None: + """Test async_flash_silabs_firmware flash failure.""" + owner1 = create_mock_owner() + owner2 = create_mock_owner() + + mock_flasher = Mock() + mock_flasher.enter_bootloader = AsyncMock() + mock_flasher.flash_firmware = AsyncMock(side_effect=RuntimeError("Failure!")) + + with ( + patch( + "homeassistant.components.homeassistant_hardware.util.guess_firmware_info", + return_value=FirmwareInfo( + device="/dev/ttyUSB0", + firmware_type=ApplicationType.EZSP, + firmware_version=None, + source="unknown", + owners=[owner1, owner2], + ), + ), + patch( + "homeassistant.components.homeassistant_hardware.util.Flasher", + return_value=mock_flasher, + ), + patch( + "homeassistant.components.homeassistant_hardware.util.parse_firmware_image" + ), + pytest.raises(HomeAssistantError, match="Failed to flash firmware") as exc, + ): + await async_flash_silabs_firmware( + hass=hass, + device="/dev/ttyUSB0", + fw_data=b"firmware contents", + expected_installed_firmware_type=ApplicationType.SPINEL, + bootloader_reset_type=None, + ) + + # Both owning integrations/addons are stopped and restarted + assert owner1.temporarily_stop.mock_calls == [ + call(hass), + # pylint: disable-next=unnecessary-dunder-call + call().__aenter__(ANY), + # pylint: disable-next=unnecessary-dunder-call + call().__aexit__(ANY, HomeAssistantError, exc.value, ANY), + ] + assert owner2.temporarily_stop.mock_calls == [ + call(hass), + # pylint: disable-next=unnecessary-dunder-call + call().__aenter__(ANY), + # pylint: disable-next=unnecessary-dunder-call + call().__aexit__(ANY, HomeAssistantError, exc.value, ANY), + ] + + +async def test_async_flash_silabs_firmware_probe_failure(hass: HomeAssistant) -> None: + """Test async_flash_silabs_firmware probe failure.""" + owner1 = create_mock_owner() + owner2 = create_mock_owner() + + mock_flasher = Mock() + mock_flasher.enter_bootloader = AsyncMock() + mock_flasher.flash_firmware = AsyncMock() + + with ( + patch( + "homeassistant.components.homeassistant_hardware.util.guess_firmware_info", + return_value=FirmwareInfo( + device="/dev/ttyUSB0", + firmware_type=ApplicationType.EZSP, + firmware_version=None, + source="unknown", + owners=[owner1, owner2], + ), + ), + patch( + "homeassistant.components.homeassistant_hardware.util.Flasher", + return_value=mock_flasher, + ), + patch( + "homeassistant.components.homeassistant_hardware.util.parse_firmware_image" + ), + patch( + "homeassistant.components.homeassistant_hardware.util.probe_silabs_firmware_info", + return_value=None, + ), + pytest.raises( + HomeAssistantError, match="Failed to probe the firmware after flashing" + ), + ): + await async_flash_silabs_firmware( + hass=hass, + device="/dev/ttyUSB0", + fw_data=b"firmware contents", + expected_installed_firmware_type=ApplicationType.SPINEL, + bootloader_reset_type=None, + ) + + # Both owning integrations/addons are stopped and restarted + assert owner1.temporarily_stop.mock_calls == [ + call(hass), + # pylint: disable-next=unnecessary-dunder-call + call().__aenter__(ANY), + # pylint: disable-next=unnecessary-dunder-call + call().__aexit__(ANY, None, None, None), + ] + assert owner2.temporarily_stop.mock_calls == [ + call(hass), + # pylint: disable-next=unnecessary-dunder-call + call().__aenter__(ANY), + # pylint: disable-next=unnecessary-dunder-call + call().__aexit__(ANY, None, None, None), + ] diff --git a/tests/components/homeassistant_sky_connect/test_config_flow.py b/tests/components/homeassistant_sky_connect/test_config_flow.py index 44a5e0029c3fff..9dcac0732c9c3f 100644 --- a/tests/components/homeassistant_sky_connect/test_config_flow.py +++ b/tests/components/homeassistant_sky_connect/test_config_flow.py @@ -6,6 +6,7 @@ from homeassistant.components.hassio import AddonInfo, AddonState from homeassistant.components.homeassistant_hardware.firmware_config_flow import ( + STEP_PICK_FIRMWARE_THREAD, STEP_PICK_FIRMWARE_ZIGBEE, ) from homeassistant.components.homeassistant_hardware.silabs_multiprotocol_addon import ( @@ -18,6 +19,7 @@ FirmwareInfo, ) from homeassistant.components.homeassistant_sky_connect.const import DOMAIN +from homeassistant.config_entries import ConfigFlowResult from homeassistant.core import HomeAssistant from homeassistant.data_entry_flow import FlowResultType from homeassistant.helpers.service_info.usb import UsbServiceInfo @@ -28,14 +30,31 @@ @pytest.mark.parametrize( - ("usb_data", "model"), + ("step", "usb_data", "model", "fw_type", "fw_version"), [ - (USB_DATA_SKY, "Home Assistant SkyConnect"), - (USB_DATA_ZBT1, "Home Assistant Connect ZBT-1"), + ( + STEP_PICK_FIRMWARE_ZIGBEE, + USB_DATA_SKY, + "Home Assistant SkyConnect", + ApplicationType.EZSP, + "7.4.4.0 build 0", + ), + ( + STEP_PICK_FIRMWARE_THREAD, + USB_DATA_ZBT1, + "Home Assistant Connect ZBT-1", + ApplicationType.SPINEL, + "2.4.4.0", + ), ], ) async def test_config_flow( - usb_data: UsbServiceInfo, model: str, hass: HomeAssistant + step: str, + usb_data: UsbServiceInfo, + model: str, + fw_type: ApplicationType, + fw_version: str, + hass: HomeAssistant, ) -> None: """Test the config flow for SkyConnect.""" result = await hass.config_entries.flow.async_init( @@ -46,21 +65,36 @@ async def test_config_flow( assert result["step_id"] == "pick_firmware" assert result["description_placeholders"]["model"] == model - async def mock_async_step_pick_firmware_zigbee(self, data): - return await self.async_step_confirm_zigbee(user_input={}) + async def mock_install_firmware_step( + self, + fw_update_url: str, + fw_type: str, + firmware_name: str, + expected_installed_firmware_type: ApplicationType, + step_id: str, + next_step_id: str, + ) -> ConfigFlowResult: + if next_step_id == "start_otbr_addon": + next_step_id = "confirm_otbr" + + return await getattr(self, f"async_step_{next_step_id}")(user_input={}) with ( patch( - "homeassistant.components.homeassistant_hardware.firmware_config_flow.BaseFirmwareConfigFlow.async_step_pick_firmware_zigbee", + "homeassistant.components.homeassistant_hardware.firmware_config_flow.BaseFirmwareConfigFlow._ensure_thread_addon_setup", + return_value=None, + ), + patch( + "homeassistant.components.homeassistant_hardware.firmware_config_flow.BaseFirmwareConfigFlow._install_firmware_step", autospec=True, - side_effect=mock_async_step_pick_firmware_zigbee, + side_effect=mock_install_firmware_step, ), patch( "homeassistant.components.homeassistant_hardware.firmware_config_flow.probe_silabs_firmware_info", return_value=FirmwareInfo( device=usb_data.device, - firmware_type=ApplicationType.EZSP, - firmware_version="7.4.4.0 build 0", + firmware_type=fw_type, + firmware_version=fw_version, owners=[], source="probe", ), @@ -68,15 +102,15 @@ async def mock_async_step_pick_firmware_zigbee(self, data): ): result = await hass.config_entries.flow.async_configure( result["flow_id"], - user_input={"next_step_id": STEP_PICK_FIRMWARE_ZIGBEE}, + user_input={"next_step_id": step}, ) assert result["type"] is FlowResultType.CREATE_ENTRY config_entry = result["result"] assert config_entry.data == { - "firmware": "ezsp", - "firmware_version": "7.4.4.0 build 0", + "firmware": fw_type.value, + "firmware_version": fw_version, "device": usb_data.device, "manufacturer": usb_data.manufacturer, "pid": usb_data.pid, @@ -86,13 +120,17 @@ async def mock_async_step_pick_firmware_zigbee(self, data): "vid": usb_data.vid, } - # Ensure a ZHA discovery flow has been created flows = hass.config_entries.flow.async_progress() - assert len(flows) == 1 - zha_flow = flows[0] - assert zha_flow["handler"] == "zha" - assert zha_flow["context"]["source"] == "hardware" - assert zha_flow["step_id"] == "confirm" + + if step == STEP_PICK_FIRMWARE_ZIGBEE: + # Ensure a ZHA discovery flow has been created + assert len(flows) == 1 + zha_flow = flows[0] + assert zha_flow["handler"] == "zha" + assert zha_flow["context"]["source"] == "hardware" + assert zha_flow["step_id"] == "confirm" + else: + assert len(flows) == 0 @pytest.mark.parametrize( diff --git a/tests/components/homeassistant_yellow/test_config_flow.py b/tests/components/homeassistant_yellow/test_config_flow.py index 1d5a64eafb9171..cd4a194105090e 100644 --- a/tests/components/homeassistant_yellow/test_config_flow.py +++ b/tests/components/homeassistant_yellow/test_config_flow.py @@ -11,6 +11,7 @@ AddonState, ) from homeassistant.components.homeassistant_hardware.firmware_config_flow import ( + STEP_PICK_FIRMWARE_THREAD, STEP_PICK_FIRMWARE_ZIGBEE, ) from homeassistant.components.homeassistant_hardware.silabs_multiprotocol_addon import ( @@ -23,6 +24,7 @@ FirmwareInfo, ) from homeassistant.components.homeassistant_yellow.const import DOMAIN, RADIO_DEVICE +from homeassistant.config_entries import ConfigFlowResult from homeassistant.core import HomeAssistant from homeassistant.data_entry_flow import FlowResultType from homeassistant.setup import async_setup_component @@ -305,7 +307,16 @@ async def test_option_flow_led_settings_fail_2( assert result["reason"] == "write_hw_settings_error" -async def test_firmware_options_flow(hass: HomeAssistant) -> None: +@pytest.mark.parametrize( + ("step", "fw_type", "fw_version"), + [ + (STEP_PICK_FIRMWARE_ZIGBEE, ApplicationType.EZSP, "7.4.4.0 build 0"), + (STEP_PICK_FIRMWARE_THREAD, ApplicationType.SPINEL, "2.4.4.0"), + ], +) +async def test_firmware_options_flow( + step: str, fw_type: ApplicationType, fw_version: str, hass: HomeAssistant +) -> None: """Test the firmware options flow for Yellow.""" mock_integration(hass, MockModule("hassio")) await async_setup_component(hass, HASSIO_DOMAIN, {}) @@ -339,18 +350,36 @@ async def test_firmware_options_flow(hass: HomeAssistant) -> None: async def mock_async_step_pick_firmware_zigbee(self, data): return await self.async_step_confirm_zigbee(user_input={}) + async def mock_install_firmware_step( + self, + fw_update_url: str, + fw_type: str, + firmware_name: str, + expected_installed_firmware_type: ApplicationType, + step_id: str, + next_step_id: str, + ) -> ConfigFlowResult: + if next_step_id == "start_otbr_addon": + next_step_id = "confirm_otbr" + + return await getattr(self, f"async_step_{next_step_id}")(user_input={}) + with ( patch( - "homeassistant.components.homeassistant_hardware.firmware_config_flow.BaseFirmwareOptionsFlow.async_step_pick_firmware_zigbee", + "homeassistant.components.homeassistant_hardware.firmware_config_flow.BaseFirmwareConfigFlow._ensure_thread_addon_setup", + return_value=None, + ), + patch( + "homeassistant.components.homeassistant_hardware.firmware_config_flow.BaseFirmwareInstallFlow._install_firmware_step", autospec=True, - side_effect=mock_async_step_pick_firmware_zigbee, + side_effect=mock_install_firmware_step, ), patch( "homeassistant.components.homeassistant_hardware.firmware_config_flow.probe_silabs_firmware_info", return_value=FirmwareInfo( device=RADIO_DEVICE, - firmware_type=ApplicationType.EZSP, - firmware_version="7.4.4.0 build 0", + firmware_type=fw_type, + firmware_version=fw_version, owners=[], source="probe", ), @@ -358,15 +387,15 @@ async def mock_async_step_pick_firmware_zigbee(self, data): ): result = await hass.config_entries.options.async_configure( result["flow_id"], - user_input={"next_step_id": STEP_PICK_FIRMWARE_ZIGBEE}, + user_input={"next_step_id": step}, ) assert result["type"] is FlowResultType.CREATE_ENTRY assert result["result"] is True assert config_entry.data == { - "firmware": "ezsp", - "firmware_version": "7.4.4.0 build 0", + "firmware": fw_type.value, + "firmware_version": fw_version, } diff --git a/tests/components/litterrobot/test_sensor.py b/tests/components/litterrobot/test_sensor.py index bbc6274e56b255..76c567f54179cb 100644 --- a/tests/components/litterrobot/test_sensor.py +++ b/tests/components/litterrobot/test_sensor.py @@ -5,7 +5,11 @@ import pytest from homeassistant.components.litterrobot.sensor import icon_for_gauge_level -from homeassistant.components.sensor import DOMAIN as PLATFORM_DOMAIN, SensorDeviceClass +from homeassistant.components.sensor import ( + DOMAIN as PLATFORM_DOMAIN, + SensorDeviceClass, + SensorStateClass, +) from homeassistant.const import PERCENTAGE, STATE_UNKNOWN, UnitOfMass from homeassistant.core import HomeAssistant @@ -70,6 +74,7 @@ async def test_gauge_icon() -> None: @pytest.mark.freeze_time("2022-09-18 23:00:44+00:00") +@pytest.mark.usefixtures("entity_registry_enabled_by_default") async def test_litter_robot_sensor( hass: HomeAssistant, mock_account_with_litterrobot_4: MagicMock ) -> None: @@ -94,6 +99,9 @@ async def test_litter_robot_sensor( sensor = hass.states.get("sensor.test_pet_weight") assert sensor.state == "12.0" assert sensor.attributes["unit_of_measurement"] == UnitOfMass.POUNDS + sensor = hass.states.get("sensor.test_total_cycles") + assert sensor.state == "158" + assert sensor.attributes["state_class"] == SensorStateClass.TOTAL_INCREASING async def test_feeder_robot_sensor( diff --git a/tests/components/ollama/conftest.py b/tests/components/ollama/conftest.py index 7658d1cbfab973..c99f586a5d44c4 100644 --- a/tests/components/ollama/conftest.py +++ b/tests/components/ollama/conftest.py @@ -30,7 +30,15 @@ def mock_config_entry( entry = MockConfigEntry( domain=ollama.DOMAIN, data=TEST_USER_DATA, - options=mock_config_entry_options, + version=2, + subentries_data=[ + { + "data": mock_config_entry_options, + "subentry_type": "conversation", + "title": "Ollama Conversation", + "unique_id": None, + } + ], ) entry.add_to_hass(hass) return entry @@ -41,8 +49,10 @@ def mock_config_entry_with_assist( hass: HomeAssistant, mock_config_entry: MockConfigEntry ) -> MockConfigEntry: """Mock a config entry with assist.""" - hass.config_entries.async_update_entry( - mock_config_entry, options={CONF_LLM_HASS_API: llm.LLM_API_ASSIST} + hass.config_entries.async_update_subentry( + mock_config_entry, + next(iter(mock_config_entry.subentries.values())), + data={CONF_LLM_HASS_API: llm.LLM_API_ASSIST}, ) return mock_config_entry diff --git a/tests/components/ollama/test_config_flow.py b/tests/components/ollama/test_config_flow.py index 34282f25e90491..4b78df9bce2d6f 100644 --- a/tests/components/ollama/test_config_flow.py +++ b/tests/components/ollama/test_config_flow.py @@ -63,6 +63,37 @@ async def test_form(hass: HomeAssistant) -> None: assert len(mock_setup_entry.mock_calls) == 1 +async def test_duplicate_entry(hass: HomeAssistant) -> None: + """Test we abort on duplicate config entry.""" + MockConfigEntry( + domain=ollama.DOMAIN, + data={ + ollama.CONF_URL: "http://localhost:11434", + ollama.CONF_MODEL: "test_model", + }, + ).add_to_hass(hass) + + result = await hass.config_entries.flow.async_init( + ollama.DOMAIN, context={"source": config_entries.SOURCE_USER} + ) + assert result["type"] is FlowResultType.FORM + assert not result["errors"] + + with patch( + "homeassistant.components.ollama.config_flow.ollama.AsyncClient.list", + return_value={"models": [{"model": "test_model"}]}, + ): + result = await hass.config_entries.flow.async_configure( + result["flow_id"], + { + ollama.CONF_URL: "http://localhost:11434", + }, + ) + + assert result["type"] is FlowResultType.ABORT + assert result["reason"] == "already_configured" + + async def test_form_need_download(hass: HomeAssistant) -> None: """Test flow when a model needs to be downloaded.""" # Pretend we already set up a config entry. @@ -155,14 +186,21 @@ async def pull(self, model: str, *args, **kwargs) -> None: assert len(mock_setup_entry.mock_calls) == 1 -async def test_options( +async def test_subentry_options( hass: HomeAssistant, mock_config_entry, mock_init_component ) -> None: - """Test the options form.""" - options_flow = await hass.config_entries.options.async_init( - mock_config_entry.entry_id + """Test the subentry options form.""" + subentry = next(iter(mock_config_entry.subentries.values())) + + # Test reconfiguration + options_flow = await mock_config_entry.start_subentry_reconfigure_flow( + hass, subentry.subentry_id ) - options = await hass.config_entries.options.async_configure( + + assert options_flow["type"] is FlowResultType.FORM + assert options_flow["step_id"] == "set_options" + + options = await hass.config_entries.subentries.async_configure( options_flow["flow_id"], { ollama.CONF_PROMPT: "test prompt", @@ -172,8 +210,10 @@ async def test_options( }, ) await hass.async_block_till_done() - assert options["type"] is FlowResultType.CREATE_ENTRY - assert options["data"] == { + + assert options["type"] is FlowResultType.ABORT + assert options["reason"] == "reconfigure_successful" + assert subentry.data == { ollama.CONF_PROMPT: "test prompt", ollama.CONF_MAX_HISTORY: 100, ollama.CONF_NUM_CTX: 32768, @@ -181,6 +221,22 @@ async def test_options( } +async def test_creating_conversation_subentry_not_loaded( + hass: HomeAssistant, + mock_init_component, + mock_config_entry: MockConfigEntry, +) -> None: + """Test creating a conversation subentry when entry is not loaded.""" + await hass.config_entries.async_unload(mock_config_entry.entry_id) + result = await hass.config_entries.subentries.async_init( + (mock_config_entry.entry_id, "conversation"), + context={"source": config_entries.SOURCE_USER}, + ) + + assert result["type"] is FlowResultType.ABORT + assert result["reason"] == "entry_not_loaded" + + @pytest.mark.parametrize( ("side_effect", "error"), [ diff --git a/tests/components/ollama/test_conversation.py b/tests/components/ollama/test_conversation.py index e83c2a3495f8e7..cebb185bd08c60 100644 --- a/tests/components/ollama/test_conversation.py +++ b/tests/components/ollama/test_conversation.py @@ -35,7 +35,7 @@ async def stream_generator(response: dict | list[dict]) -> AsyncGenerator[dict]: yield msg -@pytest.mark.parametrize("agent_id", [None, "conversation.mock_title"]) +@pytest.mark.parametrize("agent_id", [None, "conversation.ollama_conversation"]) async def test_chat( hass: HomeAssistant, mock_config_entry: MockConfigEntry, @@ -149,9 +149,11 @@ async def test_template_variables( mock_user.id = "12345" mock_user.name = "Test User" - hass.config_entries.async_update_entry( + subentry = next(iter(mock_config_entry.subentries.values())) + hass.config_entries.async_update_subentry( mock_config_entry, - options={ + subentry, + data={ "prompt": ( "The user name is {{ user_name }}. " "The user id is {{ llm_context.context.user_id }}." @@ -382,10 +384,12 @@ async def test_unknown_hass_api( mock_init_component, ) -> None: """Test when we reference an API that no longer exists.""" - hass.config_entries.async_update_entry( + subentry = next(iter(mock_config_entry.subentries.values())) + hass.config_entries.async_update_subentry( mock_config_entry, - options={ - **mock_config_entry.options, + subentry, + data={ + **subentry.data, CONF_LLM_HASS_API: "non-existing", }, ) @@ -518,8 +522,9 @@ def stream(*args, **kwargs) -> AsyncGenerator[dict]: with ( patch("ollama.AsyncClient.chat", side_effect=stream) as mock_chat, ): - hass.config_entries.async_update_entry( - mock_config_entry, options={ollama.CONF_MAX_HISTORY: 0} + subentry = next(iter(mock_config_entry.subentries.values())) + hass.config_entries.async_update_subentry( + mock_config_entry, subentry, data={ollama.CONF_MAX_HISTORY: 0} ) for i in range(100): result = await conversation.async_converse( @@ -563,9 +568,11 @@ async def test_template_error( hass: HomeAssistant, mock_config_entry: MockConfigEntry ) -> None: """Test that template error handling works.""" - hass.config_entries.async_update_entry( + subentry = next(iter(mock_config_entry.subentries.values())) + hass.config_entries.async_update_subentry( mock_config_entry, - options={ + subentry, + data={ "prompt": "talk like a {% if True %}smarthome{% else %}pirate please.", }, ) @@ -593,7 +600,7 @@ async def test_conversation_agent( ) assert agent.supported_languages == MATCH_ALL - state = hass.states.get("conversation.mock_title") + state = hass.states.get("conversation.ollama_conversation") assert state assert state.attributes[ATTR_SUPPORTED_FEATURES] == 0 @@ -609,7 +616,7 @@ async def test_conversation_agent_with_assist( ) assert agent.supported_languages == MATCH_ALL - state = hass.states.get("conversation.mock_title") + state = hass.states.get("conversation.ollama_conversation") assert state assert ( state.attributes[ATTR_SUPPORTED_FEATURES] @@ -642,7 +649,7 @@ async def test_options( "test message", None, Context(), - agent_id="conversation.mock_title", + agent_id="conversation.ollama_conversation", ) assert mock_chat.call_count == 1 @@ -667,9 +674,11 @@ async def test_reasoning_filter( entry = MockConfigEntry() entry.add_to_hass(hass) - hass.config_entries.async_update_entry( + subentry = next(iter(mock_config_entry.subentries.values())) + hass.config_entries.async_update_subentry( mock_config_entry, - options={ + subentry, + data={ ollama.CONF_THINK: think, }, ) diff --git a/tests/components/ollama/test_init.py b/tests/components/ollama/test_init.py index d1074226837812..e11eb98451a345 100644 --- a/tests/components/ollama/test_init.py +++ b/tests/components/ollama/test_init.py @@ -6,9 +6,13 @@ import pytest from homeassistant.components import ollama +from homeassistant.components.ollama.const import DOMAIN from homeassistant.core import HomeAssistant +from homeassistant.helpers import device_registry as dr, entity_registry as er from homeassistant.setup import async_setup_component +from . import TEST_OPTIONS, TEST_USER_DATA + from tests.common import MockConfigEntry @@ -34,3 +38,250 @@ async def test_init_error( assert await async_setup_component(hass, ollama.DOMAIN, {}) await hass.async_block_till_done() assert error in caplog.text + + +async def test_migration_from_v1_to_v2( + hass: HomeAssistant, + device_registry: dr.DeviceRegistry, + entity_registry: er.EntityRegistry, +) -> None: + """Test migration from version 1 to version 2.""" + # Create a v1 config entry with conversation options and an entity + mock_config_entry = MockConfigEntry( + domain=DOMAIN, + data=TEST_USER_DATA, + options=TEST_OPTIONS, + version=1, + title="llama-3.2-8b", + ) + mock_config_entry.add_to_hass(hass) + + device = device_registry.async_get_or_create( + config_entry_id=mock_config_entry.entry_id, + identifiers={(DOMAIN, mock_config_entry.entry_id)}, + name=mock_config_entry.title, + manufacturer="Ollama", + model="Ollama", + entry_type=dr.DeviceEntryType.SERVICE, + ) + entity = entity_registry.async_get_or_create( + "conversation", + DOMAIN, + mock_config_entry.entry_id, + config_entry=mock_config_entry, + device_id=device.id, + suggested_object_id="llama_3_2_8b", + ) + + # Run migration + with patch( + "homeassistant.components.ollama.async_setup_entry", + return_value=True, + ): + await hass.config_entries.async_setup(mock_config_entry.entry_id) + + assert mock_config_entry.version == 2 + assert mock_config_entry.data == TEST_USER_DATA + assert mock_config_entry.options == {} + + assert len(mock_config_entry.subentries) == 1 + + subentry = next(iter(mock_config_entry.subentries.values())) + assert subentry.unique_id is None + assert subentry.title == "llama-3.2-8b" + assert subentry.subentry_type == "conversation" + assert subentry.data == TEST_OPTIONS + + migrated_entity = entity_registry.async_get(entity.entity_id) + assert migrated_entity is not None + assert migrated_entity.config_entry_id == mock_config_entry.entry_id + assert migrated_entity.config_subentry_id == subentry.subentry_id + assert migrated_entity.unique_id == subentry.subentry_id + + # Check device migration + assert not device_registry.async_get_device( + identifiers={(DOMAIN, mock_config_entry.entry_id)} + ) + assert ( + migrated_device := device_registry.async_get_device( + identifiers={(DOMAIN, subentry.subentry_id)} + ) + ) + assert migrated_device.identifiers == {(DOMAIN, subentry.subentry_id)} + assert migrated_device.id == device.id + + +async def test_migration_from_v1_to_v2_with_multiple_urls( + hass: HomeAssistant, + device_registry: dr.DeviceRegistry, + entity_registry: er.EntityRegistry, +) -> None: + """Test migration from version 1 to version 2 with different URLs.""" + # Create two v1 config entries with different URLs + mock_config_entry = MockConfigEntry( + domain=DOMAIN, + data={"url": "http://localhost:11434", "model": "llama3.2:latest"}, + options=TEST_OPTIONS, + version=1, + title="Ollama 1", + ) + mock_config_entry.add_to_hass(hass) + mock_config_entry_2 = MockConfigEntry( + domain=DOMAIN, + data={"url": "http://localhost:11435", "model": "llama3.2:latest"}, + options=TEST_OPTIONS, + version=1, + title="Ollama 2", + ) + mock_config_entry_2.add_to_hass(hass) + + device = device_registry.async_get_or_create( + config_entry_id=mock_config_entry.entry_id, + identifiers={(DOMAIN, mock_config_entry.entry_id)}, + name=mock_config_entry.title, + manufacturer="Ollama", + model="Ollama 1", + entry_type=dr.DeviceEntryType.SERVICE, + ) + entity_registry.async_get_or_create( + "conversation", + DOMAIN, + mock_config_entry.entry_id, + config_entry=mock_config_entry, + device_id=device.id, + suggested_object_id="ollama_1", + ) + + device_2 = device_registry.async_get_or_create( + config_entry_id=mock_config_entry_2.entry_id, + identifiers={(DOMAIN, mock_config_entry_2.entry_id)}, + name=mock_config_entry_2.title, + manufacturer="Ollama", + model="Ollama 2", + entry_type=dr.DeviceEntryType.SERVICE, + ) + entity_registry.async_get_or_create( + "conversation", + DOMAIN, + mock_config_entry_2.entry_id, + config_entry=mock_config_entry_2, + device_id=device_2.id, + suggested_object_id="ollama_2", + ) + + # Run migration + with patch( + "homeassistant.components.ollama.async_setup_entry", + return_value=True, + ): + await hass.config_entries.async_setup(mock_config_entry.entry_id) + await hass.async_block_till_done() + + entries = hass.config_entries.async_entries(DOMAIN) + assert len(entries) == 2 + + for idx, entry in enumerate(entries): + assert entry.version == 2 + assert not entry.options + assert len(entry.subentries) == 1 + subentry = list(entry.subentries.values())[0] + assert subentry.subentry_type == "conversation" + assert subentry.data == TEST_OPTIONS + assert subentry.title == f"Ollama {idx + 1}" + + dev = device_registry.async_get_device( + identifiers={(DOMAIN, list(entry.subentries.values())[0].subentry_id)} + ) + assert dev is not None + + +async def test_migration_from_v1_to_v2_with_same_urls( + hass: HomeAssistant, + device_registry: dr.DeviceRegistry, + entity_registry: er.EntityRegistry, +) -> None: + """Test migration from version 1 to version 2 with same URLs consolidates entries.""" + # Create two v1 config entries with the same URL + mock_config_entry = MockConfigEntry( + domain=DOMAIN, + data={"url": "http://localhost:11434", "model": "llama3.2:latest"}, + options=TEST_OPTIONS, + version=1, + title="Ollama", + ) + mock_config_entry.add_to_hass(hass) + mock_config_entry_2 = MockConfigEntry( + domain=DOMAIN, + data={"url": "http://localhost:11434", "model": "llama3.2:latest"}, # Same URL + options=TEST_OPTIONS, + version=1, + title="Ollama 2", + ) + mock_config_entry_2.add_to_hass(hass) + + device = device_registry.async_get_or_create( + config_entry_id=mock_config_entry.entry_id, + identifiers={(DOMAIN, mock_config_entry.entry_id)}, + name=mock_config_entry.title, + manufacturer="Ollama", + model="Ollama", + entry_type=dr.DeviceEntryType.SERVICE, + ) + entity_registry.async_get_or_create( + "conversation", + DOMAIN, + mock_config_entry.entry_id, + config_entry=mock_config_entry, + device_id=device.id, + suggested_object_id="ollama", + ) + + device_2 = device_registry.async_get_or_create( + config_entry_id=mock_config_entry_2.entry_id, + identifiers={(DOMAIN, mock_config_entry_2.entry_id)}, + name=mock_config_entry_2.title, + manufacturer="Ollama", + model="Ollama", + entry_type=dr.DeviceEntryType.SERVICE, + ) + entity_registry.async_get_or_create( + "conversation", + DOMAIN, + mock_config_entry_2.entry_id, + config_entry=mock_config_entry_2, + device_id=device_2.id, + suggested_object_id="ollama_2", + ) + + # Run migration + with patch( + "homeassistant.components.ollama.async_setup_entry", + return_value=True, + ): + await hass.config_entries.async_setup(mock_config_entry.entry_id) + await hass.async_block_till_done() + + # Should have only one entry left (consolidated) + entries = hass.config_entries.async_entries(DOMAIN) + assert len(entries) == 1 + + entry = entries[0] + assert entry.version == 2 + assert not entry.options + assert len(entry.subentries) == 2 # Two subentries from the two original entries + + # Check both subentries exist with correct data + subentries = list(entry.subentries.values()) + titles = [sub.title for sub in subentries] + assert "Ollama" in titles + assert "Ollama 2" in titles + + for subentry in subentries: + assert subentry.subentry_type == "conversation" + assert subentry.data == TEST_OPTIONS + + # Check devices were migrated correctly + dev = device_registry.async_get_device( + identifiers={(DOMAIN, subentry.subentry_id)} + ) + assert dev is not None diff --git a/tests/components/openai_conversation/conftest.py b/tests/components/openai_conversation/conftest.py index 4639d0dc8e0854..aa17c333a795a8 100644 --- a/tests/components/openai_conversation/conftest.py +++ b/tests/components/openai_conversation/conftest.py @@ -4,6 +4,7 @@ import pytest +from homeassistant.components.openai_conversation.const import DEFAULT_CONVERSATION_NAME from homeassistant.const import CONF_LLM_HASS_API from homeassistant.core import HomeAssistant from homeassistant.helpers import llm @@ -21,6 +22,15 @@ def mock_config_entry(hass: HomeAssistant) -> MockConfigEntry: data={ "api_key": "bla", }, + version=2, + subentries_data=[ + { + "data": {}, + "subentry_type": "conversation", + "title": DEFAULT_CONVERSATION_NAME, + "unique_id": None, + } + ], ) entry.add_to_hass(hass) return entry @@ -31,8 +41,10 @@ def mock_config_entry_with_assist( hass: HomeAssistant, mock_config_entry: MockConfigEntry ) -> MockConfigEntry: """Mock a config entry with assist.""" - hass.config_entries.async_update_entry( - mock_config_entry, options={CONF_LLM_HASS_API: llm.LLM_API_ASSIST} + hass.config_entries.async_update_subentry( + mock_config_entry, + next(iter(mock_config_entry.subentries.values())), + data={CONF_LLM_HASS_API: llm.LLM_API_ASSIST}, ) return mock_config_entry diff --git a/tests/components/openai_conversation/snapshots/test_conversation.ambr b/tests/components/openai_conversation/snapshots/test_conversation.ambr index 0f874969aff8e7..48ad0878b2faae 100644 --- a/tests/components/openai_conversation/snapshots/test_conversation.ambr +++ b/tests/components/openai_conversation/snapshots/test_conversation.ambr @@ -6,7 +6,7 @@ 'role': 'user', }), dict({ - 'agent_id': 'conversation.openai', + 'agent_id': 'conversation.openai_conversation', 'content': None, 'role': 'assistant', 'tool_calls': list([ @@ -20,14 +20,14 @@ ]), }), dict({ - 'agent_id': 'conversation.openai', + 'agent_id': 'conversation.openai_conversation', 'role': 'tool_result', 'tool_call_id': 'call_call_1', 'tool_name': 'test_tool', 'tool_result': 'value1', }), dict({ - 'agent_id': 'conversation.openai', + 'agent_id': 'conversation.openai_conversation', 'content': None, 'role': 'assistant', 'tool_calls': list([ @@ -41,14 +41,14 @@ ]), }), dict({ - 'agent_id': 'conversation.openai', + 'agent_id': 'conversation.openai_conversation', 'role': 'tool_result', 'tool_call_id': 'call_call_2', 'tool_name': 'test_tool', 'tool_result': 'value2', }), dict({ - 'agent_id': 'conversation.openai', + 'agent_id': 'conversation.openai_conversation', 'content': 'Cool', 'role': 'assistant', 'tool_calls': None, @@ -62,7 +62,7 @@ 'role': 'user', }), dict({ - 'agent_id': 'conversation.openai', + 'agent_id': 'conversation.openai_conversation', 'content': None, 'role': 'assistant', 'tool_calls': list([ @@ -76,14 +76,14 @@ ]), }), dict({ - 'agent_id': 'conversation.openai', + 'agent_id': 'conversation.openai_conversation', 'role': 'tool_result', 'tool_call_id': 'call_call_1', 'tool_name': 'test_tool', 'tool_result': 'value1', }), dict({ - 'agent_id': 'conversation.openai', + 'agent_id': 'conversation.openai_conversation', 'content': 'Cool', 'role': 'assistant', 'tool_calls': None, diff --git a/tests/components/openai_conversation/test_config_flow.py b/tests/components/openai_conversation/test_config_flow.py index ad5bbffaed360f..b77542fbab31c3 100644 --- a/tests/components/openai_conversation/test_config_flow.py +++ b/tests/components/openai_conversation/test_config_flow.py @@ -24,12 +24,13 @@ CONF_WEB_SEARCH_REGION, CONF_WEB_SEARCH_TIMEZONE, CONF_WEB_SEARCH_USER_LOCATION, + DEFAULT_CONVERSATION_NAME, DOMAIN, RECOMMENDED_CHAT_MODEL, RECOMMENDED_MAX_TOKENS, RECOMMENDED_TOP_P, ) -from homeassistant.const import CONF_LLM_HASS_API +from homeassistant.const import CONF_API_KEY, CONF_LLM_HASS_API from homeassistant.core import HomeAssistant from homeassistant.data_entry_flow import FlowResultType @@ -72,42 +73,132 @@ async def test_form(hass: HomeAssistant) -> None: assert result2["data"] == { "api_key": "bla", } - assert result2["options"] == RECOMMENDED_OPTIONS + assert result2["options"] == {} + assert result2["subentries"] == [ + { + "subentry_type": "conversation", + "data": RECOMMENDED_OPTIONS, + "title": DEFAULT_CONVERSATION_NAME, + "unique_id": None, + } + ] assert len(mock_setup_entry.mock_calls) == 1 -async def test_options_recommended( +async def test_duplicate_entry(hass: HomeAssistant) -> None: + """Test we abort on duplicate config entry.""" + MockConfigEntry( + domain=DOMAIN, + data={CONF_API_KEY: "bla"}, + ).add_to_hass(hass) + + result = await hass.config_entries.flow.async_init( + DOMAIN, context={"source": config_entries.SOURCE_USER} + ) + assert result["type"] is FlowResultType.FORM + assert not result["errors"] + + with patch( + "homeassistant.components.openai_conversation.config_flow.openai.resources.models.AsyncModels.list", + ): + result = await hass.config_entries.flow.async_configure( + result["flow_id"], + { + CONF_API_KEY: "bla", + }, + ) + + assert result["type"] is FlowResultType.ABORT + assert result["reason"] == "already_configured" + + +async def test_creating_conversation_subentry( + hass: HomeAssistant, + mock_init_component: None, + mock_config_entry: MockConfigEntry, +) -> None: + """Test creating a conversation subentry.""" + mock_config_entry.add_to_hass(hass) + + result = await hass.config_entries.subentries.async_init( + (mock_config_entry.entry_id, "conversation"), + context={"source": config_entries.SOURCE_USER}, + ) + + assert result["type"] is FlowResultType.FORM + assert result["step_id"] == "init" + assert not result["errors"] + + result2 = await hass.config_entries.subentries.async_configure( + result["flow_id"], + {"name": "My Custom Agent", **RECOMMENDED_OPTIONS}, + ) + await hass.async_block_till_done() + + assert result2["type"] is FlowResultType.CREATE_ENTRY + assert result2["title"] == "My Custom Agent" + + processed_options = RECOMMENDED_OPTIONS.copy() + processed_options[CONF_PROMPT] = processed_options[CONF_PROMPT].strip() + + assert result2["data"] == processed_options + + +async def test_creating_conversation_subentry_not_loaded( + hass: HomeAssistant, + mock_init_component, + mock_config_entry: MockConfigEntry, +) -> None: + """Test creating a conversation subentry when entry is not loaded.""" + await hass.config_entries.async_unload(mock_config_entry.entry_id) + with patch( + "homeassistant.components.openai_conversation.config_flow.openai.resources.models.AsyncModels.list", + return_value=[], + ): + result = await hass.config_entries.subentries.async_init( + (mock_config_entry.entry_id, "conversation"), + context={"source": config_entries.SOURCE_USER}, + ) + + assert result["type"] is FlowResultType.ABORT + assert result["reason"] == "entry_not_loaded" + + +async def test_subentry_recommended( hass: HomeAssistant, mock_config_entry, mock_init_component ) -> None: - """Test the options flow with recommended settings.""" - options_flow = await hass.config_entries.options.async_init( - mock_config_entry.entry_id + """Test the subentry flow with recommended settings.""" + subentry = next(iter(mock_config_entry.subentries.values())) + subentry_flow = await mock_config_entry.start_subentry_reconfigure_flow( + hass, subentry.subentry_id ) - options = await hass.config_entries.options.async_configure( - options_flow["flow_id"], + options = await hass.config_entries.subentries.async_configure( + subentry_flow["flow_id"], { "prompt": "Speak like a pirate", "recommended": True, }, ) await hass.async_block_till_done() - assert options["type"] is FlowResultType.CREATE_ENTRY - assert options["data"]["prompt"] == "Speak like a pirate" + assert options["type"] is FlowResultType.ABORT + assert options["reason"] == "reconfigure_successful" + assert subentry.data["prompt"] == "Speak like a pirate" -async def test_options_unsupported_model( +async def test_subentry_unsupported_model( hass: HomeAssistant, mock_config_entry, mock_init_component ) -> None: - """Test the options form giving error about models not supported.""" - options_flow = await hass.config_entries.options.async_init( - mock_config_entry.entry_id + """Test the subentry form giving error about models not supported.""" + subentry = next(iter(mock_config_entry.subentries.values())) + subentry_flow = await mock_config_entry.start_subentry_reconfigure_flow( + hass, subentry.subentry_id ) - assert options_flow["type"] == FlowResultType.FORM - assert options_flow["step_id"] == "init" + assert subentry_flow["type"] == FlowResultType.FORM + assert subentry_flow["step_id"] == "init" # Configure initial step - options_flow = await hass.config_entries.options.async_configure( - options_flow["flow_id"], + subentry_flow = await hass.config_entries.subentries.async_configure( + subentry_flow["flow_id"], { CONF_RECOMMENDED: False, CONF_PROMPT: "Speak like a pirate", @@ -115,19 +206,19 @@ async def test_options_unsupported_model( }, ) await hass.async_block_till_done() - assert options_flow["type"] == FlowResultType.FORM - assert options_flow["step_id"] == "advanced" + assert subentry_flow["type"] == FlowResultType.FORM + assert subentry_flow["step_id"] == "advanced" # Configure advanced step - options_flow = await hass.config_entries.options.async_configure( - options_flow["flow_id"], + subentry_flow = await hass.config_entries.subentries.async_configure( + subentry_flow["flow_id"], { CONF_CHAT_MODEL: "o1-mini", }, ) await hass.async_block_till_done() - assert options_flow["type"] is FlowResultType.FORM - assert options_flow["errors"] == {"chat_model": "model_not_supported"} + assert subentry_flow["type"] is FlowResultType.FORM + assert subentry_flow["errors"] == {"chat_model": "model_not_supported"} @pytest.mark.parametrize( @@ -494,7 +585,7 @@ async def test_form_invalid_auth(hass: HomeAssistant, side_effect, error) -> Non ), ], ) -async def test_options_switching( +async def test_subentry_switching( hass: HomeAssistant, mock_config_entry, mock_init_component, @@ -502,16 +593,22 @@ async def test_options_switching( new_options, expected_options, ) -> None: - """Test the options form.""" - hass.config_entries.async_update_entry(mock_config_entry, options=current_options) - options = await hass.config_entries.options.async_init(mock_config_entry.entry_id) - assert options["step_id"] == "init" + """Test the subentry form.""" + subentry = next(iter(mock_config_entry.subentries.values())) + hass.config_entries.async_update_subentry( + mock_config_entry, subentry, data=current_options + ) + await hass.async_block_till_done() + subentry_flow = await mock_config_entry.start_subentry_reconfigure_flow( + hass, subentry.subentry_id + ) + assert subentry_flow["step_id"] == "init" for step_options in new_options: - assert options["type"] == FlowResultType.FORM + assert subentry_flow["type"] == FlowResultType.FORM # Test that current options are showed as suggested values: - for key in options["data_schema"].schema: + for key in subentry_flow["data_schema"].schema: if ( isinstance(key.description, dict) and "suggested_value" in key.description @@ -523,38 +620,42 @@ async def test_options_switching( assert key.description["suggested_value"] == current_option # Configure current step - options = await hass.config_entries.options.async_configure( - options["flow_id"], + subentry_flow = await hass.config_entries.subentries.async_configure( + subentry_flow["flow_id"], step_options, ) await hass.async_block_till_done() - assert options["type"] is FlowResultType.CREATE_ENTRY - assert options["data"] == expected_options + assert subentry_flow["type"] is FlowResultType.ABORT + assert subentry_flow["reason"] == "reconfigure_successful" + assert subentry.data == expected_options -async def test_options_web_search_user_location( +async def test_subentry_web_search_user_location( hass: HomeAssistant, mock_config_entry, mock_init_component ) -> None: """Test fetching user location.""" - options = await hass.config_entries.options.async_init(mock_config_entry.entry_id) - assert options["type"] == FlowResultType.FORM - assert options["step_id"] == "init" + subentry = next(iter(mock_config_entry.subentries.values())) + subentry_flow = await mock_config_entry.start_subentry_reconfigure_flow( + hass, subentry.subentry_id + ) + assert subentry_flow["type"] == FlowResultType.FORM + assert subentry_flow["step_id"] == "init" # Configure initial step - options = await hass.config_entries.options.async_configure( - options["flow_id"], + subentry_flow = await hass.config_entries.subentries.async_configure( + subentry_flow["flow_id"], { CONF_RECOMMENDED: False, CONF_PROMPT: "Speak like a pirate", }, ) - assert options["type"] == FlowResultType.FORM - assert options["step_id"] == "advanced" + assert subentry_flow["type"] == FlowResultType.FORM + assert subentry_flow["step_id"] == "advanced" # Configure advanced step - options = await hass.config_entries.options.async_configure( - options["flow_id"], + subentry_flow = await hass.config_entries.subentries.async_configure( + subentry_flow["flow_id"], { CONF_TEMPERATURE: 1.0, CONF_CHAT_MODEL: RECOMMENDED_CHAT_MODEL, @@ -563,8 +664,8 @@ async def test_options_web_search_user_location( }, ) await hass.async_block_till_done() - assert options["type"] == FlowResultType.FORM - assert options["step_id"] == "model" + assert subentry_flow["type"] == FlowResultType.FORM + assert subentry_flow["step_id"] == "model" hass.config.country = "US" hass.config.time_zone = "America/Los_Angeles" @@ -601,8 +702,8 @@ async def test_options_web_search_user_location( ) # Configure model step - options = await hass.config_entries.options.async_configure( - options["flow_id"], + subentry_flow = await hass.config_entries.subentries.async_configure( + subentry_flow["flow_id"], { CONF_WEB_SEARCH: True, CONF_WEB_SEARCH_CONTEXT_SIZE: "medium", @@ -614,8 +715,9 @@ async def test_options_web_search_user_location( mock_create.call_args.kwargs["input"][0]["content"] == "Where are the following" " coordinates located: (37.7749, -122.4194)?" ) - assert options["type"] is FlowResultType.CREATE_ENTRY - assert options["data"] == { + assert subentry_flow["type"] is FlowResultType.ABORT + assert subentry_flow["reason"] == "reconfigure_successful" + assert subentry.data == { CONF_RECOMMENDED: False, CONF_PROMPT: "Speak like a pirate", CONF_TEMPERATURE: 1.0, diff --git a/tests/components/openai_conversation/test_conversation.py b/tests/components/openai_conversation/test_conversation.py index 99559cb3b61277..8621465bd146ad 100644 --- a/tests/components/openai_conversation/test_conversation.py +++ b/tests/components/openai_conversation/test_conversation.py @@ -153,20 +153,18 @@ async def test_entity( mock_init_component, ) -> None: """Test entity properties.""" - state = hass.states.get("conversation.openai") + state = hass.states.get("conversation.openai_conversation") assert state assert state.attributes["supported_features"] == 0 - hass.config_entries.async_update_entry( + hass.config_entries.async_update_subentry( mock_config_entry, - options={ - **mock_config_entry.options, - CONF_LLM_HASS_API: "assist", - }, + next(iter(mock_config_entry.subentries.values())), + data={CONF_LLM_HASS_API: "assist"}, ) await hass.config_entries.async_reload(mock_config_entry.entry_id) - state = hass.states.get("conversation.openai") + state = hass.states.get("conversation.openai_conversation") assert state assert ( state.attributes["supported_features"] @@ -261,7 +259,7 @@ async def test_incomplete_response( "Please tell me a big story", "mock-conversation-id", Context(), - agent_id="conversation.openai", + agent_id="conversation.openai_conversation", ) assert result.response.response_type == intent.IntentResponseType.ERROR, result @@ -285,7 +283,7 @@ async def test_incomplete_response( "please tell me a big story", "mock-conversation-id", Context(), - agent_id="conversation.openai", + agent_id="conversation.openai_conversation", ) assert result.response.response_type == intent.IntentResponseType.ERROR, result @@ -324,7 +322,7 @@ async def test_failed_response( "next natural number please", "mock-conversation-id", Context(), - agent_id="conversation.openai", + agent_id="conversation.openai_conversation", ) assert result.response.response_type == intent.IntentResponseType.ERROR, result @@ -583,7 +581,7 @@ async def test_function_call( "Please call the test function", mock_chat_log.conversation_id, Context(), - agent_id="conversation.openai", + agent_id="conversation.openai_conversation", ) assert mock_create_stream.call_args.kwargs["input"][2] == { @@ -630,7 +628,7 @@ async def test_function_call_without_reasoning( "Please call the test function", mock_chat_log.conversation_id, Context(), - agent_id="conversation.openai", + agent_id="conversation.openai_conversation", ) assert result.response.response_type == intent.IntentResponseType.ACTION_DONE @@ -686,7 +684,7 @@ async def test_function_call_invalid( "Please call the test function", "mock-conversation-id", Context(), - agent_id="conversation.openai", + agent_id="conversation.openai_conversation", ) @@ -720,7 +718,7 @@ async def test_assist_api_tools_conversion( ] await conversation.async_converse( - hass, "hello", None, Context(), agent_id="conversation.openai" + hass, "hello", None, Context(), agent_id="conversation.openai_conversation" ) tools = mock_create_stream.mock_calls[0][2]["tools"] @@ -735,10 +733,12 @@ async def test_web_search( mock_chat_log: MockChatLog, # noqa: F811 ) -> None: """Test web_search_tool.""" - hass.config_entries.async_update_entry( + subentry = next(iter(mock_config_entry.subentries.values())) + hass.config_entries.async_update_subentry( mock_config_entry, - options={ - **mock_config_entry.options, + subentry, + data={ + **subentry.data, CONF_WEB_SEARCH: True, CONF_WEB_SEARCH_CONTEXT_SIZE: "low", CONF_WEB_SEARCH_USER_LOCATION: True, @@ -764,7 +764,7 @@ async def test_web_search( "What's on the latest news?", mock_chat_log.conversation_id, Context(), - agent_id="conversation.openai", + agent_id="conversation.openai_conversation", ) assert mock_create_stream.mock_calls[0][2]["tools"] == [ diff --git a/tests/components/openai_conversation/test_init.py b/tests/components/openai_conversation/test_init.py index b4f816707e9daf..d209554e8d3502 100644 --- a/tests/components/openai_conversation/test_init.py +++ b/tests/components/openai_conversation/test_init.py @@ -15,8 +15,10 @@ import pytest from homeassistant.components.openai_conversation import CONF_FILENAMES +from homeassistant.components.openai_conversation.const import DOMAIN from homeassistant.core import HomeAssistant from homeassistant.exceptions import HomeAssistantError, ServiceValidationError +from homeassistant.helpers import device_registry as dr, entity_registry as er from homeassistant.setup import async_setup_component from tests.common import MockConfigEntry @@ -536,3 +538,271 @@ async def test_generate_content_service_error( blocking=True, return_response=True, ) + + +async def test_migration_from_v1_to_v2( + hass: HomeAssistant, + device_registry: dr.DeviceRegistry, + entity_registry: er.EntityRegistry, +) -> None: + """Test migration from version 1 to version 2.""" + # Create a v1 config entry with conversation options and an entity + OPTIONS = { + "recommended": True, + "llm_hass_api": ["assist"], + "prompt": "You are a helpful assistant", + "chat_model": "gpt-4o-mini", + } + mock_config_entry = MockConfigEntry( + domain=DOMAIN, + data={"api_key": "1234"}, + options=OPTIONS, + version=1, + title="ChatGPT", + ) + mock_config_entry.add_to_hass(hass) + + device = device_registry.async_get_or_create( + config_entry_id=mock_config_entry.entry_id, + identifiers={(DOMAIN, mock_config_entry.entry_id)}, + name=mock_config_entry.title, + manufacturer="OpenAI", + model="ChatGPT", + entry_type=dr.DeviceEntryType.SERVICE, + ) + entity = entity_registry.async_get_or_create( + "conversation", + DOMAIN, + mock_config_entry.entry_id, + config_entry=mock_config_entry, + device_id=device.id, + suggested_object_id="google_generative_ai_conversation", + ) + + # Run migration + with patch( + "homeassistant.components.openai_conversation.async_setup_entry", + return_value=True, + ): + await hass.config_entries.async_setup(mock_config_entry.entry_id) + await hass.async_block_till_done() + + assert mock_config_entry.version == 2 + assert mock_config_entry.data == {"api_key": "1234"} + assert mock_config_entry.options == {} + + assert len(mock_config_entry.subentries) == 1 + + subentry = next(iter(mock_config_entry.subentries.values())) + assert subentry.unique_id is None + assert subentry.title == "ChatGPT" + assert subentry.subentry_type == "conversation" + assert subentry.data == OPTIONS + + migrated_entity = entity_registry.async_get(entity.entity_id) + assert migrated_entity is not None + assert migrated_entity.config_entry_id == mock_config_entry.entry_id + assert migrated_entity.config_subentry_id == subentry.subentry_id + assert migrated_entity.unique_id == subentry.subentry_id + + # Check device migration + assert not device_registry.async_get_device( + identifiers={(DOMAIN, mock_config_entry.entry_id)} + ) + assert ( + migrated_device := device_registry.async_get_device( + identifiers={(DOMAIN, subentry.subentry_id)} + ) + ) + assert migrated_device.identifiers == {(DOMAIN, subentry.subentry_id)} + assert migrated_device.id == device.id + + +async def test_migration_from_v1_to_v2_with_multiple_keys( + hass: HomeAssistant, + device_registry: dr.DeviceRegistry, + entity_registry: er.EntityRegistry, +) -> None: + """Test migration from version 1 to version 2 with different API keys.""" + # Create two v1 config entries with different API keys + options = { + "recommended": True, + "llm_hass_api": ["assist"], + "prompt": "You are a helpful assistant", + "chat_model": "gpt-4o-mini", + } + mock_config_entry = MockConfigEntry( + domain=DOMAIN, + data={"api_key": "1234"}, + options=options, + version=1, + title="ChatGPT 1", + ) + mock_config_entry.add_to_hass(hass) + mock_config_entry_2 = MockConfigEntry( + domain=DOMAIN, + data={"api_key": "12345"}, + options=options, + version=1, + title="ChatGPT 2", + ) + mock_config_entry_2.add_to_hass(hass) + + device = device_registry.async_get_or_create( + config_entry_id=mock_config_entry.entry_id, + identifiers={(DOMAIN, mock_config_entry.entry_id)}, + name=mock_config_entry.title, + manufacturer="OpenAI", + model="ChatGPT 1", + entry_type=dr.DeviceEntryType.SERVICE, + ) + entity_registry.async_get_or_create( + "conversation", + DOMAIN, + mock_config_entry.entry_id, + config_entry=mock_config_entry, + device_id=device.id, + suggested_object_id="chatgpt_1", + ) + + device_2 = device_registry.async_get_or_create( + config_entry_id=mock_config_entry_2.entry_id, + identifiers={(DOMAIN, mock_config_entry_2.entry_id)}, + name=mock_config_entry_2.title, + manufacturer="OpenAI", + model="ChatGPT 2", + entry_type=dr.DeviceEntryType.SERVICE, + ) + entity_registry.async_get_or_create( + "conversation", + DOMAIN, + mock_config_entry_2.entry_id, + config_entry=mock_config_entry_2, + device_id=device_2.id, + suggested_object_id="chatgpt_2", + ) + + # Run migration + with patch( + "homeassistant.components.openai_conversation.async_setup_entry", + return_value=True, + ): + await hass.config_entries.async_setup(mock_config_entry.entry_id) + await hass.async_block_till_done() + await hass.async_block_till_done() + + entries = hass.config_entries.async_entries(DOMAIN) + assert len(entries) == 2 + + for idx, entry in enumerate(entries): + assert entry.version == 2 + assert not entry.options + assert len(entry.subentries) == 1 + subentry = list(entry.subentries.values())[0] + assert subentry.subentry_type == "conversation" + assert subentry.data == options + assert subentry.title == f"ChatGPT {idx + 1}" + + dev = device_registry.async_get_device( + identifiers={(DOMAIN, list(entry.subentries.values())[0].subentry_id)} + ) + assert dev is not None + + +async def test_migration_from_v1_to_v2_with_same_keys( + hass: HomeAssistant, + device_registry: dr.DeviceRegistry, + entity_registry: er.EntityRegistry, +) -> None: + """Test migration from version 1 to version 2 with same API keys consolidates entries.""" + # Create two v1 config entries with the same API key + options = { + "recommended": True, + "llm_hass_api": ["assist"], + "prompt": "You are a helpful assistant", + "chat_model": "gpt-4o-mini", + } + mock_config_entry = MockConfigEntry( + domain=DOMAIN, + data={"api_key": "1234"}, + options=options, + version=1, + title="ChatGPT", + ) + mock_config_entry.add_to_hass(hass) + mock_config_entry_2 = MockConfigEntry( + domain=DOMAIN, + data={"api_key": "1234"}, # Same API key + options=options, + version=1, + title="ChatGPT 2", + ) + mock_config_entry_2.add_to_hass(hass) + + device = device_registry.async_get_or_create( + config_entry_id=mock_config_entry.entry_id, + identifiers={(DOMAIN, mock_config_entry.entry_id)}, + name=mock_config_entry.title, + manufacturer="OpenAI", + model="ChatGPT", + entry_type=dr.DeviceEntryType.SERVICE, + ) + entity_registry.async_get_or_create( + "conversation", + DOMAIN, + mock_config_entry.entry_id, + config_entry=mock_config_entry, + device_id=device.id, + suggested_object_id="chatgpt", + ) + + device_2 = device_registry.async_get_or_create( + config_entry_id=mock_config_entry_2.entry_id, + identifiers={(DOMAIN, mock_config_entry_2.entry_id)}, + name=mock_config_entry_2.title, + manufacturer="OpenAI", + model="ChatGPT", + entry_type=dr.DeviceEntryType.SERVICE, + ) + entity_registry.async_get_or_create( + "conversation", + DOMAIN, + mock_config_entry_2.entry_id, + config_entry=mock_config_entry_2, + device_id=device_2.id, + suggested_object_id="chatgpt_2", + ) + + # Run migration + with patch( + "homeassistant.components.openai_conversation.async_setup_entry", + return_value=True, + ): + await hass.config_entries.async_setup(mock_config_entry.entry_id) + await hass.async_block_till_done() + await hass.async_block_till_done() + + # Should have only one entry left (consolidated) + entries = hass.config_entries.async_entries(DOMAIN) + assert len(entries) == 1 + + entry = entries[0] + assert entry.version == 2 + assert not entry.options + assert len(entry.subentries) == 2 # Two subentries from the two original entries + + # Check both subentries exist with correct data + subentries = list(entry.subentries.values()) + titles = [sub.title for sub in subentries] + assert "ChatGPT" in titles + assert "ChatGPT 2" in titles + + for subentry in subentries: + assert subentry.subentry_type == "conversation" + assert subentry.data == options + + # Check devices were migrated correctly + dev = device_registry.async_get_device( + identifiers={(DOMAIN, subentry.subentry_id)} + ) + assert dev is not None diff --git a/tests/components/reolink/conftest.py b/tests/components/reolink/conftest.py index 2f37fca251a230..6d5e7d2688e9fe 100644 --- a/tests/components/reolink/conftest.py +++ b/tests/components/reolink/conftest.py @@ -77,6 +77,9 @@ def _init_host_mock(host_mock: MagicMock) -> None: host_mock.get_stream_source = AsyncMock() host_mock.get_snapshot = AsyncMock() host_mock.get_encoding = AsyncMock(return_value="h264") + host_mock.pull_point_request = AsyncMock() + host_mock.set_audio = AsyncMock() + host_mock.set_email = AsyncMock() host_mock.ONVIF_event_callback = AsyncMock() host_mock.is_nvr = True host_mock.is_hub = False @@ -271,6 +274,7 @@ def reolink_chime(reolink_host: MagicMock) -> None: "people": {"switch": 0, "musicId": 1}, "visitor": {"switch": 1, "musicId": 2}, } + TEST_CHIME.remove = AsyncMock() reolink_host.chime_list = [TEST_CHIME] reolink_host.chime.return_value = TEST_CHIME diff --git a/tests/components/reolink/test_init.py b/tests/components/reolink/test_init.py index ed71314e961f6f..e439d3dff935ee 100644 --- a/tests/components/reolink/test_init.py +++ b/tests/components/reolink/test_init.py @@ -7,7 +7,6 @@ from freezegun.api import FrozenDateTimeFactory import pytest -from reolink_aio.api import Chime from reolink_aio.exceptions import ( CredentialsInvalidError, LoginPrivacyModeError, @@ -270,22 +269,25 @@ async def test_removing_disconnected_cams( @pytest.mark.parametrize( - ("attr", "value", "expected_models"), + ("attr", "value", "expected_models", "expected_remove_call_count"), [ ( None, None, [TEST_HOST_MODEL, TEST_CAM_MODEL, CHIME_MODEL], + 1, ), ( "connect_state", -1, [TEST_HOST_MODEL, TEST_CAM_MODEL], + 0, ), ( "remove", -1, [TEST_HOST_MODEL, TEST_CAM_MODEL], + 1, ), ], ) @@ -294,12 +296,13 @@ async def test_removing_chime( hass_ws_client: WebSocketGenerator, config_entry: MockConfigEntry, reolink_host: MagicMock, - reolink_chime: Chime, + reolink_chime: MagicMock, device_registry: dr.DeviceRegistry, entity_registry: er.EntityRegistry, attr: str | None, value: Any, expected_models: list[str], + expected_remove_call_count: int, ) -> None: """Test removing a chime.""" reolink_host.channels = [0] @@ -324,7 +327,7 @@ async def test_remove_chime(*args, **key_args): """Remove chime.""" reolink_chime.connect_state = -1 - reolink_chime.remove = test_remove_chime + reolink_chime.remove = AsyncMock(side_effect=test_remove_chime) elif attr is not None: setattr(reolink_chime, attr, value) @@ -334,6 +337,7 @@ async def test_remove_chime(*args, **key_args): if device.model == CHIME_MODEL: response = await client.remove_device(device.id, config_entry.entry_id) assert response["success"] == expected_success + assert reolink_chime.remove.call_count == expected_remove_call_count device_entries = dr.async_entries_for_config_entry( device_registry, config_entry.entry_id @@ -1156,11 +1160,11 @@ def register_callback( async def test_baichaun_only( hass: HomeAssistant, - reolink_connect: MagicMock, + reolink_host: MagicMock, config_entry: MockConfigEntry, ) -> None: """Test initializing a baichuan only device.""" - reolink_connect.baichuan_only = True + reolink_host.baichuan_only = True with patch("homeassistant.components.reolink.PLATFORMS", [Platform.SWITCH]): assert await hass.config_entries.async_setup(config_entry.entry_id) diff --git a/tests/components/reolink/test_services.py b/tests/components/reolink/test_services.py index 6ae9a2d97299d7..38819bbd51d998 100644 --- a/tests/components/reolink/test_services.py +++ b/tests/components/reolink/test_services.py @@ -20,8 +20,8 @@ async def test_play_chime_service_entity( hass: HomeAssistant, config_entry: MockConfigEntry, - reolink_connect: MagicMock, - test_chime: Chime, + reolink_host: MagicMock, + reolink_chime: Chime, entity_registry: er.EntityRegistry, ) -> None: """Test chime play service.""" @@ -37,14 +37,14 @@ async def test_play_chime_service_entity( device_id = entity.device_id # Test chime play service with device - test_chime.play = AsyncMock() + reolink_chime.play = AsyncMock() await hass.services.async_call( DOMAIN, "play_chime", {ATTR_DEVICE_ID: [device_id], ATTR_RINGTONE: "attraction"}, blocking=True, ) - test_chime.play.assert_called_once() + reolink_chime.play.assert_called_once() # Test errors with pytest.raises(ServiceValidationError): @@ -55,7 +55,7 @@ async def test_play_chime_service_entity( blocking=True, ) - test_chime.play = AsyncMock(side_effect=ReolinkError("Test error")) + reolink_chime.play = AsyncMock(side_effect=ReolinkError("Test error")) with pytest.raises(HomeAssistantError): await hass.services.async_call( DOMAIN, @@ -64,7 +64,7 @@ async def test_play_chime_service_entity( blocking=True, ) - test_chime.play = AsyncMock(side_effect=InvalidParameterError("Test error")) + reolink_chime.play = AsyncMock(side_effect=InvalidParameterError("Test error")) with pytest.raises(ServiceValidationError): await hass.services.async_call( DOMAIN, @@ -73,7 +73,7 @@ async def test_play_chime_service_entity( blocking=True, ) - reolink_connect.chime.return_value = None + reolink_host.chime.return_value = None with pytest.raises(ServiceValidationError): await hass.services.async_call( DOMAIN, @@ -86,8 +86,8 @@ async def test_play_chime_service_entity( async def test_play_chime_service_unloaded( hass: HomeAssistant, config_entry: MockConfigEntry, - reolink_connect: MagicMock, - test_chime: Chime, + reolink_host: MagicMock, + reolink_chime: Chime, entity_registry: er.EntityRegistry, ) -> None: """Test chime play service when config entry is unloaded.""" diff --git a/tests/components/reolink/test_switch.py b/tests/components/reolink/test_switch.py index 2b2c33f0e8fba3..9c0f2295a20142 100644 --- a/tests/components/reolink/test_switch.py +++ b/tests/components/reolink/test_switch.py @@ -33,11 +33,11 @@ async def test_switch( hass: HomeAssistant, config_entry: MockConfigEntry, freezer: FrozenDateTimeFactory, - reolink_connect: MagicMock, + reolink_host: MagicMock, ) -> None: """Test switch entity.""" - reolink_connect.camera_name.return_value = TEST_CAM_NAME - reolink_connect.audio_record.return_value = True + reolink_host.camera_name.return_value = TEST_CAM_NAME + reolink_host.audio_record.return_value = True with patch("homeassistant.components.reolink.PLATFORMS", [Platform.SWITCH]): assert await hass.config_entries.async_setup(config_entry.entry_id) @@ -47,7 +47,7 @@ async def test_switch( entity_id = f"{Platform.SWITCH}.{TEST_CAM_NAME}_record_audio" assert hass.states.get(entity_id).state == STATE_ON - reolink_connect.audio_record.return_value = False + reolink_host.audio_record.return_value = False freezer.tick(DEVICE_UPDATE_INTERVAL) async_fire_time_changed(hass) await hass.async_block_till_done() @@ -61,9 +61,9 @@ async def test_switch( {ATTR_ENTITY_ID: entity_id}, blocking=True, ) - reolink_connect.set_audio.assert_called_with(0, True) + reolink_host.set_audio.assert_called_with(0, True) - reolink_connect.set_audio.side_effect = ReolinkError("Test error") + reolink_host.set_audio.side_effect = ReolinkError("Test error") with pytest.raises(HomeAssistantError): await hass.services.async_call( SWITCH_DOMAIN, @@ -73,16 +73,16 @@ async def test_switch( ) # test switch turn off - reolink_connect.set_audio.reset_mock(side_effect=True) + reolink_host.set_audio.reset_mock(side_effect=True) await hass.services.async_call( SWITCH_DOMAIN, SERVICE_TURN_OFF, {ATTR_ENTITY_ID: entity_id}, blocking=True, ) - reolink_connect.set_audio.assert_called_with(0, False) + reolink_host.set_audio.assert_called_with(0, False) - reolink_connect.set_audio.side_effect = ReolinkError("Test error") + reolink_host.set_audio.side_effect = ReolinkError("Test error") with pytest.raises(HomeAssistantError): await hass.services.async_call( SWITCH_DOMAIN, @@ -91,29 +91,27 @@ async def test_switch( blocking=True, ) - reolink_connect.set_audio.reset_mock(side_effect=True) + reolink_host.set_audio.reset_mock(side_effect=True) - reolink_connect.camera_online.return_value = False + reolink_host.camera_online.return_value = False freezer.tick(DEVICE_UPDATE_INTERVAL) async_fire_time_changed(hass) await hass.async_block_till_done() assert hass.states.get(entity_id).state == STATE_UNAVAILABLE - reolink_connect.camera_online.return_value = True - async def test_host_switch( hass: HomeAssistant, config_entry: MockConfigEntry, freezer: FrozenDateTimeFactory, - reolink_connect: MagicMock, + reolink_host: MagicMock, ) -> None: """Test host switch entity.""" - reolink_connect.camera_name.return_value = TEST_CAM_NAME - reolink_connect.email_enabled.return_value = True - reolink_connect.is_hub = False - reolink_connect.supported.return_value = True + reolink_host.camera_name.return_value = TEST_CAM_NAME + reolink_host.email_enabled.return_value = True + reolink_host.is_hub = False + reolink_host.supported.return_value = True with patch("homeassistant.components.reolink.PLATFORMS", [Platform.SWITCH]): assert await hass.config_entries.async_setup(config_entry.entry_id) @@ -123,7 +121,7 @@ async def test_host_switch( entity_id = f"{Platform.SWITCH}.{TEST_NVR_NAME}_email_on_event" assert hass.states.get(entity_id).state == STATE_ON - reolink_connect.email_enabled.return_value = False + reolink_host.email_enabled.return_value = False freezer.tick(DEVICE_UPDATE_INTERVAL) async_fire_time_changed(hass) await hass.async_block_till_done() @@ -137,9 +135,9 @@ async def test_host_switch( {ATTR_ENTITY_ID: entity_id}, blocking=True, ) - reolink_connect.set_email.assert_called_with(None, True) + reolink_host.set_email.assert_called_with(None, True) - reolink_connect.set_email.side_effect = ReolinkError("Test error") + reolink_host.set_email.side_effect = ReolinkError("Test error") with pytest.raises(HomeAssistantError): await hass.services.async_call( SWITCH_DOMAIN, @@ -149,16 +147,16 @@ async def test_host_switch( ) # test switch turn off - reolink_connect.set_email.reset_mock(side_effect=True) + reolink_host.set_email.reset_mock(side_effect=True) await hass.services.async_call( SWITCH_DOMAIN, SERVICE_TURN_OFF, {ATTR_ENTITY_ID: entity_id}, blocking=True, ) - reolink_connect.set_email.assert_called_with(None, False) + reolink_host.set_email.assert_called_with(None, False) - reolink_connect.set_email.side_effect = ReolinkError("Test error") + reolink_host.set_email.side_effect = ReolinkError("Test error") with pytest.raises(HomeAssistantError): await hass.services.async_call( SWITCH_DOMAIN, @@ -167,15 +165,13 @@ async def test_host_switch( blocking=True, ) - reolink_connect.set_email.reset_mock(side_effect=True) - async def test_chime_switch( hass: HomeAssistant, config_entry: MockConfigEntry, freezer: FrozenDateTimeFactory, - reolink_connect: MagicMock, - test_chime: Chime, + reolink_host: MagicMock, + reolink_chime: Chime, ) -> None: """Test host switch entity.""" with patch("homeassistant.components.reolink.PLATFORMS", [Platform.SWITCH]): @@ -186,7 +182,7 @@ async def test_chime_switch( entity_id = f"{Platform.SWITCH}.test_chime_led" assert hass.states.get(entity_id).state == STATE_ON - test_chime.led_state = False + reolink_chime.led_state = False freezer.tick(DEVICE_UPDATE_INTERVAL) async_fire_time_changed(hass) await hass.async_block_till_done() @@ -194,16 +190,16 @@ async def test_chime_switch( assert hass.states.get(entity_id).state == STATE_OFF # test switch turn on - test_chime.set_option = AsyncMock() + reolink_chime.set_option = AsyncMock() await hass.services.async_call( SWITCH_DOMAIN, SERVICE_TURN_ON, {ATTR_ENTITY_ID: entity_id}, blocking=True, ) - test_chime.set_option.assert_called_with(led=True) + reolink_chime.set_option.assert_called_with(led=True) - test_chime.set_option.side_effect = ReolinkError("Test error") + reolink_chime.set_option.side_effect = ReolinkError("Test error") with pytest.raises(HomeAssistantError): await hass.services.async_call( SWITCH_DOMAIN, @@ -213,16 +209,16 @@ async def test_chime_switch( ) # test switch turn off - test_chime.set_option.reset_mock(side_effect=True) + reolink_chime.set_option.reset_mock(side_effect=True) await hass.services.async_call( SWITCH_DOMAIN, SERVICE_TURN_OFF, {ATTR_ENTITY_ID: entity_id}, blocking=True, ) - test_chime.set_option.assert_called_with(led=False) + reolink_chime.set_option.assert_called_with(led=False) - test_chime.set_option.side_effect = ReolinkError("Test error") + reolink_chime.set_option.side_effect = ReolinkError("Test error") with pytest.raises(HomeAssistantError): await hass.services.async_call( SWITCH_DOMAIN, @@ -231,8 +227,6 @@ async def test_chime_switch( blocking=True, ) - test_chime.set_option.reset_mock(side_effect=True) - @pytest.mark.parametrize( ( @@ -265,7 +259,7 @@ async def test_chime_switch( async def test_cleanup_hub_switches( hass: HomeAssistant, config_entry: MockConfigEntry, - reolink_connect: MagicMock, + reolink_host: MagicMock, entity_registry: er.EntityRegistry, original_id: str, capability: str, @@ -279,9 +273,9 @@ def mock_supported(ch, cap): domain = Platform.SWITCH - reolink_connect.channels = [0] - reolink_connect.is_hub = True - reolink_connect.supported = mock_supported + reolink_host.channels = [0] + reolink_host.is_hub = True + reolink_host.supported = mock_supported entity_registry.async_get_or_create( domain=domain, @@ -301,9 +295,6 @@ def mock_supported(ch, cap): assert entity_registry.async_get_entity_id(domain, DOMAIN, original_id) is None - reolink_connect.is_hub = False - reolink_connect.supported.return_value = True - @pytest.mark.parametrize( ( @@ -336,7 +327,7 @@ def mock_supported(ch, cap): async def test_hub_switches_repair_issue( hass: HomeAssistant, config_entry: MockConfigEntry, - reolink_connect: MagicMock, + reolink_host: MagicMock, entity_registry: er.EntityRegistry, issue_registry: ir.IssueRegistry, original_id: str, @@ -351,9 +342,9 @@ def mock_supported(ch, cap): domain = Platform.SWITCH - reolink_connect.channels = [0] - reolink_connect.is_hub = True - reolink_connect.supported = mock_supported + reolink_host.channels = [0] + reolink_host.is_hub = True + reolink_host.supported = mock_supported entity_registry.async_get_or_create( domain=domain, @@ -373,6 +364,3 @@ def mock_supported(ch, cap): assert entity_registry.async_get_entity_id(domain, DOMAIN, original_id) assert (DOMAIN, "hub_switch_deprecated") in issue_registry.issues - - reolink_connect.is_hub = False - reolink_connect.supported.return_value = True diff --git a/tests/components/wyoming/__init__.py b/tests/components/wyoming/__init__.py index 4540cdaabfdfba..de82dc08719a2c 100644 --- a/tests/components/wyoming/__init__.py +++ b/tests/components/wyoming/__init__.py @@ -69,6 +69,29 @@ ) ] ) +TTS_STREAMING_INFO = Info( + tts=[ + TtsProgram( + name="Test Streaming TTS", + description="Test Streaming TTS", + installed=True, + attribution=TEST_ATTR, + voices=[ + TtsVoice( + name="Test Voice", + description="Test Voice", + installed=True, + attribution=TEST_ATTR, + languages=["en-US"], + speakers=[TtsVoiceSpeaker(name="Test Speaker")], + version=None, + ) + ], + version=None, + supports_synthesize_streaming=True, + ) + ] +) WAKE_WORD_INFO = Info( wake=[ WakeProgram( @@ -155,9 +178,15 @@ def __init__(self, responses: list[Event]) -> None: self.port: int | None = None self.written: list[Event] = [] self.responses = responses + self.is_connected: bool | None = None async def connect(self) -> None: """Connect.""" + self.is_connected = True + + async def disconnect(self) -> None: + """Disconnect.""" + self.is_connected = False async def write_event(self, event: Event): """Send.""" diff --git a/tests/components/wyoming/conftest.py b/tests/components/wyoming/conftest.py index 125edc547c6c8c..2974bb4b013312 100644 --- a/tests/components/wyoming/conftest.py +++ b/tests/components/wyoming/conftest.py @@ -19,6 +19,7 @@ SATELLITE_INFO, STT_INFO, TTS_INFO, + TTS_STREAMING_INFO, WAKE_WORD_INFO, ) @@ -148,6 +149,20 @@ async def init_wyoming_tts( return tts_config_entry +@pytest.fixture +async def init_wyoming_streaming_tts( + hass: HomeAssistant, tts_config_entry: ConfigEntry +) -> ConfigEntry: + """Initialize Wyoming streaming TTS.""" + with patch( + "homeassistant.components.wyoming.data.load_wyoming_info", + return_value=TTS_STREAMING_INFO, + ): + await hass.config_entries.async_setup(tts_config_entry.entry_id) + + return tts_config_entry + + @pytest.fixture async def init_wyoming_wake_word( hass: HomeAssistant, wake_word_config_entry: ConfigEntry diff --git a/tests/components/wyoming/snapshots/test_tts.ambr b/tests/components/wyoming/snapshots/test_tts.ambr index 7ca5204e66c739..53cc02eaacf2fa 100644 --- a/tests/components/wyoming/snapshots/test_tts.ambr +++ b/tests/components/wyoming/snapshots/test_tts.ambr @@ -32,6 +32,43 @@ }), ]) # --- +# name: test_get_tts_audio_streaming + list([ + dict({ + 'data': dict({ + }), + 'payload': None, + 'type': 'synthesize-start', + }), + dict({ + 'data': dict({ + 'text': 'Hello ', + }), + 'payload': None, + 'type': 'synthesize-chunk', + }), + dict({ + 'data': dict({ + 'text': 'Word.', + }), + 'payload': None, + 'type': 'synthesize-chunk', + }), + dict({ + 'data': dict({ + 'text': 'Hello Word.', + }), + 'payload': None, + 'type': 'synthesize', + }), + dict({ + 'data': dict({ + }), + 'payload': None, + 'type': 'synthesize-stop', + }), + ]) +# --- # name: test_voice_speaker list([ dict({ diff --git a/tests/components/wyoming/test_satellite.py b/tests/components/wyoming/test_satellite.py index dec5d6cbebde31..870e2696601846 100644 --- a/tests/components/wyoming/test_satellite.py +++ b/tests/components/wyoming/test_satellite.py @@ -1472,3 +1472,184 @@ def tts_response_finished(self): # Stop the satellite await hass.config_entries.async_unload(entry.entry_id) await hass.async_block_till_done() + + +async def test_satellite_tts_streaming(hass: HomeAssistant) -> None: + """Test running a streaming TTS pipeline with a satellite.""" + assert await async_setup_component(hass, assist_pipeline.DOMAIN, {}) + + events = [ + RunPipeline(start_stage=PipelineStage.ASR, end_stage=PipelineStage.TTS).event(), + ] + + pipeline_kwargs: dict[str, Any] = {} + pipeline_event_callback: Callable[[assist_pipeline.PipelineEvent], None] | None = ( + None + ) + run_pipeline_called = asyncio.Event() + audio_chunk_received = asyncio.Event() + + async def async_pipeline_from_audio_stream( + hass: HomeAssistant, + context, + event_callback, + stt_metadata, + stt_stream, + **kwargs, + ) -> None: + nonlocal pipeline_kwargs, pipeline_event_callback + pipeline_kwargs = kwargs + pipeline_event_callback = event_callback + + run_pipeline_called.set() + async for chunk in stt_stream: + if chunk: + audio_chunk_received.set() + break + + with ( + patch( + "homeassistant.components.wyoming.data.load_wyoming_info", + return_value=SATELLITE_INFO, + ), + patch( + "homeassistant.components.wyoming.assist_satellite.AsyncTcpClient", + SatelliteAsyncTcpClient(events), + ) as mock_client, + patch( + "homeassistant.components.assist_satellite.entity.async_pipeline_from_audio_stream", + async_pipeline_from_audio_stream, + ), + patch("homeassistant.components.wyoming.assist_satellite._PING_SEND_DELAY", 0), + ): + entry = await setup_config_entry(hass) + device: SatelliteDevice = hass.data[wyoming.DOMAIN][entry.entry_id].device + assert device is not None + + async with asyncio.timeout(1): + await mock_client.connect_event.wait() + await mock_client.run_satellite_event.wait() + + async with asyncio.timeout(1): + await run_pipeline_called.wait() + + assert pipeline_event_callback is not None + assert pipeline_kwargs.get("device_id") == device.device_id + + # Send TTS info early + mock_tts_result_stream = MockResultStream(hass, "wav", get_test_wav()) + pipeline_event_callback( + assist_pipeline.PipelineEvent( + assist_pipeline.PipelineEventType.RUN_START, + {"tts_output": {"token": mock_tts_result_stream.token}}, + ) + ) + + # Speech-to-text started + pipeline_event_callback( + assist_pipeline.PipelineEvent( + assist_pipeline.PipelineEventType.STT_START, + {"metadata": {"language": "en"}}, + ) + ) + async with asyncio.timeout(1): + await mock_client.transcribe_event.wait() + + # Push in some audio + mock_client.inject_event( + AudioChunk(rate=16000, width=2, channels=1, audio=bytes(1024)).event() + ) + + # User started speaking + pipeline_event_callback( + assist_pipeline.PipelineEvent( + assist_pipeline.PipelineEventType.STT_VAD_START, {"timestamp": 1234} + ) + ) + async with asyncio.timeout(1): + await mock_client.voice_started_event.wait() + + # User stopped speaking + pipeline_event_callback( + assist_pipeline.PipelineEvent( + assist_pipeline.PipelineEventType.STT_VAD_END, {"timestamp": 5678} + ) + ) + async with asyncio.timeout(1): + await mock_client.voice_stopped_event.wait() + + # Speech-to-text transcription + pipeline_event_callback( + assist_pipeline.PipelineEvent( + assist_pipeline.PipelineEventType.STT_END, + {"stt_output": {"text": "test transcript"}}, + ) + ) + async with asyncio.timeout(1): + await mock_client.transcript_event.wait() + + # Intent progress starts TTS streaming early with info received in the + # run-start event. + pipeline_event_callback( + assist_pipeline.PipelineEvent( + assist_pipeline.PipelineEventType.INTENT_PROGRESS, + {"tts_start_streaming": True}, + ) + ) + + # TTS events are sent now. In practice, these would be streamed as text + # chunks are generated. + async with asyncio.timeout(1): + await mock_client.tts_audio_start_event.wait() + await mock_client.tts_audio_chunk_event.wait() + await mock_client.tts_audio_stop_event.wait() + + # Verify audio chunk from test WAV + assert mock_client.tts_audio_chunk is not None + assert mock_client.tts_audio_chunk.rate == 22050 + assert mock_client.tts_audio_chunk.width == 2 + assert mock_client.tts_audio_chunk.channels == 1 + assert mock_client.tts_audio_chunk.audio == b"1234" + + # Text-to-speech text + pipeline_event_callback( + assist_pipeline.PipelineEvent( + assist_pipeline.PipelineEventType.TTS_START, + { + "tts_input": "test text to speak", + "voice": "test voice", + }, + ) + ) + + # synthesize event is sent with complete message for non-streaming clients + async with asyncio.timeout(1): + await mock_client.synthesize_event.wait() + + assert mock_client.synthesize is not None + assert mock_client.synthesize.text == "test text to speak" + assert mock_client.synthesize.voice is not None + assert mock_client.synthesize.voice.name == "test voice" + + # Because we started streaming TTS after intent progress, we should not + # stream it again on tts-end. + with patch( + "homeassistant.components.wyoming.assist_satellite.WyomingAssistSatellite._stream_tts" + ) as mock_stream_tts: + pipeline_event_callback( + assist_pipeline.PipelineEvent( + assist_pipeline.PipelineEventType.TTS_END, + {"tts_output": {"token": mock_tts_result_stream.token}}, + ) + ) + + mock_stream_tts.assert_not_called() + + # Pipeline finished + pipeline_event_callback( + assist_pipeline.PipelineEvent(assist_pipeline.PipelineEventType.RUN_END) + ) + + # Stop the satellite + await hass.config_entries.async_unload(entry.entry_id) + await hass.async_block_till_done() diff --git a/tests/components/wyoming/test_tts.py b/tests/components/wyoming/test_tts.py index c658bff1d0cf41..3374328f4119aa 100644 --- a/tests/components/wyoming/test_tts.py +++ b/tests/components/wyoming/test_tts.py @@ -8,7 +8,8 @@ import pytest from syrupy.assertion import SnapshotAssertion -from wyoming.audio import AudioChunk, AudioStop +from wyoming.audio import AudioChunk, AudioStart, AudioStop +from wyoming.tts import SynthesizeStopped from homeassistant.components import tts, wyoming from homeassistant.core import HomeAssistant @@ -43,11 +44,11 @@ async def test_get_tts_audio( hass: HomeAssistant, init_wyoming_tts, snapshot: SnapshotAssertion ) -> None: """Test get audio.""" + entity = hass.data[DATA_INSTANCES]["tts"].get_entity("tts.test_tts") + assert entity is not None + assert not entity.async_supports_streaming_input() + audio = bytes(100) - audio_events = [ - AudioChunk(audio=audio, rate=16000, width=2, channels=1).event(), - AudioStop().event(), - ] # Verify audio audio_events = [ @@ -215,3 +216,52 @@ async def test_voice_speaker( ), ) assert mock_client.written == snapshot + + +async def test_get_tts_audio_streaming( + hass: HomeAssistant, init_wyoming_streaming_tts, snapshot: SnapshotAssertion +) -> None: + """Test get audio with streaming.""" + entity = hass.data[DATA_INSTANCES]["tts"].get_entity("tts.test_streaming_tts") + assert entity is not None + assert entity.async_supports_streaming_input() + + audio = bytes(100) + + # Verify audio + audio_events = [ + AudioStart(rate=16000, width=2, channels=1).event(), + AudioChunk(audio=audio, rate=16000, width=2, channels=1).event(), + AudioStop().event(), + SynthesizeStopped().event(), + ] + + async def message_gen(): + yield "Hello " + yield "Word." + + with patch( + "homeassistant.components.wyoming.tts.AsyncTcpClient", + MockAsyncTcpClient(audio_events), + ) as mock_client: + stream = tts.async_create_stream( + hass, + "tts.test_streaming_tts", + "en-US", + options={tts.ATTR_PREFERRED_FORMAT: "wav"}, + ) + stream.async_set_message_stream(message_gen()) + data = b"".join([chunk async for chunk in stream.async_stream_result()]) + + # Ensure client was disconnected properly + assert mock_client.is_connected is False + + assert data is not None + with io.BytesIO(data) as wav_io, wave.open(wav_io, "rb") as wav_file: + assert wav_file.getframerate() == 16000 + assert wav_file.getsampwidth() == 2 + assert wav_file.getnchannels() == 1 + assert wav_file.getnframes() == 0 # streaming + assert data[44:] == audio # WAV header is 44 bytes + + assert mock_client.written == snapshot