Skip to content

Commit bcb1600

Browse files
committed
fix: Convert Infinity and -Infinity to np.inf and -np.inf
1 parent 317eac4 commit bcb1600

File tree

2 files changed

+151
-4
lines changed

2 files changed

+151
-4
lines changed

policyengine/utils/reforms.py

Lines changed: 36 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,45 @@
11
import re
2-
from pydantic import RootModel, ValidationError, Field, model_validator
3-
from typing import Dict, TYPE_CHECKING
4-
from annotated_types import Ge, Le
2+
from pydantic import (
3+
RootModel,
4+
ValidationError,
5+
Field,
6+
model_validator,
7+
field_validator,
8+
)
9+
from typing import Dict, Self, Any, TYPE_CHECKING
510
from typing_extensions import Annotated
611
from typing import Callable
712
from policyengine_core.simulations import Simulation
813

914

15+
class ParameterChangeDict(RootModel):
16+
"""A dict of changes to a parameter, with custom date string as keys
17+
and various possible value types."""
18+
19+
root: Dict[str, Any]
20+
21+
@model_validator(mode="after")
22+
def check_keys(self) -> Self:
23+
for key in self.root.keys():
24+
# Check if key is YYYY-MM-DD.YYYY-MM-DD
25+
if not re.match(r"^\d{4}-\d{2}-\d{2}\.\d{4}-\d{2}-\d{2}$", key):
26+
raise ValueError(f"Invalid date format in key: {key}")
27+
return self
28+
29+
# Convert "Infinity" to "np.inf" and "-Infinity" to "-np.inf"
30+
@field_validator("root", mode="after")
31+
@classmethod
32+
def convert_infinity(cls, value: Dict[str, Any]) -> Dict[str, Any]:
33+
for key, val in value.items():
34+
if isinstance(val, str):
35+
if val == "Infinity":
36+
value[key] = float("inf")
37+
elif val == "-Infinity":
38+
value[key] = float("-inf")
39+
return value
40+
41+
1042
class ParametricReform(RootModel):
1143
"""A reform that just changes parameter values."""
1244

13-
root: Dict[str, Dict | float | bool]
45+
root: Dict[str, ParameterChangeDict]

tests/utils/test_reforms.py

Lines changed: 115 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,115 @@
1+
from pydantic import ValidationError
2+
import pytest
3+
import numpy as np
4+
5+
from policyengine.utils.reforms import ParameterChangeDict, ParametricReform
6+
7+
8+
class TestParameterChangeDict:
9+
def test_schema__given_float_inputs__returns_valid_dict(self):
10+
input_data = {
11+
"2023-01-01.2023-12-31": 0.1,
12+
"2024-01-01.2024-12-31": 0.2,
13+
}
14+
15+
result = ParameterChangeDict(root=input_data)
16+
17+
assert isinstance(result, ParameterChangeDict)
18+
assert result.root == input_data
19+
20+
def test_schema__given_string_inputs__returns_valid_dict(self):
21+
input_data = {
22+
"2023-01-01.2023-12-31": "0.1",
23+
"2024-01-01.2024-12-31": "0.2",
24+
}
25+
26+
result = ParameterChangeDict(root=input_data)
27+
28+
assert isinstance(result, ParameterChangeDict)
29+
assert result.root == input_data
30+
31+
def test_schema__given_infinity_string__returns_valid_dict(self):
32+
input_data = {
33+
"2023-01-01.2023-12-31": "Infinity",
34+
"2024-01-01.2024-12-31": "-Infinity",
35+
}
36+
37+
result = ParameterChangeDict(root=input_data)
38+
39+
assert isinstance(result, ParameterChangeDict)
40+
assert result.root == {
41+
"2023-01-01.2023-12-31": np.inf,
42+
"2024-01-01.2024-12-31": -np.inf,
43+
}
44+
45+
def test_schema__given_invalid_date_format__raises_validation_error(self):
46+
input_data = {"2023-01-01.2023-12-31": 0.1, "invalid_date_format": 0.2}
47+
48+
with pytest.raises(
49+
ValidationError, match="Invalid date format in key"
50+
):
51+
ParameterChangeDict(root=input_data)
52+
53+
def test_schema__given_invalid_key_type__raises_validation_error(self):
54+
input_data = {123: 0.1, "2024-01-01.2024-12-31": 0.2}
55+
56+
with pytest.raises(
57+
ValidationError, match="validation error for ParameterChangeDict"
58+
):
59+
ParameterChangeDict(root=input_data)
60+
61+
62+
class TestParametricReform:
63+
def test_schema__given_valid_dict__returns_valid_reform(self):
64+
input_data = {
65+
"parameter1": {
66+
"2023-01-01.2023-12-31": 0.1,
67+
"2024-01-01.2024-12-31": 0.2,
68+
},
69+
"parameter2": {
70+
"2023-01-01.2023-12-31": 0.3,
71+
"2024-01-01.2024-12-31": 0.4,
72+
},
73+
}
74+
75+
expected_output_data = {
76+
"parameter1": ParameterChangeDict(
77+
root={
78+
"2023-01-01.2023-12-31": 0.1,
79+
"2024-01-01.2024-12-31": 0.2,
80+
}
81+
),
82+
"parameter2": ParameterChangeDict(
83+
root={
84+
"2023-01-01.2023-12-31": 0.3,
85+
"2024-01-01.2024-12-31": 0.4,
86+
}
87+
),
88+
}
89+
90+
result = ParametricReform(root=input_data)
91+
92+
assert isinstance(result, ParametricReform)
93+
assert result.root == expected_output_data
94+
95+
def test_schema__given_invalid_key_type__raises_validation_error(self):
96+
input_data = {
97+
123: {"2023-01-01.2023-12-31": 0.1, "2024-01-01.2024-12-31": 0.2},
98+
"valid_parameter": {
99+
"2023-01-01.2023-12-31": 0.3,
100+
"2024-01-01.2024-12-31": 0.4,
101+
},
102+
}
103+
104+
with pytest.raises(
105+
ValidationError, match=r"validation errors? for ParametricReform"
106+
):
107+
ParametricReform(root=input_data)
108+
109+
def test_schema__given_dateless_structure__raises_validation_error(self):
110+
input_data = {"parameter1": 0.1, "parameter2": 0.2}
111+
112+
with pytest.raises(
113+
ValidationError, match=r"validation errors? for ParametricReform"
114+
):
115+
ParametricReform(root=input_data)

0 commit comments

Comments
 (0)