diff --git a/policyengine/constants.py b/policyengine/constants.py index c6bca554..43351271 100644 --- a/policyengine/constants.py +++ b/policyengine/constants.py @@ -4,7 +4,6 @@ from policyengine.utils.data_download import download # Datasets - ENHANCED_FRS = "hf://policyengine/policyengine-uk-data/enhanced_frs_2022_23.h5" FRS = "hf://policyengine/policyengine-uk-data/frs_2022_23.h5" ENHANCED_CPS = "hf://policyengine/policyengine-us-data/enhanced_cps_2024.h5" diff --git a/policyengine/outputs/macro/comparison/calculate_economy_comparison.py b/policyengine/outputs/macro/comparison/calculate_economy_comparison.py index 6592c305..f53691f9 100644 --- a/policyengine/outputs/macro/comparison/calculate_economy_comparison.py +++ b/policyengine/outputs/macro/comparison/calculate_economy_comparison.py @@ -10,6 +10,7 @@ from policyengine.outputs.macro.single.calculate_single_economy import ( SingleEconomy, ) +from policyengine.utils.packages import get_country_package_version from typing import List, Dict @@ -775,6 +776,7 @@ def uk_constituency_breakdown( class EconomyComparison(BaseModel): + country_package_version: str budget: BudgetaryImpact detailed_budget: DetailedBudgetaryImpact decile: DecileImpact @@ -823,6 +825,7 @@ def calculate_economy_comparison( ) return EconomyComparison( + country_package_version=get_country_package_version(country_id), budget=budgetary_impact_data, detailed_budget=detailed_budgetary_impact_data, decile=decile_impact_data, diff --git a/policyengine/simulation.py b/policyengine/simulation.py index 61f00acb..ecb794fc 100644 --- a/policyengine/simulation.py +++ b/policyengine/simulation.py @@ -181,7 +181,7 @@ def _initialise_simulation( time_period: TimePeriodType, region: RegionType, subsample: SubsampleType, - ): + ) -> CountrySimulation: macro = scope == "macro" _simulation_type: Type[CountrySimulation] = { "uk": { diff --git a/policyengine/utils/packages.py b/policyengine/utils/packages.py new file mode 100644 index 00000000..78a254f2 --- /dev/null +++ b/policyengine/utils/packages.py @@ -0,0 +1,28 @@ +from importlib.metadata import version + + +def get_country_package_name(country_id: str) -> str: + if country_id in NON_STANDARD_COUNTRY_CODES: + return f"policyengine_{NON_STANDARD_COUNTRY_CODES[country_id]}" + if country_id in COUNTRY_IDS: + return f"policyengine_{country_id}" + raise ValueError( + f"Unsupported country ID: {country_id}. Supported IDs are: {COUNTRY_IDS}" + ) + + +def get_country_package_version(country_id: str) -> str: + + package_name = get_country_package_name(country_id) + return version(package_name) + + +COUNTRY_IDS = ["us", "uk", "ca", "il", "ng"] + +NON_STANDARD_COUNTRY_CODES = { + "ca": "canada", +} + +COUNTRY_PACKAGES = [ + get_country_package_name(country) for country in COUNTRY_IDS +] diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/fixtures/utils/__init__.py b/tests/fixtures/utils/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/fixtures/utils/packages.py b/tests/fixtures/utils/packages.py new file mode 100644 index 00000000..2877aef7 --- /dev/null +++ b/tests/fixtures/utils/packages.py @@ -0,0 +1,19 @@ +from policyengine.utils.packages import COUNTRY_PACKAGES +import pytest +from unittest.mock import patch + +MOCK_VERSION = "MOCK_VERSION" + + +@pytest.fixture +def patch_importlib_version(): + def mock_version(package_name): + if package_name in COUNTRY_PACKAGES: + return MOCK_VERSION + else: + raise Exception(f"Package {package_name} not found") + + with patch( + "policyengine.utils.packages.version", side_effect=mock_version + ) as mock: + yield mock diff --git a/tests/utils/__init__.py b/tests/utils/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/utils/test_packages.py b/tests/utils/test_packages.py new file mode 100644 index 00000000..822377b0 --- /dev/null +++ b/tests/utils/test_packages.py @@ -0,0 +1,54 @@ +import pytest +from policyengine.utils.packages import ( + get_country_package_name, + get_country_package_version, +) +from tests.fixtures.utils.packages import ( + patch_importlib_version, + MOCK_VERSION, +) + + +class TestGetCountryPackageName: + def test__given_country_id__then_return_package_name(self): + test_country = "us" + + test_package_name = get_country_package_name(test_country) + + assert test_package_name == "policyengine_us" + + def test__given_non_standard_country_id__then_return_package_name(self): + test_country = "ca" + + test_package_name = get_country_package_name(test_country) + + assert test_package_name == "policyengine_canada" + + def test__given_unsupported_country_id__then_return_package_name(self): + test_country = "zz" + + with pytest.raises(Exception, match="Unsupported country ID: zz"): + get_country_package_name(test_country) + + +class TestGetCountryPackageVersion: + def test__given_package_exists__then_return_version( + self, patch_importlib_version + ): + test_country = "us" + + test_version = get_country_package_version(test_country) + + # Version number defined by mock + assert test_version == MOCK_VERSION + + def test__given_package_does_not_exist__then_raise_exception( + self, patch_importlib_version + ): + test_country = "zz" + + with pytest.raises( + Exception, + match=f"Unsupported country ID: {test_country}. Supported IDs are: ", + ): + get_country_package_version(test_country)