Skip to content

Commit 642952b

Browse files
committed
fix: Redo package version management
1 parent 9f2ee8b commit 642952b

File tree

6 files changed

+85
-83
lines changed

6 files changed

+85
-83
lines changed

policyengine/constants.py

Lines changed: 0 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -3,40 +3,7 @@
33
from policyengine_core.data import Dataset
44
from policyengine.utils.data_download import download
55

6-
SUPPORTED_COUNTRY_IDS = [
7-
"us",
8-
"uk",
9-
]
10-
11-
UNSUPPORTED_COUNTRY_IDS = [
12-
"ca",
13-
"il",
14-
"ng",
15-
]
16-
17-
SUPPORTED_COUNTRY_PACKAGES = [
18-
f"policyengine_{country}" for country in SUPPORTED_COUNTRY_IDS
19-
]
20-
21-
22-
def _package_name_for(country_id: str) -> str:
23-
return (
24-
f"policyengine_{country_id}"
25-
if country_id != "ca"
26-
else "policyengine_canada"
27-
)
28-
29-
30-
UNSUPPORTED_COUNTRY_PACKAGES = [
31-
_package_name_for(country) for country in UNSUPPORTED_COUNTRY_IDS
32-
]
33-
34-
ALL_COUNTRY_PACKAGES = (
35-
SUPPORTED_COUNTRY_PACKAGES + UNSUPPORTED_COUNTRY_PACKAGES
36-
)
37-
386
# Datasets
39-
407
ENHANCED_FRS = "hf://policyengine/policyengine-uk-data/enhanced_frs_2022_23.h5"
418
FRS = "hf://policyengine/policyengine-uk-data/frs_2022_23.h5"
429
ENHANCED_CPS = "hf://policyengine/policyengine-us-data/enhanced_cps_2024.h5"

policyengine/utils/packages.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
from importlib.metadata import version
2+
3+
4+
def get_country_package_name(country_id: str) -> str:
5+
if country_id in NON_STANDARD_COUNTRY_CODES:
6+
return f"policyengine_{NON_STANDARD_COUNTRY_CODES[country_id]}"
7+
if country_id in COUNTRY_IDS:
8+
return f"policyengine_{country_id}"
9+
raise ValueError(
10+
f"Unsupported country ID: {country_id}. Supported IDs are: {COUNTRY_IDS}"
11+
)
12+
13+
14+
def get_country_package_version(country_id: str) -> str:
15+
16+
package_name = get_country_package_name(country_id)
17+
return version(package_name)
18+
19+
20+
COUNTRY_IDS = ["us", "uk", "ca", "il", "ng"]
21+
22+
NON_STANDARD_COUNTRY_CODES = {
23+
"ca": "canada",
24+
}
25+
26+
COUNTRY_PACKAGES = [
27+
get_country_package_name(country) for country in COUNTRY_IDS
28+
]

policyengine/utils/versioning.py

Lines changed: 0 additions & 14 deletions
This file was deleted.
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
1-
from policyengine.constants import ALL_COUNTRY_PACKAGES
2-
from importlib.metadata import PackageNotFoundError
1+
from policyengine.utils.packages import COUNTRY_PACKAGES
32
import pytest
43
from unittest.mock import patch
54

@@ -9,12 +8,12 @@
98
@pytest.fixture
109
def patch_importlib_version():
1110
def mock_version(package_name):
12-
if package_name in ALL_COUNTRY_PACKAGES:
11+
if package_name in COUNTRY_PACKAGES:
1312
return MOCK_VERSION
1413
else:
1514
raise Exception(f"Package {package_name} not found")
1615

1716
with patch(
18-
"policyengine.utils.versioning.version", side_effect=mock_version
17+
"policyengine.utils.packages.version", side_effect=mock_version
1918
) as mock:
2019
yield mock

tests/utils/test_packages.py

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
import pytest
2+
from policyengine.utils.packages import (
3+
get_country_package_name,
4+
get_country_package_version,
5+
)
6+
from tests.fixtures.utils.packages import (
7+
patch_importlib_version,
8+
MOCK_VERSION,
9+
)
10+
11+
12+
class TestGetCountryPackageName:
13+
def test__given_country_id__then_return_package_name(self):
14+
test_country = "us"
15+
16+
test_package_name = get_country_package_name(test_country)
17+
18+
assert test_package_name == "policyengine_us"
19+
20+
def test__given_non_standard_country_id__then_return_package_name(self):
21+
test_country = "ca"
22+
23+
test_package_name = get_country_package_name(test_country)
24+
25+
assert test_package_name == "policyengine_canada"
26+
27+
def test__given_unsupported_country_id__then_return_package_name(self):
28+
test_country = "zz"
29+
30+
with pytest.raises(Exception, match="Unsupported country ID: zz"):
31+
get_country_package_name(test_country)
32+
33+
34+
class TestGetCountryPackageVersion:
35+
def test__given_package_exists__then_return_version(
36+
self, patch_importlib_version
37+
):
38+
test_country = "us"
39+
40+
test_version = get_country_package_version(test_country)
41+
42+
# Version number defined by mock
43+
assert test_version == MOCK_VERSION
44+
45+
def test__given_package_does_not_exist__then_raise_exception(
46+
self, patch_importlib_version
47+
):
48+
test_country = "zz"
49+
50+
with pytest.raises(
51+
Exception,
52+
match=f"Unsupported country ID: {test_country}. Supported IDs are: ",
53+
):
54+
get_country_package_version(test_country)

tests/utils/test_versioning.py

Lines changed: 0 additions & 32 deletions
This file was deleted.

0 commit comments

Comments
 (0)