Skip to content

Commit 1753baf

Browse files
Add method to track entity state changes from target selectors (home-assistant#148086)
Co-authored-by: Erik Montnemery <[email protected]>
1 parent 8421ca7 commit 1753baf

File tree

2 files changed

+303
-6
lines changed

2 files changed

+303
-6
lines changed

homeassistant/helpers/target.py

Lines changed: 113 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,11 @@
22

33
from __future__ import annotations
44

5+
from collections.abc import Callable
56
import dataclasses
7+
import logging
68
from logging import Logger
7-
from typing import TypeGuard
9+
from typing import Any, TypeGuard
810

911
from homeassistant.const import (
1012
ATTR_AREA_ID,
@@ -14,7 +16,14 @@
1416
ATTR_LABEL_ID,
1517
ENTITY_MATCH_NONE,
1618
)
17-
from homeassistant.core import HomeAssistant
19+
from homeassistant.core import (
20+
CALLBACK_TYPE,
21+
Event,
22+
EventStateChangedData,
23+
HomeAssistant,
24+
callback,
25+
)
26+
from homeassistant.exceptions import HomeAssistantError
1827

1928
from . import (
2029
area_registry as ar,
@@ -25,8 +34,11 @@
2534
group,
2635
label_registry as lr,
2736
)
37+
from .event import async_track_state_change_event
2838
from .typing import ConfigType
2939

40+
_LOGGER = logging.getLogger(__name__)
41+
3042

3143
def _has_match(ids: str | list[str] | None) -> TypeGuard[str | list[str]]:
3244
"""Check if ids can match anything."""
@@ -238,3 +250,102 @@ def async_extract_referenced_entity_ids(
238250
)
239251

240252
return selected
253+
254+
255+
class TargetStateChangeTracker:
256+
"""Helper class to manage state change tracking for targets."""
257+
258+
def __init__(
259+
self,
260+
hass: HomeAssistant,
261+
selector_data: TargetSelectorData,
262+
action: Callable[[Event[EventStateChangedData]], Any],
263+
) -> None:
264+
"""Initialize the state change tracker."""
265+
self._hass = hass
266+
self._selector_data = selector_data
267+
self._action = action
268+
269+
self._state_change_unsub: CALLBACK_TYPE | None = None
270+
self._registry_unsubs: list[CALLBACK_TYPE] = []
271+
272+
def async_setup(self) -> Callable[[], None]:
273+
"""Set up the state change tracking."""
274+
self._setup_registry_listeners()
275+
self._track_entities_state_change()
276+
return self._unsubscribe
277+
278+
def _track_entities_state_change(self) -> None:
279+
"""Set up state change tracking for currently selected entities."""
280+
selected = async_extract_referenced_entity_ids(
281+
self._hass, self._selector_data, expand_group=False
282+
)
283+
284+
@callback
285+
def state_change_listener(event: Event[EventStateChangedData]) -> None:
286+
"""Handle state change events."""
287+
if (
288+
event.data["entity_id"] in selected.referenced
289+
or event.data["entity_id"] in selected.indirectly_referenced
290+
):
291+
self._action(event)
292+
293+
tracked_entities = selected.referenced.union(selected.indirectly_referenced)
294+
295+
_LOGGER.debug("Tracking state changes for entities: %s", tracked_entities)
296+
self._state_change_unsub = async_track_state_change_event(
297+
self._hass, tracked_entities, state_change_listener
298+
)
299+
300+
def _setup_registry_listeners(self) -> None:
301+
"""Set up listeners for registry changes that require resubscription."""
302+
303+
@callback
304+
def resubscribe_state_change_event(event: Event[Any] | None = None) -> None:
305+
"""Resubscribe to state change events when registry changes."""
306+
if self._state_change_unsub:
307+
self._state_change_unsub()
308+
self._track_entities_state_change()
309+
310+
# Subscribe to registry updates that can change the entities to track:
311+
# - Entity registry: entity added/removed; entity labels changed; entity area changed.
312+
# - Device registry: device labels changed; device area changed.
313+
# - Area registry: area floor changed.
314+
#
315+
# We don't track other registries (like floor or label registries) because their
316+
# changes don't affect which entities are tracked.
317+
self._registry_unsubs = [
318+
self._hass.bus.async_listen(
319+
er.EVENT_ENTITY_REGISTRY_UPDATED, resubscribe_state_change_event
320+
),
321+
self._hass.bus.async_listen(
322+
dr.EVENT_DEVICE_REGISTRY_UPDATED, resubscribe_state_change_event
323+
),
324+
self._hass.bus.async_listen(
325+
ar.EVENT_AREA_REGISTRY_UPDATED, resubscribe_state_change_event
326+
),
327+
]
328+
329+
def _unsubscribe(self) -> None:
330+
"""Unsubscribe from all events."""
331+
for registry_unsub in self._registry_unsubs:
332+
registry_unsub()
333+
self._registry_unsubs.clear()
334+
if self._state_change_unsub:
335+
self._state_change_unsub()
336+
self._state_change_unsub = None
337+
338+
339+
def async_track_target_selector_state_change_event(
340+
hass: HomeAssistant,
341+
target_selector_config: ConfigType,
342+
action: Callable[[Event[EventStateChangedData]], Any],
343+
) -> CALLBACK_TYPE:
344+
"""Track state changes for entities referenced directly or indirectly in a target selector."""
345+
selector_data = TargetSelectorData(target_selector_config)
346+
if not selector_data.has_any_selector:
347+
raise HomeAssistantError(
348+
f"Target selector {target_selector_config} does not have any selectors defined"
349+
)
350+
tracker = TargetStateChangeTracker(hass, selector_data, action)
351+
return tracker.async_setup()

tests/helpers/test_target.py

Lines changed: 190 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,6 @@
22

33
import pytest
44

5-
# TODO(abmantis): is this import needed?
6-
# To prevent circular import when running just this file
7-
import homeassistant.components # noqa: F401
85
from homeassistant.components.group import Group
96
from homeassistant.const import (
107
ATTR_AREA_ID,
@@ -17,17 +14,21 @@
1714
STATE_ON,
1815
EntityCategory,
1916
)
20-
from homeassistant.core import HomeAssistant
17+
from homeassistant.core import Event, EventStateChangedData, HomeAssistant, callback
18+
from homeassistant.exceptions import HomeAssistantError
2119
from homeassistant.helpers import (
2220
area_registry as ar,
2321
device_registry as dr,
2422
entity_registry as er,
23+
floor_registry as fr,
24+
label_registry as lr,
2525
target,
2626
)
2727
from homeassistant.helpers.typing import ConfigType
2828
from homeassistant.setup import async_setup_component
2929

3030
from tests.common import (
31+
MockConfigEntry,
3132
RegistryEntryWithDefaults,
3233
mock_area_registry,
3334
mock_device_registry,
@@ -457,3 +458,188 @@ async def test_extract_referenced_entity_ids(
457458
)
458459
== expected_selected
459460
)
461+
462+
463+
async def test_async_track_target_selector_state_change_event_empty_selector(
464+
hass: HomeAssistant, caplog: pytest.LogCaptureFixture
465+
) -> None:
466+
"""Test async_track_target_selector_state_change_event with empty selector."""
467+
468+
@callback
469+
def state_change_callback(event):
470+
"""Handle state change events."""
471+
472+
with pytest.raises(HomeAssistantError) as excinfo:
473+
target.async_track_target_selector_state_change_event(
474+
hass, {}, state_change_callback
475+
)
476+
assert str(excinfo.value) == (
477+
"Target selector {} does not have any selectors defined"
478+
)
479+
480+
481+
async def test_async_track_target_selector_state_change_event(
482+
hass: HomeAssistant,
483+
) -> None:
484+
"""Test async_track_target_selector_state_change_event with multiple targets."""
485+
events: list[Event[EventStateChangedData]] = []
486+
487+
@callback
488+
def state_change_callback(event: Event[EventStateChangedData]):
489+
"""Handle state change events."""
490+
events.append(event)
491+
492+
last_state = STATE_OFF
493+
494+
async def set_states_and_check_events(
495+
entities_to_set_state: list[str], entities_to_assert_change: list[str]
496+
) -> None:
497+
"""Toggle the state entities and check for events."""
498+
nonlocal last_state
499+
last_state = STATE_ON if last_state == STATE_OFF else STATE_OFF
500+
for entity_id in entities_to_set_state:
501+
hass.states.async_set(entity_id, last_state)
502+
await hass.async_block_till_done()
503+
504+
assert len(events) == len(entities_to_assert_change)
505+
entities_seen = set()
506+
for event in events:
507+
entities_seen.add(event.data["entity_id"])
508+
assert event.data["new_state"].state == last_state
509+
assert entities_seen == set(entities_to_assert_change)
510+
events.clear()
511+
512+
config_entry = MockConfigEntry(domain="test")
513+
config_entry.add_to_hass(hass)
514+
515+
device_reg = dr.async_get(hass)
516+
device_entry = device_reg.async_get_or_create(
517+
config_entry_id=config_entry.entry_id,
518+
identifiers={("test", "device_1")},
519+
)
520+
521+
untargeted_device_entry = device_reg.async_get_or_create(
522+
config_entry_id=config_entry.entry_id,
523+
identifiers={("test", "area_device")},
524+
)
525+
526+
entity_reg = er.async_get(hass)
527+
device_entity = entity_reg.async_get_or_create(
528+
domain="light",
529+
platform="test",
530+
unique_id="device_light",
531+
device_id=device_entry.id,
532+
).entity_id
533+
534+
untargeted_device_entity = entity_reg.async_get_or_create(
535+
domain="light",
536+
platform="test",
537+
unique_id="area_device_light",
538+
device_id=untargeted_device_entry.id,
539+
).entity_id
540+
541+
untargeted_entity = entity_reg.async_get_or_create(
542+
domain="light",
543+
platform="test",
544+
unique_id="untargeted_light",
545+
).entity_id
546+
547+
targeted_entity = "light.test_light"
548+
549+
targeted_entities = [targeted_entity, device_entity]
550+
await set_states_and_check_events(targeted_entities, [])
551+
552+
label = lr.async_get(hass).async_create("Test Label").name
553+
area = ar.async_get(hass).async_create("Test Area").id
554+
floor = fr.async_get(hass).async_create("Test Floor").floor_id
555+
556+
selector_config = {
557+
ATTR_ENTITY_ID: targeted_entity,
558+
ATTR_DEVICE_ID: device_entry.id,
559+
ATTR_AREA_ID: area,
560+
ATTR_FLOOR_ID: floor,
561+
ATTR_LABEL_ID: label,
562+
}
563+
unsub = target.async_track_target_selector_state_change_event(
564+
hass, selector_config, state_change_callback
565+
)
566+
567+
# Test directly targeted entity and device
568+
await set_states_and_check_events(targeted_entities, targeted_entities)
569+
570+
# Add new entity to the targeted device -> should trigger on state change
571+
device_entity_2 = entity_reg.async_get_or_create(
572+
domain="light",
573+
platform="test",
574+
unique_id="device_light_2",
575+
device_id=device_entry.id,
576+
).entity_id
577+
578+
targeted_entities = [targeted_entity, device_entity, device_entity_2]
579+
await set_states_and_check_events(targeted_entities, targeted_entities)
580+
581+
# Test untargeted entity -> should not trigger
582+
await set_states_and_check_events(
583+
[*targeted_entities, untargeted_entity], targeted_entities
584+
)
585+
586+
# Add label to untargeted entity -> should trigger now
587+
entity_reg.async_update_entity(untargeted_entity, labels={label})
588+
await set_states_and_check_events(
589+
[*targeted_entities, untargeted_entity], [*targeted_entities, untargeted_entity]
590+
)
591+
592+
# Remove label from untargeted entity -> should not trigger anymore
593+
entity_reg.async_update_entity(untargeted_entity, labels={})
594+
await set_states_and_check_events(
595+
[*targeted_entities, untargeted_entity], targeted_entities
596+
)
597+
598+
# Add area to untargeted entity -> should trigger now
599+
entity_reg.async_update_entity(untargeted_entity, area_id=area)
600+
await set_states_and_check_events(
601+
[*targeted_entities, untargeted_entity], [*targeted_entities, untargeted_entity]
602+
)
603+
604+
# Remove area from untargeted entity -> should not trigger anymore
605+
entity_reg.async_update_entity(untargeted_entity, area_id=None)
606+
await set_states_and_check_events(
607+
[*targeted_entities, untargeted_entity], targeted_entities
608+
)
609+
610+
# Add area to untargeted device -> should trigger on state change
611+
device_reg.async_update_device(untargeted_device_entry.id, area_id=area)
612+
await set_states_and_check_events(
613+
[*targeted_entities, untargeted_device_entity],
614+
[*targeted_entities, untargeted_device_entity],
615+
)
616+
617+
# Remove area from untargeted device -> should not trigger anymore
618+
device_reg.async_update_device(untargeted_device_entry.id, area_id=None)
619+
await set_states_and_check_events(
620+
[*targeted_entities, untargeted_device_entity], targeted_entities
621+
)
622+
623+
# Set the untargeted area on the untargeted entity -> should not trigger
624+
untracked_area = ar.async_get(hass).async_create("Untargeted Area").id
625+
entity_reg.async_update_entity(untargeted_entity, area_id=untracked_area)
626+
await set_states_and_check_events(
627+
[*targeted_entities, untargeted_entity], targeted_entities
628+
)
629+
630+
# Set targeted floor on the untargeted area -> should trigger now
631+
ar.async_get(hass).async_update(untracked_area, floor_id=floor)
632+
await set_states_and_check_events(
633+
[*targeted_entities, untargeted_entity],
634+
[*targeted_entities, untargeted_entity],
635+
)
636+
637+
# Remove untargeted area from targeted floor -> should not trigger anymore
638+
ar.async_get(hass).async_update(untracked_area, floor_id=None)
639+
await set_states_and_check_events(
640+
[*targeted_entities, untargeted_entity], targeted_entities
641+
)
642+
643+
# After unsubscribing, changes should not trigger
644+
unsub()
645+
await set_states_and_check_events(targeted_entities, [])

0 commit comments

Comments
 (0)