Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
36 commits
Select commit Hold shift + click to select a range
44d2664
Add `versions` argument to `Simulation`
nikhilwoodruff May 23, 2025
7800d58
Add actual data version check
nikhilwoodruff May 23, 2025
0bb9348
Add tests
nikhilwoodruff May 23, 2025
a078b2e
Add Google auth
nikhilwoodruff May 23, 2025
cb93042
Add perms
nikhilwoodruff May 23, 2025
67754f6
Fix bug
nikhilwoodruff May 23, 2025
686dc61
Fix bug
nikhilwoodruff May 23, 2025
fcf8489
Add permissions
nikhilwoodruff May 23, 2025
2180465
Begin removal of HF code and GCS versioning
nikhilwoodruff May 26, 2025
ee076ab
Add changes
nikhilwoodruff May 26, 2025
827776c
Add pip install prompt
nikhilwoodruff May 26, 2025
b0928f8
Minor improvements
nikhilwoodruff May 26, 2025
78cdfa9
Add handling for no metadata version
nikhilwoodruff May 26, 2025
a0de606
Fix some tests
nikhilwoodruff May 26, 2025
9faeeb2
Add dataset constants
nikhilwoodruff May 26, 2025
c3fdbde
Fix str | None
nikhilwoodruff May 26, 2025
8bc2c22
Address comment
nikhilwoodruff May 26, 2025
88ef3e6
Call check package version
nikhilwoodruff May 26, 2025
040d393
Model, not package
nikhilwoodruff May 26, 2025
2aff263
Fix data version check
nikhilwoodruff May 26, 2025
6a94359
Revert log level to debug
nikhilwoodruff May 26, 2025
cce966d
Fix syntax error
nikhilwoodruff May 26, 2025
643f5a5
Fix bugs in tests
nikhilwoodruff May 26, 2025
84456f4
Address comment
nikhilwoodruff May 26, 2025
7885b0d
Add to gitignore
nikhilwoodruff May 26, 2025
e728bf6
Fix tests
nikhilwoodruff May 26, 2025
20ae78b
Nit
nikhilwoodruff May 26, 2025
e9f528c
Nit 2
nikhilwoodruff May 26, 2025
2edf16b
Fix test
nikhilwoodruff May 26, 2025
8e72495
Add tests
nikhilwoodruff May 26, 2025
ec07085
Fix test
nikhilwoodruff May 26, 2025
925ef2d
minor bug fixes
nikhilwoodruff May 26, 2025
ac1fd53
Add permissions to actions
nikhilwoodruff May 26, 2025
2208bcc
Fix test
nikhilwoodruff May 26, 2025
45bc6b7
Fix optional types
nikhilwoodruff May 26, 2025
60b20cf
Use constant
nikhilwoodruff May 26, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions .github/workflows/any_changes.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,9 @@ on:

jobs:
docs:
permissions:
contents: "read"
id-token: "write"
name: Test documentation builds
runs-on: ubuntu-latest
steps:
Expand All @@ -20,14 +23,16 @@ jobs:
uses: actions/setup-python@v2
with:
python-version: '3.11'
- uses: "google-github-actions/auth@v2"
with:
workload_identity_provider: "projects/322898545428/locations/global/workloadIdentityPools/policyengine-research-id-pool/providers/prod-github-provider"
service_account: "policyengine-research@policyengine-research.iam.gserviceaccount.com"

- name: Install package
run: uv pip install .[dev] --system

- name: Test documentation builds
run: make documentation
env:
HUGGING_FACE_TOKEN: ${{ secrets.HUGGING_FACE_TOKEN }}

- name: Check documentation build
run: |
Expand Down
11 changes: 8 additions & 3 deletions .github/workflows/code_changes.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,9 @@ jobs:
args: ". -l 79 --check"
Test:
runs-on: ubuntu-latest
permissions:
contents: "read"
id-token: "write"
steps:
- name: Checkout repo
uses: actions/checkout@v2
Expand All @@ -31,11 +34,13 @@ jobs:
uses: actions/setup-python@v2
with:
python-version: '3.11'
- uses: "google-github-actions/auth@v2"
with:
workload_identity_provider: "projects/322898545428/locations/global/workloadIdentityPools/policyengine-research-id-pool/providers/prod-github-provider"
service_account: "policyengine-research@policyengine-research.iam.gserviceaccount.com"

