Skip to content

Commit a6aab08

Browse files
authored
Add get_triggers_for_target websocket command (home-assistant#156778)
1 parent 655a63c commit a6aab08

File tree

3 files changed

+645
-0
lines changed

3 files changed

+645
-0
lines changed
Lines changed: 196 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,196 @@
1+
"""Automation related helper methods for the Websocket API."""
2+
3+
from __future__ import annotations
4+
5+
from dataclasses import dataclass
6+
import logging
7+
from typing import Any, Self
8+
9+
from homeassistant.const import CONF_TARGET
10+
from homeassistant.core import HomeAssistant
11+
from homeassistant.helpers import target as target_helpers
12+
from homeassistant.helpers.entity import (
13+
entity_sources,
14+
get_device_class,
15+
get_supported_features,
16+
)
17+
from homeassistant.helpers.trigger import (
18+
async_get_all_descriptions as async_get_all_trigger_descriptions,
19+
)
20+
from homeassistant.helpers.typing import ConfigType
21+
22+
_LOGGER = logging.getLogger(__name__)
23+
24+
25+
@dataclass(slots=True, kw_only=True)
26+
class _EntityFilter:
27+
"""Single entity filter configuration."""
28+
29+
integration: str | None
30+
domains: set[str]
31+
device_classes: set[str]
32+
supported_features: set[int]
33+
34+
def matches(
35+
self, hass: HomeAssistant, entity_id: str, domain: str, integration: str
36+
) -> bool:
37+
"""Return if entity matches all criteria in this filter."""
38+
if self.integration and integration != self.integration:
39+
return False
40+
41+
if self.domains and domain not in self.domains:
42+
return False
43+
44+
if self.device_classes:
45+
if (
46+
entity_device_class := get_device_class(hass, entity_id)
47+
) is None or entity_device_class not in self.device_classes:
48+
return False
49+
50+
if self.supported_features:
51+
entity_supported_features = get_supported_features(hass, entity_id)
52+
if not any(
53+
feature & entity_supported_features == feature
54+
for feature in self.supported_features
55+
):
56+
return False
57+
58+
return True
59+
60+
61+
@dataclass(slots=True, kw_only=True)
62+
class _AutomationComponentLookupData:
63+
"""Helper class for looking up automation components."""
64+
65+
component: str
66+
filters: list[_EntityFilter]
67+
68+
@classmethod
69+
def create(cls, component: str, target_description: dict[str, Any]) -> Self:
70+
"""Build automation component lookup data from target description."""
71+
filters: list[_EntityFilter] = []
72+
73+
entity_filters_config = target_description.get("entity", [])
74+
for entity_filter_config in entity_filters_config:
75+
entity_filter = _EntityFilter(
76+
integration=entity_filter_config.get("integration"),
77+
domains=set(entity_filter_config.get("domain", [])),
78+
device_classes=set(entity_filter_config.get("device_class", [])),
79+
supported_features=set(
80+
entity_filter_config.get("supported_features", [])
81+
),
82+
)
83+
filters.append(entity_filter)
84+
85+
return cls(component=component, filters=filters)
86+
87+
def matches(
88+
self, hass: HomeAssistant, entity_id: str, domain: str, integration: str
89+
) -> bool:
90+
"""Return if entity matches ANY of the filters."""
91+
if not self.filters:
92+
return True
93+
return any(
94+
f.matches(hass, entity_id, domain, integration) for f in self.filters
95+
)
96+
97+
98+
def _get_automation_component_domains(
99+
target_description: dict[str, Any],
100+
) -> set[str | None]:
101+
"""Get a list of domains (including integration domains) of an automation component.
102+
103+
The list of domains is extracted from each target's entity filters.
104+
If a filter is missing both domain and integration keys, None is added to the
105+
returned set.
106+
"""
107+
entity_filters_config = target_description.get("entity", [])
108+
if not entity_filters_config:
109+
return {None}
110+
111+
domains: set[str | None] = set()
112+
for entity_filter_config in entity_filters_config:
113+
filter_integration = entity_filter_config.get("integration")
114+
filter_domains = entity_filter_config.get("domain", [])
115+
116+
if not filter_domains and not filter_integration:
117+
domains.add(None)
118+
continue
119+
120+
if filter_integration:
121+
domains.add(filter_integration)
122+
123+
for domain in filter_domains:
124+
domains.add(domain)
125+
126+
return domains
127+
128+
129+
def _async_get_automation_components_for_target(
130+
hass: HomeAssistant,
131+
target_selection: ConfigType,
132+
expand_group: bool,
133+
component_descriptions: dict[str, dict[str, Any] | None],
134+
) -> set[str]:
135+
"""Get automation components (triggers/conditions/services) for a target.
136+
137+
Returns all components that can be used on any entity that are currently part of a target.
138+
"""
139+
extracted = target_helpers.async_extract_referenced_entity_ids(
140+
hass,
141+
target_helpers.TargetSelectorData(target_selection),
142+
expand_group=expand_group,
143+
)
144+
_LOGGER.debug("Extracted entities for lookup: %s", extracted)
145+
146+
# Build lookup structure: domain -> list of trigger/condition/service lookup data
147+
domain_components: dict[str | None, list[_AutomationComponentLookupData]] = {}
148+
component_count = 0
149+
for component, description in component_descriptions.items():
150+
if description is None or CONF_TARGET not in description:
151+
_LOGGER.debug("Skipping component %s without target description", component)
152+
continue
153+
domains = _get_automation_component_domains(description[CONF_TARGET])
154+
lookup_data = _AutomationComponentLookupData.create(
155+
component, description[CONF_TARGET]
156+
)
157+
for domain in domains:
158+
domain_components.setdefault(domain, []).append(lookup_data)
159+
component_count += 1
160+
161+
_LOGGER.debug("Automation components per domain: %s", domain_components)
162+
163+
entity_infos = entity_sources(hass)
164+
matched_components: set[str] = set()
165+
for entity_id in extracted.referenced | extracted.indirectly_referenced:
166+
if component_count == len(matched_components):
167+
# All automation components matched already, so we don't need to iterate further
168+
break
169+
170+
entity_info = entity_infos.get(entity_id)
171+
if entity_info is None:
172+
_LOGGER.debug("No entity source found for %s", entity_id)
173+
continue
174+
175+
entity_domain = entity_id.split(".")[0]
176+
entity_integration = entity_info["domain"]
177+
for domain in (entity_domain, entity_integration, None):
178+
for component_data in domain_components.get(domain, []):
179+
if component_data.component in matched_components:
180+
continue
181+
if component_data.matches(
182+
hass, entity_id, entity_domain, entity_integration
183+
):
184+
matched_components.add(component_data.component)
185+
186+
return matched_components
187+
188+
189+
async def async_get_triggers_for_target(
190+
hass: HomeAssistant, target_selector: ConfigType, expand_group: bool
191+
) -> set[str]:
192+
"""Get triggers for a target."""
193+
descriptions = await async_get_all_trigger_descriptions(hass)
194+
return _async_get_automation_components_for_target(
195+
hass, target_selector, expand_group, descriptions
196+
)

homeassistant/components/websocket_api/commands.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,7 @@
8787
from homeassistant.util.json import format_unserializable_data
8888

8989
from . import const, decorators, messages
90+
from .automation import async_get_triggers_for_target
9091
from .connection import ActiveConnection
9192
from .messages import construct_event_message, construct_result_message
9293

@@ -107,6 +108,7 @@ def async_register_commands(
107108
async_reg(hass, handle_entity_source)
108109
async_reg(hass, handle_execute_script)
109110
async_reg(hass, handle_extract_from_target)
111+
async_reg(hass, handle_get_triggers_for_target)
110112
async_reg(hass, handle_fire_event)
111113
async_reg(hass, handle_get_config)
112114
async_reg(hass, handle_get_services)
@@ -877,6 +879,29 @@ def handle_extract_from_target(
877879
connection.send_result(msg["id"], extracted_dict)
878880

879881

882+
@decorators.websocket_command(
883+
{
884+
vol.Required("type"): "get_triggers_for_target",
885+
vol.Required("target"): cv.TARGET_FIELDS,
886+
vol.Optional("expand_group", default=True): bool,
887+
}
888+
)
889+
@decorators.async_response
890+
async def handle_get_triggers_for_target(
891+
hass: HomeAssistant, connection: ActiveConnection, msg: dict[str, Any]
892+
) -> None:
893+
"""Handle get triggers for target command.
894+
895+
This command returns all triggers that can be used with any entities that are currently
896+
part of a target.
897+
"""
898+
triggers = await async_get_triggers_for_target(
899+
hass, msg["target"], msg["expand_group"]
900+
)
901+
902+
connection.send_result(msg["id"], triggers)
903+
904+
880905
@decorators.websocket_command(
881906
{
882907
vol.Required("type"): "subscribe_trigger",

0 commit comments

Comments
 (0)