From 843b7d7cec3c22120a5ef4f46b186e917197c33a Mon Sep 17 00:00:00 2001 From: Nikhil Woodruff Date: Mon, 17 Feb 2025 13:17:29 +0000 Subject: [PATCH] Move simulation parameters down to the function level Fixes #101 --- policyengine/__init__.py | 2 +- .../calculate_household_comparison.py | 4 +- .../single/calculate_single_household.py | 4 +- .../calculate_economy_comparison.py | 56 ++- .../outputs/macro/comparison/charts/budget.py | 4 +- .../comparison/charts/budget_by_program.py | 4 +- .../outputs/macro/comparison/charts/decile.py | 4 +- .../macro/comparison/charts/inequality.py | 4 +- .../macro/comparison/charts/winners_losers.py | 4 +- .../outputs/macro/comparison/decile.py | 2 +- policyengine/outputs/macro/single/budget.py | 2 +- .../single/calculate_average_earnings.py | 4 +- .../macro/single/calculate_single_economy.py | 4 +- .../outputs/macro/single/inequality.py | 2 +- policyengine/policyengine.py | 291 +++++++++++++++ policyengine/simulation.py | 330 ------------------ policyengine/utils/types.py | 15 + 17 files changed, 376 insertions(+), 360 deletions(-) create mode 100644 policyengine/policyengine.py delete mode 100644 policyengine/simulation.py create mode 100644 policyengine/utils/types.py diff --git a/policyengine/__init__.py b/policyengine/__init__.py index 63a753be..380bb5a2 100644 --- a/policyengine/__init__.py +++ b/policyengine/__init__.py @@ -1 +1 @@ -from .simulation import Simulation, SimulationOptions +from .policyengine import PolicyEngine diff --git a/policyengine/outputs/household/comparison/calculate_household_comparison.py b/policyengine/outputs/household/comparison/calculate_household_comparison.py index 756be8e2..be15da6b 100644 --- a/policyengine/outputs/household/comparison/calculate_household_comparison.py +++ b/policyengine/outputs/household/comparison/calculate_household_comparison.py @@ -2,7 +2,7 @@ import typing -from policyengine import Simulation +from policyengine import PolicyEngine from pydantic import BaseModel from policyengine.utils.calculations import get_change @@ -27,7 +27,7 @@ class HouseholdComparison(BaseModel): def calculate_household_comparison( - simulation: Simulation, + engine: PolicyEngine, ) -> HouseholdComparison: """Calculate comparison statistics between two household scenarios.""" if not simulation.is_comparison: diff --git a/policyengine/outputs/household/single/calculate_single_household.py b/policyengine/outputs/household/single/calculate_single_household.py index 19cab8e2..2bd0432d 100644 --- a/policyengine/outputs/household/single/calculate_single_household.py +++ b/policyengine/outputs/household/single/calculate_single_household.py @@ -2,7 +2,7 @@ import typing -from policyengine import Simulation +from policyengine import PolicyEngine from pydantic import BaseModel from policyengine_core.simulations import Simulation as CountrySimulation @@ -33,7 +33,7 @@ class SingleHousehold(BaseModel): def calculate_single_household( - simulation: Simulation, + engine: PolicyEngine, ) -> SingleHousehold: """Calculate household statistics for a single household scenario.""" if simulation.is_comparison: diff --git a/policyengine/outputs/macro/comparison/calculate_economy_comparison.py b/policyengine/outputs/macro/comparison/calculate_economy_comparison.py index 41398eee..aaa1211f 100644 --- a/policyengine/outputs/macro/comparison/calculate_economy_comparison.py +++ b/policyengine/outputs/macro/comparison/calculate_economy_comparison.py @@ -2,10 +2,11 @@ import typing -from policyengine import Simulation +from policyengine import PolicyEngine from pydantic import BaseModel from policyengine.utils.calculations import get_change +from policyengine.constants import DEFAULT_DATASETS_BY_COUNTRY from policyengine.outputs.macro.single import ( _calculate_government_balance, @@ -17,7 +18,24 @@ from .decile import calculate_decile_impacts, DecileImpacts from typing import Literal, List - +from policyengine.utils.types import * +from pydantic import Field + +class EconomyComparisonOptions(BaseModel): + country: CountryType = Field(..., description="The country to simulate.") + data: DataType = Field(None, description="The data to simulate.") + time_period: TimePeriodType = Field( + 2025, description="The time period to simulate.", ge=2024, le=2035 + ) + reform: PolicyType = Field(None, description="The reform to simulate.") + baseline: PolicyType = Field(None, description="The baseline to simulate.") + region: RegionType = Field( + None, description="The region to simulate within the country." + ) + subsample: SubsampleType = Field( + None, + description="How many, if a subsample, households to randomly simulate.", + ) class FiscalComparison(BaseModel): baseline: FiscalSummary @@ -52,15 +70,37 @@ class EconomyComparison(BaseModel): def calculate_economy_comparison( - simulation: Simulation, + engine: PolicyEngine, + **options, ) -> EconomyComparison: """Calculate comparison statistics between two economic scenarios.""" - if not simulation.is_comparison: - raise ValueError("Simulation must be a comparison simulation.") - baseline = simulation.baseline_simulation - reform = simulation.reform_simulation - options = simulation.options + options = EconomyComparisonOptions(**options) + + if options.data is None: + options.data = DEFAULT_DATASETS_BY_COUNTRY[ + options.country + ] + + baseline = engine.expect_simulation( + name="baseline", + country=options.country, + scope="macro", + policy=options.baseline, + data=options.data, + time_period=options.time_period, + region=options.region, + ) + + reform = engine.expect_simulation( + name="reform", + country=options.country, + scope="macro", + policy=options.reform, + data=options.data, + time_period=options.time_period, + region=options.region, + ) baseline_balance = _calculate_government_balance(baseline, options) reform_balance = _calculate_government_balance(reform, options) diff --git a/policyengine/outputs/macro/comparison/charts/budget.py b/policyengine/outputs/macro/comparison/charts/budget.py index f7e453fd..78cf2e40 100644 --- a/policyengine/outputs/macro/comparison/charts/budget.py +++ b/policyengine/outputs/macro/comparison/charts/budget.py @@ -2,14 +2,14 @@ import plotly.graph_objects as go import typing -from policyengine import Simulation +from policyengine import PolicyEngine from pydantic import BaseModel from policyengine.utils.charts import * def create_budget_comparison_chart( - simulation: Simulation, + engine: PolicyEngine, ) -> go.Figure: """Create a budget comparison chart.""" if not simulation.is_comparison: diff --git a/policyengine/outputs/macro/comparison/charts/budget_by_program.py b/policyengine/outputs/macro/comparison/charts/budget_by_program.py index 382cdfcf..e5985334 100644 --- a/policyengine/outputs/macro/comparison/charts/budget_by_program.py +++ b/policyengine/outputs/macro/comparison/charts/budget_by_program.py @@ -2,14 +2,14 @@ import plotly.graph_objects as go import typing -from policyengine import Simulation +from policyengine import PolicyEngine from pydantic import BaseModel from policyengine.utils.charts import * def create_budget_program_comparison_chart( - simulation: Simulation, + engine: PolicyEngine, ) -> go.Figure: """Create a budget comparison chart.""" if not simulation.is_comparison: diff --git a/policyengine/outputs/macro/comparison/charts/decile.py b/policyengine/outputs/macro/comparison/charts/decile.py index bd2c69e7..c14c7498 100644 --- a/policyengine/outputs/macro/comparison/charts/decile.py +++ b/policyengine/outputs/macro/comparison/charts/decile.py @@ -2,7 +2,7 @@ import plotly.graph_objects as go import typing -from policyengine import Simulation +from policyengine import PolicyEngine from pydantic import BaseModel from policyengine.utils.charts import * @@ -10,7 +10,7 @@ def create_decile_chart( - simulation: Simulation, + engine: PolicyEngine, decile_variable: Literal["income", "wealth"], relative: bool, ) -> go.Figure: diff --git a/policyengine/outputs/macro/comparison/charts/inequality.py b/policyengine/outputs/macro/comparison/charts/inequality.py index e0db5c29..ad63f1b1 100644 --- a/policyengine/outputs/macro/comparison/charts/inequality.py +++ b/policyengine/outputs/macro/comparison/charts/inequality.py @@ -2,14 +2,14 @@ import plotly.graph_objects as go import typing -from policyengine import Simulation +from policyengine import PolicyEngine from pydantic import BaseModel from policyengine.utils.charts import * def create_inequality_chart( - simulation: Simulation, + engine: PolicyEngine, relative: bool, ) -> go.Figure: """Create a budget comparison chart.""" diff --git a/policyengine/outputs/macro/comparison/charts/winners_losers.py b/policyengine/outputs/macro/comparison/charts/winners_losers.py index db21e18d..4d9df801 100644 --- a/policyengine/outputs/macro/comparison/charts/winners_losers.py +++ b/policyengine/outputs/macro/comparison/charts/winners_losers.py @@ -2,7 +2,7 @@ import plotly.graph_objects as go import typing -from policyengine import Simulation +from policyengine import PolicyEngine from pydantic import BaseModel from policyengine.utils.charts import * @@ -27,7 +27,7 @@ def create_winners_losers_chart( - simulation: Simulation, + engine: PolicyEngine, decile_variable: Literal["income", "wealth"], ) -> go.Figure: """Create a budget comparison chart.""" diff --git a/policyengine/outputs/macro/comparison/decile.py b/policyengine/outputs/macro/comparison/decile.py index 9c84b295..929a20ee 100644 --- a/policyengine/outputs/macro/comparison/decile.py +++ b/policyengine/outputs/macro/comparison/decile.py @@ -1,6 +1,6 @@ import typing -from policyengine import Simulation, SimulationOptions +from policyengine import PolicyEngine from policyengine_core.simulations import Microsimulation diff --git a/policyengine/outputs/macro/single/budget.py b/policyengine/outputs/macro/single/budget.py index 468ec6c8..6d7cde47 100644 --- a/policyengine/outputs/macro/single/budget.py +++ b/policyengine/outputs/macro/single/budget.py @@ -1,6 +1,6 @@ import typing -from policyengine import Simulation, SimulationOptions +from policyengine import PolicyEngine from policyengine_core.simulations import Microsimulation diff --git a/policyengine/outputs/macro/single/calculate_average_earnings.py b/policyengine/outputs/macro/single/calculate_average_earnings.py index 764071e1..770b412d 100644 --- a/policyengine/outputs/macro/single/calculate_average_earnings.py +++ b/policyengine/outputs/macro/single/calculate_average_earnings.py @@ -1,7 +1,7 @@ -from policyengine import Simulation +from policyengine import PolicyEngine -def calculate_average_earnings(simulation: Simulation) -> float: +def calculate_average_earnings(engine: PolicyEngine) -> float: """Calculate average earnings.""" employment_income = simulation.baseline_simulation.calculate( "employment_income" diff --git a/policyengine/outputs/macro/single/calculate_single_economy.py b/policyengine/outputs/macro/single/calculate_single_economy.py index 206fd489..714dbfce 100644 --- a/policyengine/outputs/macro/single/calculate_single_economy.py +++ b/policyengine/outputs/macro/single/calculate_single_economy.py @@ -2,7 +2,7 @@ import typing -from policyengine import Simulation +from policyengine import PolicyEngine from pydantic import BaseModel @@ -19,7 +19,7 @@ class SingleEconomy(BaseModel): def calculate_single_economy( - simulation: Simulation, + engine: PolicyEngine, ) -> SingleEconomy: """Calculate economy statistics for a single economic scenario.""" options = simulation.options diff --git a/policyengine/outputs/macro/single/inequality.py b/policyengine/outputs/macro/single/inequality.py index 827a9b6a..79e6e5f1 100644 --- a/policyengine/outputs/macro/single/inequality.py +++ b/policyengine/outputs/macro/single/inequality.py @@ -1,6 +1,6 @@ import typing -from policyengine import Simulation, SimulationOptions +from policyengine import PolicyEngine from policyengine_core.simulations import Microsimulation diff --git a/policyengine/policyengine.py b/policyengine/policyengine.py new file mode 100644 index 00000000..93eab719 --- /dev/null +++ b/policyengine/policyengine.py @@ -0,0 +1,291 @@ +"""Simulate tax-benefit policy and derive society-level output statistics.""" + +from pydantic import BaseModel, Field +from typing import Literal +from .constants import DEFAULT_DATASETS_BY_COUNTRY +from policyengine_core.simulations import Simulation +from policyengine_core.simulations import ( + Microsimulation as CountryMicrosimulation, +) +from .utils.reforms import ParametricReform, SimulationAdjustment +from policyengine_core.reforms import Reform as StructuralReform +from policyengine_core.data import Dataset +from .utils.huggingface import download +from policyengine_us import ( + Simulation as USSimulation, + Microsimulation as USMicrosimulation, +) +from policyengine_uk import ( + Simulation as UKSimulation, + Microsimulation as UKMicrosimulation, +) +import h5py +from pathlib import Path +import pandas as pd +from typing import Type +from functools import wraps, partial +from typing import Dict, Any, Callable +import importlib +from policyengine.utils.types import * + + +class SimulationOptions(BaseModel): + country: CountryType = Field(..., description="The country to simulate.") + scope: ScopeType = Field(..., description="The scope of the simulation.") + data: DataType = Field(None, description="The data to simulate.") + time_period: TimePeriodType = Field( + 2025, description="The time period to simulate.", ge=2024, le=2035 + ) + reform: PolicyType = Field(None, description="The reform to simulate.") + baseline: PolicyType = Field(None, description="The baseline to simulate.") + region: RegionType = Field( + None, description="The region to simulate within the country." + ) + subsample: SubsampleType = Field( + None, + description="How many, if a subsample, households to randomly simulate.", + ) + title: str | None = Field( + "[Analysis title]", + description="The title of the analysis (for charts). If not provided, a default title will be generated.", + ) + + +class PolicyEngine: + """Simulate tax-benefit policies and derive society-level output statistics.""" + simulations: Dict[str, Simulation] = {} + + def __init__(self, **options: SimulationOptions): + self._add_output_functions() + + + def _add_output_functions(self): + folder = Path(__file__).parent / "outputs" + + for module in folder.glob("**/*.py"): + if module.stem == "__init__": + continue + python_module = ( + module.relative_to(folder.parent) + .with_suffix("") + .as_posix() + .replace("/", ".") + ) + module = importlib.import_module("policyengine." + python_module) + for name in dir(module): + func = getattr(module, name) + if isinstance(func, Callable): + if hasattr(func, "__annotations__"): + if ( + func.__annotations__.get("engine") + == PolicyEngine + ): + wrapped_func = wraps(func)( + partial(func, engine=self) + ) + wrapped_func.__annotations__ = func.__annotations__ + setattr( + self, + func.__name__, + wrapped_func, + ) + + def expect_simulation( + self, + name: str, + country: CountryType, + scope: ScopeType, + policy: PolicyType, + data: DataType, + time_period: TimePeriodType, + region: RegionType, + subsample: SubsampleType, + ) -> Simulation: + if name in self.simulations: + return self.simulations[name] + else: + simulation = self.build_simulation( + country=country, + scope=scope, + policy=policy, + data=data, + time_period=time_period, + region=region, + subsample=subsample, + ) + self.simulations[name] = simulation + return simulation + + def build_simulation( + self, + country: CountryType, + scope: ScopeType, + policy: PolicyType, + data: DataType, + time_period: TimePeriodType, + region: RegionType, + subsample: SubsampleType, + ): + macro = scope == "macro" + _simulation_type: Type[Simulation] = { + "uk": { + True: UKMicrosimulation, + False: UKSimulation, + }, + "us": { + True: USMicrosimulation, + False: USSimulation, + }, + }[country][macro] + + data = _data_handle_cps_special_case(data) + + simulation: Simulation = _simulation_type( + dataset=data if macro else None, + situation=data if not macro else None, + reform=policy, + ) + + simulation.default_calculation_period = time_period + + if region is not None: + simulation = _apply_region_to_simulation( + country=country, + simulation=simulation, + simulation_type=_simulation_type, + region=region, + policy=policy, + time_period=time_period, + ) + + if subsample is not None: + simulation = simulation.subsample(subsample) + + return simulation + +def _apply_region_to_simulation( + country: CountryType, + simulation: CountryMicrosimulation, + simulation_type: type, + region: RegionType, + policy: PolicyType | None, +) -> Simulation: + if country == "us": + df = simulation.to_input_dataframe() + state_code = simulation.calculate( + "state_code_str", map_to="person" + ).values + if region == "city/nyc": + in_nyc = simulation.calculate("in_nyc", map_to="person").values + simulation = simulation_type(dataset=df[in_nyc], reform=policy) + elif "state/" in region: + state = region.split("/")[1] + simulation = simulation_type( + dataset=df[state_code == state.upper()], reform=policy + ) + elif country == "uk": + if "country/" in region: + region = region.split("/")[1] + df = simulation.to_input_dataframe() + country = simulation.calculate( + "country", map_to="person" + ).values + simulation = simulation_type( + dataset=df[country == region.upper()], reform=policy + ) + elif "constituency/" in region: + constituency = region.split("/")[1] + constituency_names_file_path = download( + repo="policyengine/policyengine-uk-data", + repo_filename="constituencies_2024.csv", + local_folder=None, + version=None, + ) + constituency_names_file_path = Path( + constituency_names_file_path + ) + constituency_names = pd.read_csv(constituency_names_file_path) + if constituency in constituency_names.code.values: + constituency_id = constituency_names[ + constituency_names.code == constituency + ].index[0] + elif constituency in constituency_names.name.values: + constituency_id = constituency_names[ + constituency_names.name == constituency + ].index[0] + else: + raise ValueError( + f"Constituency {constituency} not found. See {constituency_names_file_path} for the list of available constituencies." + ) + weights_file_path = download( + repo="policyengine/policyengine-uk-data", + repo_filename="parliamentary_constituency_weights.h5", + local_folder=None, + version=None, + ) + + with h5py.File(weights_file_path, "r") as f: + weights = f["2025"][...] + + simulation.set_input( + "household_weight", + "2025", + weights[constituency_id], + ) + elif "local_authority/" in region: + la = region.split("/")[1] + la_names_file_path = download( + repo="policyengine/policyengine-uk-data", + repo_filename="local_authorities_2021.csv", + local_folder=None, + version=None, + ) + la_names_file_path = Path(la_names_file_path) + la_names = pd.read_csv(la_names_file_path) + if la in la_names.code.values: + la_id = la_names[la_names.code == la].index[0] + elif la in la_names.name.values: + la_id = la_names[la_names.name == la].index[0] + else: + raise ValueError( + f"Local authority {la} not found. See {la_names_file_path} for the list of available local authorities." + ) + weights_file_path = download( + repo="policyengine/policyengine-uk-data", + repo_filename="local_authority_weights.h5", + local_folder=None, + version=None, + ) + + with h5py.File(weights_file_path, "r") as f: + weights = f["2025"][...] + + simulation.set_input( + "household_weight", + "2025", + weights[la_id], + ) + + return simulation + +def _data_handle_cps_special_case( + data: DataType, +): + """Handle special case for CPS data- this data doesn't specify time periods for each variable, but we still use it intensively.""" + if data is not None and "cps_2023" in data: + if "hf://" in data: + owner, repo, filename = data.split("/")[-3:] + if "@" in filename: + version = filename.split("@")[-1] + filename = filename.split("@")[0] + else: + version = None + data = download( + repo=owner + "/" + repo, + repo_filename=filename, + local_folder=None, + version=version, + ) + data = Dataset.from_file(data, "2023") + + return data diff --git a/policyengine/simulation.py b/policyengine/simulation.py deleted file mode 100644 index bfadd921..00000000 --- a/policyengine/simulation.py +++ /dev/null @@ -1,330 +0,0 @@ -"""Simulate tax-benefit policy and derive society-level output statistics.""" - -from pydantic import BaseModel, Field -from typing import Literal -from .constants import DEFAULT_DATASETS_BY_COUNTRY -from policyengine_core.simulations import Simulation as CountrySimulation -from policyengine_core.simulations import ( - Microsimulation as CountryMicrosimulation, -) -from .utils.reforms import ParametricReform, SimulationAdjustment -from policyengine_core.reforms import Reform as StructuralReform -from policyengine_core.data import Dataset -from .utils.huggingface import download -from policyengine_us import ( - Simulation as USSimulation, - Microsimulation as USMicrosimulation, -) -from policyengine_uk import ( - Simulation as UKSimulation, - Microsimulation as UKMicrosimulation, -) -import h5py -from pathlib import Path -import pandas as pd -from typing import Type -from functools import wraps, partial -from typing import Dict, Any, Callable -import importlib - -CountryType = Literal["uk", "us"] -ScopeType = Literal["household", "macro"] -DataType = ( - str | dict | Any | None -) # Needs stricter typing. Any==policyengine_core.data.Dataset, but pydantic refuses for some reason. -TimePeriodType = int -ReformType = ( - ParametricReform | SimulationAdjustment | Type[StructuralReform] | None -) -RegionType = str | None -SubsampleType = int | None - - -class SimulationOptions(BaseModel): - country: CountryType = Field(..., description="The country to simulate.") - scope: ScopeType = Field(..., description="The scope of the simulation.") - data: DataType = Field(None, description="The data to simulate.") - time_period: TimePeriodType = Field( - 2025, description="The time period to simulate.", ge=2024, le=2035 - ) - reform: ReformType = Field(None, description="The reform to simulate.") - baseline: ReformType = Field(None, description="The baseline to simulate.") - region: RegionType = Field( - None, description="The region to simulate within the country." - ) - subsample: SubsampleType = Field( - None, - description="How many, if a subsample, households to randomly simulate.", - ) - title: str | None = Field( - "[Analysis title]", - description="The title of the analysis (for charts). If not provided, a default title will be generated.", - ) - - -class Simulation: - """Simulate tax-benefit policy and derive society-level output statistics.""" - - is_comparison: bool - """Whether the simulation is a comparison between two scenarios.""" - baseline_simulation: CountrySimulation - """The baseline tax-benefit simulation.""" - reform_simulation: CountrySimulation | None = None - """The reform tax-benefit simulation.""" - - def __init__(self, **options: SimulationOptions): - self.options = SimulationOptions(**options) - - if self.options.data is None: - self.options.data = DEFAULT_DATASETS_BY_COUNTRY[ - self.options.country - ] - - self._initialise_simulations() - self._add_output_functions() - - def _add_output_functions(self): - folder = Path(__file__).parent / "outputs" - - for module in folder.glob("**/*.py"): - if module.stem == "__init__": - continue - python_module = ( - module.relative_to(folder.parent) - .with_suffix("") - .as_posix() - .replace("/", ".") - ) - module = importlib.import_module("policyengine." + python_module) - for name in dir(module): - func = getattr(module, name) - if isinstance(func, Callable): - if hasattr(func, "__annotations__"): - if ( - func.__annotations__.get("simulation") - == Simulation - ): - wrapped_func = wraps(func)( - partial(func, simulation=self) - ) - wrapped_func.__annotations__ = func.__annotations__ - setattr( - self, - func.__name__, - wrapped_func, - ) - - def _set_data(self): - if self.options.data is None: - self.options.data = DEFAULT_DATASETS_BY_COUNTRY[ - self.options.country - ] - - self._data_handle_cps_special_case() - - def _initialise_simulations(self): - self.baseline_simulation = self._initialise_simulation( - scope=self.options.scope, - country=self.options.country, - reform=self.options.baseline, - data=self.options.data, - time_period=self.options.time_period, - region=self.options.region, - subsample=self.options.subsample, - ) - - if self.options.reform is not None: - self.reform_simulation = self._initialise_simulation( - scope=self.options.scope, - country=self.options.country, - reform=self.options.reform, - data=self.options.data, - time_period=self.options.time_period, - region=self.options.region, - subsample=self.options.subsample, - ) - self.is_comparison = True - else: - self.is_comparison = False - - def _initialise_simulation( - self, - country: CountryType, - scope: ScopeType, - reform: ReformType, - data: DataType, - time_period: TimePeriodType, - region: RegionType, - subsample: SubsampleType, - ): - macro = scope == "macro" - _simulation_type: Type[CountrySimulation] = { - "uk": { - True: UKMicrosimulation, - False: UKSimulation, - }, - "us": { - True: USMicrosimulation, - False: USSimulation, - }, - }[country][macro] - - if isinstance(reform, ParametricReform): - reform = reform.model_dump() - - simulation_editing_reform = None - - if isinstance(reform, SimulationAdjustment): - simulation_editing_reform = reform.root - reform = None - - simulation: CountrySimulation = _simulation_type( - dataset=data if macro else None, - situation=data if not macro else None, - reform=reform, - ) - - simulation.default_calculation_period = time_period - - if region is not None: - simulation = self._apply_region_to_simulation( - country=country, - simulation=simulation, - simulation_type=_simulation_type, - region=region, - reform=reform, - time_period=time_period, - ) - - if subsample is not None: - simulation = simulation.subsample(subsample) - - if simulation_editing_reform is not None: - simulation_editing_reform(simulation) - - return simulation - - def _apply_region_to_simulation( - self, - country: CountryType, - simulation: CountryMicrosimulation, - simulation_type: type, - region: RegionType, - reform: ReformType | None, - time_period: TimePeriodType, - ) -> CountrySimulation: - if country == "us": - df = simulation.to_input_dataframe() - state_code = simulation.calculate( - "state_code_str", map_to="person" - ).values - if region == "city/nyc": - in_nyc = simulation.calculate("in_nyc", map_to="person").values - simulation = simulation_type(dataset=df[in_nyc], reform=reform) - elif "state/" in region: - state = region.split("/")[1] - simulation = simulation_type( - dataset=df[state_code == state.upper()], reform=reform - ) - elif country == "uk": - if "country/" in region: - region = region.split("/")[1] - df = simulation.to_input_dataframe() - country = simulation.calculate( - "country", map_to="person" - ).values - simulation = simulation_type( - dataset=df[country == region.upper()], reform=reform - ) - elif "constituency/" in region: - constituency = region.split("/")[1] - constituency_names_file_path = download( - repo="policyengine/policyengine-uk-data", - repo_filename="constituencies_2024.csv", - local_folder=None, - version=None, - ) - constituency_names_file_path = Path( - constituency_names_file_path - ) - constituency_names = pd.read_csv(constituency_names_file_path) - if constituency in constituency_names.code.values: - constituency_id = constituency_names[ - constituency_names.code == constituency - ].index[0] - elif constituency in constituency_names.name.values: - constituency_id = constituency_names[ - constituency_names.name == constituency - ].index[0] - else: - raise ValueError( - f"Constituency {constituency} not found. See {constituency_names_file_path} for the list of available constituencies." - ) - weights_file_path = download( - repo="policyengine/policyengine-uk-data", - repo_filename="parliamentary_constituency_weights.h5", - local_folder=None, - version=None, - ) - - with h5py.File(weights_file_path, "r") as f: - weights = f[str(time_period)][...] - - simulation.set_input( - "household_weight", - simulation.default_calculation_period, - weights[constituency_id], - ) - elif "local_authority/" in region: - la = region.split("/")[1] - la_names_file_path = download( - repo="policyengine/policyengine-uk-data", - repo_filename="local_authorities_2021.csv", - local_folder=None, - version=None, - ) - la_names_file_path = Path(la_names_file_path) - la_names = pd.read_csv(la_names_file_path) - if la in la_names.code.values: - la_id = la_names[la_names.code == la].index[0] - elif la in la_names.name.values: - la_id = la_names[la_names.name == la].index[0] - else: - raise ValueError( - f"Local authority {la} not found. See {la_names_file_path} for the list of available local authorities." - ) - weights_file_path = download( - repo="policyengine/policyengine-uk-data", - repo_filename="local_authority_weights.h5", - local_folder=None, - version=None, - ) - - with h5py.File(weights_file_path, "r") as f: - weights = f[str(self.time_period)][...] - - simulation.set_input( - "household_weight", - simulation.default_calculation_period, - weights[la_id], - ) - - return simulation - - def _data_handle_cps_special_case(self): - """Handle special case for CPS data- this data doesn't specify time periods for each variable, but we still use it intensively.""" - if self.data is not None and "cps_2023" in self.data: - if "hf://" in self.data: - owner, repo, filename = self.data.split("/")[-3:] - if "@" in filename: - version = filename.split("@")[-1] - filename = filename.split("@")[0] - else: - version = None - self.data = download( - repo=owner + "/" + repo, - repo_filename=filename, - local_folder=None, - version=version, - ) - self.data = Dataset.from_file(self.data, "2023") diff --git a/policyengine/utils/types.py b/policyengine/utils/types.py new file mode 100644 index 00000000..9e6c4fcc --- /dev/null +++ b/policyengine/utils/types.py @@ -0,0 +1,15 @@ +from .reforms import ParametricReform +from policyengine_core.reforms import Reform as StructuralReform +from typing import Type, Literal, Any + +CountryType = Literal["uk", "us"] +ScopeType = Literal["household", "macro"] +DataType = ( + str | dict | Any | None +) # Needs stricter typing. Any==policyengine_core.data.Dataset, but pydantic refuses for some reason. +TimePeriodType = int +PolicyType = ( + ParametricReform | Type[StructuralReform] | None +) +RegionType = str | None +SubsampleType = int | None \ No newline at end of file