Skip to content

Commit aac814a

Browse files
authored
Merge pull request #155 from PolicyEngine/fix/correct-set-data
Redo data setting in simulation
2 parents da9bc17 + f1c9434 commit aac814a

File tree

6 files changed

+270
-62
lines changed

6 files changed

+270
-62
lines changed

changelog_entry.yaml

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
- bump: patch
2+
changes:
3+
changed:
4+
- Disambiguated filepath management in Simulation._set_data()
5+
- Refactored Simulation._set_data() to divide functionality into smaller methods
6+
- Prevented passage of non-Path URIs to Dataset.from_file() at end of Simulation._set_data() execution
7+
added:
8+
- Tests for Simulation._set_data()

policyengine/constants.py

Lines changed: 0 additions & 27 deletions
This file was deleted.

policyengine/simulation.py

Lines changed: 77 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,14 @@
11
"""Simulate tax-benefit policy and derive society-level output statistics."""
22

3+
import sys
34
from pydantic import BaseModel, Field
45
from typing import Literal
5-
from .constants import get_default_dataset
6+
from .utils.data.datasets import (
7+
get_default_dataset,
8+
process_gs_path,
9+
POLICYENGINE_DATASETS,
10+
DATASET_TIME_PERIODS,
11+
)
612
from policyengine_core.simulations import Simulation as CountrySimulation
713
from policyengine_core.simulations import (
814
Microsimulation as CountryMicrosimulation,
@@ -22,16 +28,16 @@
2228
import h5py
2329
from pathlib import Path
2430
import pandas as pd
25-
from typing import Type, Optional
31+
from typing import Type, Any, Optional
2632
from functools import wraps, partial
27-
from typing import Dict, Any, Callable
33+
from typing import Callable
2834
import importlib
2935
from policyengine.utils.data_download import download
3036

3137
CountryType = Literal["uk", "us"]
3238
ScopeType = Literal["household", "macro"]
3339
DataType = (
34-
str | dict | Any | None
40+
str | dict[Any, Any] | Dataset | None
3541
) # Needs stricter typing. Any==policyengine_core.data.Dataset, but pydantic refuses for some reason.
3642
TimePeriodType = int
3743
ReformType = ParametricReform | Type[StructuralReform] | None
@@ -72,6 +78,10 @@ class SimulationOptions(BaseModel):
7278
description="The version of the data used in the simulation. If not provided, the current data version will be used. If provided, this package will throw an error if the data version does not match. Use this as an extra safety check.",
7379
)
7480

81+
model_config = {
82+
"arbitrary_types_allowed": True,
83+
}
84+
7585

7686
class Simulation:
7787
"""Simulate tax-benefit policy and derive society-level output statistics."""
@@ -89,7 +99,10 @@ class Simulation:
8999
def __init__(self, **options: SimulationOptions):
90100
self.options = SimulationOptions(**options)
91101
self.check_model_version()
92-
self._set_data()
102+
if not isinstance(self.options.data, dict) and not isinstance(
103+
self.options.data, Dataset
104+
):
105+
self._set_data(self.options.data)
93106
self._initialise_simulations()
94107
self.check_data_version()
95108
self._add_output_functions()
@@ -125,39 +138,37 @@ def _add_output_functions(self):
125138
wrapped_func,
126139
)
127140

128-
def _set_data(self):
129-
if self.options.data is None:
130-
self.options.data = get_default_dataset(
131-
country=self.options.country,
132-
region=self.options.region,
133-
)
141+
def _set_data(self, file_address: str | None = None) -> None:
134142

135-
if isinstance(self.options.data, str):
136-
filename = self.options.data
137-
if self.options.data[:6] == "gcs://":
138-
bucket, filename = self.options.data.split("://")[-1].split(
139-
"/"
140-
)
141-
version = self.options.data_version
143+
# filename refers to file's unique name + extension;
144+
# file_address refers to URI + filename
142145

