Skip to content

Commit e5d512d

Browse files
abmantisCopilot
andauthored
Add entity filter to target state change tracker (#150064)
Co-authored-by: Copilot <[email protected]>
1 parent 2b5028b commit e5d512d

File tree

2 files changed

+121
-15
lines changed

2 files changed

+121
-15
lines changed

homeassistant/helpers/target.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -268,11 +268,13 @@ def __init__(
268268
hass: HomeAssistant,
269269
selector_data: TargetSelectorData,
270270
action: Callable[[TargetStateChangedData], Any],
271+
entity_filter: Callable[[set[str]], set[str]],
271272
) -> None:
272273
"""Initialize the state change tracker."""
273274
self._hass = hass
274275
self._selector_data = selector_data
275276
self._action = action
277+
self._entity_filter = entity_filter
276278

277279
self._state_change_unsub: CALLBACK_TYPE | None = None
278280
self._registry_unsubs: list[CALLBACK_TYPE] = []
@@ -289,7 +291,9 @@ def _track_entities_state_change(self) -> None:
289291
self._hass, self._selector_data, expand_group=False
290292
)
291293

292-
tracked_entities = selected.referenced.union(selected.indirectly_referenced)
294+
tracked_entities = self._entity_filter(
295+
selected.referenced.union(selected.indirectly_referenced)
296+
)
293297

294298
@callback
295299
def state_change_listener(event: Event[EventStateChangedData]) -> None:
@@ -348,12 +352,13 @@ def async_track_target_selector_state_change_event(
348352
hass: HomeAssistant,
349353
target_selector_config: ConfigType,
350354
action: Callable[[TargetStateChangedData], Any],
355+
entity_filter: Callable[[set[str]], set[str]] = lambda x: x,
351356
) -> CALLBACK_TYPE:
352357
"""Track state changes for entities referenced directly or indirectly in a target selector."""
353358
selector_data = TargetSelectorData(target_selector_config)
354359
if not selector_data.has_any_selector:
355360
raise HomeAssistantError(
356361
f"Target selector {target_selector_config} does not have any selectors defined"
357362
)
358-
tracker = TargetStateChangeTracker(hass, selector_data, action)
363+
tracker = TargetStateChangeTracker(hass, selector_data, action, entity_filter)
359364
return tracker.async_setup()

tests/helpers/test_target.py

Lines changed: 114 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,29 @@
3636
)
3737

3838

