Skip to content

Commit aefdf41

Browse files
authored
Extract device template functions into a devices Jinja2 extension (home-assistant#156619)
1 parent 56ab6b2 commit aefdf41

File tree

5 files changed

+471
-411
lines changed

5 files changed

+471
-411
lines changed

homeassistant/helpers/template/__init__.py

Lines changed: 1 addition & 94 deletions
Original file line numberDiff line numberDiff line change
@@ -1166,13 +1166,6 @@ def expand(hass: HomeAssistant, *args: Any) -> Iterable[State]:
11661166
return list(found.values())
11671167

11681168

1169-
def device_entities(hass: HomeAssistant, _device_id: str) -> Iterable[str]:
1170-
"""Get entity ids for entities tied to a device."""
1171-
entity_reg = er.async_get(hass)
1172-
entries = er.async_entries_for_device(entity_reg, _device_id)
1173-
return [entry.entity_id for entry in entries]
1174-
1175-
11761169
def integration_entities(hass: HomeAssistant, entry_name: str) -> Iterable[str]:
11771170
"""Get entity ids for entities tied to an integration/domain.
11781171
@@ -1214,65 +1207,6 @@ def config_entry_id(hass: HomeAssistant, entity_id: str) -> str | None:
12141207
return None
12151208

12161209

1217-
def device_id(hass: HomeAssistant, entity_id_or_device_name: str) -> str | None:
1218-
"""Get a device ID from an entity ID or device name."""
1219-
entity_reg = er.async_get(hass)
1220-
entity = entity_reg.async_get(entity_id_or_device_name)
1221-
if entity is not None:
1222-
return entity.device_id
1223-
1224-
dev_reg = dr.async_get(hass)
1225-
return next(
1226-
(
1227-
device_id
1228-
for device_id, device in dev_reg.devices.items()
1229-
if (name := device.name_by_user or device.name)
1230-
and (str(entity_id_or_device_name) == name)
1231-
),
1232-
None,
1233-
)
1234-
1235-
1236-
def device_name(hass: HomeAssistant, lookup_value: str) -> str | None:
1237-
"""Get the device name from an device id, or entity id."""
1238-
device_reg = dr.async_get(hass)
1239-
if device := device_reg.async_get(lookup_value):
1240-
return device.name_by_user or device.name
1241-
1242-
ent_reg = er.async_get(hass)
1243-
# Import here, not at top-level to avoid circular import
1244-
from homeassistant.helpers import config_validation as cv # noqa: PLC0415
1245-
1246-
try:
1247-
cv.entity_id(lookup_value)
1248-
except vol.Invalid:
1249-
pass
1250-
else:
1251-
if entity := ent_reg.async_get(lookup_value):
1252-
if entity.device_id and (device := device_reg.async_get(entity.device_id)):
1253-
return device.name_by_user or device.name
1254-
1255-
return None
1256-
1257-
1258-
def device_attr(hass: HomeAssistant, device_or_entity_id: str, attr_name: str) -> Any:
1259-
"""Get the device specific attribute."""
1260-
device_reg = dr.async_get(hass)
1261-
if not isinstance(device_or_entity_id, str):
1262-
raise TemplateError("Must provide a device or entity ID")
1263-
device = None
1264-
if (
1265-
"." in device_or_entity_id
1266-
and (_device_id := device_id(hass, device_or_entity_id)) is not None
1267-
):
1268-
device = device_reg.async_get(_device_id)
1269-
elif "." not in device_or_entity_id:
1270-
device = device_reg.async_get(device_or_entity_id)
1271-
if device is None or not hasattr(device, attr_name):
1272-
return None
1273-
return getattr(device, attr_name)
1274-
1275-
12761210
def config_entry_attr(
12771211
hass: HomeAssistant, config_entry_id_: str, attr_name: str
12781212
) -> Any:
@@ -1291,13 +1225,6 @@ def config_entry_attr(
12911225
return getattr(config_entry, attr_name)
12921226

12931227

1294-
def is_device_attr(
1295-
hass: HomeAssistant, device_or_entity_id: str, attr_name: str, attr_value: Any
1296-
) -> bool:
1297-
"""Test if a device's attribute is a specific value."""
1298-
return bool(device_attr(hass, device_or_entity_id, attr_name) == attr_value)
1299-
1300-
13011228
def issues(hass: HomeAssistant) -> dict[tuple[str, str], dict[str, Any]]:
13021229
"""Return all open issues."""
13031230
current_issues = ir.async_get(hass).issues
@@ -2260,6 +2187,7 @@ def __init__(
22602187
"homeassistant.helpers.template.extensions.CollectionExtension"
22612188
)
22622189
self.add_extension("homeassistant.helpers.template.extensions.CryptoExtension")
2190+
self.add_extension("homeassistant.helpers.template.extensions.DeviceExtension")
22632191
self.add_extension("homeassistant.helpers.template.extensions.FloorExtension")
22642192
self.add_extension("homeassistant.helpers.template.extensions.LabelExtension")
22652193
self.add_extension("homeassistant.helpers.template.extensions.MathExtension")
@@ -2377,23 +2305,6 @@ def wrapper(_: Any, *args: _P.args, **kwargs: _P.kwargs) -> _R:
23772305
self.globals["config_entry_id"] = hassfunction(config_entry_id)
23782306
self.filters["config_entry_id"] = self.globals["config_entry_id"]
23792307

2380-
# Device extensions
2381-
2382-
self.globals["device_name"] = hassfunction(device_name)
2383-
self.filters["device_name"] = self.globals["device_name"]
2384-
2385-
self.globals["device_attr"] = hassfunction(device_attr)
2386-
self.filters["device_attr"] = self.globals["device_attr"]
2387-
2388-
self.globals["device_entities"] = hassfunction(device_entities)
2389-
self.filters["device_entities"] = self.globals["device_entities"]
2390-
2391-
self.globals["is_device_attr"] = hassfunction(is_device_attr)
2392-
self.tests["is_device_attr"] = hassfunction(is_device_attr, pass_eval_context)
2393-
2394-
self.globals["device_id"] = hassfunction(device_id)
2395-
self.filters["device_id"] = self.globals["device_id"]
2396-
23972308
# Issue extensions
23982309

23992310
self.globals["issues"] = hassfunction(issues)
@@ -2415,12 +2326,9 @@ def warn_unsupported(*args: Any, **kwargs: Any) -> NoReturn:
24152326
"area_id",
24162327
"area_name",
24172328
"closest",
2418-
"device_attr",
2419-
"device_id",
24202329
"distance",
24212330
"expand",
24222331
"has_value",
2423-
"is_device_attr",
24242332
"is_hidden_entity",
24252333
"is_state_attr",
24262334
"is_state",
@@ -2438,7 +2346,6 @@ def warn_unsupported(*args: Any, **kwargs: Any) -> NoReturn:
24382346
"area_id",
24392347
"area_name",
24402348
"closest",
2441-
"device_id",
24422349
"expand",
24432350
"has_value",
24442351
]

