Skip to content

Commit ccfb1e4

Browse files
Bug prevents state tax calculation in some cases
Fixes #113
1 parent 8455b68 commit ccfb1e4

File tree

4 files changed

+47
-15
lines changed

4 files changed

+47
-15
lines changed

changelog_entry.yaml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
- bump: patch
2+
changes:
3+
fixed:
4+
- Bug in state tax revenue calculation.
5+
added:
6+
- Default dataset handling (extra backups added).

policyengine/constants.py

Lines changed: 31 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
11
"""Mainly simulation options and parameters."""
22

3+
from policyengine_core.data import Dataset
4+
from policyengine.utils.data_download import download
5+
36
# Datasets
47

58
ENHANCED_FRS = "hf://policyengine/policyengine-uk-data/enhanced_frs_2022_23.h5"
@@ -8,7 +11,31 @@
811
CPS = "hf://policyengine/policyengine-us-data/cps_2023.h5"
912
POOLED_CPS = "hf://policyengine/policyengine-us-data/pooled_3_year_cps_2023.h5"
1013

11-
DEFAULT_DATASETS_BY_COUNTRY = {
12-
"uk": ENHANCED_FRS,
13-
"us": CPS,
14-
}
14+
def get_default_dataset(country: str, region: str):
15+
if country == "uk":
16+
data_file = download(
17+
filepath="enhanced_frs_2022_23.h5",
18+
huggingface_repo="policyengine-uk-data",
19+
gcs_bucket="policyengine-uk-data-private",
20+
)
21+
time_period = None
22+
elif country == "us":
23+
if region is not None and region != "us":
24+
data_file = download(
25+
filepath="pooled_3_year_cps_2023.h5",
26+
huggingface_repo="policyengine-us-data",
27+
gcs_bucket="policyengine-us-data",
28+
)
29+
time_period = 2023
30+
else:
31+
data_file = download(
32+
filepath="cps_2023.h5",
33+
huggingface_repo="policyengine-us-data",
34+
gcs_bucket="policyengine-us-data",
35+
)
36+
time_period = 2023
37+
38+
return Dataset.from_file(
39+
file_path=data_file,
40+
time_period=time_period,
41+
)

policyengine/outputs/macro/single/calculate_single_economy.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -376,7 +376,7 @@ def calculate_single_economy(
376376

377377
if country_id == "us":
378378
try:
379-
total_state_tax = simulation.calculate(
379+
total_state_tax = task_manager.simulation.calculate(
380380
"household_state_income_tax"
381381
).sum()
382382
except:

policyengine/simulation.py

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
from pydantic import BaseModel, Field
44
from typing import Literal
5-
from .constants import DEFAULT_DATASETS_BY_COUNTRY
5+
from .constants import get_default_dataset
66
from policyengine_core.simulations import Simulation as CountrySimulation
77
from policyengine_core.simulations import (
88
Microsimulation as CountryMicrosimulation,
@@ -73,11 +73,6 @@ class Simulation:
7373
def __init__(self, **options: SimulationOptions):
7474
self.options = SimulationOptions(**options)
7575

76-
if self.options.data is None:
77-
self.options.data = DEFAULT_DATASETS_BY_COUNTRY[
78-
self.options.country
79-
]
80-
8176
self._set_data()
8277
self._initialise_simulations()
8378
self._add_output_functions()
@@ -115,11 +110,12 @@ def _add_output_functions(self):
115110

116111
def _set_data(self):
117112
if self.options.data is None:
118-
self.options.data = DEFAULT_DATASETS_BY_COUNTRY[
119-
self.options.country
120-
]
113+
self.options.data = get_default_dataset(
114+
country=self.options.country,
115+
region=self.options.region,
116+
)
121117

122-
if isinstance(self.options.data, str):
118+
elif isinstance(self.options.data, str):
123119
filename = self.options.data
124120
if "://" in self.options.data:
125121
bucket = None
@@ -129,6 +125,7 @@ def _set_data(self):
129125
bucket, filename = self.options.data.split("://")[
130126
-1
131127
].split("/")
128+
hf_org = "policyengine"
132129
elif "hf://" in self.options.data:
133130
hf_org, hf_repo, filename = self.options.data.split("://")[
134131
-1
@@ -221,6 +218,8 @@ def _initialise_simulation(
221218
if subsample is not None:
222219
simulation = simulation.subsample(subsample)
223220

221+
simulation.default_calculation_period = time_period
222+
224223
return simulation
225224

226225
def _apply_region_to_simulation(

0 commit comments

Comments
 (0)