Skip to content

Commit 25fbcbc

Browse files
frenckCopilot
andauthored
Extract floor template functions into a floors Jinja2 extension (home-assistant#156589)
Co-authored-by: Copilot <[email protected]>
1 parent a670286 commit 25fbcbc

File tree

7 files changed

+608
-408
lines changed

7 files changed

+608
-408
lines changed

homeassistant/helpers/template/__init__.py

Lines changed: 3 additions & 122 deletions
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,6 @@
5959
area_registry as ar,
6060
device_registry as dr,
6161
entity_registry as er,
62-
floor_registry as fr,
6362
issue_registry as ir,
6463
location as loc_helper,
6564
)
@@ -79,7 +78,7 @@
7978
template_context_manager,
8079
template_cv,
8180
)
82-
from .helpers import raise_no_default
81+
from .helpers import raise_no_default, resolve_area_id
8382
from .render_info import RenderInfo, render_info_cv
8483

8584
if TYPE_CHECKING:
@@ -1318,112 +1317,14 @@ def issue(hass: HomeAssistant, domain: str, issue_id: str) -> dict[str, Any] | N
13181317
return None
13191318

13201319

1321-
def floors(hass: HomeAssistant) -> Iterable[str | None]:
1322-
"""Return all floors."""
1323-
floor_registry = fr.async_get(hass)
1324-
return [floor.floor_id for floor in floor_registry.async_list_floors()]
1325-
1326-
1327-
def floor_id(hass: HomeAssistant, lookup_value: Any) -> str | None:
1328-
"""Get the floor ID from a floor or area name, alias, device id, or entity id."""
1329-
floor_registry = fr.async_get(hass)
1330-
lookup_str = str(lookup_value)
1331-
if floor := floor_registry.async_get_floor_by_name(lookup_str):
1332-
return floor.floor_id
1333-
floors_list = floor_registry.async_get_floors_by_alias(lookup_str)
1334-
if floors_list:
1335-
return floors_list[0].floor_id
1336-
1337-
if aid := area_id(hass, lookup_value):
1338-
area_reg = ar.async_get(hass)
1339-
if area := area_reg.async_get_area(aid):
1340-
return area.floor_id
1341-
1342-
return None
1343-
1344-
1345-
def floor_name(hass: HomeAssistant, lookup_value: str) -> str | None:
1346-
"""Get the floor name from a floor id."""
1347-
floor_registry = fr.async_get(hass)
1348-
if floor := floor_registry.async_get_floor(lookup_value):
1349-
return floor.name
1350-
1351-
if aid := area_id(hass, lookup_value):
1352-
area_reg = ar.async_get(hass)
1353-
if (
1354-
(area := area_reg.async_get_area(aid))
1355-
and area.floor_id
1356-
and (floor := floor_registry.async_get_floor(area.floor_id))
1357-
):
1358-
return floor.name
1359-
1360-
return None
1361-
1362-
1363-
def floor_areas(hass: HomeAssistant, floor_id_or_name: str) -> Iterable[str]:
1364-
"""Return area IDs for a given floor ID or name."""
1365-
_floor_id: str | None
1366-
# If floor_name returns a value, we know the input was an ID, otherwise we
1367-
# assume it's a name, and if it's neither, we return early
1368-
if floor_name(hass, floor_id_or_name) is not None:
1369-
_floor_id = floor_id_or_name
1370-
else:
1371-
_floor_id = floor_id(hass, floor_id_or_name)
1372-
if _floor_id is None:
1373-
return []
1374-
1375-
area_reg = ar.async_get(hass)
1376-
entries = ar.async_entries_for_floor(area_reg, _floor_id)
1377-
return [entry.id for entry in entries if entry.id]
1378-
1379-
1380-
def floor_entities(hass: HomeAssistant, floor_id_or_name: str) -> Iterable[str]:
1381-
"""Return entity_ids for a given floor ID or name."""
1382-
return [
1383-
entity_id
1384-
for area_id in floor_areas(hass, floor_id_or_name)
1385-
for entity_id in area_entities(hass, area_id)
1386-
]
1387-
1388-
13891320
def areas(hass: HomeAssistant) -> Iterable[str | None]:
13901321
"""Return all areas."""
13911322
return list(ar.async_get(hass).areas)
13921323

13931324

