Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions changelog_entry.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
- bump: patch
changes:
changed:
- Refactored ParametricReform schema into clearer subschemas.
- Added conversion of Infinity and -Infinity to np.inf and -np.inf.
80 changes: 72 additions & 8 deletions policyengine/utils/reforms.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,77 @@
import re
from pydantic import RootModel, ValidationError, Field, model_validator
from typing import Dict, TYPE_CHECKING
from annotated_types import Ge, Le
from typing_extensions import Annotated
from typing import Callable
from policyengine_core.simulations import Simulation
from pydantic import (
RootModel,
field_validator,
)
from typing import Dict, Any


class ParameterChangeValue(RootModel):
"""A value for a parameter change, which can be any primitive type or 'Infinity'/'-Infinity'"""

# To prevent validation errors, allow all types except containers
# via field validator
root: Any

@field_validator("root", mode="after")
@classmethod
def check_type(cls, value: Any) -> Any:
# Check if the value is not a container type
if isinstance(value, (dict, list, set, tuple)):
raise ValueError(
"ParameterChangeValue must not be a container type (dict, list, set, tuple)"
)
return value

# Convert "Infinity" to "np.inf" and "-Infinity" to "-np.inf"
@field_validator("root", mode="after")
@classmethod
def convert_infinity(cls, value: Any) -> Any:
if isinstance(value, str):
if value == "Infinity":
value = float("inf")
elif value == "-Infinity":
value = float("-inf")
return value


class ParameterChangeDict(RootModel):
"""
A dict of changes to a parameter, with custom date string as keys
and various possible value types.

Keys can be formatted one of two ways:
1. A single year (e.g., "YYYY")
2. A date range (e.g., "YYYY-MM-DD.YYYY-MM-DD")
"""

root: Dict[str, ParameterChangeValue]

@field_validator("root", mode="after")
@classmethod
def validate_dates(
cls, value: Dict[str, ParameterChangeValue]
) -> Dict[str, ParameterChangeValue]:

year_keys_re = r"^\d{4}$"
date_range_keys_re = r"^\d{4}-\d{2}-\d{2}\.\d{4}-\d{2}-\d{2}$"

for key in value.keys():
if not re.match(year_keys_re, key) and not re.match(
date_range_keys_re, key
):
raise ValueError(
f"Key '{key}' must be a single year (YYYY) or a date range (YYYY-MM-DD.YYYY-MM-DD)"
)
return value


class ParametricReform(RootModel):
"""A reform that just changes parameter values."""
"""
A reform that just changes parameter values.

This is a dict that equates a parameter name to either a single value or a dict of changes.

"""

root: Dict[str, Dict | float | bool]
root: Dict[str, ParameterChangeValue | ParameterChangeDict]
268 changes: 268 additions & 0 deletions tests/utils/test_reforms.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,268 @@
from pydantic import ValidationError
import pytest
import numpy as np

from policyengine.utils.reforms import (
ParameterChangeDict,
ParametricReform,
ParameterChangeValue,
)


class TestParameterChangeDict:
def test_schema__given_float_inputs__returns_valid_dict(self):
input_data = {
"2023-01-01.2023-12-31": 0.1,
"2024-01-01.2024-12-31": 0.2,
}

expected_output_data = {
"2023-01-01.2023-12-31": ParameterChangeValue(root=0.1),
"2024-01-01.2024-12-31": ParameterChangeValue(root=0.2),
}

result = ParameterChangeDict(root=input_data)

assert isinstance(result, ParameterChangeDict)
assert result.root == expected_output_data

def test_schema__given_string_inputs__returns_valid_dict(self):
input_data = {
"2023-01-01.2023-12-31": "0.1",
"2024-01-01.2024-12-31": "0.2",
}

expected_output_data = {
"2023-01-01.2023-12-31": ParameterChangeValue(root="0.1"),
"2024-01-01.2024-12-31": ParameterChangeValue(root="0.2"),
}

result = ParameterChangeDict(root=input_data)

assert isinstance(result, ParameterChangeDict)
assert result.root == expected_output_data

def test_schema__given_infinity_string__returns_valid_dict(self):
input_data = {
"2023-01-01.2023-12-31": "Infinity",
"2024-01-01.2024-12-31": "-Infinity",
}

expected_output_data = {
"2023-01-01.2023-12-31": ParameterChangeValue(root=np.inf),
"2024-01-01.2024-12-31": ParameterChangeValue(root=-np.inf),
}

result = ParameterChangeDict(root=input_data)

assert isinstance(result, ParameterChangeDict)
assert result.root == expected_output_data

def test_schema__given_yearly_input__returns_valid_dict(self):
input_data = {
"2023": 0.1,
"2024": 0.2,
}

expected_output_data = {
"2023": ParameterChangeValue(root=0.1),
"2024": ParameterChangeValue(root=0.2),
}

result = ParameterChangeDict(root=input_data)

assert isinstance(result, ParameterChangeDict)
assert result.root == expected_output_data

def test_schema__given_invalid_date_format__raises_validation_error(self):
input_data = {"2023-01-01.2023-12-31": 0.1, "invalid_date_format": 0.2}

