Skip to content

Commit c735db0

Browse files
committed
feat: Redo data setting in simulation
1 parent da9bc17 commit c735db0

File tree

5 files changed

+239
-59
lines changed

5 files changed

+239
-59
lines changed

policyengine/constants.py

Lines changed: 0 additions & 27 deletions
This file was deleted.

policyengine/simulation.py

Lines changed: 75 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
"""Simulate tax-benefit policy and derive society-level output statistics."""
22

3+
import sys
34
from pydantic import BaseModel, Field
45
from typing import Literal
5-
from .constants import get_default_dataset
6+
from .utils.data.datasets import get_default_dataset, process_gs_path, POLICYENGINE_DATASETS, DATASET_TIME_PERIODS
67
from policyengine_core.simulations import Simulation as CountrySimulation
78
from policyengine_core.simulations import (
89
Microsimulation as CountryMicrosimulation,
@@ -31,8 +32,8 @@
3132
CountryType = Literal["uk", "us"]
3233
ScopeType = Literal["household", "macro"]
3334
DataType = (
34-
str | dict | Any | None
35-
) # Needs stricter typing. Any==policyengine_core.data.Dataset, but pydantic refuses for some reason.
35+
str | Dataset | None
36+
)
3637
TimePeriodType = int
3738
ReformType = ParametricReform | Type[StructuralReform] | None
3839
RegionType = Optional[str]
@@ -72,6 +73,10 @@ class SimulationOptions(BaseModel):
7273
description="The version of the data used in the simulation. If not provided, the current data version will be used. If provided, this package will throw an error if the data version does not match. Use this as an extra safety check.",
7374
)
7475

76+
model_config = {
77+
"arbitrary_types_allowed": True,
78+
}
79+
7580

7681
class Simulation:
7782
"""Simulate tax-benefit policy and derive society-level output statistics."""
@@ -89,7 +94,8 @@ class Simulation:
8994
def __init__(self, **options: SimulationOptions):
9095
self.options = SimulationOptions(**options)
9196
self.check_model_version()
92-
self._set_data()
97+
if not isinstance(self.options.data, Dataset):
98+
self._set_data(self.options.data)
9399
self._initialise_simulations()
94100
self.check_data_version()
95101
self._add_output_functions()
@@ -125,39 +131,42 @@ def _add_output_functions(self):
125131
wrapped_func,
126132
)
127133

