diff --git a/changelog_entry.yaml b/changelog_entry.yaml
index e69de29b..5822f7a9 100644
--- a/changelog_entry.yaml
+++ b/changelog_entry.yaml
@@ -0,0 +1,5 @@
+- bump: patch
+  changes:
+    changed:
+    - Refactored ParametricReform schema into clearer subschemas.
+    - Added conversion of Infinity and -Infinity to np.inf and -np.inf.
\ No newline at end of file
diff --git a/policyengine/utils/reforms.py b/policyengine/utils/reforms.py
index 6d681f57..6dfc804b 100644
--- a/policyengine/utils/reforms.py
+++ b/policyengine/utils/reforms.py
@@ -1,13 +1,84 @@
 import re
-from pydantic import RootModel, ValidationError, Field, model_validator
-from typing import Dict, TYPE_CHECKING
-from annotated_types import Ge, Le
-from typing_extensions import Annotated
-from typing import Callable
-from policyengine_core.simulations import Simulation
+from pydantic import (
+    RootModel,
+    field_validator,
+)
+from typing import Dict, Any
+
+
+class ParameterChangeValue(RootModel):
+    """A value for a parameter change, which can be any primitive type or 'Infinity'/'-Infinity'"""
+
+    # To prevent validation errors, allow all types except containers
+    # via field validator
+    root: Any
+
+    @field_validator("root", mode="after")
+    @classmethod
+    def check_type(cls, value: Any) -> Any:
+        """Reject dict, list, set and tuple values; return anything else unchanged."""
+        if isinstance(value, (dict, list, set, tuple)):
+            raise ValueError(
+                "ParameterChangeValue must not be a container type (dict, list, set, tuple)"
+            )
+        return value
+
+    @field_validator("root", mode="after")
+    @classmethod
+    def convert_infinity(cls, value: Any) -> Any:
+        """Convert "Infinity" / "-Infinity" to float("inf") / float("-inf")."""
+        # The float infinities compare equal to np.inf and -np.inf.
+        if isinstance(value, str):
+            if value == "Infinity":
+                value = float("inf")
+            elif value == "-Infinity":
+                value = float("-inf")
+        return value
+
+
+class ParameterChangeDict(RootModel):
+    """
+    A dict of changes to a parameter, with custom date string as keys
+    and various possible value types.
+
+    Keys can be formatted one of two ways:
+    1. A single year (e.g., "YYYY")
+    2. A date range (e.g., "YYYY-MM-DD.YYYY-MM-DD")
+    """
+
+    root: Dict[str, ParameterChangeValue]
+
+    @field_validator("root", mode="after")
+    @classmethod
+    def validate_dates(
+        cls, value: Dict[str, ParameterChangeValue]
+    ) -> Dict[str, ParameterChangeValue]:
+        """
+        Check that every key is a single year or a date range.
+
+        Raises ValueError for the first key that matches neither format.
+        """
+        year_keys_re = r"\d{4}"
+        date_range_keys_re = r"\d{4}-\d{2}-\d{2}\.\d{4}-\d{2}-\d{2}"
+
+        for key in value:
+            # fullmatch, not match: with match, a trailing "$" also matches
+            # just before a final newline, so a key like "2023\n" would pass.
+            if not re.fullmatch(year_keys_re, key) and not re.fullmatch(
+                date_range_keys_re, key
+            ):
+                raise ValueError(
+                    f"Key '{key}' must be a single year (YYYY) or a date range (YYYY-MM-DD.YYYY-MM-DD)"
+                )
+        return value
 
 
 class ParametricReform(RootModel):
-    """A reform that just changes parameter values."""
+    """
+    A reform that just changes parameter values.
+
+    This is a dict that equates a parameter name to either a single value or a dict of changes.
+
+    """
 
