Skip to content

Commit 7a3eb9b

Browse files
committed
fix: Redo ParametricReform schema
1 parent bcb1600 commit 7a3eb9b

File tree

3 files changed

+261
-41
lines changed

3 files changed

+261
-41
lines changed

changelog_entry.yaml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
- bump: patch
2+
changes:
3+
changed:
4+
- Refactored ParametricReform schema into clearer subschemas.
5+
- Added conversion of Infinity and -Infinity to np.inf and -np.inf.

policyengine/utils/reforms.py

Lines changed: 61 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,45 +1,81 @@
1-
import re
21
from pydantic import (
32
RootModel,
4-
ValidationError,
53
Field,
6-
model_validator,
74
field_validator,
85
)
9-
from typing import Dict, Self, Any, TYPE_CHECKING
6+
from typing import Dict, Any
107
from typing_extensions import Annotated
11-
from typing import Callable
128
from policyengine_core.simulations import Simulation
139

1410

15-
class ParameterChangeDict(RootModel):
16-
"""A dict of changes to a parameter, with custom date string as keys
17-
and various possible value types."""
11+
class ParameterChangeValue(RootModel):
12+
"""A value for a parameter change, which can be any primitive type or 'Infinity'/'-Infinity'"""
1813

19-
root: Dict[str, Any]
14+
# To prevent validation errors, allow all types except containers
15+
# via field validator
16+
root: Any
2017

21-
@model_validator(mode="after")
22-
def check_keys(self) -> Self:
23-
for key in self.root.keys():
24-
# Check if key is YYYY-MM-DD.YYYY-MM-DD
25-
if not re.match(r"^\d{4}-\d{2}-\d{2}\.\d{4}-\d{2}-\d{2}$", key):
26-
raise ValueError(f"Invalid date format in key: {key}")
27-
return self
18+
@field_validator("root", mode="after")
19+
@classmethod
20+
def check_type(cls, value: Any) -> Any:
21+
# Check if the value is not a container type
22+
if isinstance(value, (dict, list, set, tuple)):
23+
raise ValueError(
24+
"ParameterChangeValue must not be a container type (dict, list, set, tuple)"
25+
)
26+
return value
2827

2928
# Convert "Infinity" to "np.inf" and "-Infinity" to "-np.inf"
3029
@field_validator("root", mode="after")
3130
@classmethod
32-
def convert_infinity(cls, value: Dict[str, Any]) -> Dict[str, Any]:
33-
for key, val in value.items():
34-
if isinstance(val, str):
35-
if val == "Infinity":
36-
value[key] = float("inf")
37-
elif val == "-Infinity":
38-
value[key] = float("-inf")
31+
def convert_infinity(cls, value: Any) -> Any:
32+
if isinstance(value, str):
33+
if value == "Infinity":
34+
value = float("inf")
35+
elif value == "-Infinity":
36+
value = float("-inf")
3937
return value
4038

4139

40+
class ParameterChangePeriod(RootModel):
41+
"""A period for a parameter change, which can be a single year or a date range"""
42+
43+
root: Annotated[
44+
str,
45+
Field(
46+
pattern=r"^\d{4}$|^\d{4}-\d{2}-\d{2}\.\d{4}-\d{2}-\d{2}$",
47+
description="A single year (YYYY) or a date range (YYYY-MM-DD.YYYY-MM-DD)",
48+
),
49+
]
50+
51+
def __hash__(self):
52+
return hash(self.root)
53+
54+
def __eq__(self, other):
55+
if isinstance(other, ParameterChangePeriod):
56+
return self.root == other.root
57+
return False
58+
59+
60+
class ParameterChangeDict(RootModel):
61+
"""
62+
A dict of changes to a parameter, with custom date string as keys
63+
and various possible value types.
64+
65+
Keys can be formatted one of two ways:
66+
1. A single year (e.g., "YYYY")
67+
2. A date range (e.g., "YYYY-MM-DD.YYYY-MM-DD")
68+
"""
69+
70+
root: Dict[ParameterChangePeriod, ParameterChangeValue]
71+
72+
4273
class ParametricReform(RootModel):
43-
"""A reform that just changes parameter values."""
74+
"""
75+
A reform that just changes parameter values.
76+
77+
This is a dict that equates a parameter name to either a single value or a dict of changes.
78+
79+
"""
4480

45-
root: Dict[str, ParameterChangeDict]
81+
root: Dict[str, ParameterChangeValue | ParameterChangeDict]

tests/utils/test_reforms.py

Lines changed: 195 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,12 @@
22
import pytest
33
import numpy as np
44

5-
from policyengine.utils.reforms import ParameterChangeDict, ParametricReform
5+
from policyengine.utils.reforms import (
6+
ParameterChangeDict,
7+
ParametricReform,
8+
ParameterChangeValue,
9+
ParameterChangePeriod,
10+
)
611

712