128-
def _set_data(self):
129-
if self.options.data is None:
130-
self.options.data = get_default_dataset(
134+
def _set_data(self, file_address: str | None = None) -> None:
135+
136+
# filename refers to file's unique name + extension;
137+
# file_address refers to URI + filename
138+
139+
# If None is passed, user wants default dataset; get URL, then continue initializing.
140+
if file_address is None:
141+
file_address = get_default_dataset(
131142
country=self.options.country,
132-
region=self.options.region,
143+
region=self.options.region
144+
)
145+
print(
146+
f"No data provided, using default dataset: {file_address}",
147+
file=sys.stderr,
133148
)
134149

135-
if isinstance(self.options.data, str):
136-
filename = self.options.data
137-
if self.options.data[:6] == "gcs://":
138-
bucket, filename = self.options.data.split("://")[-1].split(
139-
"/"
140-
)
141-
version = self.options.data_version
150+
if file_address not in POLICYENGINE_DATASETS:
151+
# If it's a local file, no URI present and unable to infer version.
152+
filename = file_address
153+
version = None
142154

143-
file_path, version = download(
144-
filepath=filename,
145-
gcs_bucket=bucket,
146-
version=version,
147-
return_version=True,
148-
)
149-
self.data_version = version
150-
filename = str(Path(file_path))
151-
else:
152-
# If it's a local file, we can't infer the version.
153-
version = None
154-
if "cps_2023" in filename:
155-
time_period = 2023
156-
else:
157-
time_period = None
158-
self.options.data = Dataset.from_file(
159-
filename, time_period=time_period
155+
else:
156+
# All official PolicyEngine datasets are stored in GCS;
157+
# load accordingly
158+
filename, version = self._set_data_from_gs(
159+
file_address
160160
)
161+
self.data_version = version
162+
163+
time_period = self._set_data_time_period(
164+
file_address
165+
)
166+
167+
self.options.data = Dataset.from_file(
168+
filename, time_period=time_period
169+
)
161170

162171
def _initialise_simulations(self):
163172
self.baseline_simulation = self._initialise_simulation(
@@ -361,3 +370,37 @@ def check_data_version(self) -> None:
361370
raise ValueError(
362371
f"Data version {self.data_version} does not match expected version {self.options.data_version}."
363372
)
373+
374+
def _set_data_time_period(self, file_address: str) -> Optional[int]:
375+
"""
376+
Set the time period based on the file address.
377+
If the file address is a PE dataset, return the time period from the dataset.
378+
If it's a local file, return None.
379+
"""
380+
if file_address in DATASET_TIME_PERIODS:
381+
return DATASET_TIME_PERIODS[file_address]
382+
else:
383+
# Local file, no time period available
384+
return None
385+
386+
def _set_data_from_gs(
387+
self, file_address: str
388+
) -> tuple[str, str | None]:
389+
"""
390+
Set the data from a GCS path and return the filename and version.
391+
"""
392+
393+
bucket, filename = process_gs_path(file_address)
394+
version = self.options.data_version
395+
396+
print(f"Downloading {filename} from bucket {bucket}", file=sys.stderr)
397+
398+
filepath, version = download(
399+
filepath=filename,
400+
gcs_bucket=bucket,
401+
version=version,
402+
return_version=True,
403+
)
404+
405+
return filename, version
406+
Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
"""Mainly simulation options and parameters."""
2+
3+
from typing import Tuple, Optional
4+
5+
EFRS_2022 = "gs://policyengine-uk-data-private/enhanced_frs_2022_23.h5"
6+
FRS_2022 = "gs://policyengine-uk-data-private/frs_2022_23.h5"
7+
CPS_2023 = "gs://policyengine-us-data/cps_2023.h5"
8+
CPS_2023_POOLED = "gs://policyengine-us-data/pooled_3_year_cps_2023.h5"
9+
ECPS_2024 = "gs://policyengine-us-data/ecps_2024.h5"
10+
11+
POLICYENGINE_DATASETS = [
12+
EFRS_2022,
13+
FRS_2022,
14+
CPS_2023,
15+
CPS_2023_POOLED,
16+
ECPS_2024,
17+
]
18+
19+
# Contains datasets that map to particular time_period values
20+
DATASET_TIME_PERIODS = {
21+
CPS_2023: 2023,
22+
CPS_2023_POOLED: 2023,
23+
ECPS_2024: 2023,
24+
}
25+
26+
def get_default_dataset(
27+
country: str, region: str, version: Optional[str] = None
28+
) -> str:
29+
if country == "uk":
30+
return EFRS_2022
31+
elif country == "us":
32+
if region is not None and region != "us":
33+
return CPS_2023_POOLED
34+
else:
35+
return CPS_2023
36+
37+
raise ValueError(
38+
f"Unable to select a default dataset for country {country} and region {region}."
39+
)
40+
41+
def process_gs_path(path: str) -> Tuple[str, str]:
42+
"""Process a GS path to return bucket and object."""
43+
if not path.startswith("gs://"):
44+
raise ValueError(f"Invalid GS path: {path}")
45+
46+
path = path[5:] # Remove 'gs://'
47+
bucket, obj = path.split("/", 1)
48+
return bucket, obj

tests/fixtures/simulation.py

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
from policyengine.simulation import SimulationOptions
2+
from unittest.mock import patch, Mock
3+
import pytest
4+
from policyengine.utils.data.datasets import CPS_2023
5+
6+
non_data_uk_sim_options = {
7+
"country": "uk",
8+
"scope": "macro",
9+
"region": "uk",
10+
"time_period": 2025,
11+
"reform": None,
12+
"baseline": None,
13+
}
14+
15+
non_data_us_sim_options = {
16+
"country": "us",
17+
"scope": "macro",
18+
"region": "us",
19+
"time_period": 2025,
20+
"reform": None,
21+
"baseline": None,
22+
}
23+
24+
uk_sim_options_no_data = SimulationOptions.model_validate({
25+
**non_data_uk_sim_options,
26+
"data": None,
27+
})
28+
29+
us_sim_options_cps_dataset = SimulationOptions.model_validate({
30+
**non_data_us_sim_options,
31+
"data": CPS_2023
32+
})
33+
34+
SAMPLE_DATASET_FILENAME = "sample_value.h5"
35+
SAMPLE_DATASET_BUCKET_NAME = "policyengine-uk-data-private"
36+
SAMPLE_DATASET_URI_PREFIX = "gs://"
37+
SAMPLE_DATASET_FILE_ADDRESS = f"{SAMPLE_DATASET_URI_PREFIX}{SAMPLE_DATASET_BUCKET_NAME}/{SAMPLE_DATASET_FILENAME}"
38+
39+
uk_sim_options_pe_dataset = SimulationOptions.model_validate({
40+
**non_data_uk_sim_options,
41+
"data": SAMPLE_DATASET_FILE_ADDRESS
42+
})
43+
44+
@pytest.fixture
45+
def mock_get_default_dataset():
46+
with patch(
47+
"policyengine.simulation.get_default_dataset",
48+
return_value=SAMPLE_DATASET_FILE_ADDRESS
49+
) as mock_get_default_dataset:
50+
yield mock_get_default_dataset
51+
52+
@pytest.fixture
53+
def mock_dataset():
54+
"""Simple Dataset mock fixture"""
55+
with patch('policyengine.simulation.Dataset') as mock_dataset_class:
56+
mock_instance = Mock()
57+
# Set file_path to mimic Dataset's behavior of clipping URI and bucket name from GCS paths
58+
mock_instance.from_file = Mock()
59+
mock_instance.file_path = SAMPLE_DATASET_FILENAME
60+
mock_dataset_class.from_file.return_value = mock_instance
61+
yield mock_instance

tests/test_simulation.py

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
from .fixtures.simulation import (
2+
uk_sim_options_no_data,
3+
uk_sim_options_pe_dataset,
4+
us_sim_options_cps_dataset,
5+
mock_get_default_dataset,
6+
mock_dataset,
7+
SAMPLE_DATASET_FILENAME
8+
)
9+
import sys
10+
from copy import deepcopy
11+
12+
from policyengine import Simulation
13+
14+
class TestSimulation:
15+
class TestSetData:
16+
def test__given_no_data_option__sets_default_dataset(self, mock_get_default_dataset, mock_dataset):
17+
18+
# Don't run entire init script
19+
sim = object.__new__(Simulation)
20+
sim.options = deepcopy(uk_sim_options_no_data)
21+
sim._set_data(uk_sim_options_no_data.data)
22+
23+
assert str(sim.options.data.file_path) == SAMPLE_DATASET_FILENAME
24+
def test__given_pe_dataset__sets_data_option_to_dataset(self, mock_dataset):
25+
26+
sim = object.__new__(Simulation)
27+
sim.options = deepcopy(uk_sim_options_pe_dataset)
28+
sim._set_data(uk_sim_options_pe_dataset.data)
29+
30+
assert str(sim.options.data.file_path) == SAMPLE_DATASET_FILENAME
31+
def test__given_cps_2023_in_filename__sets_time_period_to_2023(self, mock_dataset):
32+
from policyengine import Simulation
33+
34+
sim = object.__new__(Simulation)
35+
sim.options = deepcopy(us_sim_options_cps_dataset)
36+
sim._set_data(us_sim_options_cps_dataset.data)
37+
38+
assert mock_dataset.from_file.called_with(
39+
us_sim_options_cps_dataset.data,
40+
time_period=2023
41+
)
42+
class TestSetDataTimePeriod:
43+
def test__given_dataset_with_time_period__sets_time_period(self):
44+
from policyengine import Simulation
45+
46+
sim = object.__new__(Simulation)
47+
48+
print("Dataset:", us_sim_options_cps_dataset.data, file=sys.stderr)
49+
assert sim._set_data_time_period(us_sim_options_cps_dataset.data) == 2023
50+
51+
def test__given_dataset_without_time_period__does_not_set_time_period(self):
52+
from policyengine import Simulation
53+
54+
sim = object.__new__(Simulation)
55+
assert sim._set_data_time_period(uk_sim_options_pe_dataset.data) == None

0 commit comments

Comments
 (0)