Skip to content

Commit 1e10d56

Browse files
committed
chore: Lint and changelog
1 parent c735db0 commit 1e10d56

File tree

5 files changed

+70
-46
lines changed

5 files changed

+70
-46
lines changed

changelog_entry.yaml

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
- bump: patch
2+
changes:
3+
changed:
4+
- Disambiguated filepath management in Simulation._set_data()
5+
- Refactored Simulation._set_data() to divide functionality into smaller methods
6+
- Prevented passage of non-Path URIs to Dataset.from_file() at end of Simulation._set_data() execution
7+
added:
8+
- Tests for Simulation._set_data()

policyengine/simulation.py

Lines changed: 14 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,12 @@
33
import sys
44
from pydantic import BaseModel, Field
55
from typing import Literal
6-
from .utils.data.datasets import get_default_dataset, process_gs_path, POLICYENGINE_DATASETS, DATASET_TIME_PERIODS
6+
from .utils.data.datasets import (
7+
get_default_dataset,
8+
process_gs_path,
9+
POLICYENGINE_DATASETS,
10+
DATASET_TIME_PERIODS,
11+
)
712
from policyengine_core.simulations import Simulation as CountrySimulation
813
from policyengine_core.simulations import (
914
Microsimulation as CountryMicrosimulation,
@@ -31,9 +36,7 @@
3136

3237
CountryType = Literal["uk", "us"]
3338
ScopeType = Literal["household", "macro"]
34-
DataType = (
35-
str | Dataset | None
36-
)
39+
DataType = str | Dataset | None
3740
TimePeriodType = int
3841
ReformType = ParametricReform | Type[StructuralReform] | None
3942
RegionType = Optional[str]
@@ -95,7 +98,7 @@ def __init__(self, **options: SimulationOptions):
9598
self.options = SimulationOptions(**options)
9699
self.check_model_version()
97100
if not isinstance(self.options.data, Dataset):
98-
self._set_data(self.options.data)
101+
self._set_data(self.options.data)
99102
self._initialise_simulations()
100103
self.check_data_version()
101104
self._add_output_functions()
@@ -139,8 +142,7 @@ def _set_data(self, file_address: str | None = None) -> None:
139142
# If None is passed, user wants default dataset; get URL, then continue initializing.
140143
if file_address is None:
141144
file_address = get_default_dataset(
142-
country=self.options.country,
143-
region=self.options.region
145+
country=self.options.country, region=self.options.region
144146
)
145147
print(
146148
f"No data provided, using default dataset: {file_address}",
@@ -155,15 +157,11 @@ def _set_data(self, file_address: str | None = None) -> None:
155157
else:
156158
# All official PolicyEngine datasets are stored in GCS;
157159
# load accordingly
158-
filename, version = self._set_data_from_gs(
159-
file_address
160-
)
160+
filename, version = self._set_data_from_gs(file_address)
161161
self.data_version = version
162162

163-
time_period = self._set_data_time_period(
164-
file_address
165-
)
166-
163+
time_period = self._set_data_time_period(file_address)
164+
167165
self.options.data = Dataset.from_file(
168166
filename, time_period=time_period
169167
)
@@ -370,7 +368,7 @@ def check_data_version(self) -> None:
370368
raise ValueError(
371369
f"Data version {self.data_version} does not match expected version {self.options.data_version}."
372370
)
373-
371+
374372
def _set_data_time_period(self, file_address: str) -> Optional[int]:
375373
"""
376374
Set the time period based on the file address.
@@ -383,9 +381,7 @@ def _set_data_time_period(self, file_address: str) -> Optional[int]:
383381
# Local file, no time period available
384382
return None
385383

386-
def _set_data_from_gs(
387-
self, file_address: str
388-
) -> tuple[str, str | None]:
384+
def _set_data_from_gs(self, file_address: str) -> tuple[str, str | None]:
389385
"""
390386
Set the data from a GCS path and return the filename and version.
391387
"""
@@ -403,4 +399,3 @@ def _set_data_from_gs(
403399
)
404400

405401
return filename, version
406-

policyengine/utils/data/datasets.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
ECPS_2024: 2023,
2424
}
2525

26+
2627
def get_default_dataset(
2728
country: str, region: str, version: Optional[str] = None
2829
) -> str:
@@ -38,11 +39,12 @@ def get_default_dataset(
3839
f"Unable to select a default dataset for country {country} and region {region}."
3940
)
4041

42+
4143
def process_gs_path(path: str) -> Tuple[str, str]:
4244
"""Process a GS path to return bucket and object."""
4345
if not path.startswith("gs://"):
4446
raise ValueError(f"Invalid GS path: {path}")
45-
47+
4648
path = path[5:] # Remove 'gs://'
4749
bucket, obj = path.split("/", 1)
48-
return bucket, obj
50+
return bucket, obj

tests/fixtures/simulation.py

Lines changed: 18 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from policyengine.simulation import SimulationOptions
22
from unittest.mock import patch, Mock
33
import pytest
4-
from policyengine.utils.data.datasets import CPS_2023
4+
from policyengine.utils.data.datasets import CPS_2023
55

66
non_data_uk_sim_options = {
77
"country": "uk",
@@ -21,41 +21,43 @@
2121
"baseline": None,
2222
}
2323

24-
uk_sim_options_no_data = SimulationOptions.model_validate({
25-
**non_data_uk_sim_options,
26-
"data": None,
27-
})
24+
uk_sim_options_no_data = SimulationOptions.model_validate(
25+
{
26+
**non_data_uk_sim_options,
27+
"data": None,
28+
}
29+
)
2830

29-
us_sim_options_cps_dataset = SimulationOptions.model_validate({
30-
**non_data_us_sim_options,
31-
"data": CPS_2023
32-
})
31+
us_sim_options_cps_dataset = SimulationOptions.model_validate(
32+
{**non_data_us_sim_options, "data": CPS_2023}
33+
)
3334

3435
SAMPLE_DATASET_FILENAME = "sample_value.h5"
3536
SAMPLE_DATASET_BUCKET_NAME = "policyengine-uk-data-private"
3637
SAMPLE_DATASET_URI_PREFIX = "gs://"
3738
SAMPLE_DATASET_FILE_ADDRESS = f"{SAMPLE_DATASET_URI_PREFIX}{SAMPLE_DATASET_BUCKET_NAME}/{SAMPLE_DATASET_FILENAME}"
3839

39-
uk_sim_options_pe_dataset = SimulationOptions.model_validate({
40-
**non_data_uk_sim_options,
41-
"data": SAMPLE_DATASET_FILE_ADDRESS
42-
})
40+
uk_sim_options_pe_dataset = SimulationOptions.model_validate(
41+
{**non_data_uk_sim_options, "data": SAMPLE_DATASET_FILE_ADDRESS}
42+
)
43+
4344

4445
@pytest.fixture
4546
def mock_get_default_dataset():
4647
with patch(
4748
"policyengine.simulation.get_default_dataset",
48-
return_value=SAMPLE_DATASET_FILE_ADDRESS
49+
return_value=SAMPLE_DATASET_FILE_ADDRESS,
4950
) as mock_get_default_dataset:
5051
yield mock_get_default_dataset
5152

53+
5254
@pytest.fixture
5355
def mock_dataset():
5456
"""Simple Dataset mock fixture"""
55-
with patch('policyengine.simulation.Dataset') as mock_dataset_class:
57+
with patch("policyengine.simulation.Dataset") as mock_dataset_class:
5658
mock_instance = Mock()
5759
# Set file_path to mimic Dataset's behavior of clipping URI and bucket name from GCS paths
5860
mock_instance.from_file = Mock()
5961
mock_instance.file_path = SAMPLE_DATASET_FILENAME
6062
mock_dataset_class.from_file.return_value = mock_instance
61-
yield mock_instance
63+
yield mock_instance

tests/test_simulation.py

Lines changed: 26 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -4,52 +4,69 @@
44
us_sim_options_cps_dataset,
55
mock_get_default_dataset,
66
mock_dataset,
7-
SAMPLE_DATASET_FILENAME
7+
SAMPLE_DATASET_FILENAME,
88
)
99
import sys
1010
from copy import deepcopy
1111

1212
from policyengine import Simulation
1313

14+
1415
class TestSimulation:
1516
class TestSetData:
16-
def test__given_no_data_option__sets_default_dataset(self, mock_get_default_dataset, mock_dataset):
17+
def test__given_no_data_option__sets_default_dataset(
18+
self, mock_get_default_dataset, mock_dataset
19+
):
1720

1821
# Don't run entire init script
1922
sim = object.__new__(Simulation)
2023
sim.options = deepcopy(uk_sim_options_no_data)
2124
sim._set_data(uk_sim_options_no_data.data)
2225

2326
assert str(sim.options.data.file_path) == SAMPLE_DATASET_FILENAME
24-
def test__given_pe_dataset__sets_data_option_to_dataset(self, mock_dataset):
27+
28+
def test__given_pe_dataset__sets_data_option_to_dataset(
29+
self, mock_dataset
30+
):
2531

2632
sim = object.__new__(Simulation)
2733
sim.options = deepcopy(uk_sim_options_pe_dataset)
2834
sim._set_data(uk_sim_options_pe_dataset.data)
2935

3036
assert str(sim.options.data.file_path) == SAMPLE_DATASET_FILENAME
31-
def test__given_cps_2023_in_filename__sets_time_period_to_2023(self, mock_dataset):
37+
38+
def test__given_cps_2023_in_filename__sets_time_period_to_2023(
39+
self, mock_dataset
40+
):
3241
from policyengine import Simulation
3342

3443
sim = object.__new__(Simulation)
3544
sim.options = deepcopy(us_sim_options_cps_dataset)
3645
sim._set_data(us_sim_options_cps_dataset.data)
3746

3847
assert mock_dataset.from_file.called_with(
39-
us_sim_options_cps_dataset.data,
40-
time_period=2023
48+
us_sim_options_cps_dataset.data, time_period=2023
4149
)
50+
4251
class TestSetDataTimePeriod:
4352
def test__given_dataset_with_time_period__sets_time_period(self):
4453
from policyengine import Simulation
4554

4655
sim = object.__new__(Simulation)
4756

4857
print("Dataset:", us_sim_options_cps_dataset.data, file=sys.stderr)
49-
assert sim._set_data_time_period(us_sim_options_cps_dataset.data) == 2023
58+
assert (
59+
sim._set_data_time_period(us_sim_options_cps_dataset.data)
60+
== 2023
61+
)
5062

51-
def test__given_dataset_without_time_period__does_not_set_time_period(self):
63+
def test__given_dataset_without_time_period__does_not_set_time_period(
64+
self,
65+
):
5266
from policyengine import Simulation
5367

5468
sim = object.__new__(Simulation)
55-
assert sim._set_data_time_period(uk_sim_options_pe_dataset.data) == None
69+
assert (
70+
sim._set_data_time_period(uk_sim_options_pe_dataset.data)
71+
== None
72+
)

0 commit comments

Comments
 (0)