diff --git a/changelog_entry.yaml b/changelog_entry.yaml index e69de29b..760395ca 100644 --- a/changelog_entry.yaml +++ b/changelog_entry.yaml @@ -0,0 +1,8 @@ +- bump: patch + changes: + changed: + - Disambiguated filepath management in Simulation._set_data() + - Refactored Simulation._set_data() to divide functionality into smaller methods + - Prevented passage of non-Path URIs to Dataset.from_file() at end of Simulation._set_data() execution + added: + - Tests for Simulation._set_data() \ No newline at end of file diff --git a/policyengine/constants.py b/policyengine/constants.py deleted file mode 100644 index de9b6799..00000000 --- a/policyengine/constants.py +++ /dev/null @@ -1,27 +0,0 @@ -"""Mainly simulation options and parameters.""" - -from policyengine_core.data import Dataset -from policyengine.utils.data_download import download -from typing import Tuple, Optional - -EFRS_2022 = "gcs://policyengine-uk-data-private/enhanced_frs_2022_23.h5" -FRS_2022 = "gcs://policyengine-uk-data-private/frs_2022_23.h5" -CPS_2023_POOLED = "gcs://policyengine-us-data/pooled_3_year_cps_2023.h5" -CPS_2023 = "gcs://policyengine-us-data/cps_2023.h5" -ECPS_2024 = "gcs://policyengine-us-data/ecps_2024.h5" - - -def get_default_dataset( - country: str, region: str, version: Optional[str] = None -) -> str: - if country == "uk": - return EFRS_2022 - elif country == "us": - if region is not None and region != "us": - return CPS_2023_POOLED - else: - return CPS_2023 - - raise ValueError( - f"Unable to select a default dataset for country {country} and region {region}." - ) diff --git a/policyengine/simulation.py b/policyengine/simulation.py index 670cff67..a8a818fb 100644 --- a/policyengine/simulation.py +++ b/policyengine/simulation.py @@ -1,8 +1,14 @@ """Simulate tax-benefit policy and derive society-level output statistics.""" +import sys from pydantic import BaseModel, Field from typing import Literal -from .constants import get_default_dataset +from .utils.data.datasets import ( + get_default_dataset, + process_gs_path, + POLICYENGINE_DATASETS, + DATASET_TIME_PERIODS, +) from policyengine_core.simulations import Simulation as CountrySimulation from policyengine_core.simulations import ( Microsimulation as CountryMicrosimulation, @@ -22,16 +28,16 @@ import h5py from pathlib import Path import pandas as pd -from typing import Type, Optional +from typing import Type, Any, Optional from functools import wraps, partial -from typing import Dict, Any, Callable +from typing import Callable import importlib from policyengine.utils.data_download import download CountryType = Literal["uk", "us"] ScopeType = Literal["household", "macro"] DataType = ( - str | dict | Any | None + str | dict[Any, Any] | Dataset | None ) # Needs stricter typing. Any==policyengine_core.data.Dataset, but pydantic refuses for some reason. TimePeriodType = int ReformType = ParametricReform | Type[StructuralReform] | None @@ -72,6 +78,10 @@ class SimulationOptions(BaseModel): description="The version of the data used in the simulation. If not provided, the current data version will be used. If provided, this package will throw an error if the data version does not match. Use this as an extra safety check.", ) + model_config = { + "arbitrary_types_allowed": True, + } + class Simulation: """Simulate tax-benefit policy and derive society-level output statistics.""" @@ -89,7 +99,10 @@ class Simulation: def __init__(self, **options: SimulationOptions): self.options = SimulationOptions(**options) self.check_model_version() - self._set_data() + if not isinstance(self.options.data, dict) and not isinstance( + self.options.data, Dataset + ): + self._set_data(self.options.data) self._initialise_simulations() self.check_data_version() self._add_output_functions() @@ -125,39 +138,37 @@ def _add_output_functions(self): wrapped_func, ) - def _set_data(self): - if self.options.data is None: - self.options.data = get_default_dataset( - country=self.options.country, - region=self.options.region, - ) + def _set_data(self, file_address: str | None = None) -> None: - if isinstance(self.options.data, str): - filename = self.options.data - if self.options.data[:6] == "gcs://": - bucket, filename = self.options.data.split("://")[-1].split( - "/" - ) - version = self.options.data_version + # filename refers to file's unique name + extension; + # file_address refers to URI + filename - file_path, version = download( - filepath=filename, - gcs_bucket=bucket, - version=version, - return_version=True, - ) - self.data_version = version - filename = str(Path(file_path)) - else: - # If it's a local file, we can't infer the version. - version = None - if "cps_2023" in filename: - time_period = 2023 - else: - time_period = None - self.options.data = Dataset.from_file( - filename, time_period=time_period + # If None is passed, user wants default dataset; get URL, then continue initializing. + if file_address is None: + file_address = get_default_dataset( + country=self.options.country, region=self.options.region ) + print( + f"No data provided, using default dataset: {file_address}", + file=sys.stderr, + ) + + if file_address not in POLICYENGINE_DATASETS: + # If it's a local file, no URI present and unable to infer version. + filename = file_address + version = None + + else: + # All official PolicyEngine datasets are stored in GCS; + # load accordingly + filename, version = self._set_data_from_gs(file_address) + self.data_version = version + + time_period = self._set_data_time_period(file_address) + + self.options.data = Dataset.from_file( + filename, time_period=time_period + ) def _initialise_simulations(self): self.baseline_simulation = self._initialise_simulation( @@ -361,3 +372,34 @@ def check_data_version(self) -> None: raise ValueError( f"Data version {self.data_version} does not match expected version {self.options.data_version}." ) + + def _set_data_time_period(self, file_address: str) -> Optional[int]: + """ + Set the time period based on the file address. + If the file address is a PE dataset, return the time period from the dataset. + If it's a local file, return None. + """ + if file_address in DATASET_TIME_PERIODS: + return DATASET_TIME_PERIODS[file_address] + else: + # Local file, no time period available + return None + + def _set_data_from_gs(self, file_address: str) -> tuple[str, str | None]: + """ + Set the data from a GCS path and return the filename and version. + """ + + bucket, filename = process_gs_path(file_address) + version = self.options.data_version + + print(f"Downloading {filename} from bucket {bucket}", file=sys.stderr) + + filepath, version = download( + filepath=filename, + gcs_bucket=bucket, + version=version, + return_version=True, + ) + + return filename, version diff --git a/policyengine/utils/data/datasets.py b/policyengine/utils/data/datasets.py new file mode 100644 index 00000000..4dcd8af6 --- /dev/null +++ b/policyengine/utils/data/datasets.py @@ -0,0 +1,50 @@ +"""Mainly simulation options and parameters.""" + +from typing import Tuple, Optional + +EFRS_2022 = "gs://policyengine-uk-data-private/enhanced_frs_2022_23.h5" +FRS_2022 = "gs://policyengine-uk-data-private/frs_2022_23.h5" +CPS_2023 = "gs://policyengine-us-data/cps_2023.h5" +CPS_2023_POOLED = "gs://policyengine-us-data/pooled_3_year_cps_2023.h5" +ECPS_2024 = "gs://policyengine-us-data/ecps_2024.h5" + +POLICYENGINE_DATASETS = [ + EFRS_2022, + FRS_2022, + CPS_2023, + CPS_2023_POOLED, + ECPS_2024, +] + +# Contains datasets that map to particular time_period values +DATASET_TIME_PERIODS = { + CPS_2023: 2023, + CPS_2023_POOLED: 2023, + ECPS_2024: 2023, +} + + +def get_default_dataset( + country: str, region: str, version: Optional[str] = None +) -> str: + if country == "uk": + return EFRS_2022 + elif country == "us": + if region is not None and region != "us": + return CPS_2023_POOLED + else: + return CPS_2023 + + raise ValueError( + f"Unable to select a default dataset for country {country} and region {region}." + ) + + +def process_gs_path(path: str) -> Tuple[str, str]: + """Process a GS path to return bucket and object.""" + if not path.startswith("gs://"): + raise ValueError(f"Invalid GS path: {path}") + + path = path[5:] # Remove 'gs://' + bucket, obj = path.split("/", 1) + return bucket, obj diff --git a/tests/fixtures/simulation.py b/tests/fixtures/simulation.py new file mode 100644 index 00000000..7cc51720 --- /dev/null +++ b/tests/fixtures/simulation.py @@ -0,0 +1,63 @@ +from policyengine.simulation import SimulationOptions +from unittest.mock import patch, Mock +import pytest +from policyengine.utils.data.datasets import CPS_2023 + +non_data_uk_sim_options = { + "country": "uk", + "scope": "macro", + "region": "uk", + "time_period": 2025, + "reform": None, + "baseline": None, +} + +non_data_us_sim_options = { + "country": "us", + "scope": "macro", + "region": "us", + "time_period": 2025, + "reform": None, + "baseline": None, +} + +uk_sim_options_no_data = SimulationOptions.model_validate( + { + **non_data_uk_sim_options, + "data": None, + } +) + +us_sim_options_cps_dataset = SimulationOptions.model_validate( + {**non_data_us_sim_options, "data": CPS_2023} +) + +SAMPLE_DATASET_FILENAME = "sample_value.h5" +SAMPLE_DATASET_BUCKET_NAME = "policyengine-uk-data-private" +SAMPLE_DATASET_URI_PREFIX = "gs://" +SAMPLE_DATASET_FILE_ADDRESS = f"{SAMPLE_DATASET_URI_PREFIX}{SAMPLE_DATASET_BUCKET_NAME}/{SAMPLE_DATASET_FILENAME}" + +uk_sim_options_pe_dataset = SimulationOptions.model_validate( + {**non_data_uk_sim_options, "data": SAMPLE_DATASET_FILE_ADDRESS} +) + + +@pytest.fixture +def mock_get_default_dataset(): + with patch( + "policyengine.simulation.get_default_dataset", + return_value=SAMPLE_DATASET_FILE_ADDRESS, + ) as mock_get_default_dataset: + yield mock_get_default_dataset + + +@pytest.fixture +def mock_dataset(): + """Simple Dataset mock fixture""" + with patch("policyengine.simulation.Dataset") as mock_dataset_class: + mock_instance = Mock() + # Set file_path to mimic Dataset's behavior of clipping URI and bucket name from GCS paths + mock_instance.from_file = Mock() + mock_instance.file_path = SAMPLE_DATASET_FILENAME + mock_dataset_class.from_file.return_value = mock_instance + yield mock_instance diff --git a/tests/test_simulation.py b/tests/test_simulation.py new file mode 100644 index 00000000..a3697bc2 --- /dev/null +++ b/tests/test_simulation.py @@ -0,0 +1,72 @@ +from .fixtures.simulation import ( + uk_sim_options_no_data, + uk_sim_options_pe_dataset, + us_sim_options_cps_dataset, + mock_get_default_dataset, + mock_dataset, + SAMPLE_DATASET_FILENAME, +) +import sys +from copy import deepcopy + +from policyengine import Simulation + + +class TestSimulation: + class TestSetData: + def test__given_no_data_option__sets_default_dataset( + self, mock_get_default_dataset, mock_dataset + ): + + # Don't run entire init script + sim = object.__new__(Simulation) + sim.options = deepcopy(uk_sim_options_no_data) + sim._set_data(uk_sim_options_no_data.data) + + assert str(sim.options.data.file_path) == SAMPLE_DATASET_FILENAME + + def test__given_pe_dataset__sets_data_option_to_dataset( + self, mock_dataset + ): + + sim = object.__new__(Simulation) + sim.options = deepcopy(uk_sim_options_pe_dataset) + sim._set_data(uk_sim_options_pe_dataset.data) + + assert str(sim.options.data.file_path) == SAMPLE_DATASET_FILENAME + + def test__given_cps_2023_in_filename__sets_time_period_to_2023( + self, mock_dataset + ): + from policyengine import Simulation + + sim = object.__new__(Simulation) + sim.options = deepcopy(us_sim_options_cps_dataset) + sim._set_data(us_sim_options_cps_dataset.data) + + assert mock_dataset.from_file.called_with( + us_sim_options_cps_dataset.data, time_period=2023 + ) + + class TestSetDataTimePeriod: + def test__given_dataset_with_time_period__sets_time_period(self): + from policyengine import Simulation + + sim = object.__new__(Simulation) + + print("Dataset:", us_sim_options_cps_dataset.data, file=sys.stderr) + assert ( + sim._set_data_time_period(us_sim_options_cps_dataset.data) + == 2023 + ) + + def test__given_dataset_without_time_period__does_not_set_time_period( + self, + ): + from policyengine import Simulation + + sim = object.__new__(Simulation) + assert ( + sim._set_data_time_period(uk_sim_options_pe_dataset.data) + == None + )