Skip to content

Commit f484b6d

Browse files
frenckCopilot
andauthored
Extract label template functions into a label Jinja2 extension (home-assistant#156439)
Co-authored-by: Copilot <[email protected]>
1 parent 34c1d45 commit f484b6d

File tree

7 files changed

+623
-522
lines changed

7 files changed

+623
-522
lines changed

homeassistant/helpers/template/__init__.py

Lines changed: 1 addition & 122 deletions
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,6 @@
6161
entity_registry as er,
6262
floor_registry as fr,
6363
issue_registry as ir,
64-
label_registry as lr,
6564
location as loc_helper,
6665
)
6766
from homeassistant.helpers.singleton import singleton
@@ -1514,100 +1513,6 @@ def area_devices(hass: HomeAssistant, area_id_or_name: str) -> Iterable[str]:
15141513
return [entry.id for entry in entries]
15151514

15161515

1517-
def labels(hass: HomeAssistant, lookup_value: Any = None) -> Iterable[str | None]:
1518-
"""Return all labels, or those from a area ID, device ID, or entity ID."""
1519-
label_reg = lr.async_get(hass)
1520-
if lookup_value is None:
1521-
return list(label_reg.labels)
1522-
1523-
ent_reg = er.async_get(hass)
1524-
1525-
# Import here, not at top-level to avoid circular import
1526-
from homeassistant.helpers import config_validation as cv # noqa: PLC0415
1527-
1528-
lookup_value = str(lookup_value)
1529-
1530-
try:
1531-
cv.entity_id(lookup_value)
1532-
except vol.Invalid:
1533-
pass
1534-
else:
1535-
if entity := ent_reg.async_get(lookup_value):
1536-
return list(entity.labels)
1537-
1538-
# Check if this could be a device ID
1539-
dev_reg = dr.async_get(hass)
1540-
if device := dev_reg.async_get(lookup_value):
1541-
return list(device.labels)
1542-
1543-
# Check if this could be a area ID
1544-
area_reg = ar.async_get(hass)
1545-
if area := area_reg.async_get_area(lookup_value):
1546-
return list(area.labels)
1547-
1548-
return []
1549-
1550-
1551-
def label_id(hass: HomeAssistant, lookup_value: Any) -> str | None:
1552-
"""Get the label ID from a label name."""
1553-
label_reg = lr.async_get(hass)
1554-
if label := label_reg.async_get_label_by_name(str(lookup_value)):
1555-
return label.label_id
1556-
return None
1557-
1558-
1559-
def label_name(hass: HomeAssistant, lookup_value: str) -> str | None:
1560-
"""Get the label name from a label ID."""
1561-
label_reg = lr.async_get(hass)
1562-
if label := label_reg.async_get_label(lookup_value):
1563-
return label.name
1564-
return None
1565-
1566-
1567-
def label_description(hass: HomeAssistant, lookup_value: str) -> str | None:
1568-
"""Get the label description from a label ID."""
1569-
label_reg = lr.async_get(hass)
1570-
if label := label_reg.async_get_label(lookup_value):
1571-
return label.description
1572-
return None
1573-
1574-
1575-
def _label_id_or_name(hass: HomeAssistant, label_id_or_name: str) -> str | None:
1576-
"""Get the label ID from a label name or ID."""
1577-
# If label_name returns a value, we know the input was an ID, otherwise we
1578-
# assume it's a name, and if it's neither, we return early.
1579-
if label_name(hass, label_id_or_name) is not None:
1580-
return label_id_or_name
1581-
return label_id(hass, label_id_or_name)
1582-
1583-
1584-
def label_areas(hass: HomeAssistant, label_id_or_name: str) -> Iterable[str]:
1585-
"""Return areas for a given label ID or name."""
1586-
if (_label_id := _label_id_or_name(hass, label_id_or_name)) is None:
1587-
return []
1588-
area_reg = ar.async_get(hass)
1589-
entries = ar.async_entries_for_label(area_reg, _label_id)
1590-
return [entry.id for entry in entries]
1591-
1592-
1593-
def label_devices(hass: HomeAssistant, label_id_or_name: str) -> Iterable[str]:
1594-
"""Return device IDs for a given label ID or name."""
1595-
if (_label_id := _label_id_or_name(hass, label_id_or_name)) is None:
1596-
return []
1597-
dev_reg = dr.async_get(hass)
1598-
entries = dr.async_entries_for_label(dev_reg, _label_id)
1599-
return [entry.id for entry in entries]
1600-
1601-
1602-
def label_entities(hass: HomeAssistant, label_id_or_name: str) -> Iterable[str]:
1603-
"""Return entities for a given label ID or name."""
1604-
if (_label_id := _label_id_or_name(hass, label_id_or_name)) is None:
1605-
return []
1606-
ent_reg = er.async_get(hass)
1607-
entries = er.async_entries_for_label(ent_reg, _label_id)
1608-
return [entry.entity_id for entry in entries]
1609-
1610-
16111516
def closest(hass: HomeAssistant, *args: Any) -> State | None:
16121517
"""Find closest entity.
16131518
@@ -2454,6 +2359,7 @@ def __init__(
24542359
"homeassistant.helpers.template.extensions.CollectionExtension"
24552360
)
24562361
self.add_extension("homeassistant.helpers.template.extensions.CryptoExtension")
2362+
self.add_extension("homeassistant.helpers.template.extensions.LabelExtension")
24572363
self.add_extension("homeassistant.helpers.template.extensions.MathExtension")
24582364
self.add_extension("homeassistant.helpers.template.extensions.RegexExtension")
24592365
self.add_extension("homeassistant.helpers.template.extensions.StringExtension")
@@ -2603,29 +2509,6 @@ def wrapper(_: Any, *args: _P.args, **kwargs: _P.kwargs) -> _R:
26032509
self.globals["device_id"] = hassfunction(device_id)
26042510
self.filters["device_id"] = self.globals["device_id"]
26052511

2606-
# Label extensions
2607-
2608-
self.globals["labels"] = hassfunction(labels)
2609-
self.filters["labels"] = self.globals["labels"]
2610-
2611-
self.globals["label_id"] = hassfunction(label_id)
2612-
self.filters["label_id"] = self.globals["label_id"]
2613-
2614-
self.globals["label_name"] = hassfunction(label_name)
2615-
self.filters["label_name"] = self.globals["label_name"]
2616-
2617-
self.globals["label_description"] = hassfunction(label_description)
2618-
self.filters["label_description"] = self.globals["label_description"]
2619-
2620-
self.globals["label_areas"] = hassfunction(label_areas)
2621-
self.filters["label_areas"] = self.globals["label_areas"]
2622-
2623-
self.globals["label_devices"] = hassfunction(label_devices)
2624-
self.filters["label_devices"] = self.globals["label_devices"]
2625-
2626-
self.globals["label_entities"] = hassfunction(label_entities)
2627-
self.filters["label_entities"] = self.globals["label_entities"]
2628-
26292512
# Issue extensions
26302513

26312514
self.globals["issues"] = hassfunction(issues)
@@ -2658,8 +2541,6 @@ def warn_unsupported(*args: Any, **kwargs: Any) -> NoReturn:
26582541
"is_hidden_entity",
26592542
"is_state_attr",
26602543
"is_state",
2661-
"label_id",
2662-
"label_name",
26632544
"now",
26642545
"relative_time",
26652546
"state_attr",
@@ -2679,8 +2560,6 @@ def warn_unsupported(*args: Any, **kwargs: Any) -> NoReturn:
26792560
"floor_id",
26802561
"floor_name",
26812562
"has_value",
2682-
"label_id",
2683-
"label_name",
26842563
]
26852564
hass_tests = [
26862565
"has_value",

homeassistant/helpers/template/extensions/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from .base64 import Base64Extension
44
from .collection import CollectionExtension
55
from .crypto import CryptoExtension
6+
from .labels import LabelExtension
67
from .math import MathExtension
78
from .regex import RegexExtension
89
from .string import StringExtension
@@ -11,6 +12,7 @@
1112
"Base64Extension",
1213
"CollectionExtension",
1314
"CryptoExtension",
15+
"LabelExtension",
1416
"MathExtension",
1517
"RegexExtension",
1618
"StringExtension",

homeassistant/helpers/template/extensions/base.py

Lines changed: 36 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,10 @@
44

55
from collections.abc import Callable
66
from dataclasses import dataclass
7-
from typing import TYPE_CHECKING, Any, NoReturn
7+
from functools import wraps
8+
from typing import TYPE_CHECKING, Any, Concatenate, NoReturn
89

10+
from jinja2 import pass_context
911
from jinja2.ext import Extension
1012
from jinja2.nodes import Node
1113
from jinja2.parser import Parser
@@ -32,6 +34,28 @@ class TemplateFunction:
3234
requires_hass: bool = False # Whether this function requires hass to be available
3335

3436

37+
def _pass_context[**_P, _R](
38+
func: Callable[Concatenate[Any, _P], _R],
39+
jinja_context: Callable[
40+
[Callable[Concatenate[Any, _P], _R]],
41+
Callable[Concatenate[Any, _P], _R],
42+
] = pass_context,
43+
) -> Callable[Concatenate[Any, _P], _R]:
44+
"""Wrap function to pass context.
45+
46+
We mark these as a context functions to ensure they get
47+
evaluated fresh with every execution, rather than executed
48+
at compile time and the value stored. The context itself
49+
can be discarded.
50+
"""
51+
52+
@wraps(func)
53+
def wrapper(_: Any, *args: _P.args, **kwargs: _P.kwargs) -> _R:
54+
return func(*args, **kwargs)
55+
56+
return jinja_context(wrapper)
57+
58+
3559
class BaseTemplateExtension(Extension):
3660
"""Base class for Home Assistant template extensions."""
3761

@@ -65,12 +89,20 @@ def __init__(
6589
environment.tests[template_func.name] = unsupported_func
6690
continue
6791

92+
func = template_func.func
93+
94+
if template_func.requires_hass:
95+
# We wrap these as a context functions to ensure they get
96+
# evaluated fresh with every execution, rather than executed
97+
# at compile time and the value stored.
98+
func = _pass_context(func)
99+
68100
if template_func.as_global:
69-
environment.globals[template_func.name] = template_func.func
101+
environment.globals[template_func.name] = func
70102
if template_func.as_filter:
71-
environment.filters[template_func.name] = template_func.func
103+
environment.filters[template_func.name] = func
72104
if template_func.as_test:
73-
environment.tests[template_func.name] = template_func.func
105+
environment.tests[template_func.name] = func
74106

75107
@staticmethod
76108
def _create_unsupported_function(name: str) -> Callable[[], NoReturn]:

0 commit comments

Comments
 (0)