Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions changelog_entry.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
- bump: patch
changes:
changed:
- Disambiguated filepath management in Simulation._set_data()
- Refactored Simulation._set_data() to divide functionality into smaller methods
- Prevented passage of non-Path URIs to Dataset.from_file() at end of Simulation._set_data() execution
added:
- Tests for Simulation._set_data()
27 changes: 0 additions & 27 deletions policyengine/constants.py

This file was deleted.

112 changes: 77 additions & 35 deletions policyengine/simulation.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,14 @@
"""Simulate tax-benefit policy and derive society-level output statistics."""

import sys
from pydantic import BaseModel, Field
from typing import Literal
from .constants import get_default_dataset
from .utils.data.datasets import (
get_default_dataset,
process_gs_path,
POLICYENGINE_DATASETS,
DATASET_TIME_PERIODS,
)
from policyengine_core.simulations import Simulation as CountrySimulation
from policyengine_core.simulations import (
Microsimulation as CountryMicrosimulation,
Expand All @@ -22,16 +28,16 @@
import h5py
from pathlib import Path
import pandas as pd
from typing import Type, Optional
from typing import Type, Any, Optional
from functools import wraps, partial
from typing import Dict, Any, Callable
from typing import Callable
import importlib
from policyengine.utils.data_download import download

CountryType = Literal["uk", "us"]
ScopeType = Literal["household", "macro"]
DataType = (
str | dict | Any | None
str | dict[Any, Any] | Dataset | None
) # Needs stricter typing. Any==policyengine_core.data.Dataset, but pydantic refuses for some reason.
TimePeriodType = int
ReformType = ParametricReform | Type[StructuralReform] | None
Expand Down Expand Up @@ -72,6 +78,10 @@ class SimulationOptions(BaseModel):
description="The version of the data used in the simulation. If not provided, the current data version will be used. If provided, this package will throw an error if the data version does not match. Use this as an extra safety check.",
)

model_config = {
"arbitrary_types_allowed": True,
}


class Simulation:
"""Simulate tax-benefit policy and derive society-level output statistics."""
Expand All @@ -89,7 +99,10 @@ class Simulation:
def __init__(self, **options: SimulationOptions):
self.options = SimulationOptions(**options)
self.check_model_version()
self._set_data()
if not isinstance(self.options.data, dict) and not isinstance(
self.options.data, Dataset
):
self._set_data(self.options.data)
self._initialise_simulations()
self.check_data_version()
self._add_output_functions()
Expand Down Expand Up @@ -125,39 +138,37 @@ def _add_output_functions(self):
wrapped_func,
)

def _set_data(self):
if self.options.data is None:
self.options.data = get_default_dataset(
country=self.options.country,
region=self.options.region,
)
def _set_data(self, file_address: str | None = None) -> None:

if isinstance(self.options.data, str):
filename = self.options.data
if self.options.data[:6] == "gcs://":
bucket, filename = self.options.data.split("://")[-1].split(
"/"
)
version = self.options.data_version
# filename refers to file's unique name + extension;
# file_address refers to URI + filename

