diff --git a/changelog_entry.yaml b/changelog_entry.yaml index e69de29b..53ad29f2 100644 --- a/changelog_entry.yaml +++ b/changelog_entry.yaml @@ -0,0 +1,6 @@ +- bump: patch + changes: + fixed: + - Bug in state tax revenue calculation. + added: + - Default dataset handling (extra backups added). diff --git a/policyengine/constants.py b/policyengine/constants.py index 651de219..c6bca554 100644 --- a/policyengine/constants.py +++ b/policyengine/constants.py @@ -1,5 +1,8 @@ """Mainly simulation options and parameters.""" +from policyengine_core.data import Dataset +from policyengine.utils.data_download import download + # Datasets ENHANCED_FRS = "hf://policyengine/policyengine-uk-data/enhanced_frs_2022_23.h5" @@ -8,7 +11,32 @@ CPS = "hf://policyengine/policyengine-us-data/cps_2023.h5" POOLED_CPS = "hf://policyengine/policyengine-us-data/pooled_3_year_cps_2023.h5" -DEFAULT_DATASETS_BY_COUNTRY = { - "uk": ENHANCED_FRS, - "us": CPS, -} + +def get_default_dataset(country: str, region: str): + if country == "uk": + data_file = download( + filepath="enhanced_frs_2022_23.h5", + huggingface_repo="policyengine-uk-data", + gcs_bucket="policyengine-uk-data-private", + ) + time_period = None + elif country == "us": + if region is not None and region != "us": + data_file = download( + filepath="pooled_3_year_cps_2023.h5", + huggingface_repo="policyengine-us-data", + gcs_bucket="policyengine-us-data", + ) + time_period = 2023 + else: + data_file = download( + filepath="cps_2023.h5", + huggingface_repo="policyengine-us-data", + gcs_bucket="policyengine-us-data", + ) + time_period = 2023 + + return Dataset.from_file( + file_path=data_file, + time_period=time_period, + ) diff --git a/policyengine/outputs/macro/single/calculate_single_economy.py b/policyengine/outputs/macro/single/calculate_single_economy.py index 7a21133a..3ed9f517 100644 --- a/policyengine/outputs/macro/single/calculate_single_economy.py +++ b/policyengine/outputs/macro/single/calculate_single_economy.py @@ -376,7 +376,7 @@ def calculate_single_economy( if country_id == "us": try: - total_state_tax = simulation.calculate( + total_state_tax = task_manager.simulation.calculate( "household_state_income_tax" ).sum() except: diff --git a/policyengine/simulation.py b/policyengine/simulation.py index c53d30f0..61f00acb 100644 --- a/policyengine/simulation.py +++ b/policyengine/simulation.py @@ -2,7 +2,7 @@ from pydantic import BaseModel, Field from typing import Literal -from .constants import DEFAULT_DATASETS_BY_COUNTRY +from .constants import get_default_dataset from policyengine_core.simulations import Simulation as CountrySimulation from policyengine_core.simulations import ( Microsimulation as CountryMicrosimulation, @@ -73,11 +73,6 @@ class Simulation: def __init__(self, **options: SimulationOptions): self.options = SimulationOptions(**options) - if self.options.data is None: - self.options.data = DEFAULT_DATASETS_BY_COUNTRY[ - self.options.country - ] - self._set_data() self._initialise_simulations() self._add_output_functions() @@ -115,11 +110,12 @@ def _add_output_functions(self): def _set_data(self): if self.options.data is None: - self.options.data = DEFAULT_DATASETS_BY_COUNTRY[ - self.options.country - ] + self.options.data = get_default_dataset( + country=self.options.country, + region=self.options.region, + ) - if isinstance(self.options.data, str): + elif isinstance(self.options.data, str): filename = self.options.data if "://" in self.options.data: bucket = None @@ -129,6 +125,7 @@ def _set_data(self): bucket, filename = self.options.data.split("://")[ -1 ].split("/") + hf_org = "policyengine" elif "hf://" in self.options.data: hf_org, hf_repo, filename = self.options.data.split("://")[ -1 @@ -221,6 +218,8 @@ def _initialise_simulation( if subsample is not None: simulation = simulation.subsample(subsample) + simulation.default_calculation_period = time_period + return simulation def _apply_region_to_simulation(