homeassistant/helpers/template/extensions/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from .base64 import Base64Extension
44
from .collection import CollectionExtension
55
from .crypto import CryptoExtension
6+
from .devices import DeviceExtension
67
from .floors import FloorExtension
78
from .labels import LabelExtension
89
from .math import MathExtension
@@ -13,6 +14,7 @@
1314
"Base64Extension",
1415
"CollectionExtension",
1516
"CryptoExtension",
17+
"DeviceExtension",
1618
"FloorExtension",
1719
"LabelExtension",
1820
"MathExtension",
Lines changed: 139 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,139 @@
1+
"""Device functions for Home Assistant templates."""
2+
3+
from __future__ import annotations
4+
5+
from collections.abc import Iterable
6+
from typing import TYPE_CHECKING, Any
7+
8+
import voluptuous as vol
9+
10+
from homeassistant.exceptions import TemplateError
11+
from homeassistant.helpers import (
12+
config_validation as cv,
13+
device_registry as dr,
14+
entity_registry as er,
15+
)
16+
17+
from .base import BaseTemplateExtension, TemplateFunction
18+
19+
if TYPE_CHECKING:
20+
from homeassistant.helpers.template import TemplateEnvironment
21+
22+
23+
class DeviceExtension(BaseTemplateExtension):
24+
"""Extension for device-related template functions."""
25+
26+
def __init__(self, environment: TemplateEnvironment) -> None:
27+
"""Initialize the device extension."""
28+
super().__init__(
29+
environment,
30+
functions=[
31+
TemplateFunction(
32+
"device_entities",
33+
self.device_entities,
34+
as_global=True,
35+
as_filter=True,
36+
requires_hass=True,
37+
),
38+
TemplateFunction(
39+
"device_id",
40+
self.device_id,
41+
as_global=True,
42+
as_filter=True,
43+
requires_hass=True,
44+
limited_ok=False,
45+
),
46+
TemplateFunction(
47+
"device_name",
48+
self.device_name,
49+
as_global=True,
50+
as_filter=True,
51+
requires_hass=True,
52+
limited_ok=False,
53+
),
54+
TemplateFunction(
55+
"device_attr",
56+
self.device_attr,
57+
as_global=True,
58+
as_filter=True,
59+
requires_hass=True,
60+
limited_ok=False,
61+
),
62+
TemplateFunction(
63+
"is_device_attr",
64+
self.is_device_attr,
65+
as_global=True,
66+
as_test=True,
67+
requires_hass=True,
68+
limited_ok=False,
69+
),
70+
],
71+
)
72+
73+
def device_entities(self, _device_id: str) -> Iterable[str]:
74+
"""Get entity ids for entities tied to a device."""
75+
entity_reg = er.async_get(self.hass)
76+
entries = er.async_entries_for_device(entity_reg, _device_id)
77+
return [entry.entity_id for entry in entries]
78+
79+
def device_id(self, entity_id_or_device_name: str) -> str | None:
80+
"""Get a device ID from an entity ID or device name."""
81+
entity_reg = er.async_get(self.hass)
82+
entity = entity_reg.async_get(entity_id_or_device_name)
83+
if entity is not None:
84+
return entity.device_id
85+
86+
dev_reg = dr.async_get(self.hass)
87+
return next(
88+
(
89+
device_id
90+
for device_id, device in dev_reg.devices.items()
91+
if (name := device.name_by_user or device.name)
92+
and (str(entity_id_or_device_name) == name)
93+
),
94+
None,
95+
)
96+
97+
def device_name(self, lookup_value: str) -> str | None:
98+
"""Get the device name from an device id, or entity id."""
99+
device_reg = dr.async_get(self.hass)
100+
if device := device_reg.async_get(lookup_value):
101+
return device.name_by_user or device.name
102+
103+
ent_reg = er.async_get(self.hass)
104+
105+
try:
106+
cv.entity_id(lookup_value)
107+
except vol.Invalid:
108+
pass
109+
else:
110+
if entity := ent_reg.async_get(lookup_value):
111+
if entity.device_id and (
112+
device := device_reg.async_get(entity.device_id)
113+
):
114+
return device.name_by_user or device.name
115+
116+
return None
117+
118+
def device_attr(self, device_or_entity_id: str, attr_name: str) -> Any:
119+
"""Get the device specific attribute."""
120+
device_reg = dr.async_get(self.hass)
121+
if not isinstance(device_or_entity_id, str):
122+
raise TemplateError("Must provide a device or entity ID")
123+
device = None
124+
if (
125+
"." in device_or_entity_id
126+
and (_device_id := self.device_id(device_or_entity_id)) is not None
127+
):
128+
device = device_reg.async_get(_device_id)
129+
elif "." not in device_or_entity_id:
130+
device = device_reg.async_get(device_or_entity_id)
131+
if device is None or not hasattr(device, attr_name):
132+
return None
133+
return getattr(device, attr_name)
134+
135+
def is_device_attr(
136+
self, device_or_entity_id: str, attr_name: str, attr_value: Any
137+
) -> bool:
138+
"""Test if a device's attribute is a specific value."""
139+
return bool(self.device_attr(device_or_entity_id, attr_name) == attr_value)

0 commit comments

Comments
 (0)