Skip to content

Commit 90c68f8

Browse files
authored
Prevent reloading the ZHA integration while adapter firmware is being updated (home-assistant#152626)
1 parent 6b79aa7 commit 90c68f8

File tree

9 files changed

+302
-48
lines changed

9 files changed

+302
-48
lines changed

homeassistant/components/homeassistant_hardware/helpers.py

Lines changed: 77 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,20 @@
11
"""Home Assistant Hardware integration helpers."""
22

3+
from __future__ import annotations
4+
35
from collections import defaultdict
46
from collections.abc import AsyncIterator, Awaitable, Callable
7+
from contextlib import asynccontextmanager
58
import logging
6-
from typing import Protocol
9+
from typing import TYPE_CHECKING, Protocol
710

811
from homeassistant.config_entries import ConfigEntry
912
from homeassistant.core import CALLBACK_TYPE, HomeAssistant, callback as hass_callback
1013

1114
from . import DATA_COMPONENT
12-
from .util import FirmwareInfo
15+
16+
if TYPE_CHECKING:
17+
from .util import FirmwareInfo
1318

1419
_LOGGER = logging.getLogger(__name__)
1520

@@ -51,6 +56,7 @@ def __init__(self, hass: HomeAssistant) -> None:
5156
self._notification_callbacks: defaultdict[
5257
str, set[Callable[[FirmwareInfo], None]]
5358
] = defaultdict(set)
59+
self._active_firmware_updates: dict[str, str] = {}
5460

5561
def register_firmware_info_provider(
5662
self, domain: str, platform: HardwareFirmwareInfoModule
@@ -118,6 +124,36 @@ async def iter_firmware_info(self) -> AsyncIterator[FirmwareInfo]:
118124
if fw_info is not None:
119125
yield fw_info
120126

127+
def register_firmware_update_in_progress(
128+
self, device: str, source_domain: str
129+
) -> None:
130+
"""Register that a firmware update is in progress for a device."""
131+
if device in self._active_firmware_updates:
132+
current_domain = self._active_firmware_updates[device]
133+
raise ValueError(
134+
f"Firmware update already in progress for {device} by {current_domain}"
135+
)
136+
self._active_firmware_updates[device] = source_domain
137+
138+
def unregister_firmware_update_in_progress(
139+
self, device: str, source_domain: str
140+
) -> None:
141+
"""Unregister a firmware update for a device."""
142+
if device not in self._active_firmware_updates:
143+
raise ValueError(f"No firmware update in progress for {device}")
144+
145+
if self._active_firmware_updates[device] != source_domain:
146+
current_domain = self._active_firmware_updates[device]
147+
raise ValueError(
148+
f"Firmware update for {device} is owned by {current_domain}, not {source_domain}"
149+
)
150+
151+
del self._active_firmware_updates[device]
152+
153+
def is_firmware_update_in_progress(self, device: str) -> bool:
154+
"""Check if a firmware update is in progress for a device."""
155+
return device in self._active_firmware_updates
156+
121157

122158
@hass_callback
123159
def async_register_firmware_info_provider(
@@ -141,3 +177,42 @@ def async_notify_firmware_info(
141177
) -> Awaitable[None]:
142178
"""Notify the dispatcher of new firmware information."""
143179
return hass.data[DATA_COMPONENT].notify_firmware_info(domain, firmware_info)
180+
181+
182+
@hass_callback
183+
def async_register_firmware_update_in_progress(
184+
hass: HomeAssistant, device: str, source_domain: str
185+
) -> None:
186+
"""Register that a firmware update is in progress for a device."""
187+
return hass.data[DATA_COMPONENT].register_firmware_update_in_progress(
188+
device, source_domain
189+
)
190+
191+
192+
@hass_callback
193+
def async_unregister_firmware_update_in_progress(
194+
hass: HomeAssistant, device: str, source_domain: str
195+
) -> None:
196+
"""Unregister a firmware update for a device."""
197+
return hass.data[DATA_COMPONENT].unregister_firmware_update_in_progress(
198+
device, source_domain
199+
)
200+
201+
202+
@hass_callback
203+
def async_is_firmware_update_in_progress(hass: HomeAssistant, device: str) -> bool:
204+
"""Check if a firmware update is in progress for a device."""
205+
return hass.data[DATA_COMPONENT].is_firmware_update_in_progress(device)
206+
207+
208+
@asynccontextmanager
209+
async def async_firmware_update_context(
210+
hass: HomeAssistant, device: str, source_domain: str
211+
) -> AsyncIterator[None]:
212+
"""Register a device as having its firmware being actively updated."""
213+
async_register_firmware_update_in_progress(hass, device, source_domain)
214+
215+
try:
216+
yield
217+
finally:
218+
async_unregister_firmware_update_in_progress(hass, device, source_domain)

homeassistant/components/homeassistant_hardware/update.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -275,6 +275,7 @@ async def async_install(
275275
expected_installed_firmware_type=self.entity_description.expected_firmware_type,
276276
bootloader_reset_methods=self.bootloader_reset_methods,
277277
progress_callback=self._update_progress,
278+
domain=self._config_entry.domain,
278279
)
279280
finally:
280281
self._attr_in_progress = False

homeassistant/components/homeassistant_hardware/util.py

Lines changed: 41 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -26,13 +26,15 @@
2626

2727
from . import DATA_COMPONENT
2828
from .const import (
29+
DOMAIN,
2930
OTBR_ADDON_MANAGER_DATA,
3031
OTBR_ADDON_NAME,
3132
OTBR_ADDON_SLUG,
3233
ZIGBEE_FLASHER_ADDON_MANAGER_DATA,
3334
ZIGBEE_FLASHER_ADDON_NAME,
3435
ZIGBEE_FLASHER_ADDON_SLUG,
3536
)
37+
from .helpers import async_firmware_update_context
3638
from .silabs_multiprotocol_addon import (
3739
WaitingAddonManager,
3840
get_multiprotocol_addon_manager,
@@ -359,45 +361,50 @@ async def async_flash_silabs_firmware(
359361
expected_installed_firmware_type: ApplicationType,
360362
bootloader_reset_methods: Sequence[ResetTarget] = (),
361363
progress_callback: Callable[[int, int], None] | None = None,
364+
*,
365+
domain: str = DOMAIN,
362366
) -> FirmwareInfo:
363367
"""Flash firmware to the SiLabs device."""
364-
firmware_info = await guess_firmware_info(hass, device)
365-
_LOGGER.debug("Identified firmware info: %s", firmware_info)
366-
367-
fw_image = await hass.async_add_executor_job(parse_firmware_image, fw_data)
368-
369-
flasher = Flasher(
370-
device=device,
371-
probe_methods=(
372-
ApplicationType.GECKO_BOOTLOADER.as_flasher_application_type(),
373-
ApplicationType.EZSP.as_flasher_application_type(),
374-
ApplicationType.SPINEL.as_flasher_application_type(),
375-
ApplicationType.CPC.as_flasher_application_type(),
376-
),
377-
bootloader_reset=tuple(
378-
m.as_flasher_reset_target() for m in bootloader_reset_methods
379-
),
380-
)
368+
async with async_firmware_update_context(hass, device, domain):
369+
firmware_info = await guess_firmware_info(hass, device)
370+
_LOGGER.debug("Identified firmware info: %s", firmware_info)
371+
372+
fw_image = await hass.async_add_executor_job(parse_firmware_image, fw_data)
373+
374+
flasher = Flasher(
375+
device=device,
376+
probe_methods=(
377+
ApplicationType.GECKO_BOOTLOADER.as_flasher_application_type(),
378+
ApplicationType.EZSP.as_flasher_application_type(),
379+
ApplicationType.SPINEL.as_flasher_application_type(),
380+
ApplicationType.CPC.as_flasher_application_type(),
381+
),
382+
bootloader_reset=tuple(
383+
m.as_flasher_reset_target() for m in bootloader_reset_methods
384+
),
385+
)
381386

382-
async with AsyncExitStack() as stack:
383-
for owner in firmware_info.owners:
384-
await stack.enter_async_context(owner.temporarily_stop(hass))
387+
async with AsyncExitStack() as stack:
388+
for owner in firmware_info.owners:
389+
await stack.enter_async_context(owner.temporarily_stop(hass))
385390

386-
try:
387-
# Enter the bootloader with indeterminate progress
388-
await flasher.enter_bootloader()
391+
try:
392+
# Enter the bootloader with indeterminate progress
393+
await flasher.enter_bootloader()
389394

390-
# Flash the firmware, with progress
391-
await flasher.flash_firmware(fw_image, progress_callback=progress_callback)
392-
except Exception as err:
393-
raise HomeAssistantError("Failed to flash firmware") from err
395+
# Flash the firmware, with progress
396+
await flasher.flash_firmware(
397+
fw_image, progress_callback=progress_callback
398+
)
399+
except Exception as err:
400+
raise HomeAssistantError("Failed to flash firmware") from err
394401

395-
probed_firmware_info = await probe_silabs_firmware_info(
396-
device,
397-
probe_methods=(expected_installed_firmware_type,),
398-
)
402+
probed_firmware_info = await probe_silabs_firmware_info(
403+
device,
404+
probe_methods=(expected_installed_firmware_type,),
405+
)
399406

400-
if probed_firmware_info is None:
401-
raise HomeAssistantError("Failed to probe the firmware after flashing")
407+
if probed_firmware_info is None:
408+
raise HomeAssistantError("Failed to probe the firmware after flashing")
402409

403-
return probed_firmware_info
410+
return probed_firmware_info

homeassistant/components/zha/__init__.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from zigpy.exceptions import NetworkSettingsInconsistent, TransientConnectionError
1414

1515
from homeassistant.components.homeassistant_hardware.helpers import (
16+
async_is_firmware_update_in_progress,
1617
async_notify_firmware_info,
1718
async_register_firmware_info_provider,
1819
)
@@ -119,6 +120,14 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
119120
return True
120121

121122

123+
def _raise_if_port_in_use(hass: HomeAssistant, device_path: str) -> None:
124+
"""Ensure that the specified serial port is not in use by a firmware update."""
125+
if async_is_firmware_update_in_progress(hass, device_path):
126+
raise ConfigEntryNotReady(
127+
f"Firmware update in progress for device {device_path}"
128+
)
129+
130+
122131
async def async_setup_entry(hass: HomeAssistant, config_entry: ConfigEntry) -> bool:
123132
"""Set up ZHA.
124133
@@ -152,6 +161,10 @@ async def async_setup_entry(hass: HomeAssistant, config_entry: ConfigEntry) -> b
152161

153162
_LOGGER.debug("Trigger cache: %s", zha_lib_data.device_trigger_cache)
154163

164+
# Check if firmware update is in progress for this device
165+
device_path = config_entry.data[CONF_DEVICE][CONF_DEVICE_PATH]
166+
_raise_if_port_in_use(hass, device_path)
167+
155168
try:
156169
await zha_gateway.async_initialize()
157170
except NetworkSettingsInconsistent as exc:
@@ -168,7 +181,7 @@ async def async_setup_entry(hass: HomeAssistant, config_entry: ConfigEntry) -> b
168181
raise ConfigEntryNotReady from exc
169182
except Exception as exc:
170183
_LOGGER.debug("Failed to set up ZHA", exc_info=exc)
171-
device_path = config_entry.data[CONF_DEVICE][CONF_DEVICE_PATH]
184+
_raise_if_port_in_use(hass, device_path)
172185

173186
if (
174187
not device_path.startswith("socket://")

tests/components/homeassistant_hardware/test_config_flow.py

Lines changed: 17 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,9 @@
2222
BaseFirmwareConfigFlow,
2323
BaseFirmwareOptionsFlow,
2424
)
25+
from homeassistant.components.homeassistant_hardware.helpers import (
26+
async_firmware_update_context,
27+
)
2528
from homeassistant.components.homeassistant_hardware.util import (
2629
ApplicationType,
2730
FirmwareInfo,
@@ -302,18 +305,21 @@ async def mock_flash_firmware(
302305
expected_installed_firmware_type: ApplicationType,
303306
bootloader_reset_methods: Sequence[ResetTarget] = (),
304307
progress_callback: Callable[[int, int], None] | None = None,
308+
*,
309+
domain: str = "homeassistant_hardware",
305310
) -> FirmwareInfo:
306-
await asyncio.sleep(0)
307-
progress_callback(0, 100)
308-
await asyncio.sleep(0)
309-
progress_callback(50, 100)
310-
await asyncio.sleep(0)
311-
progress_callback(100, 100)
312-
313-
if flashed_firmware_info is None:
314-
raise HomeAssistantError("Failed to probe the firmware after flashing")
315-
316-
return flashed_firmware_info
311+
async with async_firmware_update_context(hass, device, domain):
312+
await asyncio.sleep(0)
313+
progress_callback(0, 100)
314+
await asyncio.sleep(0)
315+
progress_callback(50, 100)
316+
await asyncio.sleep(0)
317+
progress_callback(100, 100)
318+
319+
if flashed_firmware_info is None:
320+
raise HomeAssistantError("Failed to probe the firmware after flashing")
321+
322+
return flashed_firmware_info
317323

318324
with (
319325
patch(

tests/components/homeassistant_hardware/test_helpers.py

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,13 @@
77

88
from homeassistant.components.homeassistant_hardware.const import DATA_COMPONENT
99
from homeassistant.components.homeassistant_hardware.helpers import (
10+
async_firmware_update_context,
11+
async_is_firmware_update_in_progress,
1012
async_notify_firmware_info,
1113
async_register_firmware_info_callback,
1214
async_register_firmware_info_provider,
15+
async_register_firmware_update_in_progress,
16+
async_unregister_firmware_update_in_progress,
1317
)
1418
from homeassistant.components.homeassistant_hardware.util import (
1519
ApplicationType,
@@ -183,3 +187,73 @@ async def test_dispatcher_callback_error_handling(
183187

184188
assert callback1.mock_calls == [call(FIRMWARE_INFO_EZSP)]
185189
assert callback2.mock_calls == [call(FIRMWARE_INFO_EZSP)]
190+
191+
192+
async def test_firmware_update_tracking(hass: HomeAssistant) -> None:
193+
"""Test firmware update tracking API."""
194+
await async_setup_component(hass, "homeassistant_hardware", {})
195+
196+
device_path = "/dev/ttyUSB0"
197+
198+
assert not async_is_firmware_update_in_progress(hass, device_path)
199+
200+
# Register an update in progress
201+
async_register_firmware_update_in_progress(hass, device_path, "zha")
202+
assert async_is_firmware_update_in_progress(hass, device_path)
203+
204+
with pytest.raises(ValueError, match="Firmware update already in progress"):
205+
async_register_firmware_update_in_progress(hass, device_path, "skyconnect")
206+
207+
assert async_is_firmware_update_in_progress(hass, device_path)
208+
209+
# Unregister the update with correct domain
210+
async_unregister_firmware_update_in_progress(hass, device_path, "zha")
211+
assert not async_is_firmware_update_in_progress(hass, device_path)
212+
213+
# Test unregistering with wrong domain should raise an error
214+
async_register_firmware_update_in_progress(hass, device_path, "zha")
215+
with pytest.raises(ValueError, match="is owned by zha, not skyconnect"):
216+
async_unregister_firmware_update_in_progress(hass, device_path, "skyconnect")
217+
218+
# Still registered to zha
219+
assert async_is_firmware_update_in_progress(hass, device_path)
220+
async_unregister_firmware_update_in_progress(hass, device_path, "zha")
221+
assert not async_is_firmware_update_in_progress(hass, device_path)
222+
223+
224+
async def test_firmware_update_context_manager(hass: HomeAssistant) -> None:
225+
"""Test firmware update progress context manager."""
226+
await async_setup_component(hass, "homeassistant_hardware", {})
227+
228+
device_path = "/dev/ttyUSB0"
229+
230+
# Initially no updates in progress
231+
assert not async_is_firmware_update_in_progress(hass, device_path)
232+
233+
# Test successful completion
234+
async with async_firmware_update_context(hass, device_path, "zha"):
235+
assert async_is_firmware_update_in_progress(hass, device_path)
236+
237+
# Should be cleaned up after context
238+
assert not async_is_firmware_update_in_progress(hass, device_path)
239+
240+
# Test exception handling
241+
with pytest.raises(ValueError, match="test error"): # noqa: PT012
242+
async with async_firmware_update_context(hass, device_path, "zha"):
243+
assert async_is_firmware_update_in_progress(hass, device_path)
244+
raise ValueError("test error")
245+
246+
# Should still be cleaned up after exception
247+
assert not async_is_firmware_update_in_progress(hass, device_path)
248+
249+
# Test concurrent context manager attempts should fail
250+
async with async_firmware_update_context(hass, device_path, "zha"):
251+
assert async_is_firmware_update_in_progress(hass, device_path)
252+
253+
# Second context manager should fail to register
254+
with pytest.raises(ValueError, match="Firmware update already in progress"):
255+
async with async_firmware_update_context(hass, device_path, "skyconnect"):
256+
pytest.fail("We should not enter this context manager")
257+
258+
# Should be cleaned up after first context
259+
assert not async_is_firmware_update_in_progress(hass, device_path)

0 commit comments

Comments
 (0)