Skip to content

Commit 929d76e

Browse files
Add validation for ObjectSelector (home-assistant#153081)
1 parent fe1ff08 commit 929d76e

File tree

2 files changed

+141
-8
lines changed

2 files changed

+141
-8
lines changed

homeassistant/helpers/selector.py

Lines changed: 26 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1148,15 +1148,15 @@ def __call__(self, data: Any) -> float:
11481148
return value
11491149

11501150

1151-
class ObjectSelectorField(TypedDict):
1151+
class ObjectSelectorField(TypedDict, total=False):
11521152
"""Class to represent an object selector fields dict."""
11531153

11541154
label: str
11551155
required: bool
1156-
selector: dict[str, Any]
1156+
selector: Required[dict[str, Any]]
11571157

11581158

1159-
class ObjectSelectorConfig(BaseSelectorConfig):
1159+
class ObjectSelectorConfig(BaseSelectorConfig, total=False):
11601160
"""Class to represent an object selector config."""
11611161

11621162
fields: dict[str, ObjectSelectorField]
@@ -1176,7 +1176,7 @@ class ObjectSelector(Selector[ObjectSelectorConfig]):
11761176
{
11771177
vol.Optional("fields"): {
11781178
str: {
1179-
vol.Required("selector"): dict,
1179+
vol.Required("selector"): validate_selector,
11801180
vol.Optional("required"): bool,
11811181
vol.Optional("label"): str,
11821182
}
@@ -1194,6 +1194,28 @@ def __init__(self, config: ObjectSelectorConfig | None = None) -> None:
11941194

11951195
def __call__(self, data: Any) -> Any:
11961196
"""Validate the passed selection."""
1197+
if "fields" not in self.config:
1198+
# Return data if no fields are defined
1199+
return data
1200+
1201+
if not isinstance(data, (list, dict)):
1202+
raise vol.Invalid("Value should be a dict or a list of dicts")
1203+
if isinstance(data, list) and not self.config["multiple"]:
1204+
raise vol.Invalid("Value should not be a list")
1205+
1206+
test_data = data if isinstance(data, list) else [data]
1207+
1208+
for _config in test_data:
1209+
for field, field_data in self.config["fields"].items():
1210+
if field_data.get("required") and field not in _config:
1211+
raise vol.Invalid(f"Field {field} is required")
1212+
if field in _config:
1213+
selector(field_data["selector"])(_config[field]) # type: ignore[operator]
1214+
1215+
for key in _config:
1216+
if key not in self.config["fields"]:
1217+
raise vol.Invalid(f"Field {key} is not allowed")
1218+
11971219
return data
11981220

11991221

tests/helpers/test_selector.py

Lines changed: 115 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
"""Test selectors."""
22

33
from collections.abc import Callable, Iterable
4+
from contextlib import AbstractContextManager, nullcontext as does_not_raise
45
from enum import Enum
56
from typing import Any
67

@@ -653,7 +654,38 @@ def test_action_selector_schema(schema, valid_selections, invalid_selections) ->
653654
@pytest.mark.parametrize(
654655
("schema", "valid_selections", "invalid_selections"),
655656
[
656-
({}, ("abc123",), ()),
657+
({}, ("abc123", None, {"key": "value"}), ()),
658+
({"multiple": False}, ("abc123", None, {"key": "value"}), ()),
659+
(
660+
{
661+
"fields": {
662+
"name": {
663+
"required": True,
664+
"selector": {"text": {}},
665+
},
666+
"percentage": {
667+
"selector": {"number": {}},
668+
},
669+
},
670+
"multiple": False,
671+
"label_field": "name",
672+
"description_field": "percentage",
673+
},
674+
(
675+
{"name": "abc123", "percentage": 3},
676+
{"name": "abc123"},
677+
),
678+
(
679+
"abc123",
680+
None,
681+
{"name": "abc123", "percentage": "nope"},
682+
[
683+
{"name": "abc123", "percentage": 3},
684+
{"name": "def987", "percentage": 5},
685+
],
686+
[{"name": "abc123"}],
687+
),
688+
),
657689
(
658690
{
659691
"fields": {
@@ -669,17 +701,96 @@ def test_action_selector_schema(schema, valid_selections, invalid_selections) ->
669701
"label_field": "name",
670702
"description_field": "percentage",
671703
},
672-
(),
673-
(),
704+
(
705+
[
706+
{"name": "abc123", "percentage": 3},
707+
{"name": "def987", "percentage": 5},
708+
],
709+
[{"name": "abc123"}],
710+
),
711+
(
712+
"abc123",
713+
None,
714+
[{"name": "abc123", "percentage": "nope"}],
715+
[{"name": "abc123", "percentage": 3, "not_exist": 5}],
716+
),
674717
),
675718
],
676-
[],
677719
)
678720
def test_object_selector_schema(schema, valid_selections, invalid_selections) -> None:
679721
"""Test object selector."""
680722
_test_selector("object", schema, valid_selections, invalid_selections)
681723

682724

725+
@pytest.mark.parametrize(
726+
("schema", "raises"),
727+
[
728+
({}, does_not_raise()),
729+
({"multiple": False}, does_not_raise()),
730+
(
731+
{
732+
"fields": {
733+
"name": {
734+
"required": True,
735+
"selector": {"text": {}},
736+
},
737+
"percentage": {
738+
"selector": {"number": {}},
739+
},
740+
},
741+
"multiple": True,
742+
"label_field": "name",
743+
"description_field": "percentage",
744+
},
745+
does_not_raise(),
746+
),
747+
(
748+
{
749+
"fields": {
750+
"name": {
751+
"required": True,
752+
"selector": selector.TextSelector(),
753+
},
754+
"percentage": {
755+
"selector": selector.NumberSelector(),
756+
},
757+
},
758+
"multiple": True,
759+
"label_field": "name",
760+
"description_field": "percentage",
761+
},
762+
pytest.raises(vol.Invalid),
763+
),
764+
(
765+
{
766+
"fields": {
767+
"name": {
768+
"required": True,
769+
"selector": {"not_exist": {}},
770+
},
771+
"percentage": {
772+
"selector": {"number": {}},
773+
},
774+
},
775+
"multiple": True,
776+
"label_field": "name",
777+
"description_field": "percentage",
778+
},
779+
pytest.raises(vol.Invalid),
780+
),
781+
({"multiple": "False"}, pytest.raises(vol.Invalid)),
782+
],
783+
)
784+
def test_object_selector_validate_schema(
785+
schema: dict, raises: AbstractContextManager
786+
) -> None:
787+
"""Test object selector schemas."""
788+
# Validate selector configuration
789+
790+
with raises:
791+
selector.validate_selector({"object": schema})
792+
793+
683794
@pytest.mark.parametrize(
684795
("schema", "valid_selections", "invalid_selections"),
685796
[

0 commit comments

Comments
 (0)