with pytest.raises(
ValidationError,
match=r"validation errors? for ParameterChangeDict",
):
ParameterChangeDict(root=input_data)

def test_schema__given_non_date_key_type__raises_validation_error(self):
input_data = {123: 0.1, "2024-01-01.2024-12-31": 0.2}

with pytest.raises(
ValidationError, match="validation error for ParameterChangeDict"
):
ParameterChangeDict(root=input_data)

def test_schema__given_incorrect_date_key_type__raises_validation_error(
self,
):
input_data = {
"2023-01-01.2023-12-31": 0.1,
"2024-01-01.2024-12-31": 0.2,
"2024.01.01-2025.12.31": 0.3,
}

with pytest.raises(
ValidationError, match="validation error for ParameterChangeDict"
):
ParameterChangeDict(root=input_data)


class TestParameterChangeValue:
def test_schema__given_float_input__returns_valid_value(self):
input_data = 0.1

result = ParameterChangeValue(root=input_data)

assert isinstance(result, ParameterChangeValue)
assert result.root == input_data

def test_schema__given_string_input__returns_valid_value(self):
input_data = "0.1"

result = ParameterChangeValue(root=input_data)

assert isinstance(result, ParameterChangeValue)
assert result.root == input_data

def test_schema__given_bool_input__returns_valid_value(self):
input_data = True

result = ParameterChangeValue(root=input_data)

assert isinstance(result, ParameterChangeValue)
assert result.root == input_data

def test_schema__given_infinity_string__returns_valid_value(self):
input_data = "Infinity"

result = ParameterChangeValue(root=input_data)

assert isinstance(result, ParameterChangeValue)
assert result.root == float("inf")

def test_schema__given_negative_infinity_string__returns_valid_value(self):
input_data = "-Infinity"

result = ParameterChangeValue(root=input_data)

assert isinstance(result, ParameterChangeValue)
assert result.root == float("-inf")

def test_schema__given_invalid_type__raises_validation_error(self):
input_data = [0.1, 0.2]

with pytest.raises(
ValidationError, match="validation error for ParameterChangeValue"
):
ParameterChangeValue(root=input_data)

def test_schema__given_dict_input__raises_validation_error(self):
input_data = {"key": "value"}

with pytest.raises(
ValidationError, match="validation error for ParameterChangeValue"
):
ParameterChangeValue(root=input_data)


class TestParametricReform:
def test_schema__given_full_date_dict__returns_valid_reform(self):
input_data = {
"parameter1": {
"2023-01-01.2023-12-31": 0.1,
"2024-01-01.2024-12-31": 0.2,
},
"parameter2": {
"2023-01-01.2023-12-31": 0.3,
"2024-01-01.2024-12-31": 0.4,
},
}

expected_output_data = {
"parameter1": ParameterChangeDict(
root={
"2023-01-01.2023-12-31": 0.1,
"2024-01-01.2024-12-31": 0.2,
}
),
"parameter2": ParameterChangeDict(
root={
"2023-01-01.2023-12-31": 0.3,
"2024-01-01.2024-12-31": 0.4,
}
),
}

result = ParametricReform(root=input_data)

assert isinstance(result, ParametricReform)
assert result.root == expected_output_data

def test_schema__given_yearly_dict__returns_valid_reform(self):
input_data = {
"parameter1": {"2023": 0.1, "2024": 0.2},
"parameter2": {"2023": 0.3, "2024": 0.4},
}

expected_output_data = {
"parameter1": ParameterChangeDict(root={"2023": 0.1, "2024": 0.2}),
"parameter2": ParameterChangeDict(root={"2023": 0.3, "2024": 0.4}),
}

result = ParametricReform(root=input_data)

assert isinstance(result, ParametricReform)
assert result.root == expected_output_data

def test_schema__given_single_value_dict__returns_valid_reform(self):
input_data = {
"parameter1": 0.1,
"parameter2": 0.2,
}

expected_output_data = {
"parameter1": ParameterChangeValue(root=0.1),
"parameter2": ParameterChangeValue(root=0.2),
}

result = ParametricReform(root=input_data)

assert isinstance(result, ParametricReform)
assert result.root == expected_output_data

def test_schema__given_mixed_dict__returns_valid_reform(self):
input_data = {
"parameter1": {
"2023-01-01.2023-12-31": 0.1,
"2024-01-01.2024-12-31": 0.2,
},
"parameter2": 0.3,
}

expected_output_data = {
"parameter1": ParameterChangeDict(
root={
"2023-01-01.2023-12-31": 0.1,
"2024-01-01.2024-12-31": 0.2,
}
),
"parameter2": ParameterChangeValue(root=0.3),
}

result = ParametricReform(root=input_data)

assert isinstance(result, ParametricReform)
assert result.root == expected_output_data

def test_schema__given_invalid_key_type__raises_validation_error(self):
input_data = {
123: {"2023-01-01.2023-12-31": 0.1, "2024-01-01.2024-12-31": 0.2},
"valid_parameter": {
"2023-01-01.2023-12-31": 0.3,
"2024-01-01.2024-12-31": 0.4,
},
}

with pytest.raises(
ValidationError, match=r"validation errors? for ParametricReform"
):
ParametricReform(root=input_data)
Loading