diff --git a/changelog_entry.yaml b/changelog_entry.yaml index e69de29bb..747786c96 100644 --- a/changelog_entry.yaml +++ b/changelog_entry.yaml @@ -0,0 +1,4 @@ +- bump: minor + changes: + changed: + - Society-wide reports for the US nationwide now call district breakdown-enabled simulation API \ No newline at end of file diff --git a/policyengine_api/routes/economy_routes.py b/policyengine_api/routes/economy_routes.py index c0de06730..59e7fdf4c 100644 --- a/policyengine_api/routes/economy_routes.py +++ b/policyengine_api/routes/economy_routes.py @@ -34,6 +34,16 @@ def get_economic_impact( region = options.pop("region") dataset = options.pop("dataset", "default") time_period = options.pop("time_period") + + # Handle district breakdowns - only for US national simulations + include_district_breakdowns_raw = options.pop( + "include_district_breakdowns", "false" + ) + include_district_breakdowns = ( + include_district_breakdowns_raw.lower() == "true" + ) + if include_district_breakdowns and country_id == "us" and region == "us": + dataset = "national-with-breakdowns" target: Literal["general", "cliff"] = options.pop("target", "general") api_version = options.pop( "version", COUNTRY_PACKAGE_VERSIONS.get(country_id) diff --git a/policyengine_api/services/economy_service.py b/policyengine_api/services/economy_service.py index 9ca08b69d..1b1c9c84c 100644 --- a/policyengine_api/services/economy_service.py +++ b/policyengine_api/services/economy_service.py @@ -419,6 +419,7 @@ def _handle_create_impact( baseline_policy=baseline_policy, region=setup_options.region, time_period=setup_options.time_period, + dataset=setup_options.dataset, scope="macro", include_cliffs=setup_options.target == "cliff", model_version=setup_options.model_version, @@ -460,6 +461,7 @@ def _setup_sim_options( include_cliffs: bool = False, model_version: str | None = None, data_version: str | None = None, + dataset: str = "default", ) -> SimulationOptions: """ Set up the simulation options for the simulation API job. @@ -476,7 +478,9 @@ def _setup_sim_options( "region": self._setup_region( country_id=country_id, region=region ), - "data": self._setup_data(country_id=country_id, region=region), + "data": self._setup_data( + country_id=country_id, region=region, dataset=dataset + ), "model_version": model_version, "data_version": data_version, } @@ -520,13 +524,27 @@ def _validate_us_region(self, region: str) -> None: else: raise ValueError(f"Invalid US region: '{region}'") - def _setup_data(self, country_id: str, region: str) -> str: + # Dataset keywords that are passed directly to the simulation API + # instead of being resolved via get_default_dataset + PASSTHROUGH_DATASETS = { + "national-with-breakdowns", + "national-with-breakdowns-test", + } + + def _setup_data( + self, country_id: str, region: str, dataset: str = "default" + ) -> str: """ Determine the dataset to use based on the country and region. - Uses policyengine's get_default_dataset to resolve the appropriate - GCS path, making the dataset visible in GCP Console workflow inputs. + If the dataset is in PASSTHROUGH_DATASETS, it will be passed directly + to the simulation API. Otherwise, uses policyengine's get_default_dataset + to resolve the appropriate GCS path. """ + # If the dataset is a recognized passthrough keyword, use it directly + if dataset in self.PASSTHROUGH_DATASETS: + return dataset + try: return get_default_dataset(country_id, region) except ValueError as e: diff --git a/tests/unit/services/test_economy_service.py b/tests/unit/services/test_economy_service.py index 1220c24b8..6161cb377 100644 --- a/tests/unit/services/test_economy_service.py +++ b/tests/unit/services/test_economy_service.py @@ -968,6 +968,38 @@ def test__given_invalid_country__raises_value_error(self, mock_logger): service._setup_data("invalid", "region") assert "invalid" in str(exc_info.value).lower() + def test__given_passthrough_dataset__returns_dataset_directly(self): + # Test with passthrough dataset (national-with-breakdowns) + service = EconomyService() + result = service._setup_data( + "us", "us", dataset="national-with-breakdowns" + ) + assert result == "national-with-breakdowns" + + def test__given_passthrough_test_dataset__returns_dataset_directly( + self, + ): + # Test with passthrough test dataset + service = EconomyService() + result = service._setup_data( + "us", "us", dataset="national-with-breakdowns-test" + ) + assert result == "national-with-breakdowns-test" + + def test__given_default_dataset__uses_get_default_dataset(self): + # Test that "default" falls through to get_default_dataset + service = EconomyService() + result = service._setup_data("us", "state/ca", dataset="default") + assert result == "gs://policyengine-us-data/states/CA.h5" + + def test__given_unknown_dataset__uses_get_default_dataset(self): + # Test that unknown dataset values fall through to get_default_dataset + service = EconomyService() + result = service._setup_data( + "us", "state/ca", dataset="unknown-dataset" + ) + assert result == "gs://policyengine-us-data/states/CA.h5" + class TestValidateUsRegion: """Tests for the _validate_us_region method."""