39+
async def set_states_and_check_target_events(
40+
hass: HomeAssistant,
41+
events: list[target.TargetStateChangedData],
42+
state: str,
43+
entities_to_set_state: list[str],
44+
entities_to_assert_change: list[str],
45+
) -> None:
46+
"""Toggle the state entities and check for events."""
47+
for entity_id in entities_to_set_state:
48+
hass.states.async_set(entity_id, state)
49+
await hass.async_block_till_done()
50+
51+
assert len(events) == len(entities_to_assert_change)
52+
entities_seen = set()
53+
for event in events:
54+
state_change_event = event.state_change_event
55+
entities_seen.add(state_change_event.data["entity_id"])
56+
assert state_change_event.data["new_state"].state == state
57+
assert event.targeted_entity_ids == set(entities_to_assert_change)
58+
assert entities_seen == set(entities_to_assert_change)
59+
events.clear()
60+
61+
3962
@pytest.fixture
4063
def registries_mock(hass: HomeAssistant) -> None:
4164
"""Mock including floor and area info."""
@@ -497,19 +520,9 @@ async def set_states_and_check_events(
497520
"""Toggle the state entities and check for events."""
498521
nonlocal last_state
499522
last_state = STATE_ON if last_state == STATE_OFF else STATE_OFF
500-
for entity_id in entities_to_set_state:
501-
hass.states.async_set(entity_id, last_state)
502-
await hass.async_block_till_done()
503-
504-
assert len(events) == len(entities_to_assert_change)
505-
entities_seen = set()
506-
for event in events:
507-
state_change_event = event.state_change_event
508-
entities_seen.add(state_change_event.data["entity_id"])
509-
assert state_change_event.data["new_state"].state == last_state
510-
assert event.targeted_entity_ids == set(entities_to_assert_change)
511-
assert entities_seen == set(entities_to_assert_change)
512-
events.clear()
523+
await set_states_and_check_target_events(
524+
hass, events, last_state, entities_to_set_state, entities_to_assert_change
525+
)
513526

514527
config_entry = MockConfigEntry(domain="test")
515528
config_entry.add_to_hass(hass)
@@ -645,3 +658,91 @@ async def set_states_and_check_events(
645658
# After unsubscribing, changes should not trigger
646659
unsub()
647660
await set_states_and_check_events(targeted_entities, [])
661+
662+
663+
async def test_async_track_target_selector_state_change_event_filter(
664+
hass: HomeAssistant,
665+
) -> None:
666+
"""Test async_track_target_selector_state_change_event with entity filter."""
667+
events: list[target.TargetStateChangedData] = []
668+
669+
filtered_entity = ""
670+
671+
@callback
672+
def entity_filter(entity_ids: set[str]) -> set[str]:
673+
return {entity_id for entity_id in entity_ids if entity_id != filtered_entity}
674+
675+
@callback
676+
def state_change_callback(event: target.TargetStateChangedData):
677+
"""Handle state change events."""
678+
events.append(event)
679+
680+
last_state = STATE_OFF
681+
682+
async def set_states_and_check_events(
683+
entities_to_set_state: list[str], entities_to_assert_change: list[str]
684+
) -> None:
685+
"""Toggle the state entities and check for events."""
686+
nonlocal last_state
687+
last_state = STATE_ON if last_state == STATE_OFF else STATE_OFF
688+
await set_states_and_check_target_events(
689+
hass, events, last_state, entities_to_set_state, entities_to_assert_change
690+
)
691+
692+
config_entry = MockConfigEntry(domain="test")
693+
config_entry.add_to_hass(hass)
694+
695+
entity_reg = er.async_get(hass)
696+
697+
label = lr.async_get(hass).async_create("Test Label").name
698+
label_entity = entity_reg.async_get_or_create(
699+
domain="light",
700+
platform="test",
701+
unique_id="label_light",
702+
).entity_id
703+
entity_reg.async_update_entity(label_entity, labels={label})
704+
705+
targeted_entity = "light.test_light"
706+
707+
targeted_entities = [targeted_entity, label_entity]
708+
await set_states_and_check_events(targeted_entities, [])
709+
710+
selector_config = {
711+
ATTR_ENTITY_ID: targeted_entity,
712+
ATTR_LABEL_ID: label,
713+
}
714+
unsub = target.async_track_target_selector_state_change_event(
715+
hass, selector_config, state_change_callback, entity_filter
716+
)
717+
718+
await set_states_and_check_events(
719+
targeted_entities, [targeted_entity, label_entity]
720+
)
721+
722+
filtered_entity = targeted_entity
723+
# Fire an event so that the targeted entities are re-evaluated
724+
hass.bus.async_fire(
725+
er.EVENT_ENTITY_REGISTRY_UPDATED,
726+
{
727+
"action": "update",
728+
"entity_id": "light.other",
729+
"changes": {},
730+
},
731+
)
732+
await set_states_and_check_events([targeted_entity, label_entity], [label_entity])
733+
734+
filtered_entity = label_entity
735+
# Fire an event so that the targeted entities are re-evaluated
736+
hass.bus.async_fire(
737+
er.EVENT_ENTITY_REGISTRY_UPDATED,
738+
{
739+
"action": "update",
740+
"entity_id": "light.other",
741+
"changes": {},
742+
},
743+
)
744+
await set_states_and_check_events(
745+
[targeted_entity, label_entity], [targeted_entity]
746+
)
747+
748+
unsub()

0 commit comments

Comments
 (0)