- name: Install package
run: uv pip install .[dev] --system

- name: Run tests
run: make test
env:
HUGGING_FACE_TOKEN: ${{ secrets.HUGGING_FACE_TOKEN }}
run: make test
9 changes: 7 additions & 2 deletions .github/workflows/publish_documentation.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,9 @@ on:

jobs:
Publish:
permissions:
contents: "read"
id-token: "write"
runs-on: ubuntu-latest
steps:
- name: Checkout repo
Expand All @@ -15,15 +18,17 @@ jobs:
uses: actions/setup-python@v5
with:
python-version: 3.12
- uses: "google-github-actions/auth@v2"
with:
workload_identity_provider: "projects/322898545428/locations/global/workloadIdentityPools/policyengine-research-id-pool/providers/prod-github-provider"
service_account: "policyengine-research@policyengine-research.iam.gserviceaccount.com"
- name: Publish a git tag
run: ".github/publish-git-tag.sh || true"
- name: Install package
run: make install

- name: Build documentation
run: make documentation
env:
HUGGING_FACE_TOKEN: ${{ secrets.HUGGING_FACE_TOKEN }}

- name: Deploy documentation
uses: JamesIves/github-pages-deploy-action@releases/v3
Expand Down
2 changes: 0 additions & 2 deletions .github/workflows/publish_package.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,6 @@ jobs:

- name: Test documentation builds
run: make documentation
env:
HUGGING_FACE_TOKEN: ${{ secrets.HUGGING_FACE_TOKEN }}

- name: Deploy documentation
uses: JamesIves/github-pages-deploy-action@releases/v3
Expand Down
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -162,3 +162,6 @@ cython_debug/
*.ipynb

!docs/**/*.ipynb

**/*.h5
**/*.csv
4 changes: 4 additions & 0 deletions changelog_entry.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
- bump: minor
changes:
added:
- Error handling for data and package version mismatches.
36 changes: 31 additions & 5 deletions docs/concepts/simulation.ipynb

Large diffs are not rendered by default.

42 changes: 14 additions & 28 deletions policyengine/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,40 +2,26 @@

from policyengine_core.data import Dataset
from policyengine.utils.data_download import download
from typing import Tuple, Optional

# Datasets
ENHANCED_FRS = "hf://policyengine/policyengine-uk-data/enhanced_frs_2022_23.h5"
FRS = "hf://policyengine/policyengine-uk-data/frs_2022_23.h5"
ENHANCED_CPS = "hf://policyengine/policyengine-us-data/enhanced_cps_2024.h5"
CPS = "hf://policyengine/policyengine-us-data/cps_2023.h5"
POOLED_CPS = "hf://policyengine/policyengine-us-data/pooled_3_year_cps_2023.h5"
EFRS_2022 = "gcs://policyengine-uk-data-private/enhanced_frs_2022_23.h5"
FRS_2022 = "gcs://policyengine-uk-data-private/frs_2022_23.h5"
CPS_2023_POOLED = "gcs://policyengine-us-data/pooled_3_year_cps_2023.h5"
CPS_2023 = "gcs://policyengine-us-data/cps_2023.h5"
ECPS_2024 = "gcs://policyengine-us-data/ecps_2024.h5"


def get_default_dataset(country: str, region: str):
def get_default_dataset(
country: str, region: str, version: Optional[str] = None
) -> str:
if country == "uk":
data_file = download(
filepath="enhanced_frs_2022_23.h5",
huggingface_repo="policyengine-uk-data",
gcs_bucket="policyengine-uk-data-private",
)
time_period = None
return EFRS_2022
elif country == "us":
if region is not None and region != "us":
data_file = download(
filepath="pooled_3_year_cps_2023.h5",
huggingface_repo="policyengine-us-data",
gcs_bucket="policyengine-us-data",
)
time_period = 2023
return CPS_2023_POOLED
else:
data_file = download(
filepath="cps_2023.h5",
huggingface_repo="policyengine-us-data",
gcs_bucket="policyengine-us-data",
)
time_period = 2023
return CPS_2023

