Skip to content
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
Show all changes
36 commits
Select commit Hold shift + click to select a range
44d2664
Add `versions` argument to `Simulation`
nikhilwoodruff May 23, 2025
7800d58
Add actual data version check
nikhilwoodruff May 23, 2025
0bb9348
Add tests
nikhilwoodruff May 23, 2025
a078b2e
Add Google auth
nikhilwoodruff May 23, 2025
cb93042
Add perms
nikhilwoodruff May 23, 2025
67754f6
Fix bug
nikhilwoodruff May 23, 2025
686dc61
Fix bug
nikhilwoodruff May 23, 2025
fcf8489
Add permissions
nikhilwoodruff May 23, 2025
2180465
Begin removal of HF code and GCS versioning
nikhilwoodruff May 26, 2025
ee076ab
Add changes
nikhilwoodruff May 26, 2025
827776c
Add pip install prompt
nikhilwoodruff May 26, 2025
b0928f8
Minor improvements
nikhilwoodruff May 26, 2025
78cdfa9
Add handling for no metadata version
nikhilwoodruff May 26, 2025
a0de606
Fix some tests
nikhilwoodruff May 26, 2025
9faeeb2
Add dataset constants
nikhilwoodruff May 26, 2025
c3fdbde
Fix str | None
nikhilwoodruff May 26, 2025
8bc2c22
Address comment
nikhilwoodruff May 26, 2025
88ef3e6
Call check package version
nikhilwoodruff May 26, 2025
040d393
Model, not package
nikhilwoodruff May 26, 2025
2aff263
Fix data version check
nikhilwoodruff May 26, 2025
6a94359
Revert log level to debug
nikhilwoodruff May 26, 2025
cce966d
Fix syntax error
nikhilwoodruff May 26, 2025
643f5a5
Fix bugs in tests
nikhilwoodruff May 26, 2025
84456f4
Address comment
nikhilwoodruff May 26, 2025
7885b0d
Add to gitignore
nikhilwoodruff May 26, 2025
e728bf6
Fix tests
nikhilwoodruff May 26, 2025
20ae78b
Nit
nikhilwoodruff May 26, 2025
e9f528c
Nit 2
nikhilwoodruff May 26, 2025
2edf16b
Fix test
nikhilwoodruff May 26, 2025
8e72495
Add tests
nikhilwoodruff May 26, 2025
ec07085
Fix test
nikhilwoodruff May 26, 2025
925ef2d
minor bug fixes
nikhilwoodruff May 26, 2025
ac1fd53
Add permissions to actions
nikhilwoodruff May 26, 2025
2208bcc
Fix test
nikhilwoodruff May 26, 2025
45bc6b7
Fix optional types
nikhilwoodruff May 26, 2025
60b20cf
Use constant
nikhilwoodruff May 26, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions .github/workflows/code_changes.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,9 @@ jobs:
args: ". -l 79 --check"
Test:
runs-on: ubuntu-latest
permissions:
contents: "read"
id-token: "write"
steps:
- name: Checkout repo
uses: actions/checkout@v2
Expand All @@ -31,6 +34,10 @@ jobs:
uses: actions/setup-python@v2
with:
python-version: '3.11'
- uses: "google-github-actions/auth@v2"
with:
workload_identity_provider: "projects/322898545428/locations/global/workloadIdentityPools/policyengine-research-id-pool/providers/prod-github-provider"
service_account: "policyengine-research@policyengine-research.iam.gserviceaccount.com"

- name: Install package
run: uv pip install .[dev] --system
Expand Down
7 changes: 7 additions & 0 deletions .github/workflows/publish_documentation.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,9 @@ on:

jobs:
Publish:
permissions:
contents: "read"
id-token: "write"
runs-on: ubuntu-latest
steps:
- name: Checkout repo
Expand All @@ -15,6 +18,10 @@ jobs:
uses: actions/setup-python@v5
with:
python-version: 3.12
- uses: "google-github-actions/auth@v2"
with:
workload_identity_provider: "projects/322898545428/locations/global/workloadIdentityPools/policyengine-research-id-pool/providers/prod-github-provider"
service_account: "policyengine-research@policyengine-research.iam.gserviceaccount.com"
- name: Publish a git tag
run: ".github/publish-git-tag.sh || true"
- name: Install package
Expand Down
4 changes: 4 additions & 0 deletions changelog_entry.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
- bump: minor
changes:
added:
- Error handling for data and package version mismatches.
28 changes: 14 additions & 14 deletions policyengine/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,40 +2,40 @@

from policyengine_core.data import Dataset
from policyengine.utils.data_download import download
from typing import Tuple