143-
file_path, version = download(
144-
filepath=filename,
145-
gcs_bucket=bucket,
146-
version=version,
147-
return_version=True,
148-
)
149-
self.data_version = version
150-
filename = str(Path(file_path))
151-
else:
152-
# If it's a local file, we can't infer the version.
153-
version = None
154-
if "cps_2023" in filename:
155-
time_period = 2023
156-
else:
157-
time_period = None
158-
self.options.data = Dataset.from_file(
159-
filename, time_period=time_period
146+
# If None is passed, user wants default dataset; get URL, then continue initializing.
147+
if file_address is None:
148+
file_address = get_default_dataset(
149+
country=self.options.country, region=self.options.region
160150
)
151+
print(
152+
f"No data provided, using default dataset: {file_address}",
153+
file=sys.stderr,
154+
)
155+
156+
if file_address not in POLICYENGINE_DATASETS:
157+
# If it's a local file, no URI present and unable to infer version.
158+
filename = file_address
159+
version = None
160+
161+
else:
162+
# All official PolicyEngine datasets are stored in GCS;
163+
# load accordingly
164+
filename, version = self._set_data_from_gs(file_address)
165+
self.data_version = version
166+
167+
time_period = self._set_data_time_period(file_address)
168+
169+
self.options.data = Dataset.from_file(
170+
filename, time_period=time_period
171+
)
161172

162173
def _initialise_simulations(self):
163174
self.baseline_simulation = self._initialise_simulation(
@@ -361,3 +372,34 @@ def check_data_version(self) -> None:
361372
raise ValueError(
362373
f"Data version {self.data_version} does not match expected version {self.options.data_version}."
363374
)
375+
376+
def _set_data_time_period(self, file_address: str) -> Optional[int]:
    """
    Look up the time period associated with a dataset address.

    Addresses that appear in DATASET_TIME_PERIODS yield their mapped time
    period; any other address (for example a local file) yields None.
    Nothing is stored on the instance; the value is only returned.
    """
    # dict.get returns None for addresses without a registered time period
    return DATASET_TIME_PERIODS.get(file_address)
387+
388+
def _set_data_from_gs(self, file_address: str) -> tuple[str, str | None]:
    """
    Download a dataset from a GCS address and report where it was saved.

    Args:
        file_address: Full ``gs://<bucket>/<object>`` address of the dataset.

    Returns:
        A tuple of (path returned by ``download``, as a string; version
        returned by ``download``).

    Raises:
        ValueError: Propagated from ``process_gs_path`` when ``file_address``
            is not a valid ``gs://`` address.
    """
    bucket, filename = process_gs_path(file_address)

    print(f"Downloading {filename} from bucket {bucket}", file=sys.stderr)

    # The requested data version (possibly None) is passed straight through.
    filepath, version = download(
        filepath=filename,
        gcs_bucket=bucket,
        version=self.options.data_version,
        return_version=True,
    )

    # Hand back the path reported by `download`, not the bucket object name,
    # so the caller opens the file that was actually written.
    return str(Path(filepath)), version
policyengine/utils/data/datasets.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
"""Mainly simulation options and parameters."""
2+
3+
from typing import Tuple, Optional
4+
5+
# Google Cloud Storage addresses (gs://<bucket>/<object>) of the known datasets.
EFRS_2022 = "gs://policyengine-uk-data-private/enhanced_frs_2022_23.h5"
FRS_2022 = "gs://policyengine-uk-data-private/frs_2022_23.h5"
CPS_2023 = "gs://policyengine-us-data/cps_2023.h5"
CPS_2023_POOLED = "gs://policyengine-us-data/pooled_3_year_cps_2023.h5"
ECPS_2024 = "gs://policyengine-us-data/ecps_2024.h5"

# All addresses defined above. Simulation._set_data() downloads an address
# from GCS only when it is a member of this list; anything else is handled
# as a local file.
POLICYENGINE_DATASETS = [
    EFRS_2022,
    FRS_2022,
    CPS_2023,
    CPS_2023_POOLED,
    ECPS_2024,
]

# Contains datasets that map to particular time_period values
# NOTE(review): ECPS_2024 maps to 2023 rather than 2024 — confirm this is intentional.
DATASET_TIME_PERIODS = {
    CPS_2023: 2023,
    CPS_2023_POOLED: 2023,
    ECPS_2024: 2023,
}
25+
26+
27+
def get_default_dataset(
    country: str, region: Optional[str], version: Optional[str] = None
) -> str:
    """Return the address of the default dataset for a country and region.

    Args:
        country: Country code; only "uk" and "us" are recognised.
        region: Region within the country, or None. It is only consulted
            for "us": any value other than None or "us" selects the pooled
            CPS dataset.
        version: Not used by this function.

    Returns:
        The ``gs://`` address of the default dataset.

    Raises:
        ValueError: If ``country`` is neither "uk" nor "us".
    """
    if country == "uk":
        return EFRS_2022

    if country == "us":
        uses_pooled = region is not None and region != "us"
        return CPS_2023_POOLED if uses_pooled else CPS_2023

    raise ValueError(
        f"Unable to select a default dataset for country {country} and region {region}."
    )
41+
42+
43+
def process_gs_path(path: str) -> Tuple[str, str]:
    """Split a Google Cloud Storage address into bucket and object names.

    Args:
        path: Address of the form ``gs://<bucket>/<object>``.

    Returns:
        ``(bucket, object)``. Only the first ``/`` after the bucket is used
        as the separator, so the object name keeps any further ``/``.

    Raises:
        ValueError: If ``path`` does not start with ``gs://``, or if the
            bucket or the object name is missing.
    """
    prefix = "gs://"
    if not path.startswith(prefix):
        raise ValueError(f"Invalid GS path: {path}")

    # partition never raises, so an address without "/" is reported with the
    # same explicit message instead of a tuple-unpacking error.
    bucket, _, obj = path[len(prefix):].partition("/")
    if not bucket or not obj:
        raise ValueError(f"Invalid GS path: {path}")

    return bucket, obj

tests/fixtures/simulation.py

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
from policyengine.simulation import SimulationOptions
2+
from unittest.mock import patch, Mock
3+
import pytest
4+
from policyengine.utils.data.datasets import CPS_2023
5+
6+
# Simulation options shared by the UK fixtures; the "data" key is added per case.
non_data_uk_sim_options = {
    "country": "uk",
    "scope": "macro",
    "region": "uk",
    "time_period": 2025,
    "reform": None,
    "baseline": None,
}

# Simulation options shared by the US fixtures; the "data" key is added per case.
non_data_us_sim_options = {
    "country": "us",
    "scope": "macro",
    "region": "us",
    "time_period": 2025,
    "reform": None,
    "baseline": None,
}

# UK options with no dataset given.
uk_sim_options_no_data = SimulationOptions.model_validate(
    {
        **non_data_uk_sim_options,
        "data": None,
    }
)

# US options pointing at the CPS_2023 dataset address.
us_sim_options_cps_dataset = SimulationOptions.model_validate(
    {**non_data_us_sim_options, "data": CPS_2023}
)

# Pieces of a sample gs:// address, kept separate so tests can compare
# against the bare filename.
SAMPLE_DATASET_FILENAME = "sample_value.h5"
SAMPLE_DATASET_BUCKET_NAME = "policyengine-uk-data-private"
SAMPLE_DATASET_URI_PREFIX = "gs://"
SAMPLE_DATASET_FILE_ADDRESS = f"{SAMPLE_DATASET_URI_PREFIX}{SAMPLE_DATASET_BUCKET_NAME}/{SAMPLE_DATASET_FILENAME}"

# UK options pointing at the sample address.
# NOTE(review): this address is not one of the entries in POLICYENGINE_DATASETS,
# so Simulation._set_data() presumably handles it through its local-file
# branch — confirm that is the intended coverage.
uk_sim_options_pe_dataset = SimulationOptions.model_validate(
    {**non_data_uk_sim_options, "data": SAMPLE_DATASET_FILE_ADDRESS}
)
43+
44+
45+
@pytest.fixture
def mock_get_default_dataset():
    """Patch get_default_dataset in policyengine.simulation to return the sample address."""
    patcher = patch(
        "policyengine.simulation.get_default_dataset",
        return_value=SAMPLE_DATASET_FILE_ADDRESS,
    )
    with patcher as patched_get_default:
        yield patched_get_default
52+
53+
54+
@pytest.fixture
def mock_dataset():
    """Simple Dataset mock fixture.

    Patches ``Dataset`` in ``policyengine.simulation`` and yields the mock
    object that the patched ``Dataset.from_file`` returns. The yielded mock's
    ``from_file`` attribute is the patched ``Dataset.from_file`` itself, so
    tests can inspect the calls made by the code under test.
    """
    with patch("policyengine.simulation.Dataset") as mock_dataset_class:
        mock_instance = Mock()
        # Set file_path to mimic Dataset's behavior of clipping URI and bucket name from GCS paths
        mock_instance.file_path = SAMPLE_DATASET_FILENAME
        mock_dataset_class.from_file.return_value = mock_instance
        # The code under test calls Dataset.from_file on the patched class;
        # a separate Mock() here would never record those calls.
        mock_instance.from_file = mock_dataset_class.from_file
        yield mock_instance

tests/test_simulation.py

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
from .fixtures.simulation import (
2+
uk_sim_options_no_data,
3+
uk_sim_options_pe_dataset,
4+
us_sim_options_cps_dataset,
5+
mock_get_default_dataset,
6+
mock_dataset,
7+
SAMPLE_DATASET_FILENAME,
8+
)
9+
import sys
10+
from copy import deepcopy
11+
12+
from policyengine import Simulation
13+
14+
15+
class TestSimulation:
    """Unit tests for Simulation's data-setting helpers."""

    class TestSetData:
        """Tests for Simulation._set_data()."""

        def test__given_no_data_option__sets_default_dataset(
            self, mock_get_default_dataset, mock_dataset
        ):
            # Don't run entire init script
            sim = object.__new__(Simulation)
            sim.options = deepcopy(uk_sim_options_no_data)
            sim._set_data(uk_sim_options_no_data.data)

            assert str(sim.options.data.file_path) == SAMPLE_DATASET_FILENAME

        def test__given_pe_dataset__sets_data_option_to_dataset(
            self, mock_dataset
        ):
            sim = object.__new__(Simulation)
            sim.options = deepcopy(uk_sim_options_pe_dataset)
            sim._set_data(uk_sim_options_pe_dataset.data)

            assert str(sim.options.data.file_path) == SAMPLE_DATASET_FILENAME

        def test__given_cps_2023_in_filename__sets_time_period_to_2023(
            self, mock_dataset
        ):
            import policyengine.simulation as simulation_module

            # NOTE(review): `download` is not patched by these fixtures, so
            # this test presumably reaches the real download helper — confirm.
            sim = object.__new__(Simulation)
            sim.options = deepcopy(us_sim_options_cps_dataset)
            sim._set_data(us_sim_options_cps_dataset.data)

            # While the mock_dataset fixture is active, Dataset in
            # policyengine.simulation is the patched class. Inspect the call
            # it actually received; `mock.called_with(...)` is not an
            # assertion and can never fail meaningfully.
            from_file = simulation_module.Dataset.from_file
            from_file.assert_called_once()
            assert from_file.call_args.kwargs["time_period"] == 2023

    class TestSetDataTimePeriod:
        """Tests for Simulation._set_data_time_period()."""

        def test__given_dataset_with_time_period__sets_time_period(self):
            sim = object.__new__(Simulation)

            assert (
                sim._set_data_time_period(us_sim_options_cps_dataset.data)
                == 2023
            )

        def test__given_dataset_without_time_period__does_not_set_time_period(
            self,
        ):
            sim = object.__new__(Simulation)

            assert (
                sim._set_data_time_period(uk_sim_options_pe_dataset.data)
                is None
            )

0 commit comments

Comments
 (0)