13941325
def area_id(hass: HomeAssistant, lookup_value: str) -> str | None:
13951326
"""Get the area ID from an area name, alias, device id, or entity id."""
1396-
area_reg = ar.async_get(hass)
1397-
lookup_str = str(lookup_value)
1398-
if area := area_reg.async_get_area_by_name(lookup_str):
1399-
return area.id
1400-
areas_list = area_reg.async_get_areas_by_alias(lookup_str)
1401-
if areas_list:
1402-
return areas_list[0].id
1403-
1404-
ent_reg = er.async_get(hass)
1405-
dev_reg = dr.async_get(hass)
1406-
# Import here, not at top-level to avoid circular import
1407-
from homeassistant.helpers import config_validation as cv # noqa: PLC0415
1408-
1409-
try:
1410-
cv.entity_id(lookup_value)
1411-
except vol.Invalid:
1412-
pass
1413-
else:
1414-
if entity := ent_reg.async_get(lookup_value):
1415-
# If entity has an area ID, return that
1416-
if entity.area_id:
1417-
return entity.area_id
1418-
# If entity has a device ID, return the area ID for the device
1419-
if entity.device_id and (device := dev_reg.async_get(entity.device_id)):
1420-
return device.area_id
1421-
1422-
# Check if this could be a device ID
1423-
if device := dev_reg.async_get(lookup_value):
1424-
return device.area_id
1425-
1426-
return None
1327+
return resolve_area_id(hass, lookup_value)
14271328

14281329

14291330
def _get_area_name(area_reg: ar.AreaRegistry, valid_area_id: str) -> str:
@@ -2359,6 +2260,7 @@ def __init__(
23592260
"homeassistant.helpers.template.extensions.CollectionExtension"
23602261
)
23612262
self.add_extension("homeassistant.helpers.template.extensions.CryptoExtension")
2263+
self.add_extension("homeassistant.helpers.template.extensions.FloorExtension")
23622264
self.add_extension("homeassistant.helpers.template.extensions.LabelExtension")
23632265
self.add_extension("homeassistant.helpers.template.extensions.MathExtension")
23642266
self.add_extension("homeassistant.helpers.template.extensions.RegexExtension")
@@ -2462,23 +2364,6 @@ def wrapper(_: Any, *args: _P.args, **kwargs: _P.kwargs) -> _R:
24622364
self.globals["area_devices"] = hassfunction(area_devices)
24632365
self.filters["area_devices"] = self.globals["area_devices"]
24642366

2465-
# Floor extensions
2466-
2467-
self.globals["floors"] = hassfunction(floors)
2468-
self.filters["floors"] = self.globals["floors"]
2469-
2470-
self.globals["floor_id"] = hassfunction(floor_id)
2471-
self.filters["floor_id"] = self.globals["floor_id"]
2472-
2473-
self.globals["floor_name"] = hassfunction(floor_name)
2474-
self.filters["floor_name"] = self.globals["floor_name"]
2475-
2476-
self.globals["floor_areas"] = hassfunction(floor_areas)
2477-
self.filters["floor_areas"] = self.globals["floor_areas"]
2478-
2479-
self.globals["floor_entities"] = hassfunction(floor_entities)
2480-
self.filters["floor_entities"] = self.globals["floor_entities"]
2481-
24822367
# Integration extensions
24832368

