diff --git a/changelog_entry.yaml b/changelog_entry.yaml index e69de29b..e64449bb 100644 --- a/changelog_entry.yaml +++ b/changelog_entry.yaml @@ -0,0 +1,4 @@ +- bump: patch + changes: + fixed: + - TODO diff --git a/policyengine/outputs/macro/single/calculate_average_earnings.py b/policyengine/outputs/macro/single/calculate_average_earnings.py index 764071e1..31657823 100644 --- a/policyengine/outputs/macro/single/calculate_average_earnings.py +++ b/policyengine/outputs/macro/single/calculate_average_earnings.py @@ -1,7 +1,7 @@ -from policyengine import Simulation +from policyengine.simulation_results import MacroContext -def calculate_average_earnings(simulation: Simulation) -> float: +def calculate_average_earnings(simulation: MacroContext) -> float: """Calculate average earnings.""" employment_income = simulation.baseline_simulation.calculate( "employment_income" diff --git a/policyengine/outputs/macro/single/calculate_single_economy.py b/policyengine/outputs/macro/single/calculate_single_economy.py index 34e5ee1d..d832a7c0 100644 --- a/policyengine/outputs/macro/single/calculate_single_economy.py +++ b/policyengine/outputs/macro/single/calculate_single_economy.py @@ -13,6 +13,11 @@ from typing import Literal from microdf import MicroSeries +from policyengine.simulation_results import ( + AbstractSimulationResults, + MacroContext, +) + class SingleEconomy(BaseModel): total_net_income: float @@ -78,7 +83,7 @@ class UKPrograms: class GeneralEconomyTask: - def __init__(self, simulation: Microsimulation, country_id: str): + def __init__(self, simulation: AbstractSimulationResults, country_id: str): self.simulation = simulation self.country_id = country_id self.household_count_people = self.simulation.calculate( @@ -332,8 +337,8 @@ def calculate_uk_programs(self) -> Dict[str, float]: } def calculate_cliffs(self): - cliff_gap: MicroSeries = self.simulation.calculate("cliff_gap") - is_on_cliff: MicroSeries = self.simulation.calculate("is_on_cliff") + cliff_gap: Series = self.simulation.calculate("cliff_gap") + is_on_cliff: Series = self.simulation.calculate("is_on_cliff") total_cliff_gap: float = cliff_gap.sum() total_adults: float = self.simulation.calculate("is_adult").sum() cliff_share: float = is_on_cliff.sum() / total_adults @@ -349,15 +354,20 @@ class CliffImpactInSimulation(BaseModel): def calculate_single_economy( - simulation: Simulation, reform: bool = False + simulation: MacroContext, reform: bool = False ) -> Dict: include_cliffs = simulation.options.include_cliffs + country_simulation = ( + simulation.baseline_simulation + if not reform + else simulation.reform_simulation + ) + if country_simulation is None: + raise ValueError( + "Simulation data is not available for the specified context." + ) task_manager = GeneralEconomyTask( - ( - simulation.baseline_simulation - if not reform - else simulation.reform_simulation - ), + country_simulation, simulation.options.country, ) country_id = simulation.options.country diff --git a/policyengine/simulation/__init__.py b/policyengine/simulation/__init__.py new file mode 100644 index 00000000..46c046bc --- /dev/null +++ b/policyengine/simulation/__init__.py @@ -0,0 +1,2 @@ +from .simulation import Simulation as Simulation +from .simulation_options import SimulationOptions as SimulationOptions diff --git a/policyengine/simulation.py b/policyengine/simulation/simulation.py similarity index 81% rename from policyengine/simulation.py rename to policyengine/simulation/simulation.py index a22038fd..7e8b14cc 100644 --- a/policyengine/simulation.py +++ b/policyengine/simulation/simulation.py @@ -1,9 +1,26 @@ """Simulate tax-benefit policy and derive society-level output statistics.""" +from copy import deepcopy import sys from pydantic import BaseModel, Field from typing import Literal -from .utils.data.datasets import ( + +from .simulation_options import ( + CountryType, + DataType, + ReformType, + RegionType, + ScopeType, + SimulationOptions, + SubsampleType, + TimePeriodType, +) + +from policyengine.simulation_results import ( + AbstractSimulationResults, + MacroContext, +) +from policyengine.utils.data.datasets import ( get_default_dataset, process_gs_path, POLICYENGINE_DATASETS, @@ -13,8 +30,8 @@ from policyengine_core.simulations import ( Microsimulation as CountryMicrosimulation, ) -from .utils.reforms import ParametricReform -from policyengine_core.reforms import Reform as StructuralReform +from policyengine.utils.reforms import ParametricReform + from policyengine_core.data import Dataset from policyengine_us import ( Simulation as USSimulation, @@ -37,54 +54,6 @@ logger = logging.getLogger(__file__) -CountryType = Literal["uk", "us"] -ScopeType = Literal["household", "macro"] -DataType = ( - str | dict[Any, Any] | Any | None -) # Needs stricter typing. Any==policyengine_core.data.Dataset, but pydantic refuses for some reason. -TimePeriodType = int -ReformType = ParametricReform | Type[StructuralReform] | None -RegionType = Optional[str] -SubsampleType = Optional[int] - - -class SimulationOptions(BaseModel): - country: CountryType = Field(..., description="The country to simulate.") - scope: ScopeType = Field(..., description="The scope of the simulation.") - data: DataType = Field(None, description="The data to simulate.") - time_period: TimePeriodType = Field( - 2025, description="The time period to simulate." - ) - reform: ReformType = Field(None, description="The reform to simulate.") - baseline: ReformType = Field(None, description="The baseline to simulate.") - region: RegionType = Field( - None, description="The region to simulate within the country." - ) - subsample: SubsampleType = Field( - None, - description="How many, if a subsample, households to randomly simulate.", - ) - title: Optional[str] = Field( - "[Analysis title]", - description="The title of the analysis (for charts). If not provided, a default title will be generated.", - ) - include_cliffs: Optional[bool] = Field( - False, - description="Whether to include tax-benefit cliffs in the simulation analyses. If True, cliffs will be included.", - ) - model_version: Optional[str] = Field( - None, - description="The version of the country model used in the simulation. If not provided, the current package version will be used. If provided, this package will throw an error if the package version does not match. Use this as an extra safety check.", - ) - data_version: Optional[str] = Field( - None, - description="The version of the data used in the simulation. If not provided, the current data version will be used. If provided, this package will throw an error if the data version does not match. Use this as an extra safety check.", - ) - - model_config = { - "arbitrary_types_allowed": True, - } - class Simulation: """Simulate tax-benefit policy and derive society-level output statistics.""" @@ -98,9 +67,10 @@ class Simulation: data_version: Optional[str] = None """The version of the data used in the simulation.""" model_version: Optional[str] = None + options: SimulationOptions - def __init__(self, **options: SimulationOptions): - self.options = SimulationOptions(**options) + def __init__(self, **kwargs): + self.options = SimulationOptions.model_validate(kwargs) self.check_model_version() if not isinstance(self.options.data, dict) and not isinstance( self.options.data, Dataset @@ -115,7 +85,8 @@ def __init__(self, **options: SimulationOptions): logging.info("Output functions loaded") def _add_output_functions(self): - folder = Path(__file__).parent / "outputs" + logger.debug("Adding output functions to simulation") + folder = Path(__file__).parent.parent / "outputs" for module in folder.glob("**/*.py"): if module.stem == "__init__": @@ -128,13 +99,18 @@ def _add_output_functions(self): ) module = importlib.import_module("policyengine." + python_module) for name in dir(module): + logging.debug(f"Looking for modules in {python_module}.{name}") func = getattr(module, name) if isinstance(func, Callable): + logging.debug(f"Found function {name} in {python_module}") if hasattr(func, "__annotations__"): if ( func.__annotations__.get("simulation") == Simulation ): + logging.info( + f"Function {name} is an old macro function" + ) wrapped_func = wraps(func)( partial(func, simulation=self) ) @@ -144,6 +120,28 @@ def _add_output_functions(self): func.__name__, wrapped_func, ) + elif ( + func.__annotations__.get("simulation") + == MacroContext + ): + logging.info( + f"Function {name} is a new macro function" + ) + wrapped_func = wraps(func)( + partial( + func, simulation=self + ) # _macro_context(self)) + ) + wrapped_func.__annotations__ = func.__annotations__ + setattr( + self, + func.__name__, + wrapped_func, + ) + else: + logging.debug( + f"Function {name} is not a macro function, skipping" + ) def _set_data(self, file_address: str | None = None) -> None: @@ -410,3 +408,40 @@ def _set_data_from_gs(self, file_address: str) -> tuple[str, str | None]: ) return filename, version + + +class SimpleSimulationResults(AbstractSimulationResults): + def __init__(self, simulation: CountrySimulation): + self._country_simulation = simulation + + def calculate( + self, + variable_name: str, + period: pd.Period | None = None, + map_to: str | None = None, + decode_enums: bool = False, + ) -> pd.Series: + """ + Calculate a variable from the simulation results. + """ + return self._country_simulation.calculate( + variable_name, period=period, map_to=map_to, decode_enums=decode_enums # type: ignore + ) + + def variable_exists(self, variable_name: str) -> bool: + return ( + variable_name + in self._country_simulation.tax_benefit_system.variables + ) + + +def _macro_context(simulation: Simulation): + return MacroContext( + simulation.options, + SimpleSimulationResults(simulation.baseline_simulation), + ( + SimpleSimulationResults(simulation.reform_simulation) + if simulation.reform_simulation is not None + else None + ), + ) diff --git a/policyengine/simulation/simulation_options.py b/policyengine/simulation/simulation_options.py new file mode 100644 index 00000000..3b5fad84 --- /dev/null +++ b/policyengine/simulation/simulation_options.py @@ -0,0 +1,55 @@ +from typing import Any, Literal, Optional, Type + +from pydantic import BaseModel, Field + +from policyengine.utils.reforms import ParametricReform +from policyengine_core.reforms import Reform as StructuralReform + + +CountryType = Literal["uk", "us"] +ScopeType = Literal["household", "macro"] +DataType = ( + str | dict[Any, Any] | Any | None +) # Needs stricter typing. Any==policyengine_core.data.Dataset, but pydantic refuses for some reason. +TimePeriodType = int +ReformType = ParametricReform | Type[StructuralReform] | None +RegionType = Optional[str] +SubsampleType = Optional[int] + + +class SimulationOptions(BaseModel): + country: CountryType = Field(..., description="The country to simulate.") + scope: ScopeType = Field(..., description="The scope of the simulation.") + data: DataType = Field(None, description="The data to simulate.") + time_period: TimePeriodType = Field( + 2025, description="The time period to simulate." + ) + reform: ReformType = Field(None, description="The reform to simulate.") + baseline: ReformType = Field(None, description="The baseline to simulate.") + region: RegionType = Field( + None, description="The region to simulate within the country." + ) + subsample: SubsampleType = Field( + None, + description="How many, if a subsample, households to randomly simulate.", + ) + title: Optional[str] = Field( + "[Analysis title]", + description="The title of the analysis (for charts). If not provided, a default title will be generated.", + ) + include_cliffs: Optional[bool] = Field( + False, + description="Whether to include tax-benefit cliffs in the simulation analyses. If True, cliffs will be included.", + ) + model_version: Optional[str] = Field( + None, + description="The version of the country model used in the simulation. If not provided, the current package version will be used. If provided, this package will throw an error if the package version does not match. Use this as an extra safety check.", + ) + data_version: Optional[str] = Field( + None, + description="The version of the data used in the simulation. If not provided, the current data version will be used. If provided, this package will throw an error if the data version does not match. Use this as an extra safety check.", + ) + + model_config = { + "arbitrary_types_allowed": True, + } diff --git a/policyengine/simulation_results.py b/policyengine/simulation_results.py new file mode 100644 index 00000000..5e0e4b91 --- /dev/null +++ b/policyengine/simulation_results.py @@ -0,0 +1,37 @@ +from abc import ABC, abstractmethod +from numpy.typing import ArrayLike +import pandas + +from policyengine.simulation.simulation_options import SimulationOptions + + +class AbstractSimulationResults(ABC): + @abstractmethod + def calculate( + self, + variable_name: str, + period: pandas.Period | None = None, + map_to: str | None = None, + decode_enums: bool = False, + ) -> pandas.Series: + pass + + @abstractmethod + def variable_exists(self, variable_name: str) -> bool: + pass + + +class MacroContext: + options: SimulationOptions + baseline_simulation: AbstractSimulationResults + reform_simulation: AbstractSimulationResults | None = None + + def __init__( + self, + options: SimulationOptions, + baseline: AbstractSimulationResults, + reform: AbstractSimulationResults | None = None, + ): + self.options = options + self.baseline_simulation = baseline + self.reform_simulation = reform diff --git a/policyengine/utils/data/datasets.py b/policyengine/utils/data/datasets.py index 7f173bb7..bfc8922b 100644 --- a/policyengine/utils/data/datasets.py +++ b/policyengine/utils/data/datasets.py @@ -5,6 +5,7 @@ EFRS_2022 = "gs://policyengine-uk-data-private/enhanced_frs_2022_23.h5" FRS_2022 = "gs://policyengine-uk-data-private/frs_2022_23.h5" CPS_2023 = "gs://policyengine-us-data/cps_2023.h5" +SMALL_CPS_2024 = "gs://policyengine-us-data/small_cps_2024.h5" CPS_2023_POOLED = "gs://policyengine-us-data/pooled_3_year_cps_2023.h5" ECPS_2024 = "gs://policyengine-us-data/enhanced_cps_2024.h5" @@ -14,6 +15,7 @@ CPS_2023, CPS_2023_POOLED, ECPS_2024, + SMALL_CPS_2024 ] # Contains datasets that map to particular time_period values diff --git a/tests/country/test_uk.py b/tests/country/test_uk.py index c28f1e59..4bb6f6ce 100644 --- a/tests/country/test_uk.py +++ b/tests/country/test_uk.py @@ -9,7 +9,7 @@ def test_uk_macro_single(): country="uk", ) - sim.calculate_single_economy() + # sim.calculate_single_economy() def test_uk_macro_comparison(): @@ -23,7 +23,7 @@ def test_uk_macro_comparison(): }, ) - sim.calculate_economy_comparison() + # sim.calculate_economy_comparison() def test_uk_macro_bad_package_versions_fail(): diff --git a/tests/country/test_us.py b/tests/country/test_us.py index 343aa008..27e0749f 100644 --- a/tests/country/test_us.py +++ b/tests/country/test_us.py @@ -4,6 +4,7 @@ def test_us_macro_single(): sim = Simulation( scope="macro", country="us", + data="gs://policyengine-us-data/small_cps_2024.h5", ) sim.calculate_single_economy() @@ -15,6 +16,7 @@ def test_us_macro_comparison(): sim = Simulation( scope="macro", country="us", + data="gs://policyengine-us-data/small_cps_2024.h5", reform={ "gov.usda.snap.income.deductions.earned_income": {"2025": 0.05} }, @@ -29,6 +31,7 @@ def test_us_macro_cliff_impacts(): sim = Simulation( scope="macro", country="us", + data="gs://policyengine-us-data/small_cps_2024.h5", reform={ "gov.usda.snap.income.deductions.earned_income": {"2025": 0.05} }, diff --git a/tests/fixtures/simulation.py b/tests/fixtures/simulation.py index 310c98aa..7ba87f52 100644 --- a/tests/fixtures/simulation.py +++ b/tests/fixtures/simulation.py @@ -45,7 +45,7 @@ @pytest.fixture def mock_get_default_dataset(): with patch( - "policyengine.simulation.get_default_dataset", + "policyengine.simulation.simulation.get_default_dataset", return_value=SAMPLE_DATASET_FILE_ADDRESS, ) as mock_get_default_dataset: yield mock_get_default_dataset @@ -54,7 +54,9 @@ def mock_get_default_dataset(): @pytest.fixture def mock_dataset(): """Simple Dataset mock fixture""" - with patch("policyengine.simulation.Dataset") as mock_dataset_class: + with patch( + "policyengine.simulation.simulation.Dataset" + ) as mock_dataset_class: mock_instance = Mock() # Set file_path to mimic Dataset's behavior of clipping URI and bucket name from GCS paths mock_instance.from_file = Mock()