return Dataset.from_file(
file_path=data_file,
time_period=time_period,
raise ValueError(
f"Unable to select a default dataset for country {country} and region {region}."
)
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,7 @@
from policyengine.outputs.macro.single.calculate_single_economy import (
SingleEconomy,
)
from policyengine.utils.packages import get_country_package_version
from typing import List, Dict
from typing import List, Dict, Optional


class BudgetaryImpact(BaseModel):
Expand Down Expand Up @@ -711,7 +710,6 @@ def uk_constituency_breakdown(
reform_hnet = reform.household_net_income

constituency_weights_path = download(
huggingface_repo="policyengine-uk-data",
gcs_bucket="policyengine-uk-data-private",
filepath="parliamentary_constituency_weights.h5",
)
Expand All @@ -721,7 +719,6 @@ def uk_constituency_breakdown(
] # {2025: array(650, 100180) where cell i, j is the weight of household record i in constituency j}

constituency_names_path = download(
huggingface_repo="policyengine-uk-data",
gcs_bucket="policyengine-uk-data-private",
filepath="constituencies_2024.csv",
)
Expand Down Expand Up @@ -786,7 +783,10 @@ class CliffImpact(BaseModel):


class EconomyComparison(BaseModel):
country_package_version: str
model_version: Optional[str] = (
None # Optional while some datasets have no tagged version.
)
data_version: Optional[str] = None
budget: BudgetaryImpact
detailed_budget: DetailedBudgetaryImpact
decile: DecileImpact
Expand Down Expand Up @@ -849,7 +849,8 @@ def calculate_economy_comparison(
cliff_impact = None

return EconomyComparison(
country_package_version=get_country_package_version(country_id),
model_version=simulation.model_version,
data_version=simulation.data_version,
budget=budgetary_impact_data,
detailed_budget=detailed_budgetary_impact_data,
decile=decile_impact_data,
Expand Down
84 changes: 58 additions & 26 deletions policyengine/simulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,11 @@
Simulation as UKSimulation,
Microsimulation as UKMicrosimulation,
)
from importlib import metadata
import h5py
from pathlib import Path
import pandas as pd
from typing import Type
from typing import Type, Optional
from functools import wraps, partial
from typing import Dict, Any, Callable
import importlib
Expand All @@ -34,8 +35,8 @@
) # Needs stricter typing. Any==policyengine_core.data.Dataset, but pydantic refuses for some reason.
TimePeriodType = int
ReformType = ParametricReform | Type[StructuralReform] | None
RegionType = str | None
SubsampleType = int | None
RegionType = Optional[str]
SubsampleType = Optional[int]


class SimulationOptions(BaseModel):
Expand All @@ -54,14 +55,22 @@ class SimulationOptions(BaseModel):
None,
description="How many, if a subsample, households to randomly simulate.",
)
title: str | None = Field(
title: Optional[str] = Field(
"[Analysis title]",
description="The title of the analysis (for charts). If not provided, a default title will be generated.",
)
include_cliffs: bool | None = Field(
include_cliffs: Optional[bool] = Field(
False,
description="Whether to include tax-benefit cliffs in the simulation analyses. If True, cliffs will be included.",
)
model_version: Optional[str] = Field(
None,
description="The version of the country model used in the simulation. If not provided, the current package version will be used. If provided, this package will throw an error if the package version does not match. Use this as an extra safety check.",
)
data_version: Optional[str] = Field(
None,
description="The version of the data used in the simulation. If not provided, the current data version will be used. If provided, this package will throw an error if the data version does not match. Use this as an extra safety check.",
)


class Simulation:
Expand All @@ -73,12 +82,16 @@ class Simulation:
"""The baseline tax-benefit simulation."""
reform_simulation: CountrySimulation | None = None
"""The reform tax-benefit simulation."""
data_version: Optional[str] = None
"""The version of the data used in the simulation."""
model_version: Optional[str] = None

def __init__(self, **options: SimulationOptions):
self.options = SimulationOptions(**options)

self.check_model_version()
self._set_data()
self._initialise_simulations()
self.check_data_version()
self._add_output_functions()

def _add_output_functions(self):
Expand Down Expand Up @@ -119,29 +132,23 @@ def _set_data(self):
region=self.options.region,
)

elif isinstance(self.options.data, str):
if isinstance(self.options.data, str):
filename = self.options.data
if "://" in self.options.data:
bucket = None
hf_repo = None
hf_org = None
if "gs://" in self.options.data:
bucket, filename = self.options.data.split("://")[
-1
].split("/")
hf_org = "policyengine"
elif "hf://" in self.options.data:
hf_org, hf_repo, filename = self.options.data.split("://")[
-1
].split("/", 2)
if self.options.data[:6] == "gcs://":
bucket, filename = self.options.data.split("://")[-1].split(
"/"
)
version = self.options.data_version

file_path = download(
filepath=filename,
huggingface_org=hf_org,
huggingface_repo=hf_repo,
gcs_bucket=bucket,
version=version,
)
filename = str(Path(file_path))
else:
# If it's a local file, we can't infer the version.
version = None
if "cps_2023" in filename:
time_period = 2023
else:
Expand Down Expand Up @@ -260,7 +267,6 @@ def _apply_region_to_simulation(
elif "constituency/" in region:
constituency = region.split("/")[1]
constituency_names_file_path = download(
huggingface_repo="policyengine-uk-data",
gcs_bucket="policyengine-uk-data-private",
filepath="constituencies_2024.csv",
)
Expand All @@ -281,7 +287,6 @@ def _apply_region_to_simulation(
f"Constituency {constituency} not found. See {constituency_names_file_path} for the list of available constituencies."
)
weights_file_path = download(
huggingface_repo="policyengine-uk-data",
gcs_bucket="policyengine-uk-data-private",
filepath="parliamentary_constituency_weights.h5",
)
Expand All @@ -297,7 +302,6 @@ def _apply_region_to_simulation(
elif "local_authority/" in region:
la = region.split("/")[1]
la_names_file_path = download(
huggingface_repo="policyengine-uk-data",
gcs_bucket="policyengine-uk-data-private",
filepath="local_authorities_2021.csv",
)
Expand All @@ -312,7 +316,6 @@ def _apply_region_to_simulation(
f"Local authority {la} not found. See {la_names_file_path} for the list of available local authorities."
)
weights_file_path = download(
huggingface_repo="policyengine-uk-data",
gcs_bucket="policyengine-uk-data-private",
filepath="local_authority_weights.h5",
)
Expand All @@ -327,3 +330,32 @@ def _apply_region_to_simulation(
)

return simulation

def check_model_version(self) -> None:
    """
    Record the installed country model version and verify it if requested.

    The country model package is looked up as ``policyengine-<country>``
    using the installed distribution metadata. The installed version is
    always stored on ``self.model_version`` so that it is available even
    when the caller did not ask for a specific version. When
    ``self.options.model_version`` is set, it is additionally compared
    against the installed version.

    Raises:
        ValueError: If a model version was requested and the country
            package is not installed, or its installed version differs
            from the requested one.
    """
    package = f"policyengine-{self.options.country}"
    target_version = self.options.model_version

    try:
        installed_version = metadata.version(package)
    except metadata.PackageNotFoundError as err:
        if target_version is None:
            # No version was requested, so there is nothing to verify;
            # leave model_version unset rather than failing.
            return
        raise ValueError(
            f"Package {package} not found. Try running `pip install {package}`."
        ) from err

    # Record the version before comparing so it is populated even when the
    # caller did not supply an expected version.
    self.model_version = installed_version

    if target_version is not None and installed_version != target_version:
        raise ValueError(
            f"Package {package} version {installed_version} does not match expected version {target_version}. Try running `pip install {package}=={target_version}`."
        )

def check_data_version(self) -> None:
    """
    Verify the simulation's data version against the requested one.

    Does nothing when ``self.options.data_version`` is not set. Otherwise
    compares it with ``self.data_version``.

    Raises:
        ValueError: If a data version was requested and it differs from
            ``self.data_version``.
    """
    # No expected version supplied: the check is opt-in, so skip it.
    if self.options.data_version is None:
        return

    if self.data_version == self.options.data_version:
        return

    raise ValueError(
        f"Data version {self.data_version} does not match expected version {self.options.data_version}."
    )
Loading
Loading