Skip to content

Commit e5b2d44

Browse files
authored
Extract area template functions into an areas Jinja2 extension (home-assistant#156629)
1 parent 4d4ad90 commit e5b2d44

File tree

5 files changed

+476
-418
lines changed

5 files changed

+476
-418
lines changed

homeassistant/helpers/template/__init__.py

Lines changed: 2 additions & 116 deletions
Original file line numberDiff line numberDiff line change
@@ -56,8 +56,6 @@
5656
)
5757
from homeassistant.exceptions import TemplateError
5858
from homeassistant.helpers import (
59-
area_registry as ar,
60-
device_registry as dr,
6159
entity_registry as er,
6260
issue_registry as ir,
6361
location as loc_helper,
@@ -78,7 +76,7 @@
7876
template_context_manager,
7977
template_cv,
8078
)
81-
from .helpers import raise_no_default, resolve_area_id
79+
from .helpers import raise_no_default
8280
from .render_info import RenderInfo, render_info_cv
8381

8482
if TYPE_CHECKING:
@@ -1244,103 +1242,6 @@ def issue(hass: HomeAssistant, domain: str, issue_id: str) -> dict[str, Any] | N
12441242
return None
12451243

12461244

1247-
def areas(hass: HomeAssistant) -> Iterable[str | None]:
1248-
"""Return all areas."""
1249-
return list(ar.async_get(hass).areas)
1250-
1251-
1252-
def area_id(hass: HomeAssistant, lookup_value: str) -> str | None:
1253-
"""Get the area ID from an area name, alias, device id, or entity id."""
1254-
return resolve_area_id(hass, lookup_value)
1255-
1256-
1257-
def _get_area_name(area_reg: ar.AreaRegistry, valid_area_id: str) -> str:
1258-
"""Get area name from valid area ID."""
1259-
area = area_reg.async_get_area(valid_area_id)
1260-
assert area
1261-
return area.name
1262-
1263-
1264-
def area_name(hass: HomeAssistant, lookup_value: str) -> str | None:
1265-
"""Get the area name from an area id, device id, or entity id."""
1266-
area_reg = ar.async_get(hass)
1267-
if area := area_reg.async_get_area(lookup_value):
1268-
return area.name
1269-
1270-
dev_reg = dr.async_get(hass)
1271-
ent_reg = er.async_get(hass)
1272-
# Import here, not at top-level to avoid circular import
1273-
from homeassistant.helpers import config_validation as cv # noqa: PLC0415
1274-
1275-
try:
1276-
cv.entity_id(lookup_value)
1277-
except vol.Invalid:
1278-
pass
1279-
else:
1280-
if entity := ent_reg.async_get(lookup_value):
1281-
# If entity has an area ID, get the area name for that
1282-
if entity.area_id:
1283-
return _get_area_name(area_reg, entity.area_id)
1284-
# If entity has a device ID and the device exists with an area ID, get the
1285-
# area name for that
1286-
if (
1287-
entity.device_id
1288-
and (device := dev_reg.async_get(entity.device_id))
1289-
and device.area_id
1290-
):
1291-
return _get_area_name(area_reg, device.area_id)
1292-
1293-
if (device := dev_reg.async_get(lookup_value)) and device.area_id:
1294-
return _get_area_name(area_reg, device.area_id)
1295-
1296-
return None
1297-
1298-
1299-
def area_entities(hass: HomeAssistant, area_id_or_name: str) -> Iterable[str]:
1300-
"""Return entities for a given area ID or name."""
1301-
_area_id: str | None
1302-
# if area_name returns a value, we know the input was an ID, otherwise we
1303-
# assume it's a name, and if it's neither, we return early
1304-
if area_name(hass, area_id_or_name) is None:
1305-
_area_id = area_id(hass, area_id_or_name)
1306-
else:
1307-
_area_id = area_id_or_name
1308-
if _area_id is None:
1309-
return []
1310-
ent_reg = er.async_get(hass)
1311-
entity_ids = [
1312-
entry.entity_id for entry in er.async_entries_for_area(ent_reg, _area_id)
1313-
]
1314-
dev_reg = dr.async_get(hass)
1315-
# We also need to add entities tied to a device in the area that don't themselves
1316-
# have an area specified since they inherit the area from the device.
1317-
entity_ids.extend(
1318-
[
1319-
entity.entity_id
1320-
for device in dr.async_entries_for_area(dev_reg, _area_id)
1321-
for entity in er.async_entries_for_device(ent_reg, device.id)
1322-
if entity.area_id is None
1323-
]
1324-
)
1325-
return entity_ids
1326-
1327-
1328-
def area_devices(hass: HomeAssistant, area_id_or_name: str) -> Iterable[str]:
1329-
"""Return device IDs for a given area ID or name."""
1330-
_area_id: str | None
1331-
# if area_name returns a value, we know the input was an ID, otherwise we
1332-
# assume it's a name, and if it's neither, we return early
1333-
if area_name(hass, area_id_or_name) is not None:
1334-
_area_id = area_id_or_name
1335-
else:
1336-
_area_id = area_id(hass, area_id_or_name)
1337-
if _area_id is None:
1338-
return []
1339-
dev_reg = dr.async_get(hass)
1340-
entries = dr.async_entries_for_area(dev_reg, _area_id)
1341-
return [entry.id for entry in entries]
1342-
1343-
13441245
def closest(hass: HomeAssistant, *args: Any) -> State | None:
13451246
"""Find closest entity.
13461247
@@ -2182,6 +2083,7 @@ def __init__(
21822083
] = weakref.WeakValueDictionary()
21832084
self.add_extension("jinja2.ext.loopcontrols")
21842085
self.add_extension("jinja2.ext.do")
2086+
self.add_extension("homeassistant.helpers.template.extensions.AreaExtension")
21852087
self.add_extension("homeassistant.helpers.template.extensions.Base64Extension")
21862088
self.add_extension(
21872089
"homeassistant.helpers.template.extensions.CollectionExtension"
@@ -2276,22 +2178,6 @@ def wrapper(_: Any, *args: _P.args, **kwargs: _P.kwargs) -> _R:
22762178

22772179
return jinja_context(wrapper)
22782180

2279-
# Area extensions
2280-
2281-
self.globals["areas"] = hassfunction(areas)
2282-
2283-
self.globals["area_id"] = hassfunction(area_id)
2284-
self.filters["area_id"] = self.globals["area_id"]
2285-
2286-
self.globals["area_name"] = hassfunction(area_name)
2287-
self.filters["area_name"] = self.globals["area_name"]
2288-
2289-
self.globals["area_entities"] = hassfunction(area_entities)
2290-
self.filters["area_entities"] = self.globals["area_entities"]
2291-
2292-
self.globals["area_devices"] = hassfunction(area_devices)
2293-
self.filters["area_devices"] = self.globals["area_devices"]
2294-
22952181
# Integration extensions
22962182

22972183
self.globals["integration_entities"] = hassfunction(integration_entities)

homeassistant/helpers/template/extensions/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
"""Home Assistant template extensions."""
22

3+
from .areas import AreaExtension
34
from .base64 import Base64Extension
45
from .collection import CollectionExtension
56
from .crypto import CryptoExtension
@@ -11,6 +12,7 @@
1112
from .string import StringExtension
1213

1314
__all__ = [
15+
"AreaExtension",
1416
"Base64Extension",
1517
"CollectionExtension",
1618
"CryptoExtension",
Lines changed: 159 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,159 @@
1+
"""Area functions for Home Assistant templates."""
2+
3+
from __future__ import annotations
4+
5+
from collections.abc import Iterable
6+
from typing import TYPE_CHECKING
7+
8+
import voluptuous as vol
9+
10+
from homeassistant.helpers import (
11+
area_registry as ar,
12+
device_registry as dr,
13+
entity_registry as er,
14+
)
15+
from homeassistant.helpers.template.helpers import resolve_area_id
16+
17+
from .base import BaseTemplateExtension, TemplateFunction
18+
19+
if TYPE_CHECKING:
20+
from homeassistant.helpers.template import TemplateEnvironment
21+
22+
23+
class AreaExtension(BaseTemplateExtension):
24+
"""Extension for area-related template functions."""
25+
26+
def __init__(self, environment: TemplateEnvironment) -> None:
27+
"""Initialize the area extension."""
28+
super().__init__(
29+
environment,
30+
functions=[
31+
TemplateFunction(
32+
"areas",
33+
self.areas,
34+
as_global=True,
35+
requires_hass=True,
36+
),
37+
TemplateFunction(
38+
"area_id",
39+
self.area_id,
40+
as_global=True,
41+
as_filter=True,
42+
requires_hass=True,
43+
limited_ok=False,
44+
),
45+
TemplateFunction(
46+
"area_name",
47+
self.area_name,
48+
as_global=True,
49+
as_filter=True,
50+
requires_hass=True,
51+
limited_ok=False,
52+
),
53+
TemplateFunction(
54+
"area_entities",
55+
self.area_entities,
56+
as_global=True,
57+
as_filter=True,
58+
requires_hass=True,
59+
),
60+
TemplateFunction(
61+
"area_devices",
62+
self.area_devices,
63+
as_global=True,
64+
as_filter=True,
65+
requires_hass=True,
66+
),
67+
],
68+
)
69+
70+
def areas(self) -> Iterable[str | None]:
71+
"""Return all areas."""
72+
return list(ar.async_get(self.hass).areas)
73+
74+
def area_id(self, lookup_value: str) -> str | None:
75+
"""Get the area ID from an area name, alias, device id, or entity id."""
76+
return resolve_area_id(self.hass, lookup_value)
77+
78+
def _get_area_name(self, area_reg: ar.AreaRegistry, valid_area_id: str) -> str:
79+
"""Get area name from valid area ID."""
80+
area = area_reg.async_get_area(valid_area_id)
81+
assert area
82+
return area.name
83+
84+
def area_name(self, lookup_value: str) -> str | None:
85+
"""Get the area name from an area id, device id, or entity id."""
86+
area_reg = ar.async_get(self.hass)
87+
if area := area_reg.async_get_area(lookup_value):
88+
return area.name
89+
90+
dev_reg = dr.async_get(self.hass)
91+
ent_reg = er.async_get(self.hass)
92+
# Import here, not at top-level to avoid circular import
93+
from homeassistant.helpers import config_validation as cv # noqa: PLC0415
94+
95+
try:
96+
cv.entity_id(lookup_value)
97+
except vol.Invalid:
98+
pass
99+
else:
100+
if entity := ent_reg.async_get(lookup_value):
101+
# If entity has an area ID, get the area name for that
102+
if entity.area_id:
103+
return self._get_area_name(area_reg, entity.area_id)
104+
# If entity has a device ID and the device exists with an area ID, get the
105+
# area name for that
106+
if (
107+
entity.device_id
108+
and (device := dev_reg.async_get(entity.device_id))
109+
and device.area_id
110+
):
111+
return self._get_area_name(area_reg, device.area_id)
112+
113+
if (device := dev_reg.async_get(lookup_value)) and device.area_id:
114+
return self._get_area_name(area_reg, device.area_id)
115+
116+
return None
117+
118+
def area_entities(self, area_id_or_name: str) -> Iterable[str]:
119+
"""Return entities for a given area ID or name."""
120+
_area_id: str | None
121+
# if area_name returns a value, we know the input was an ID, otherwise we
122+
# assume it's a name, and if it's neither, we return early
123+
if self.area_name(area_id_or_name) is None:
124+
_area_id = self.area_id(area_id_or_name)
125+
else:
126+
_area_id = area_id_or_name
127+
if _area_id is None:
128+
return []
129+
ent_reg = er.async_get(self.hass)
130+
entity_ids = [
131+
entry.entity_id for entry in er.async_entries_for_area(ent_reg, _area_id)
132+
]
133+
dev_reg = dr.async_get(self.hass)
134+
# We also need to add entities tied to a device in the area that don't themselves
135+
# have an area specified since they inherit the area from the device.
136+
entity_ids.extend(
137+
[
138+
entity.entity_id
139+
for device in dr.async_entries_for_area(dev_reg, _area_id)
140+
for entity in er.async_entries_for_device(ent_reg, device.id)
141+
if entity.area_id is None
142+
]
143+
)
144+
return entity_ids
145+
146+
def area_devices(self, area_id_or_name: str) -> Iterable[str]:
147+
"""Return device IDs for a given area ID or name."""
148+
_area_id: str | None
149+
# if area_name returns a value, we know the input was an ID, otherwise we
150+
# assume it's a name, and if it's neither, we return early
151+
if self.area_name(area_id_or_name) is not None:
152+
_area_id = area_id_or_name
153+
else:
154+
_area_id = self.area_id(area_id_or_name)
155+
if _area_id is None:
156+
return []
157+
dev_reg = dr.async_get(self.hass)
158+
entries = dr.async_entries_for_area(dev_reg, _area_id)
159+
return [entry.id for entry in entries]

0 commit comments

Comments
 (0)