-    root: Dict[str, Dict | float | bool]
+    root: Dict[str, ParameterChangeValue | ParameterChangeDict]
diff --git a/tests/utils/test_reforms.py b/tests/utils/test_reforms.py
new file mode 100644
index 00000000..cf43630d
--- /dev/null
+++ b/tests/utils/test_reforms.py
@@ -0,0 +1,278 @@
+from pydantic import ValidationError
+import pytest
+import numpy as np
+
+from policyengine.utils.reforms import (
+    ParameterChangeDict,
+    ParametricReform,
+    ParameterChangeValue,
+)
+
+
+class TestParameterChangeDict:
+    def test_schema__given_float_inputs__returns_valid_dict(self):
+        input_data = {
+            "2023-01-01.2023-12-31": 0.1,
+            "2024-01-01.2024-12-31": 0.2,
+        }
+
+        expected_output_data = {
+            "2023-01-01.2023-12-31": ParameterChangeValue(root=0.1),
+            "2024-01-01.2024-12-31": ParameterChangeValue(root=0.2),
+        }
+
+        result = ParameterChangeDict(root=input_data)
+
+        assert isinstance(result, ParameterChangeDict)
+        assert result.root == expected_output_data
+
+    def test_schema__given_string_inputs__returns_valid_dict(self):
+        input_data = {
+            "2023-01-01.2023-12-31": "0.1",
+            "2024-01-01.2024-12-31": "0.2",
+        }
+
+        expected_output_data = {
+            "2023-01-01.2023-12-31": ParameterChangeValue(root="0.1"),
+            "2024-01-01.2024-12-31": ParameterChangeValue(root="0.2"),
+        }
+
+        result = ParameterChangeDict(root=input_data)
+
+        assert isinstance(result, ParameterChangeDict)
+        assert result.root == expected_output_data
+
+    def test_schema__given_infinity_string__returns_valid_dict(self):
+        input_data = {
+            "2023-01-01.2023-12-31": "Infinity",
+            "2024-01-01.2024-12-31": "-Infinity",
+        }
+
+        expected_output_data = {
+            "2023-01-01.2023-12-31": ParameterChangeValue(root=np.inf),
+            "2024-01-01.2024-12-31": ParameterChangeValue(root=-np.inf),
+        }
+
+        result = ParameterChangeDict(root=input_data)
+
+        assert isinstance(result, ParameterChangeDict)
+        assert result.root == expected_output_data
+
+    def test_schema__given_yearly_input__returns_valid_dict(self):
+        input_data = {
+            "2023": 0.1,
+            "2024": 0.2,
+        }
+
+        expected_output_data = {
+            "2023": ParameterChangeValue(root=0.1),
+            "2024": ParameterChangeValue(root=0.2),
+        }
+
+        result = ParameterChangeDict(root=input_data)
+
+        assert isinstance(result, ParameterChangeDict)
+        assert result.root == expected_output_data
+
+    def test_schema__given_invalid_date_format__raises_validation_error(self):
+        input_data = {"2023-01-01.2023-12-31": 0.1, "invalid_date_format": 0.2}
+
+        with pytest.raises(
+            ValidationError,
+            match=r"validation errors? for ParameterChangeDict",
+        ):
+            ParameterChangeDict(root=input_data)
+
+    def test_schema__given_non_date_key_type__raises_validation_error(self):
+        input_data = {123: 0.1, "2024-01-01.2024-12-31": 0.2}
+
+        with pytest.raises(
+            ValidationError, match="validation error for ParameterChangeDict"
+        ):
+            ParameterChangeDict(root=input_data)
+
+    def test_schema__given_incorrect_date_key_type__raises_validation_error(
+        self,
+    ):
+        input_data = {
+            "2023-01-01.2023-12-31": 0.1,
+            "2024-01-01.2024-12-31": 0.2,
+            "2024.01.01-2025.12.31": 0.3,
+        }
+
+        with pytest.raises(
+            ValidationError, match="validation error for ParameterChangeDict"
+        ):
+            ParameterChangeDict(root=input_data)
+
+    def test_schema__given_key_with_trailing_newline__raises_validation_error(
+        self,
+    ):
+        input_data = {"2023\n": 0.1}
+
+        with pytest.raises(
+            ValidationError, match="validation error for ParameterChangeDict"
+        ):
+            ParameterChangeDict(root=input_data)
+
+
+class TestParameterChangeValue:
+    def test_schema__given_float_input__returns_valid_value(self):
+        input_data = 0.1
+
+        result = ParameterChangeValue(root=input_data)
+
+        assert isinstance(result, ParameterChangeValue)
+        assert result.root == input_data
+
+    def test_schema__given_string_input__returns_valid_value(self):
+        input_data = "0.1"
+
+        result = ParameterChangeValue(root=input_data)
+
+        assert isinstance(result, ParameterChangeValue)
+        assert result.root == input_data
+
+    def test_schema__given_bool_input__returns_valid_value(self):
+        input_data = True
+
+        result = ParameterChangeValue(root=input_data)
+
+        assert isinstance(result, ParameterChangeValue)
+        assert result.root == input_data
+
+    def test_schema__given_infinity_string__returns_valid_value(self):
+        input_data = "Infinity"
+
+        result = ParameterChangeValue(root=input_data)
+
+        assert isinstance(result, ParameterChangeValue)
+        assert result.root == float("inf")
+
+    def test_schema__given_negative_infinity_string__returns_valid_value(self):
+        input_data = "-Infinity"
+
+        result = ParameterChangeValue(root=input_data)
+
+        assert isinstance(result, ParameterChangeValue)
+        assert result.root == float("-inf")
+
+    def test_schema__given_invalid_type__raises_validation_error(self):
+        input_data = [0.1, 0.2]
+
+        with pytest.raises(
+            ValidationError, match="validation error for ParameterChangeValue"
+        ):
+            ParameterChangeValue(root=input_data)
+
+    def test_schema__given_dict_input__raises_validation_error(self):
+        input_data = {"key": "value"}
+
+        with pytest.raises(
+            ValidationError, match="validation error for ParameterChangeValue"
+        ):
+            ParameterChangeValue(root=input_data)
+
+
+class TestParametricReform:
+    def test_schema__given_full_date_dict__returns_valid_reform(self):
+        input_data = {
+            "parameter1": {
+                "2023-01-01.2023-12-31": 0.1,
+                "2024-01-01.2024-12-31": 0.2,
+            },
+            "parameter2": {
+                "2023-01-01.2023-12-31": 0.3,
+                "2024-01-01.2024-12-31": 0.4,
+            },
+        }
+
+        expected_output_data = {
+            "parameter1": ParameterChangeDict(
+                root={
+                    "2023-01-01.2023-12-31": 0.1,
+                    "2024-01-01.2024-12-31": 0.2,
+                }
+            ),
+            "parameter2": ParameterChangeDict(
+                root={
+                    "2023-01-01.2023-12-31": 0.3,
+                    "2024-01-01.2024-12-31": 0.4,
+                }
+            ),
+        }
+
+        result = ParametricReform(root=input_data)
+
+        assert isinstance(result, ParametricReform)
+        assert result.root == expected_output_data
+
+    def test_schema__given_yearly_dict__returns_valid_reform(self):
+        input_data = {
+            "parameter1": {"2023": 0.1, "2024": 0.2},
+            "parameter2": {"2023": 0.3, "2024": 0.4},
+        }
+
+        expected_output_data = {
+            "parameter1": ParameterChangeDict(root={"2023": 0.1, "2024": 0.2}),
+            "parameter2": ParameterChangeDict(root={"2023": 0.3, "2024": 0.4}),
+        }
+
+        result = ParametricReform(root=input_data)
+
+        assert isinstance(result, ParametricReform)
+        assert result.root == expected_output_data
+
+    def test_schema__given_single_value_dict__returns_valid_reform(self):
+        input_data = {
+            "parameter1": 0.1,
+            "parameter2": 0.2,
+        }
+
+        expected_output_data = {
+            "parameter1": ParameterChangeValue(root=0.1),
+            "parameter2": ParameterChangeValue(root=0.2),
+        }
+
+        result = ParametricReform(root=input_data)
+
+        assert isinstance(result, ParametricReform)
+        assert result.root == expected_output_data
+
+    def test_schema__given_mixed_dict__returns_valid_reform(self):
+        input_data = {
+            "parameter1": {
+                "2023-01-01.2023-12-31": 0.1,
+                "2024-01-01.2024-12-31": 0.2,
+            },
+            "parameter2": 0.3,
+        }
+
+        expected_output_data = {
+            "parameter1": ParameterChangeDict(
+                root={
+                    "2023-01-01.2023-12-31": 0.1,
+                    "2024-01-01.2024-12-31": 0.2,
+                }
+            ),
+            "parameter2": ParameterChangeValue(root=0.3),
+        }
+
+        result = ParametricReform(root=input_data)
+
+        assert isinstance(result, ParametricReform)
+        assert result.root == expected_output_data
+
+    def test_schema__given_invalid_key_type__raises_validation_error(self):
+        input_data = {
+            123: {"2023-01-01.2023-12-31": 0.1, "2024-01-01.2024-12-31": 0.2},
+            "valid_parameter": {
+                "2023-01-01.2023-12-31": 0.3,
+                "2024-01-01.2024-12-31": 0.4,
+            },
+        }
+
+        with pytest.raises(
+            ValidationError, match=r"validation errors? for ParametricReform"
+        ):
+            ParametricReform(root=input_data)