Skip to content

Commit 91220dd

Browse files
authored
Fix template variable validation (#54)
* Handle template variable validation edge cases * Include all BaseCase & EdgeCase fields in template variable validation
1 parent e8b8372 commit 91220dd

File tree

3 files changed

+48
-29
lines changed

3 files changed

+48
-29
lines changed

afp/schemas.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
"""AFP data structures."""
22

33
from decimal import Decimal
4+
from itertools import chain
45
from typing import Annotated, Any, ClassVar, Literal, Self
56

67
from pydantic import AfterValidator, BeforeValidator, Field, model_validator
@@ -321,9 +322,17 @@ def _cross_validate(self) -> Self:
321322
self.product.min_price,
322323
self.product.max_price,
323324
)
324-
validators.validate_outcome_space_conditions(
325-
self.outcome_space.base_case.condition,
326-
[case.condition for case in self.outcome_space.edge_cases],
325+
validators.validate_outcome_space_template_variables(
326+
[
327+
self.outcome_space.base_case.condition,
328+
self.outcome_space.base_case.fsp_resolution,
329+
]
330+
+ list(
331+
chain.from_iterable(
332+
[edge_case.condition, edge_case.fsp_resolution]
333+
for edge_case in self.outcome_space.edge_cases
334+
)
335+
),
327336
self.outcome_point.model_dump(),
328337
)
329338
if isinstance(self.outcome_space, OutcomeSpaceTimeSeries) and isinstance(

afp/validators.py

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from decimal import Decimal
44
from functools import reduce
55
from operator import getitem
6-
from typing import Any
6+
from typing import Any, Iterable
77

88
import requests
99
from binascii import Error
@@ -178,23 +178,23 @@ def validate_oracle_fallback_fsp(
178178
)
179179

180180

181-
def validate_outcome_space_conditions(
182-
base_case_condition: str,
183-
edge_case_conditions: list[str],
184-
outcome_point_dict: dict[Any, Any],
181+
def validate_outcome_space_template_variables(
182+
values: Iterable[str], outcome_point_dict: dict[Any, Any]
185183
) -> None:
186-
conditions = [base_case_condition] + edge_case_conditions
187-
schemas = ["BaseCaseResolution"] + [
188-
f"EdgeCase[{i}]" for i in range(len(edge_case_conditions))
189-
]
190-
for condition, schema in zip(conditions, schemas):
191-
for variable in re.findall(r"{(.+?)}", condition):
192-
parts = variable.split(".")
184+
for value in values:
185+
for variable in re.findall(r"{(.*?)}", value):
193186
try:
194-
reduce(getitem, parts, outcome_point_dict)
195-
except KeyError:
187+
referred_value = reduce(
188+
getitem, variable.split("."), outcome_point_dict
189+
)
190+
except (TypeError, KeyError):
191+
raise ValueError(
192+
f"OutcomeSpace: Invalid template variable '{variable}'"
193+
)
194+
if isinstance(referred_value, dict) or isinstance(referred_value, list): # type: ignore
196195
raise ValueError(
197-
f"{schema}: condition: Invalid template variable '{variable}'"
196+
f"OutcomeSpace: Template variable '{variable}' "
197+
"should not refer to a nested object or list"
198198
)
199199

200200

tests/test_validators.py

Lines changed: 21 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -110,7 +110,7 @@ def test_validate_price_limits__error():
110110
validators.validate_price_limits(Decimal("0.11"), Decimal("0.10"))
111111

112112

113-
def test_validate_outcome_space_conditions__pass():
113+
def test_validate_outcome_space_template_variables__pass():
114114
dct = {
115115
"a": {
116116
"b": {
@@ -119,15 +119,29 @@ def test_validate_outcome_space_conditions__pass():
119119
},
120120
},
121121
}
122-
base_case_condition = "Reference to {a.b.c}"
123-
edge_case_conditions = ["And to {a.b.d} as well"]
124122

125-
validators.validate_outcome_space_conditions(
126-
base_case_condition, edge_case_conditions, dct
123+
validators.validate_outcome_space_template_variables(
124+
[
125+
"Reference to {a.b.c} should pass",
126+
"And to {a.b.d} as well",
127+
"So as having no template variable",
128+
],
129+
dct,
127130
)
128131

129132

130-
def test_validate_outcome_space_conditions__error():
133+
@pytest.mark.parametrize(
134+
"value",
135+
[
136+
"{a.b.c} and {a.b.e}",
137+
"{c}",
138+
"{a.b.c.d}",
139+
"{a.b}",
140+
"{}",
141+
],
142+
ids=str,
143+
)
144+
def test_validate_outcome_space_tempate_variables__error(value):
131145
dct = {
132146
"a": {
133147
"b": {
@@ -136,13 +150,9 @@ def test_validate_outcome_space_conditions__error():
136150
},
137151
},
138152
}
139-
base_case_condition = "Reference to {a.b.e}"
140-
edge_case_conditions = []
141153

142154
with pytest.raises(ValueError):
143-
validators.validate_outcome_space_conditions(
144-
base_case_condition, edge_case_conditions, dct
145-
)
155+
validators.validate_outcome_space_template_variables([value], dct)
146156

147157

148158
@pytest.mark.parametrize(

0 commit comments

Comments
 (0)