Skip to content

Commit bbcf802

Browse files
Merge pull request #112 from PolicyEngine/nikhilwoodruff/issue111
Add GCP dataset downloads
2 parents e67f4cb + c968bf9 commit bbcf802

File tree

8 files changed

+149
-52
lines changed

8 files changed

+149
-52
lines changed

changelog_entry.yaml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
- bump: minor
2+
changes:
3+
added:
4+
- Google Cloud Storage data downloads.

policyengine/outputs/macro/comparison/calculate_economy_comparison.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
from microdf import MicroSeries
44
import numpy as np
5-
from policyengine_core.tools.hugging_face import download_huggingface_dataset
5+
from policyengine.utils.data_download import download
66
import pandas as pd
77
import h5py
88
from pydantic import BaseModel
@@ -709,18 +709,20 @@ def uk_constituency_breakdown(
709709
baseline_hnet = baseline.household_net_income
710710
reform_hnet = reform.household_net_income
711711

712-
constituency_weights_path = download_huggingface_dataset(
713-
repo="policyengine/policyengine-uk-data",
714-
repo_filename="parliamentary_constituency_weights.h5",
712+
constituency_weights_path = download(
713+
huggingface_repo="policyengine-uk-data",
714+
gcs_bucket="policyengine-uk-data-private",
715+
filepath="parliamentary_constituency_weights.h5",
715716
)
716717
with h5py.File(constituency_weights_path, "r") as f:
717718
weights = f["2025"][
718719
...
719720
] # {2025: array(650, 100180) where cell i, j is the weight of household record i in constituency j}
720721

721-
constituency_names_path = download_huggingface_dataset(
722-
repo="policyengine/policyengine-uk-data",
723-
repo_filename="constituencies_2024.csv",
722+
constituency_names_path = download(
723+
huggingface_repo="policyengine-uk-data",
724+
gcs_bucket="policyengine-uk-data-private",
725+
filepath="constituencies_2024.csv",
724726
)
725727
constituency_names = pd.read_csv(
726728
constituency_names_path

policyengine/simulation.py

Lines changed: 45 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
"""Simulate tax-benefit policy and derive society-level output statistics."""
2-
2+
33
from pydantic import BaseModel, Field
44
from typing import Literal
55
from .constants import DEFAULT_DATASETS_BY_COUNTRY
@@ -10,7 +10,6 @@
1010
from .utils.reforms import ParametricReform
1111
from policyengine_core.reforms import Reform as StructuralReform
1212
from policyengine_core.data import Dataset
13-
from .utils.huggingface import download
1413
from policyengine_us import (
1514
Simulation as USSimulation,
1615
Microsimulation as USMicrosimulation,
@@ -26,6 +25,7 @@
2625
from functools import wraps, partial
2726
from typing import Dict, Any, Callable
2827
import importlib
28+
from policyengine.utils.data_download import download
2929

3030
CountryType = Literal["uk", "us"]
3131
ScopeType = Literal["household", "macro"]
@@ -78,6 +78,7 @@ def __init__(self, **options: SimulationOptions):
7878
self.options.country
7979
]
8080

81+
self._set_data()
8182
self._initialise_simulations()
8283
self._add_output_functions()
8384

@@ -118,7 +119,36 @@ def _set_data(self):
118119
self.options.country
119120
]
120121

121-
self._data_handle_cps_special_case()
122+
if isinstance(self.options.data, str):
123+
filename = self.options.data
124+
if "://" in self.options.data:
125+
bucket = None
126+
hf_repo = None
127+
hf_org = None
128+
if "gs://" in self.options.data:
129+
bucket, filename = self.options.data.split("://")[
130+
-1
131+
].split("/")
132+
elif "hf://" in self.options.data:
133+
hf_org, hf_repo, filename = self.options.data.split("://")[
134+
-1
135+
].split("/", 2)
136+
137+
if not Path(filename).exists():
138+
file_path = download(
139+
filepath=filename,
140+
huggingface_org=hf_org,
141+
huggingface_repo=hf_repo,
142+
gcs_bucket=bucket,
143+
)
144+
filename = str(Path(file_path))
145+
if "cps_2023" in filename:
146+
time_period = 2023
147+
else:
148+
time_period = None
149+
self.options.data = Dataset.from_file(
150+
filename, time_period=time_period
151+
)
122152

123153
def _initialise_simulations(self):
124154
self.baseline_simulation = self._initialise_simulation(
@@ -228,10 +258,9 @@ def _apply_region_to_simulation(
228258
elif "constituency/" in region:
229259
constituency = region.split("/")[1]
230260
constituency_names_file_path = download(
231-
repo="policyengine/policyengine-uk-data",
232-
repo_filename="constituencies_2024.csv",
233-
local_folder=None,
234-
version=None,
261+
huggingface_repo="policyengine-uk-data",
262+
gcs_bucket="policyengine-uk-data-private",
263+
filepath="constituencies_2024.csv",
235264
)
236265
constituency_names_file_path = Path(
237266
constituency_names_file_path
@@ -250,10 +279,9 @@ def _apply_region_to_simulation(
250279
f"Constituency {constituency} not found. See {constituency_names_file_path} for the list of available constituencies."
251280
)
252281
weights_file_path = download(
253-
repo="policyengine/policyengine-uk-data",
254-
repo_filename="parliamentary_constituency_weights.h5",
255-
local_folder=None,
256-
version=None,
282+
huggingface_repo="policyengine-uk-data",
283+
gcs_bucket="policyengine-uk-data-private",
284+
filepath="parliamentary_constituency_weights.h5",
257285
)
258286

259287
with h5py.File(weights_file_path, "r") as f:
@@ -267,10 +295,9 @@ def _apply_region_to_simulation(
267295
elif "local_authority/" in region:
268296
la = region.split("/")[1]
269297
la_names_file_path = download(
270-
repo="policyengine/policyengine-uk-data",
271-
repo_filename="local_authorities_2021.csv",
272-
local_folder=None,
273-
version=None,
298+
huggingface_repo="policyengine-uk-data",
299+
gcs_bucket="policyengine-uk-data-private",
300+
filepath="local_authorities_2021.csv",
274301
)
275302
la_names_file_path = Path(la_names_file_path)
276303
la_names = pd.read_csv(la_names_file_path)
@@ -283,10 +310,9 @@ def _apply_region_to_simulation(
283310
f"Local authority {la} not found. See {la_names_file_path} for the list of available local authorities."
284311
)
285312
weights_file_path = download(
286-
repo="policyengine/policyengine-uk-data",
287-
repo_filename="local_authority_weights.h5",
288-
local_folder=None,
289-
version=None,
313+
huggingface_repo="policyengine-uk-data",
314+
gcs_bucket="policyengine-uk-data-private",
315+
filepath="local_authority_weights.h5",
290316
)
291317

292318
with h5py.File(weights_file_path, "r") as f:
@@ -299,21 +325,3 @@ def _apply_region_to_simulation(
299325
)
300326

301327
return simulation
302-
303-
def _data_handle_cps_special_case(self):
304-
"""Handle special case for CPS data- this data doesn't specify time periods for each variable, but we still use it intensively."""
305-
if self.data is not None and "cps_2023" in self.data:
306-
if "hf://" in self.data:
307-
owner, repo, filename = self.data.split("/")[-3:]
308-
if "@" in filename:
309-
version = filename.split("@")[-1]
310-
filename = filename.split("@")[0]
311-
else:
312-
version = None
313-
self.data = download(
314-
repo=owner + "/" + repo,
315-
repo_filename=filename,
316-
local_folder=None,
317-
version=version,
318-
)
319-
self.data = Dataset.from_file(self.data, "2023")
Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
from pathlib import Path
import logging

from pydantic import BaseModel

from policyengine.utils.huggingface import download_from_hf
from policyengine.utils.google_cloud_bucket import download_file_from_gcs

# Module-level logger; never monkey-patch logging.info (the original
# `logging.info = print` rebinding leaks into every other module).
logger = logging.getLogger(__name__)


class DataFile(BaseModel):
    """Validated description of a dataset file and its download sources."""

    # Path of the file in the remote store; also used as the local path.
    filepath: str
    # Hugging Face organisation (owner) of the hosting repo.
    huggingface_org: str
    # Hugging Face repo name; None disables the Hugging Face source.
    huggingface_repo: str | None = None
    # Google Cloud Storage bucket name; None disables the GCS source.
    gcs_bucket: str | None = None


def download(
    filepath: str,
    huggingface_repo: str | None = None,
    gcs_bucket: str | None = None,
    huggingface_org: str = "policyengine",
) -> str:
    """Download a data file, trying Hugging Face first and then GCS.

    Args:
        filepath: Path of the file in the remote store; also used as the
            local destination path.
        huggingface_repo: Hugging Face repo name (without the org prefix).
        gcs_bucket: Google Cloud Storage bucket name.
        huggingface_org: Hugging Face organisation owning the repo.

    Returns:
        The local path of the downloaded (or already-present) file.

    Raises:
        ValueError: If no download source is configured, or the Hugging
            Face download failed and no GCS bucket was provided.
    """
    data_file = DataFile(
        filepath=filepath,
        huggingface_org=huggingface_org,
        huggingface_repo=huggingface_repo,
        gcs_bucket=gcs_bucket,
    )

    # Skip the download entirely if the file already exists locally.
    if Path(filepath).exists():
        logger.info("File %s already exists. Skipping download.", filepath)
        return filepath

    if data_file.huggingface_repo is not None:
        logger.info("Using Hugging Face for download.")
        try:
            return download_from_hf(
                repo=f"{data_file.huggingface_org}/{data_file.huggingface_repo}",
                repo_filename=data_file.filepath,
            )
        except Exception:
            # Deliberate best-effort: fall through to GCS (if configured)
            # rather than failing hard. Narrowed from a bare `except:` so
            # KeyboardInterrupt/SystemExit still propagate.
            logger.exception("Failed to download from Hugging Face.")

    if data_file.gcs_bucket is not None:
        logger.info("Using Google Cloud Storage for download.")
        download_file_from_gcs(
            bucket_name=data_file.gcs_bucket,
            file_name=filepath,
            destination_path=filepath,
        )
        return filepath

    raise ValueError(
        "No valid download method specified. Please provide either a Hugging Face repo or a Google Cloud Storage bucket."
    )
Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
def download_file_from_gcs(
    bucket_name: str, file_name: str, destination_path: str
) -> str:
    """
    Download a file from Google Cloud Storage to a local path.

    Args:
        bucket_name (str): The name of the GCS bucket.
        file_name (str): The name of the file in the GCS bucket.
        destination_path (str): The local path where the file will be saved.

    Returns:
        str: The local path the file was written to (``destination_path``).
    """
    # Imported lazily so importing this module does not require
    # google-cloud-storage unless a GCS download is actually performed.
    from google.cloud import storage

    # Initialize a client (uses application-default credentials from the
    # environment — TODO confirm against deployment configuration).
    client = storage.Client()

    # Get the bucket
    bucket = client.bucket(bucket_name)

    # Create a blob object from the file name
    blob = bucket.blob(file_name)

    # Download the file to a local path
    blob.download_to_filename(destination_path)

    # Fixed: the original annotated `-> None` (and documented "Returns:
    # None") while actually returning the path; callers rely on the path.
    return destination_path

policyengine/utils/huggingface.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import time
55

66

7-
def download(
7+
def download_from_hf(
88
repo: str,
99
repo_filename: str,
1010
local_folder: str | None = None,

policyengine/utils/maps.py

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import pandas as pd
22
import plotly.express as px
33
import pandas as pd
4-
from policyengine.utils.huggingface import download
4+
from policyengine.utils.data_download import download
55
import plotly.express as px
66
from policyengine.utils.charts import *
77

@@ -10,16 +10,12 @@ def get_location_options_table(location_type: str) -> pd.DataFrame:
1010
if location_type == "parliamentary_constituencies":
1111
area_names_file_path = download(
1212
repo="policyengine/policyengine-uk-data",
13-
repo_filename="constituencies_2024.csv",
14-
local_folder=None,
15-
version=None,
13+
filepath="constituencies_2024.csv",
1614
)
1715
elif location_type == "local_authorities":
1816
area_names_file_path = download(
1917
repo="policyengine/policyengine-uk-data",
20-
repo_filename="local_authorities_2021.csv",
21-
local_folder=None,
22-
version=None,
18+
filepath="local_authorities_2021.csv",
2319
)
2420
df = pd.read_csv(area_names_file_path)
2521
return df

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ dependencies = [
1919
"microdf_python",
2020
"getpass4",
2121
"pydantic",
22+
"google-cloud-storage",
2223
]
2324

2425
[project.optional-dependencies]

0 commit comments

Comments
 (0)