Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion policyengine/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
from policyengine.utils.data_download import download

# Datasets

ENHANCED_FRS = "hf://policyengine/policyengine-uk-data/enhanced_frs_2022_23.h5"
FRS = "hf://policyengine/policyengine-uk-data/frs_2022_23.h5"
ENHANCED_CPS = "hf://policyengine/policyengine-us-data/enhanced_cps_2024.h5"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from policyengine.outputs.macro.single.calculate_single_economy import (
SingleEconomy,
)
from policyengine.utils.packages import get_country_package_version
from typing import List, Dict


Expand Down Expand Up @@ -775,6 +776,7 @@ def uk_constituency_breakdown(


class EconomyComparison(BaseModel):
country_package_version: str
budget: BudgetaryImpact
detailed_budget: DetailedBudgetaryImpact
decile: DecileImpact
Expand Down Expand Up @@ -823,6 +825,7 @@ def calculate_economy_comparison(
)

return EconomyComparison(
country_package_version=get_country_package_version(country_id),
budget=budgetary_impact_data,
detailed_budget=detailed_budgetary_impact_data,
decile=decile_impact_data,
Expand Down
2 changes: 1 addition & 1 deletion policyengine/simulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,7 @@ def _initialise_simulation(
time_period: TimePeriodType,
region: RegionType,
subsample: SubsampleType,
):
) -> CountrySimulation:
macro = scope == "macro"
_simulation_type: Type[CountrySimulation] = {
"uk": {
Expand Down
28 changes: 28 additions & 0 deletions policyengine/utils/packages.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
from importlib.metadata import version


def get_country_package_name(country_id: str) -> str:
if country_id in NON_STANDARD_COUNTRY_CODES:
return f"policyengine_{NON_STANDARD_COUNTRY_CODES[country_id]}"
if country_id in COUNTRY_IDS:
return f"policyengine_{country_id}"
raise ValueError(
f"Unsupported country ID: {country_id}. Supported IDs are: {COUNTRY_IDS}"
)


def get_country_package_version(country_id: str) -> str:

package_name = get_country_package_name(country_id)
return version(package_name)


COUNTRY_IDS = ["us", "uk", "ca", "il", "ng"]

NON_STANDARD_COUNTRY_CODES = {
"ca": "canada",
}

COUNTRY_PACKAGES = [
get_country_package_name(country) for country in COUNTRY_IDS
]
Empty file added tests/__init__.py
Empty file.
Empty file.
19 changes: 19 additions & 0 deletions tests/fixtures/utils/packages.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
from policyengine.utils.packages import COUNTRY_PACKAGES
import pytest
from unittest.mock import patch

MOCK_VERSION = "MOCK_VERSION"


@pytest.fixture
def patch_importlib_version():
def mock_version(package_name):
if package_name in COUNTRY_PACKAGES:
return MOCK_VERSION
else:
raise Exception(f"Package {package_name} not found")

with patch(
"policyengine.utils.packages.version", side_effect=mock_version
) as mock:
yield mock
Empty file added tests/utils/__init__.py
Empty file.
54 changes: 54 additions & 0 deletions tests/utils/test_packages.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
import pytest
from policyengine.utils.packages import (
get_country_package_name,
get_country_package_version,
)
from tests.fixtures.utils.packages import (
patch_importlib_version,
MOCK_VERSION,
)


class TestGetCountryPackageName:
def test__given_country_id__then_return_package_name(self):
test_country = "us"

test_package_name = get_country_package_name(test_country)

assert test_package_name == "policyengine_us"

def test__given_non_standard_country_id__then_return_package_name(self):
test_country = "ca"

test_package_name = get_country_package_name(test_country)

assert test_package_name == "policyengine_canada"

def test__given_unsupported_country_id__then_return_package_name(self):
test_country = "zz"

with pytest.raises(Exception, match="Unsupported country ID: zz"):
get_country_package_name(test_country)


class TestGetCountryPackageVersion:
def test__given_package_exists__then_return_version(
self, patch_importlib_version
):
test_country = "us"

test_version = get_country_package_version(test_country)

# Version number defined by mock
assert test_version == MOCK_VERSION

def test__given_package_does_not_exist__then_raise_exception(
self, patch_importlib_version
):
test_country = "zz"

with pytest.raises(
Exception,
match=f"Unsupported country ID: {test_country}. Supported IDs are: ",
):
get_country_package_version(test_country)
Loading