# Datasets
ENHANCED_FRS = "hf://policyengine/policyengine-uk-data/enhanced_frs_2022_23.h5"
FRS = "hf://policyengine/policyengine-uk-data/frs_2022_23.h5"
ENHANCED_CPS = "hf://policyengine/policyengine-us-data/enhanced_cps_2024.h5"
CPS = "hf://policyengine/policyengine-us-data/cps_2023.h5"
POOLED_CPS = "hf://policyengine/policyengine-us-data/pooled_3_year_cps_2023.h5"


def get_default_dataset(country: str, region: str) -> Tuple[Dataset, str]:
    """
    Download the default dataset for a country and load it.

    Args:
        country: Country code; "uk" and "us" are handled.
        region: Region within the country. For "us", any value other than
            ``None`` or "us" selects the pooled 3-year CPS file instead of
            the single-year CPS file. Not consulted for "uk".

    Returns:
        A ``(dataset, version)`` tuple: the dataset built with
        ``Dataset.from_file`` and the version value returned by
        ``download(..., return_version=True)``.

    Raises:
        ValueError: If ``country`` is neither "uk" nor "us".
    """
    if country == "uk":
        data_file, version = download(
            filepath="enhanced_frs_2022_23.h5",
            huggingface_repo="policyengine-uk-data",
            gcs_bucket="policyengine-uk-data-private",
            return_version=True,
        )
        time_period = None
    elif country == "us":
        if region is not None and region != "us":
            data_file, version = download(
                filepath="pooled_3_year_cps_2023.h5",
                huggingface_repo="policyengine-us-data",
                gcs_bucket="policyengine-us-data",
                return_version=True,
            )
            time_period = 2023
        else:
            data_file, version = download(
                filepath="cps_2023.h5",
                huggingface_repo="policyengine-us-data",
                gcs_bucket="policyengine-us-data",
                return_version=True,
            )
            time_period = 2023
    else:
        # Without this guard an unknown country reaches the return statement
        # with data_file/version/time_period unbound (UnboundLocalError).
        raise ValueError(
            f"No default dataset is available for country {country!r}."
        )

    return (
        Dataset.from_file(
            file_path=data_file,
            time_period=time_period,
        ),
        version,
    )
51 changes: 49 additions & 2 deletions policyengine/simulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
Simulation as UKSimulation,
Microsimulation as UKMicrosimulation,
)
from importlib import metadata
import h5py
from pathlib import Path
import pandas as pd
Expand Down Expand Up @@ -62,6 +63,14 @@ class SimulationOptions(BaseModel):
False,
description="Whether to include tax-benefit cliffs in the simulation analyses. If True, cliffs will be included.",
)
package_versions: Dict[str, str] | None = Field(
None,
description="The versions of the packages used in the simulation. If not provided, the current package versions will be used. If provided, this package will throw an error if the package versions do not match.",
)
data_versions: Dict[str, str] | None = Field(
None,
description="The versions of the data used in the simulation. If not provided, the current data versions will be used. If provided, this package will throw an error if the data versions do not match.",
)


class Simulation:
Expand All @@ -73,6 +82,7 @@ class Simulation:
"""The baseline tax-benefit simulation."""
reform_simulation: CountrySimulation | None = None
"""The reform tax-benefit simulation."""
data_versions: Dict[str, str] | None = None

def __init__(self, **options: SimulationOptions):
self.options = SimulationOptions(**options)
Expand Down Expand Up @@ -113,8 +123,9 @@ def _add_output_functions(self):
)

def _set_data(self):
self.data_versions = {}
if self.options.data is None:
self.options.data = get_default_dataset(
self.options.data, version = get_default_dataset(
country=self.options.country,
region=self.options.region,
)
Expand All @@ -135,13 +146,17 @@ def _set_data(self):
-1
].split("/", 2)

file_path = download(
file_path, version = download(
filepath=filename,
huggingface_org=hf_org,
huggingface_repo=hf_repo,
gcs_bucket=bucket,
return_version=True,
)
filename = str(Path(file_path))
else:
# If it's a local file, we can't infer the version.
version = None
if "cps_2023" in filename:
time_period = 2023
else:
Expand All @@ -150,6 +165,8 @@ def _set_data(self):
filename, time_period=time_period
)

self.data_versions[self.options.data.file_path.name] = version