file_path, version = download(
filepath=filename,
gcs_bucket=bucket,
version=version,
return_version=True,
)
self.data_version = version
filename = str(Path(file_path))
else:
# If it's a local file, we can't infer the version.
version = None
if "cps_2023" in filename:
time_period = 2023
else:
time_period = None
self.options.data = Dataset.from_file(
filename, time_period=time_period
# If None is passed, user wants default dataset; get URL, then continue initializing.
if file_address is None:
file_address = get_default_dataset(
country=self.options.country, region=self.options.region
)
print(
f"No data provided, using default dataset: {file_address}",
file=sys.stderr,
)

if file_address not in POLICYENGINE_DATASETS:
# If it's a local file, no URI present and unable to infer version.
filename = file_address
version = None

else:
# All official PolicyEngine datasets are stored in GCS;
# load accordingly
filename, version = self._set_data_from_gs(file_address)
self.data_version = version

time_period = self._set_data_time_period(file_address)

self.options.data = Dataset.from_file(
filename, time_period=time_period
)

def _initialise_simulations(self):
self.baseline_simulation = self._initialise_simulation(
Expand Down Expand Up @@ -361,3 +372,34 @@ def check_data_version(self) -> None:
raise ValueError(
f"Data version {self.data_version} does not match expected version {self.options.data_version}."
)

def _set_data_time_period(self, file_address: str) -> Optional[int]:
"""
Set the time period based on the file address.
If the file address is a PE dataset, return the time period from the dataset.
If it's a local file, return None.
"""
if file_address in DATASET_TIME_PERIODS:
return DATASET_TIME_PERIODS[file_address]
else:
# Local file, no time period available
return None

def _set_data_from_gs(self, file_address: str) -> tuple[str, str | None]:
"""
Set the data from a GCS path and return the filename and version.
"""

bucket, filename = process_gs_path(file_address)
version = self.options.data_version

print(f"Downloading {filename} from bucket {bucket}", file=sys.stderr)

filepath, version = download(
filepath=filename,
gcs_bucket=bucket,
version=version,
return_version=True,
)

return filename, version
50 changes: 50 additions & 0 deletions policyengine/utils/data/datasets.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
"""Mainly simulation options and parameters."""

from typing import Tuple, Optional

EFRS_2022 = "gs://policyengine-uk-data-private/enhanced_frs_2022_23.h5"
FRS_2022 = "gs://policyengine-uk-data-private/frs_2022_23.h5"
CPS_2023 = "gs://policyengine-us-data/cps_2023.h5"
CPS_2023_POOLED = "gs://policyengine-us-data/pooled_3_year_cps_2023.h5"
ECPS_2024 = "gs://policyengine-us-data/ecps_2024.h5"

POLICYENGINE_DATASETS = [
EFRS_2022,
FRS_2022,
CPS_2023,
CPS_2023_POOLED,
ECPS_2024,
]

# Contains datasets that map to particular time_period values
DATASET_TIME_PERIODS = {
CPS_2023: 2023,
CPS_2023_POOLED: 2023,
ECPS_2024: 2023,
}


def get_default_dataset(
country: str, region: str, version: Optional[str] = None
) -> str:
if country == "uk":
return EFRS_2022
elif country == "us":
if region is not None and region != "us":
return CPS_2023_POOLED
else:
return CPS_2023

raise ValueError(
f"Unable to select a default dataset for country {country} and region {region}."
)


def process_gs_path(path: str) -> Tuple[str, str]:
"""Process a GS path to return bucket and object."""
if not path.startswith("gs://"):
raise ValueError(f"Invalid GS path: {path}")

path = path[5:] # Remove 'gs://'
bucket, obj = path.split("/", 1)
return bucket, obj
63 changes: 63 additions & 0 deletions tests/fixtures/simulation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
from policyengine.simulation import SimulationOptions
from unittest.mock import patch, Mock
import pytest
from policyengine.utils.data.datasets import CPS_2023

non_data_uk_sim_options = {
"country": "uk",
"scope": "macro",
"region": "uk",
"time_period": 2025,
"reform": None,
"baseline": None,
}

non_data_us_sim_options = {
"country": "us",
"scope": "macro",
"region": "us",
"time_period": 2025,
"reform": None,
"baseline": None,
}

uk_sim_options_no_data = SimulationOptions.model_validate(
{
**non_data_uk_sim_options,
"data": None,
}
)

us_sim_options_cps_dataset = SimulationOptions.model_validate(
{**non_data_us_sim_options, "data": CPS_2023}
)

SAMPLE_DATASET_FILENAME = "sample_value.h5"
SAMPLE_DATASET_BUCKET_NAME = "policyengine-uk-data-private"
SAMPLE_DATASET_URI_PREFIX = "gs://"
SAMPLE_DATASET_FILE_ADDRESS = f"{SAMPLE_DATASET_URI_PREFIX}{SAMPLE_DATASET_BUCKET_NAME}/{SAMPLE_DATASET_FILENAME}"

uk_sim_options_pe_dataset = SimulationOptions.model_validate(
{**non_data_uk_sim_options, "data": SAMPLE_DATASET_FILE_ADDRESS}
)


@pytest.fixture
def mock_get_default_dataset():
with patch(
"policyengine.simulation.get_default_dataset",
return_value=SAMPLE_DATASET_FILE_ADDRESS,
) as mock_get_default_dataset:
yield mock_get_default_dataset


@pytest.fixture
def mock_dataset():
"""Simple Dataset mock fixture"""
with patch("policyengine.simulation.Dataset") as mock_dataset_class:
mock_instance = Mock()
# Set file_path to mimic Dataset's behavior of clipping URI and bucket name from GCS paths
mock_instance.from_file = Mock()
mock_instance.file_path = SAMPLE_DATASET_FILENAME
mock_dataset_class.from_file.return_value = mock_instance
yield mock_instance
72 changes: 72 additions & 0 deletions tests/test_simulation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
from .fixtures.simulation import (
uk_sim_options_no_data,
uk_sim_options_pe_dataset,
us_sim_options_cps_dataset,
mock_get_default_dataset,
mock_dataset,
SAMPLE_DATASET_FILENAME,
)
import sys
from copy import deepcopy

from policyengine import Simulation


class TestSimulation:
class TestSetData:
def test__given_no_data_option__sets_default_dataset(
self, mock_get_default_dataset, mock_dataset
):

# Don't run entire init script
sim = object.__new__(Simulation)
sim.options = deepcopy(uk_sim_options_no_data)
sim._set_data(uk_sim_options_no_data.data)

assert str(sim.options.data.file_path) == SAMPLE_DATASET_FILENAME

def test__given_pe_dataset__sets_data_option_to_dataset(
self, mock_dataset
):

sim = object.__new__(Simulation)
sim.options = deepcopy(uk_sim_options_pe_dataset)
sim._set_data(uk_sim_options_pe_dataset.data)

assert str(sim.options.data.file_path) == SAMPLE_DATASET_FILENAME

def test__given_cps_2023_in_filename__sets_time_period_to_2023(
self, mock_dataset
):
from policyengine import Simulation

sim = object.__new__(Simulation)
sim.options = deepcopy(us_sim_options_cps_dataset)
sim._set_data(us_sim_options_cps_dataset.data)

assert mock_dataset.from_file.called_with(
us_sim_options_cps_dataset.data, time_period=2023
)

class TestSetDataTimePeriod:
def test__given_dataset_with_time_period__sets_time_period(self):
from policyengine import Simulation

sim = object.__new__(Simulation)

print("Dataset:", us_sim_options_cps_dataset.data, file=sys.stderr)
assert (
sim._set_data_time_period(us_sim_options_cps_dataset.data)
== 2023
)

def test__given_dataset_without_time_period__does_not_set_time_period(
self,
):
from policyengine import Simulation

sim = object.__new__(Simulation)
assert (
sim._set_data_time_period(uk_sim_options_pe_dataset.data)
== None
)
Loading