813
class TestParameterChangeDict:
@@ -12,41 +17,82 @@ def test_schema__given_float_inputs__returns_valid_dict(self):
1217
"2024-01-01.2024-12-31": 0.2,
1318
}
1419

20+
expected_output_data = {
21+
ParameterChangePeriod(
22+
root="2023-01-01.2023-12-31"
23+
): ParameterChangeValue(root=0.1),
24+
ParameterChangePeriod(
25+
root="2024-01-01.2024-12-31"
26+
): ParameterChangeValue(root=0.2),
27+
}
28+
1529
result = ParameterChangeDict(root=input_data)
1630

1731
assert isinstance(result, ParameterChangeDict)
18-
assert result.root == input_data
32+
assert result.root == expected_output_data
1933

2034
def test_schema__given_string_inputs__returns_valid_dict(self):
2135
input_data = {
2236
"2023-01-01.2023-12-31": "0.1",
2337
"2024-01-01.2024-12-31": "0.2",
2438
}
2539

40+
expected_output_data = {
41+
ParameterChangePeriod(
42+
root="2023-01-01.2023-12-31"
43+
): ParameterChangeValue(root="0.1"),
44+
ParameterChangePeriod(
45+
root="2024-01-01.2024-12-31"
46+
): ParameterChangeValue(root="0.2"),
47+
}
48+
2649
result = ParameterChangeDict(root=input_data)
2750

2851
assert isinstance(result, ParameterChangeDict)
29-
assert result.root == input_data
52+
assert result.root == expected_output_data
3053

3154
def test_schema__given_infinity_string__returns_valid_dict(self):
3255
input_data = {
3356
"2023-01-01.2023-12-31": "Infinity",
3457
"2024-01-01.2024-12-31": "-Infinity",
3558
}
3659

60+
expected_output_data = {
61+
ParameterChangePeriod(
62+
root="2023-01-01.2023-12-31"
63+
): ParameterChangeValue(root=np.inf),
64+
ParameterChangePeriod(
65+
root="2024-01-01.2024-12-31"
66+
): ParameterChangeValue(root=-np.inf),
67+
}
68+
3769
result = ParameterChangeDict(root=input_data)
3870

