11"""Simulate tax-benefit policy and derive society-level output statistics."""
22
3+ import sys
34from pydantic import BaseModel , Field
45from typing import Literal
5- from .constants import get_default_dataset
6+ from .utils .data .datasets import (
7+ get_default_dataset ,
8+ process_gs_path ,
9+ POLICYENGINE_DATASETS ,
10+ DATASET_TIME_PERIODS ,
11+ )
612from policyengine_core .simulations import Simulation as CountrySimulation
713from policyengine_core .simulations import (
814 Microsimulation as CountryMicrosimulation ,
2228import h5py
2329from pathlib import Path
2430import pandas as pd
25- from typing import Type , Optional
31+ from typing import Type , Any , Optional
2632from functools import wraps , partial
27- from typing import Dict , Any , Callable
33+ from typing import Callable
2834import importlib
2935from policyengine .utils .data_download import download
3036
3137CountryType = Literal ["uk" , "us" ]
3238ScopeType = Literal ["household" , "macro" ]
3339DataType = (
34- str | dict | Any | None
40+ str | dict [ Any , Any ] | Dataset | None
3541) # Needs stricter typing. Any==policyengine_core.data.Dataset, but pydantic refuses for some reason.
3642TimePeriodType = int
3743ReformType = ParametricReform | Type [StructuralReform ] | None
@@ -72,6 +78,10 @@ class SimulationOptions(BaseModel):
7278 description = "The version of the data used in the simulation. If not provided, the current data version will be used. If provided, this package will throw an error if the data version does not match. Use this as an extra safety check." ,
7379 )
7480
81+ model_config = {
82+ "arbitrary_types_allowed" : True ,
83+ }
84+
7585
7686class Simulation :
7787 """Simulate tax-benefit policy and derive society-level output statistics."""
@@ -89,7 +99,10 @@ class Simulation:
8999 def __init__ (self , ** options : SimulationOptions ):
90100 self .options = SimulationOptions (** options )
91101 self .check_model_version ()
92- self ._set_data ()
102+ if not isinstance (self .options .data , dict ) and not isinstance (
103+ self .options .data , Dataset
104+ ):
105+ self ._set_data (self .options .data )
93106 self ._initialise_simulations ()
94107 self .check_data_version ()
95108 self ._add_output_functions ()
@@ -125,39 +138,37 @@ def _add_output_functions(self):
125138 wrapped_func ,
126139 )
127140
128- def _set_data (self ):
129- if self .options .data is None :
130- self .options .data = get_default_dataset (
131- country = self .options .country ,
132- region = self .options .region ,
133- )
141+ def _set_data (self , file_address : str | None = None ) -> None :
134142
135- if isinstance (self .options .data , str ):
136- filename = self .options .data
137- if self .options .data [:6 ] == "gcs://" :
138- bucket , filename = self .options .data .split ("://" )[- 1 ].split (
139- "/"
140- )
141- version = self .options .data_version
143+ # filename refers to file's unique name + extension;
144+ # file_address refers to URI + filename
142145
143- file_path , version = download (
144- filepath = filename ,
145- gcs_bucket = bucket ,
146- version = version ,
147- return_version = True ,
148- )
149- self .data_version = version
150- filename = str (Path (file_path ))
151- else :
152- # If it's a local file, we can't infer the version.
153- version = None
154- if "cps_2023" in filename :
155- time_period = 2023
156- else :
157- time_period = None
158- self .options .data = Dataset .from_file (
159- filename , time_period = time_period
146+ # If None is passed, user wants default dataset; get URL, then continue initializing.
147+ if file_address is None :
148+ file_address = get_default_dataset (
149+ country = self .options .country , region = self .options .region
160150 )
151+ print (
152+ f"No data provided, using default dataset: { file_address } " ,
153+ file = sys .stderr ,
154+ )
155+
156+ if file_address not in POLICYENGINE_DATASETS :
157+ # If it's a local file, no URI present and unable to infer version.
158+ filename = file_address
159+ version = None
160+
161+ else :
162+ # All official PolicyEngine datasets are stored in GCS;
163+ # load accordingly
164+ filename , version = self ._set_data_from_gs (file_address )
165+ self .data_version = version
166+
167+ time_period = self ._set_data_time_period (file_address )
168+
169+ self .options .data = Dataset .from_file (
170+ filename , time_period = time_period
171+ )
161172
162173 def _initialise_simulations (self ):
163174 self .baseline_simulation = self ._initialise_simulation (
@@ -361,3 +372,34 @@ def check_data_version(self) -> None:
361372 raise ValueError (
362373 f"Data version { self .data_version } does not match expected version { self .options .data_version } ."
363374 )
375+
376+ def _set_data_time_period (self , file_address : str ) -> Optional [int ]:
377+ """
378+ Set the time period based on the file address.
379+ If the file address is a PE dataset, return the time period from the dataset.
380+ If it's a local file, return None.
381+ """
382+ if file_address in DATASET_TIME_PERIODS :
383+ return DATASET_TIME_PERIODS [file_address ]
384+ else :
385+ # Local file, no time period available
386+ return None
387+
388+ def _set_data_from_gs (self , file_address : str ) -> tuple [str , str | None ]:
389+ """
390+ Set the data from a GCS path and return the filename and version.
391+ """
392+
393+ bucket , filename = process_gs_path (file_address )
394+ version = self .options .data_version
395+
396+ print (f"Downloading { filename } from bucket { bucket } " , file = sys .stderr )
397+
398+ filepath , version = download (
399+ filepath = filename ,
400+ gcs_bucket = bucket ,
401+ version = version ,
402+ return_version = True ,
403+ )
404+
405+ return filename , version
0 commit comments