Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions changelog_entry.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
- bump: minor
changes:
changed:
- Society-wide reports for the US nationwide now call district breakdown-enabled simulation API
10 changes: 10 additions & 0 deletions policyengine_api/routes/economy_routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,16 @@ def get_economic_impact(
region = options.pop("region")
dataset = options.pop("dataset", "default")
time_period = options.pop("time_period")

# Handle district breakdowns - only for US national simulations
include_district_breakdowns_raw = options.pop(
"include_district_breakdowns", "false"
)
include_district_breakdowns = (
include_district_breakdowns_raw.lower() == "true"
)
if include_district_breakdowns and country_id == "us" and region == "us":
dataset = "national-with-breakdowns"
target: Literal["general", "cliff"] = options.pop("target", "general")
api_version = options.pop(
"version", COUNTRY_PACKAGE_VERSIONS.get(country_id)
Expand Down
26 changes: 22 additions & 4 deletions policyengine_api/services/economy_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -419,6 +419,7 @@ def _handle_create_impact(
baseline_policy=baseline_policy,
region=setup_options.region,
time_period=setup_options.time_period,
dataset=setup_options.dataset,
scope="macro",
include_cliffs=setup_options.target == "cliff",
model_version=setup_options.model_version,
Expand Down Expand Up @@ -460,6 +461,7 @@ def _setup_sim_options(
include_cliffs: bool = False,
model_version: str | None = None,
data_version: str | None = None,
dataset: str = "default",
) -> SimulationOptions:
"""
Set up the simulation options for the simulation API job.
Expand All @@ -476,7 +478,9 @@ def _setup_sim_options(
"region": self._setup_region(
country_id=country_id, region=region
),
"data": self._setup_data(country_id=country_id, region=region),
"data": self._setup_data(
country_id=country_id, region=region, dataset=dataset
),
"model_version": model_version,
"data_version": data_version,
}
Expand Down Expand Up @@ -520,13 +524,27 @@ def _validate_us_region(self, region: str) -> None:
else:
raise ValueError(f"Invalid US region: '{region}'")

def _setup_data(self, country_id: str, region: str) -> str:
# Dataset keywords that are passed directly to the simulation API
# instead of being resolved via get_default_dataset
PASSTHROUGH_DATASETS = {
"national-with-breakdowns",
"national-with-breakdowns-test",
}

def _setup_data(
self, country_id: str, region: str, dataset: str = "default"
) -> str:
"""
Determine the dataset to use based on the country and region.

Uses policyengine's get_default_dataset to resolve the appropriate
GCS path, making the dataset visible in GCP Console workflow inputs.
If the dataset is in PASSTHROUGH_DATASETS, it will be passed directly
to the simulation API. Otherwise, uses policyengine's get_default_dataset
to resolve the appropriate GCS path.
"""
# If the dataset is a recognized passthrough keyword, use it directly
if dataset in self.PASSTHROUGH_DATASETS:
return dataset

try:
return get_default_dataset(country_id, region)
except ValueError as e:
Expand Down
32 changes: 32 additions & 0 deletions tests/unit/services/test_economy_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -968,6 +968,38 @@ def test__given_invalid_country__raises_value_error(self, mock_logger):
service._setup_data("invalid", "region")
assert "invalid" in str(exc_info.value).lower()

def test__given_passthrough_dataset__returns_dataset_directly(self):
# Test with passthrough dataset (national-with-breakdowns)
service = EconomyService()
result = service._setup_data(
"us", "us", dataset="national-with-breakdowns"
)
assert result == "national-with-breakdowns"

def test__given_passthrough_test_dataset__returns_dataset_directly(
self,
):
# Test with passthrough test dataset
service = EconomyService()
result = service._setup_data(
"us", "us", dataset="national-with-breakdowns-test"
)
assert result == "national-with-breakdowns-test"

def test__given_default_dataset__uses_get_default_dataset(self):
# Test that "default" falls through to get_default_dataset
service = EconomyService()
result = service._setup_data("us", "state/ca", dataset="default")
assert result == "gs://policyengine-us-data/states/CA.h5"

def test__given_unknown_dataset__uses_get_default_dataset(self):
# Test that unknown dataset values fall through to get_default_dataset
service = EconomyService()
result = service._setup_data(
"us", "state/ca", dataset="unknown-dataset"
)
assert result == "gs://policyengine-us-data/states/CA.h5"

class TestValidateUsRegion:
"""Tests for the _validate_us_region method."""

Expand Down
Loading