def _initialise_simulations(self):
self.baseline_simulation = self._initialise_simulation(
scope=self.options.scope,
Expand Down Expand Up @@ -327,3 +344,33 @@ def _apply_region_to_simulation(
)

return simulation

def check_package_versions(self) -> None:
    """
    Check the installed package versions against the versions pinned in the options.

    Does nothing when ``self.options.package_versions`` is ``None``.

    Raises:
        ValueError: If a pinned package is not installed, or if its installed
            version differs from the pinned version.
    """
    expected_versions = self.options.package_versions
    if expected_versions is None:
        return

    for package, version in expected_versions.items():
        try:
            installed_version = metadata.version(package)
        except metadata.PackageNotFoundError as err:
            # Chain explicitly so the traceback names the failed metadata
            # lookup as the cause rather than an incidental context.
            raise ValueError(f"Package {package} not found.") from err
        if installed_version != version:
            raise ValueError(
                f"Package {package} version {installed_version} does not match expected version {version}."
            )

def check_data_versions(self) -> None:
    """
    Check the loaded data file versions against the versions pinned in the options.

    Does nothing when ``self.options.data_versions`` is ``None``.

    Raises:
        ValueError: If a pinned data file has no recorded version on this
            simulation, or if its recorded version differs from the pinned
            version.
    """
    expected_versions = self.options.data_versions
    if expected_versions is None:
        return

    # `data_versions` is None until `_set_data` assigns a dict; treat that
    # as "no files loaded" so the caller gets the documented ValueError
    # instead of a TypeError from a membership test on None.
    loaded_versions = self.data_versions or {}

    for file, version in expected_versions.items():
        if file not in loaded_versions:
            raise ValueError(f"Data file {file} not found in simulation.")
        if loaded_versions[file] != version:
            raise ValueError(
                f"Data file {file} version {loaded_versions[file]} does not match expected version {version}."
            )
28 changes: 26 additions & 2 deletions policyengine/utils/data_download.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
from policyengine.utils.huggingface import download_from_hf
from policyengine.utils.google_cloud_bucket import download_file_from_gcs
from pydantic import BaseModel
import json
from typing import Tuple


class DataFile(BaseModel):
Expand All @@ -18,7 +20,8 @@ def download(
huggingface_repo: str = None,
gcs_bucket: str = None,
huggingface_org: str = "policyengine",
):
return_version: bool = False,
) -> str | Tuple[str, str]:
data_file = DataFile(
filepath=filepath,
huggingface_org=huggingface_org,
Expand All @@ -31,12 +34,24 @@ def download(
if data_file.huggingface_repo is not None:
logging.info("Using Hugging Face for download.")
try:
return download_from_hf(
data = download_from_hf(
repo=data_file.huggingface_org
+ "/"
+ data_file.huggingface_repo,
repo_filename=data_file.filepath,
)
if return_version:
version_file = download_from_hf(
repo=data_file.huggingface_org
+ "/"
+ data_file.huggingface_repo,
repo_filename="version.json",
return_version=True,
)
with open(version_file, "r") as f:
version = json.load(f).get("version")
return data, version
return data
except:
logging.info("Failed to download from Hugging Face.")

Expand All @@ -47,6 +62,15 @@ def download(
file_name=filepath,
destination_path=filepath,
)
if return_version:
download_file_from_gcs(
bucket_name=data_file.gcs_bucket,
file_name="version.json",
destination_path="version.json",
)
with open("version.json", "r") as f:
version = json.load(f).get("version")
return filepath, version
return filepath

raise ValueError(
Expand Down
42 changes: 42 additions & 0 deletions tests/country/test_uk.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,3 +21,45 @@ def test_uk_macro_comparison():
)

sim.calculate_economy_comparison()


def test_uk_macro_bad_package_versions_fail():
    """
    A simulation pinned to an impossible package version must be rejected.

    The package version check raises ValueError on a mismatch, so that is
    the only exception accepted here. Any other outcome, including the
    constructor succeeding, fails the test.
    """
    from policyengine import Simulation

    try:
        Simulation(
            scope="macro",
            country="uk",
            reform={
                "gov.hmrc.income_tax.allowances.personal_allowance.amount": 15_000,
            },
            package_versions={
                "policyengine-uk": "-1.0.0",
            },
        )
    except ValueError:
        # Expected: the pinned version cannot match any installed version.
        return

    # Raised outside the try block so it cannot be swallowed by the handler
    # above (the previous bare `except` made this test pass unconditionally).
    raise AssertionError(
        "Simulation should have failed with a bad package version."
    )


def test_uk_macro_bad_data_versions_fail():
    """
    A simulation pinned to an impossible data version must be rejected.

    The data version check raises ValueError on a mismatch, so that is the
    only exception accepted here. Any other outcome, including the
    constructor succeeding, fails the test.
    """
    from policyengine import Simulation

    try:
        Simulation(
            scope="macro",
            country="uk",
            reform={
                "gov.hmrc.income_tax.allowances.personal_allowance.amount": 15_000,
            },
            data_versions={
                "enhanced_frs_2022_23.h5": "-1.0.0",
            },
        )
    except ValueError:
        # Expected: the pinned version cannot match the downloaded data.
        return

    # Raised outside the try block so it cannot be swallowed by the handler
    # above (the previous bare `except` made this test pass unconditionally).
    raise AssertionError(
        "Simulation should have failed with a bad data version."
    )
Loading