Skip to content

Commit 188459e

Browse files
Allow use of Selector in ObjectSelector fields (home-assistant#147929)
1 parent 7324a12 commit 188459e

File tree

3 files changed

+246
-3
lines changed

3 files changed

+246
-3
lines changed

homeassistant/helpers/selector.py

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from __future__ import annotations
44

55
from collections.abc import Callable, Mapping, Sequence
6+
from copy import deepcopy
67
from enum import StrEnum
78
from functools import cache
89
import importlib
@@ -1153,7 +1154,7 @@ class ObjectSelectorField(TypedDict, total=False):
11531154

11541155
label: str
11551156
required: bool
1156-
selector: Required[dict[str, Any]]
1157+
selector: Required[Selector | dict[str, Any]]
11571158

11581159

11591160
class ObjectSelectorConfig(BaseSelectorConfig, total=False):
@@ -1176,7 +1177,7 @@ class ObjectSelector(Selector[ObjectSelectorConfig]):
11761177
{
11771178
vol.Optional("fields"): {
11781179
str: {
1179-
vol.Required("selector"): validate_selector,
1180+
vol.Required("selector"): vol.Any(Selector, validate_selector),
11801181
vol.Optional("required"): bool,
11811182
vol.Optional("label"): str,
11821183
}
@@ -1192,6 +1193,21 @@ def __init__(self, config: ObjectSelectorConfig | None = None) -> None:
11921193
"""Instantiate a selector."""
11931194
super().__init__(config)
11941195

1196+
def serialize(self) -> dict[str, dict[str, ObjectSelectorConfig]]:
1197+
"""Serialize ObjectSelector for voluptuous_serialize."""
1198+
_config = deepcopy(self.config)
1199+
if "fields" in _config:
1200+
for field_items in _config["fields"].values():
1201+
if isinstance(field_items["selector"], ObjectSelector):
1202+
field_items["selector"] = field_items["selector"].serialize()
1203+
elif isinstance(field_items["selector"], Selector):
1204+
field_items["selector"] = {
1205+
field_items["selector"].selector_type: field_items[
1206+
"selector"
1207+
].config
1208+
}
1209+
return {"selector": {self.selector_type: _config}}
1210+
11951211
def __call__(self, data: Any) -> Any:
11961212
"""Validate the passed selection."""
11971213
if "fields" not in self.config:
Lines changed: 118 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,118 @@
1+
# serializer version: 1
2+
# name: test_nested_object_selectors
3+
dict({
4+
'selector': dict({
5+
'object': dict({
6+
'description_field': 'percentage',
7+
'fields': dict({
8+
'name': dict({
9+
'required': True,
10+
'selector': dict({
11+
'text': dict({
12+
'multiline': False,
13+
'multiple': False,
14+
}),
15+
}),
16+
}),
17+
'object': dict({
18+
'selector': dict({
19+
'selector': dict({
20+
'object': dict({
21+
'description_field': 'other_name',
22+
'fields': dict({
23+
'new_object': dict({
24+
'required': True,
25+
'selector': dict({
26+
'selector': dict({
27+
'object': dict({
28+
'description_field': 'description',
29+
'fields': dict({
30+
'description': dict({
31+
'required': True,
32+
'selector': dict({
33+
'text': dict({
34+
'multiline': False,
35+
'multiple': False,
36+
}),
37+
}),
38+
}),
39+
'title': dict({
40+
'required': True,
41+
'selector': dict({
42+
'text': dict({
43+
'multiline': False,
44+
'multiple': False,
45+
}),
46+
}),
47+
}),
48+
}),
49+
'label_field': 'title',
50+
'multiple': False,
51+
}),
52+
}),
53+
}),
54+
}),
55+
'no_name': dict({
56+
'required': True,
57+
'selector': dict({
58+
'text': dict({
59+
'multiline': False,
60+
'multiple': False,
61+
}),
62+
}),
63+
}),
64+
'other_name': dict({
65+
'required': True,
66+
'selector': dict({
67+
'text': dict({
68+
'multiline': False,
69+
'multiple': False,
70+
}),
71+
}),
72+
}),
73+
}),
74+
'label_field': 'no_name',
75+
'multiple': False,
76+
}),
77+
}),
78+
}),
79+
}),
80+
}),
81+
'label_field': 'name',
82+
'multiple': True,
83+
}),
84+
}),
85+
})
86+
# ---
87+
# name: test_object_selector_uses_selectors
88+
dict({
89+
'selector': dict({
90+
'object': dict({
91+
'description_field': 'percentage',
92+
'fields': dict({
93+
'name': dict({
94+
'required': True,
95+
'selector': dict({
96+
'text': dict({
97+
'multiline': False,
98+
'multiple': False,
99+
}),
100+
}),
101+
}),
102+
'percentage': dict({
103+
'selector': dict({
104+
'number': dict({
105+
'max': 100.0,
106+
'min': 0.0,
107+
'mode': 'slider',
108+
'step': 1.0,
109+
}),
110+
}),
111+
}),
112+
}),
113+
'label_field': 'name',
114+
'multiple': True,
115+
}),
116+
}),
117+
})
118+
# ---

tests/helpers/test_selector.py

Lines changed: 110 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from typing import Any
77

88
import pytest
9+
from syrupy.assertion import SnapshotAssertion
910
import voluptuous as vol
1011

1112
from homeassistant.helpers import selector
@@ -722,6 +723,114 @@ def test_object_selector_schema(schema, valid_selections, invalid_selections) ->
722723
_test_selector("object", schema, valid_selections, invalid_selections)
723724

724725

726+
def test_object_selector_uses_selectors(snapshot: SnapshotAssertion) -> None:
727+
"""Test ObjectSelector custom serializer for using Selector in ObjectSelectorField."""
728+
729+
selector_type = "object"
730+
schema = {
731+
"fields": {
732+
"name": {
733+
"required": True,
734+
"selector": selector.TextSelector(),
735+
},
736+
"percentage": {
737+
"selector": selector.NumberSelector(
738+
selector.NumberSelectorConfig(min=0, max=100)
739+
),
740+
},
741+
},
742+
"multiple": True,
743+
"label_field": "name",
744+
"description_field": "percentage",
745+
}
746+
747+
# Validate selector configuration
748+
config = {selector_type: schema}
749+
selector.validate_selector(config)
750+
selector_instance = selector.selector(config)
751+
752+
# Serialize selector
753+
selector_instance = selector.selector({selector_type: schema})
754+
assert selector_instance.serialize() != {
755+
"selector": {selector_type: selector_instance.config}
756+
}
757+
assert selector_instance.serialize() == snapshot()
758+
759+
# Test serialized selector can be dumped to YAML
760+
yaml_util.dump(selector_instance.serialize())
761+
762+
763+
def test_nested_object_selectors(snapshot: SnapshotAssertion) -> None:
764+
"""Test ObjectSelector custom serializer with nested ObjectSelectors."""
765+
766+
selector_type = "object"
767+
schema = {
768+
"fields": {
769+
"name": {
770+
"required": True,
771+
"selector": selector.TextSelector(),
772+
},
773+
"object": {
774+
"selector": selector.ObjectSelector(
775+
selector.ObjectSelectorConfig(
776+
fields={
777+
"no_name": {
778+
"required": True,
779+
"selector": selector.TextSelector(),
780+
},
781+
"other_name": {
782+
"required": True,
783+
"selector": selector.TextSelector(),
784+
},
785+
"new_object": {
786+
"required": True,
787+
"selector": selector.ObjectSelector(
788+
selector.ObjectSelectorConfig(
789+
fields={
790+
"title": {
791+
"required": True,
792+
"selector": selector.TextSelector(),
793+
},
794+
"description": {
795+
"required": True,
796+
"selector": selector.TextSelector(),
797+
},
798+
},
799+
multiple=False,
800+
label_field="title",
801+
description_field="description",
802+
)
803+
),
804+
},
805+
},
806+
multiple=False,
807+
label_field="no_name",
808+
description_field="other_name",
809+
)
810+
),
811+
},
812+
},
813+
"multiple": True,
814+
"label_field": "name",
815+
"description_field": "percentage",
816+
}
817+
818+
# Validate selector configuration
819+
config = {selector_type: schema}
820+
selector.validate_selector(config)
821+
selector_instance = selector.selector(config)
822+
823+
# Serialize selector
824+
selector_instance = selector.selector({selector_type: schema})
825+
assert selector_instance.serialize() != {
826+
"selector": {selector_type: selector_instance.config}
827+
}
828+
assert selector_instance.serialize() == snapshot()
829+
830+
# Test serialized selector can be dumped to YAML
831+
yaml_util.dump(selector_instance.serialize())
832+
833+
725834
@pytest.mark.parametrize(
726835
("schema", "raises"),
727836
[
@@ -759,7 +868,7 @@ def test_object_selector_schema(schema, valid_selections, invalid_selections) ->
759868
"label_field": "name",
760869
"description_field": "percentage",
761870
},
762-
pytest.raises(vol.Invalid),
871+
does_not_raise(),
763872
),
764873
(
765874
{

0 commit comments

Comments
 (0)