24842369
self.globals["integration_entities"] = hassfunction(integration_entities)
@@ -2534,8 +2419,6 @@ def warn_unsupported(*args: Any, **kwargs: Any) -> NoReturn:
25342419
"device_id",
25352420
"distance",
25362421
"expand",
2537-
"floor_id",
2538-
"floor_name",
25392422
"has_value",
25402423
"is_device_attr",
25412424
"is_hidden_entity",
@@ -2557,8 +2440,6 @@ def warn_unsupported(*args: Any, **kwargs: Any) -> NoReturn:
25572440
"closest",
25582441
"device_id",
25592442
"expand",
2560-
"floor_id",
2561-
"floor_name",
25622443
"has_value",
25632444
]
25642445
hass_tests = [

homeassistant/helpers/template/extensions/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from .base64 import Base64Extension
44
from .collection import CollectionExtension
55
from .crypto import CryptoExtension
6+
from .floors import FloorExtension
67
from .labels import LabelExtension
78
from .math import MathExtension
89
from .regex import RegexExtension
@@ -12,6 +13,7 @@
1213
"Base64Extension",
1314
"CollectionExtension",
1415
"CryptoExtension",
16+
"FloorExtension",
1517
"LabelExtension",
1618
"MathExtension",
1719
"RegexExtension",
Lines changed: 157 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,157 @@
1+
"""Floor functions for Home Assistant templates."""
2+
3+
from __future__ import annotations
4+
5+
from collections.abc import Iterable
6+
from typing import TYPE_CHECKING, Any
7+
8+
from homeassistant.helpers import (
9+
area_registry as ar,
10+
device_registry as dr,
11+
entity_registry as er,
12+
floor_registry as fr,
13+
)
14+
from homeassistant.helpers.template.helpers import resolve_area_id
15+
16+
from .base import BaseTemplateExtension, TemplateFunction
17+
18+
if TYPE_CHECKING:
19+
from homeassistant.helpers.template import TemplateEnvironment
20+
21+
22+
class FloorExtension(BaseTemplateExtension):
23+
"""Extension for floor-related template functions."""
24+
25+
def __init__(self, environment: TemplateEnvironment) -> None:
26+
"""Initialize the floor extension."""
27+
super().__init__(
28+
environment,
29+
functions=[
30+
TemplateFunction(
31+
"floors",
32+
self.floors,
33+
as_global=True,
34+
requires_hass=True,
35+
),
36+
TemplateFunction(
37+
"floor_id",
38+
self.floor_id,
39+
as_global=True,
40+
as_filter=True,
41+
requires_hass=True,
42+
limited_ok=False,
43+
),
44+
TemplateFunction(
45+
"floor_name",
46+
self.floor_name,
47+
as_global=True,
48+
as_filter=True,
49+
requires_hass=True,
50+
limited_ok=False,
51+
),
52+
TemplateFunction(
53+
"floor_areas",
54+
self.floor_areas,
55+
as_global=True,
56+
as_filter=True,
57+
requires_hass=True,
58+
),
59+
TemplateFunction(
60+
"floor_entities",
61+
self.floor_entities,
62+
as_global=True,
63+
as_filter=True,
64+
requires_hass=True,
65+
),
66+
],
67+
)
68+
69+
def floors(self) -> Iterable[str | None]:
70+
"""Return all floors."""
71+
floor_registry = fr.async_get(self.hass)
72+
return [floor.floor_id for floor in floor_registry.async_list_floors()]
73+
74+
def floor_id(self, lookup_value: Any) -> str | None:
75+
"""Get the floor ID from a floor or area name, alias, device id, or entity id."""
76+
floor_registry = fr.async_get(self.hass)
77+
lookup_str = str(lookup_value)
78+
79+
# Check if it's a floor name or alias
80+
if floor := floor_registry.async_get_floor_by_name(lookup_str):
81+
return floor.floor_id
82+
floors_list = floor_registry.async_get_floors_by_alias(lookup_str)
83+
if floors_list:
84+
return floors_list[0].floor_id
85+
86+
# Resolve to area ID and get floor from area
87+
if aid := resolve_area_id(self.hass, lookup_value):
88+
area_reg = ar.async_get(self.hass)
89+
if area := area_reg.async_get_area(aid):
90+
return area.floor_id
91+
92+
return None
93+
94+
def floor_name(self, lookup_value: str) -> str | None:
95+
"""Get the floor name from a floor id."""
96+
floor_registry = fr.async_get(self.hass)
97+
98+
# Check if it's a floor ID
99+
if floor := floor_registry.async_get_floor(lookup_value):
100+
return floor.name
101+
102+
# Resolve to area ID and get floor name from area's floor
103+
if aid := resolve_area_id(self.hass, lookup_value):
104+
area_reg = ar.async_get(self.hass)
105+
if (
106+
(area := area_reg.async_get_area(aid))
107+
and area.floor_id
108+
and (floor := floor_registry.async_get_floor(area.floor_id))
109+
):
110+
return floor.name
111+
112+
return None
113+
114+
def _floor_id_or_name(self, floor_id_or_name: str) -> str | None:
115+
"""Get the floor ID from a floor name or ID."""
116+
# If floor_name returns a value, we know the input was an ID, otherwise we
117+
# assume it's a name, and if it's neither, we return early.
118+
if self.floor_name(floor_id_or_name) is not None:
119+
return floor_id_or_name
120+
return self.floor_id(floor_id_or_name)
121+
122+
def floor_areas(self, floor_id_or_name: str) -> Iterable[str]:
123+
"""Return area IDs for a given floor ID or name."""
124+
if (_floor_id := self._floor_id_or_name(floor_id_or_name)) is None:
125+
return []
126+
127+
area_reg = ar.async_get(self.hass)
128+
entries = ar.async_entries_for_floor(area_reg, _floor_id)
129+
return [entry.id for entry in entries if entry.id]
130+
131+
def floor_entities(self, floor_id_or_name: str) -> Iterable[str]:
132+
"""Return entity_ids for a given floor ID or name."""
133+
ent_reg = er.async_get(self.hass)
134+
dev_reg = dr.async_get(self.hass)
135+
entity_ids = []
136+
137+
for area_id in self.floor_areas(floor_id_or_name):
138+
# Get entities directly assigned to the area
139+
entity_ids.extend(
140+
[
141+
entry.entity_id
142+
for entry in er.async_entries_for_area(ent_reg, area_id)
143+
]
144+
)
145+
146+
# Also add entities tied to a device in the area that don't themselves
147+
# have an area specified since they inherit the area from the device
148+
entity_ids.extend(
149+
[
150+
entity.entity_id
151+
for device in dr.async_entries_for_area(dev_reg, area_id)
152+
for entity in er.async_entries_for_device(ent_reg, device.id)
153+
if entity.area_id is None
154+
]
155+
)
156+
157+
return entity_ids

0 commit comments

Comments
 (0)