3971
assert isinstance(result, ParameterChangeDict)
40-
assert result.root == {
41-
"2023-01-01.2023-12-31": np.inf,
42-
"2024-01-01.2024-12-31": -np.inf,
72+
assert result.root == expected_output_data
73+
74+
def test_schema__given_yearly_input__returns_valid_dict(self):
75+
input_data = {
76+
"2023": 0.1,
77+
"2024": 0.2,
78+
}
79+
80+
expected_output_data = {
81+
ParameterChangePeriod(root="2023"): ParameterChangeValue(root=0.1),
82+
ParameterChangePeriod(root="2024"): ParameterChangeValue(root=0.2),
4383
}
4484

85+
result = ParameterChangeDict(root=input_data)
86+
87+
assert isinstance(result, ParameterChangeDict)
88+
assert result.root == expected_output_data
89+
4590
def test_schema__given_invalid_date_format__raises_validation_error(self):
4691
input_data = {"2023-01-01.2023-12-31": 0.1, "invalid_date_format": 0.2}
4792

4893
with pytest.raises(
49-
ValidationError, match="Invalid date format in key"
94+
ValidationError,
95+
match=r"validation errors? for ParameterChangeDict",
5096
):
5197
ParameterChangeDict(root=input_data)
5298

@@ -59,8 +105,93 @@ def test_schema__given_invalid_key_type__raises_validation_error(self):
59105
ParameterChangeDict(root=input_data)
60106

61107

108+
class TestParameterChangePeriod:
109+
def test_schema__given_valid_year__returns_valid_period(self):
110+
input_data = "2023"
111+
112+
result = ParameterChangePeriod(root=input_data)
113+
114+
assert isinstance(result, ParameterChangePeriod)
115+
assert result.root == input_data
116+
117+
def test_schema__given_valid_date_range__returns_valid_period(self):
118+
input_data = "2023-01-01.2023-12-31"
119+
120+
result = ParameterChangePeriod(root=input_data)
121+
122+
assert isinstance(result, ParameterChangePeriod)
123+
assert result.root == input_data
124+
125+
def test_schema__given_invalid_date_format__raises_validation_error(self):
126+
input_data = "2023.01.01-2024.12.31"
127+
128+
with pytest.raises(
129+
ValidationError,
130+
match=r"validation errors? for ParameterChangePeriod",
131+
):
132+
ParameterChangePeriod(root=input_data)
133+
134+
135+
class TestParameterChangeValue:
136+
def test_schema__given_float_input__returns_valid_value(self):
137+
input_data = 0.1
138+
139+
result = ParameterChangeValue(root=input_data)
140+
141+
assert isinstance(result, ParameterChangeValue)
142+
assert result.root == input_data
143+
144+
def test_schema__given_string_input__returns_valid_value(self):
145+
input_data = "0.1"
146+
147+
result = ParameterChangeValue(root=input_data)
148+
149+
assert isinstance(result, ParameterChangeValue)
150+
assert result.root == input_data
151+
152+
def test_schema__given_bool_input__returns_valid_value(self):
153+
input_data = True
154+
155+
result = ParameterChangeValue(root=input_data)
156+
157+
assert isinstance(result, ParameterChangeValue)
158+
assert result.root == input_data
159+
160+
def test_schema__given_infinity_string__returns_valid_value(self):
161+
input_data = "Infinity"
162+
163+
result = ParameterChangeValue(root=input_data)
164+
165+
assert isinstance(result, ParameterChangeValue)
166+
assert result.root == float("inf")
167+
168+
def test_schema__given_negative_infinity_string__returns_valid_value(self):
169+
input_data = "-Infinity"
170+
171+
result = ParameterChangeValue(root=input_data)
172+
173+
assert isinstance(result, ParameterChangeValue)
174+
assert result.root == float("-inf")
175+
176+
def test_schema__given_invalid_type__raises_validation_error(self):
177+
input_data = [0.1, 0.2]
178+
179+
with pytest.raises(
180+
ValidationError, match="validation error for ParameterChangeValue"
181+
):
182+
ParameterChangeValue(root=input_data)
183+
184+
def test_schema__given_dict_input__raises_validation_error(self):
185+
input_data = {"key": "value"}
186+
187+
with pytest.raises(
188+
ValidationError, match="validation error for ParameterChangeValue"
189+
):
190+
ParameterChangeValue(root=input_data)
191+
192+
62193
class TestParametricReform:
63-
def test_schema__given_valid_dict__returns_valid_reform(self):
194+
def test_schema__given_full_date_dict__returns_valid_reform(self):
64195
input_data = {
65196
"parameter1": {
66197
"2023-01-01.2023-12-31": 0.1,
@@ -92,6 +223,62 @@ def test_schema__given_valid_dict__returns_valid_reform(self):
92223
assert isinstance(result, ParametricReform)
93224
assert result.root == expected_output_data
94225

226+
def test_schema__given_yearly_dict__returns_valid_reform(self):
227+
input_data = {
228+
"parameter1": {"2023": 0.1, "2024": 0.2},
229+
"parameter2": {"2023": 0.3, "2024": 0.4},
230+
}
231+
232+
expected_output_data = {
233+
"parameter1": ParameterChangeDict(root={"2023": 0.1, "2024": 0.2}),
234+
"parameter2": ParameterChangeDict(root={"2023": 0.3, "2024": 0.4}),
235+
}
236+
237+
result = ParametricReform(root=input_data)
238+
239+
assert isinstance(result, ParametricReform)
240+
assert result.root == expected_output_data
241+
242+
def test_schema__given_single_value_dict__returns_valid_reform(self):
243+
input_data = {
244+
"parameter1": 0.1,
245+
"parameter2": 0.2,
246+
}
247+
248+
expected_output_data = {
249+
"parameter1": ParameterChangeValue(root=0.1),
250+
"parameter2": ParameterChangeValue(root=0.2),
251+
}
252+
253+
result = ParametricReform(root=input_data)
254+
255+
assert isinstance(result, ParametricReform)
256+
assert result.root == expected_output_data
257+
258+
def test_schema__given_mixed_dict__returns_valid_reform(self):
259+
input_data = {
260+
"parameter1": {
261+
"2023-01-01.2023-12-31": 0.1,
262+
"2024-01-01.2024-12-31": 0.2,
263+
},
264+
"parameter2": 0.3,
265+
}
266+
267+
expected_output_data = {
268+
"parameter1": ParameterChangeDict(
269+
root={
270+
"2023-01-01.2023-12-31": 0.1,
271+
"2024-01-01.2024-12-31": 0.2,
272+
}
273+
),
274+
"parameter2": ParameterChangeValue(root=0.3),
275+
}
276+
277+
result = ParametricReform(root=input_data)
278+
279+
assert isinstance(result, ParametricReform)
280+
assert result.root == expected_output_data
281+
95282
def test_schema__given_invalid_key_type__raises_validation_error(self):
96283
input_data = {
97284
123: {"2023-01-01.2023-12-31": 0.1, "2024-01-01.2024-12-31": 0.2},
@@ -105,11 +292,3 @@ def test_schema__given_invalid_key_type__raises_validation_error(self):
105292
ValidationError, match=r"validation errors? for ParametricReform"
106293
):
107294
ParametricReform(root=input_data)
108-
109-
def test_schema__given_dateless_structure__raises_validation_error(self):
110-
input_data = {"parameter1": 0.1, "parameter2": 0.2}
111-
112-
with pytest.raises(
113-
ValidationError, match=r"validation errors? for ParametricReform"
114-
):
115-
ParametricReform(root=input_data)

0 commit comments

Comments
 (0)