Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions changelog_entry.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
- bump: minor
changes:
added:
- Google Cloud Storage data downloads.
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from microdf import MicroSeries
import numpy as np
from policyengine_core.tools.hugging_face import download_huggingface_dataset
from policyengine.utils.data_download import download
import pandas as pd
import h5py
from pydantic import BaseModel
Expand Down Expand Up @@ -709,18 +709,20 @@ def uk_constituency_breakdown(
baseline_hnet = baseline.household_net_income
reform_hnet = reform.household_net_income

constituency_weights_path = download_huggingface_dataset(
repo="policyengine/policyengine-uk-data",
repo_filename="parliamentary_constituency_weights.h5",
constituency_weights_path = download(
huggingface_repo="policyengine-uk-data",
gcs_bucket="policyengine-uk-data-private",
filepath="parliamentary_constituency_weights.h5",
)
with h5py.File(constituency_weights_path, "r") as f:
weights = f["2025"][
...
] # {2025: array(650, 100180) where cell i, j is the weight of household record i in constituency j}

constituency_names_path = download_huggingface_dataset(
repo="policyengine/policyengine-uk-data",
repo_filename="constituencies_2024.csv",
constituency_names_path = download(
huggingface_repo="policyengine-uk-data",
gcs_bucket="policyengine-uk-data-private",
filepath="constituencies_2024.csv",
)
constituency_names = pd.read_csv(
constituency_names_path
Expand Down
82 changes: 45 additions & 37 deletions policyengine/simulation.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
"""Simulate tax-benefit policy and derive society-level output statistics."""

from pydantic import BaseModel, Field
from typing import Literal
from .constants import DEFAULT_DATASETS_BY_COUNTRY
Expand All @@ -10,7 +10,6 @@
from .utils.reforms import ParametricReform
from policyengine_core.reforms import Reform as StructuralReform
from policyengine_core.data import Dataset
from .utils.huggingface import download
from policyengine_us import (
Simulation as USSimulation,
Microsimulation as USMicrosimulation,
Expand All @@ -26,6 +25,7 @@
from functools import wraps, partial
from typing import Dict, Any, Callable
import importlib
from policyengine.utils.data_download import download

CountryType = Literal["uk", "us"]
ScopeType = Literal["household", "macro"]
Expand Down Expand Up @@ -78,6 +78,7 @@ def __init__(self, **options: SimulationOptions):
self.options.country
]

self._set_data()
self._initialise_simulations()
self._add_output_functions()

Expand Down Expand Up @@ -118,7 +119,36 @@ def _set_data(self):
self.options.country
]

self._data_handle_cps_special_case()
if isinstance(self.options.data, str):
filename = self.options.data
if "://" in self.options.data:
bucket = None
hf_repo = None
hf_org = None
if "gs://" in self.options.data:
bucket, filename = self.options.data.split("://")[
-1
].split("/")
elif "hf://" in self.options.data:
hf_org, hf_repo, filename = self.options.data.split("://")[
-1
].split("/", 2)

if not Path(filename).exists():
file_path = download(
filepath=filename,
huggingface_org=hf_org,
huggingface_repo=hf_repo,
gcs_bucket=bucket,
)
filename = str(Path(file_path))
if "cps_2023" in filename:
time_period = 2023
else:
time_period = None
self.options.data = Dataset.from_file(
filename, time_period=time_period
)

def _initialise_simulations(self):
self.baseline_simulation = self._initialise_simulation(
Expand Down Expand Up @@ -228,10 +258,9 @@ def _apply_region_to_simulation(
elif "constituency/" in region:
constituency = region.split("/")[1]
constituency_names_file_path = download(
repo="policyengine/policyengine-uk-data",
repo_filename="constituencies_2024.csv",
local_folder=None,
version=None,
huggingface_repo="policyengine-uk-data",
gcs_bucket="policyengine-uk-data-private",
filepath="constituencies_2024.csv",
)
constituency_names_file_path = Path(
constituency_names_file_path
Expand All @@ -250,10 +279,9 @@ def _apply_region_to_simulation(
f"Constituency {constituency} not found. See {constituency_names_file_path} for the list of available constituencies."
)
weights_file_path = download(
repo="policyengine/policyengine-uk-data",
repo_filename="parliamentary_constituency_weights.h5",
local_folder=None,
version=None,
huggingface_repo="policyengine-uk-data",
gcs_bucket="policyengine-uk-data-private",
filepath="parliamentary_constituency_weights.h5",
)

with h5py.File(weights_file_path, "r") as f:
Expand All @@ -267,10 +295,9 @@ def _apply_region_to_simulation(
elif "local_authority/" in region:
la = region.split("/")[1]
la_names_file_path = download(
repo="policyengine/policyengine-uk-data",
repo_filename="local_authorities_2021.csv",
local_folder=None,
version=None,
huggingface_repo="policyengine-uk-data",
gcs_bucket="policyengine-uk-data-private",
filepath="local_authorities_2021.csv",
)
la_names_file_path = Path(la_names_file_path)
la_names = pd.read_csv(la_names_file_path)
Expand All @@ -283,10 +310,9 @@ def _apply_region_to_simulation(
f"Local authority {la} not found. See {la_names_file_path} for the list of available local authorities."
)
weights_file_path = download(
repo="policyengine/policyengine-uk-data",
repo_filename="local_authority_weights.h5",
local_folder=None,
version=None,
huggingface_repo="policyengine-uk-data",
gcs_bucket="policyengine-uk-data-private",
filepath="local_authority_weights.h5",
)

with h5py.File(weights_file_path, "r") as f:
Expand All @@ -299,21 +325,3 @@ def _apply_region_to_simulation(
)

return simulation

def _data_handle_cps_special_case(self):
"""Handle special case for CPS data- this data doesn't specify time periods for each variable, but we still use it intensively."""
if self.data is not None and "cps_2023" in self.data:
if "hf://" in self.data:
owner, repo, filename = self.data.split("/")[-3:]
if "@" in filename:
version = filename.split("@")[-1]
filename = filename.split("@")[0]
else:
version = None
self.data = download(
repo=owner + "/" + repo,
repo_filename=filename,
local_folder=None,
version=version,
)
self.data = Dataset.from_file(self.data, "2023")
57 changes: 57 additions & 0 deletions policyengine/utils/data_download.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
from pathlib import Path
import logging
import os
from policyengine.utils.huggingface import download_from_hf
from policyengine.utils.google_cloud_bucket import download_file_from_gcs
from pydantic import BaseModel


class DataFile(BaseModel):
    """Validated description of a remote data file and where it can be fetched from."""

    # Name/path of the file; also used as the local destination path.
    filepath: str
    # Hugging Face organisation the repo belongs to (e.g. "policyengine").
    huggingface_org: str
    # Hugging Face repo name (without the org prefix); None disables HF download.
    huggingface_repo: str | None = None
    # Google Cloud Storage bucket used as a fallback source; None disables GCS.
    gcs_bucket: str | None = None


def download(
    filepath: str,
    huggingface_repo: str | None = None,
    gcs_bucket: str | None = None,
    huggingface_org: str = "policyengine",
) -> str:
    """Fetch a data file, preferring a local copy, then Hugging Face, then GCS.

    Args:
        filepath: Name/path of the file to fetch; also used as the local
            destination when downloading from GCS.
        huggingface_repo: Hugging Face repo name (without the org prefix).
            If None, Hugging Face is not attempted.
        gcs_bucket: Google Cloud Storage bucket name used as a fallback.
            If None, GCS is not attempted.
        huggingface_org: Hugging Face organisation owning the repo.

    Returns:
        The local path of the downloaded (or pre-existing) file.

    Raises:
        ValueError: If the file is not already local and neither a Hugging
            Face repo nor a GCS bucket was provided.
    """
    # Validate the combination of arguments up front via the pydantic model.
    data_file = DataFile(
        filepath=filepath,
        huggingface_org=huggingface_org,
        huggingface_repo=huggingface_repo,
        gcs_bucket=gcs_bucket,
    )

    # Skip the network entirely if the file is already present locally.
    if Path(filepath).exists():
        logging.info(f"File {filepath} already exists. Skipping download.")
        return filepath

    if data_file.huggingface_repo is not None:
        logging.info("Using Hugging Face for download.")
        try:
            return download_from_hf(
                repo=data_file.huggingface_org
                + "/"
                + data_file.huggingface_repo,
                repo_filename=data_file.filepath,
            )
        except Exception:
            # Fall through to GCS (if configured) rather than failing hard;
            # never swallow KeyboardInterrupt/SystemExit with a bare except.
            logging.info("Failed to download from Hugging Face.")

    if data_file.gcs_bucket is not None:
        logging.info("Using Google Cloud Storage for download.")
        download_file_from_gcs(
            bucket_name=data_file.gcs_bucket,
            file_name=filepath,
            destination_path=filepath,
        )
        return filepath

    raise ValueError(
        "No valid download method specified. Please provide either a Hugging Face repo or a Google Cloud Storage bucket."
    )
29 changes: 29 additions & 0 deletions policyengine/utils/google_cloud_bucket.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
def download_file_from_gcs(
    bucket_name: str, file_name: str, destination_path: str
) -> str:
    """
    Download a file from Google Cloud Storage to a local path.

    Args:
        bucket_name (str): The name of the GCS bucket.
        file_name (str): The name (blob path) of the file in the GCS bucket.
        destination_path (str): The local path where the file will be saved.

    Returns:
        str: The local path the file was written to (same as destination_path).
    """
    # Imported lazily so the module can be imported without
    # google-cloud-storage installed when GCS downloads are never used.
    from google.cloud import storage

    # Initialize a client (uses application-default credentials).
    client = storage.Client()

    # Get the bucket
    bucket = client.bucket(bucket_name)

    # Create a blob object from the file name
    blob = bucket.blob(file_name)

    # Download the file to a local path
    blob.download_to_filename(destination_path)

    return destination_path
2 changes: 1 addition & 1 deletion policyengine/utils/huggingface.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import time


def download(
def download_from_hf(
repo: str,
repo_filename: str,
local_folder: str | None = None,
Expand Down
10 changes: 3 additions & 7 deletions policyengine/utils/maps.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import pandas as pd
import plotly.express as px
import pandas as pd
from policyengine.utils.huggingface import download
from policyengine.utils.data_download import download
import plotly.express as px
from policyengine.utils.charts import *

Expand All @@ -10,16 +10,12 @@ def get_location_options_table(location_type: str) -> pd.DataFrame:
if location_type == "parliamentary_constituencies":
area_names_file_path = download(
repo="policyengine/policyengine-uk-data",
repo_filename="constituencies_2024.csv",
local_folder=None,
version=None,
filepath="constituencies_2024.csv",
)
elif location_type == "local_authorities":
area_names_file_path = download(
repo="policyengine/policyengine-uk-data",
repo_filename="local_authorities_2021.csv",
local_folder=None,
version=None,
filepath="local_authorities_2021.csv",
)
df = pd.read_csv(area_names_file_path)
return df
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ dependencies = [
"microdf_python",
"getpass4",
"pydantic",
"google-cloud-storage",
]

[project.optional-dependencies]
Expand Down
Loading