11"""Simulate tax-benefit policy and derive society-level output statistics."""
22
3+ import sys
34from pydantic import BaseModel , Field
45from typing import Literal
5- from .constants import get_default_dataset
6+ from .utils . data . datasets import get_default_dataset , process_gs_path , POLICYENGINE_DATASETS , DATASET_TIME_PERIODS
67from policyengine_core .simulations import Simulation as CountrySimulation
78from policyengine_core .simulations import (
89 Microsimulation as CountryMicrosimulation ,
3132CountryType = Literal ["uk" , "us" ]
3233ScopeType = Literal ["household" , "macro" ]
3334DataType = (
34- str | dict | Any | None
35- ) # Needs stricter typing. Any==policyengine_core.data.Dataset, but pydantic refuses for some reason.
35+ str | Dataset | None
36+ )
3637TimePeriodType = int
3738ReformType = ParametricReform | Type [StructuralReform ] | None
3839RegionType = Optional [str ]
@@ -72,6 +73,10 @@ class SimulationOptions(BaseModel):
7273 description = "The version of the data used in the simulation. If not provided, the current data version will be used. If provided, this package will throw an error if the data version does not match. Use this as an extra safety check." ,
7374 )
7475
76+ model_config = {
77+ "arbitrary_types_allowed" : True ,
78+ }
79+
7580
7681class Simulation :
7782 """Simulate tax-benefit policy and derive society-level output statistics."""
@@ -89,7 +94,8 @@ class Simulation:
8994 def __init__ (self , ** options : SimulationOptions ):
9095 self .options = SimulationOptions (** options )
9196 self .check_model_version ()
92- self ._set_data ()
97+ if not isinstance (self .options .data , Dataset ):
98+ self ._set_data (self .options .data )
9399 self ._initialise_simulations ()
94100 self .check_data_version ()
95101 self ._add_output_functions ()
@@ -125,39 +131,42 @@ def _add_output_functions(self):
125131 wrapped_func ,
126132 )
127133
128- def _set_data (self ):
129- if self .options .data is None :
130- self .options .data = get_default_dataset (
134+ def _set_data (self , file_address : str | None = None ) -> None :
135+
136+ # filename refers to file's unique name + extension;
137+ # file_address refers to URI + filename
138+
139+ # If None is passed, user wants default dataset; get URL, then continue initializing.
140+ if file_address is None :
141+ file_address = get_default_dataset (
131142 country = self .options .country ,
132- region = self .options .region ,
143+ region = self .options .region
144+ )
145+ print (
146+ f"No data provided, using default dataset: { file_address } " ,
147+ file = sys .stderr ,
133148 )
134149
135- if isinstance (self .options .data , str ):
136- filename = self .options .data
137- if self .options .data [:6 ] == "gcs://" :
138- bucket , filename = self .options .data .split ("://" )[- 1 ].split (
139- "/"
140- )
141- version = self .options .data_version
150+ if file_address not in POLICYENGINE_DATASETS :
151+ # If it's a local file, no URI present and unable to infer version.
152+ filename = file_address
153+ version = None
142154
143- file_path , version = download (
144- filepath = filename ,
145- gcs_bucket = bucket ,
146- version = version ,
147- return_version = True ,
148- )
149- self .data_version = version
150- filename = str (Path (file_path ))
151- else :
152- # If it's a local file, we can't infer the version.
153- version = None
154- if "cps_2023" in filename :
155- time_period = 2023
156- else :
157- time_period = None
158- self .options .data = Dataset .from_file (
159- filename , time_period = time_period
155+ else :
156+ # All official PolicyEngine datasets are stored in GCS;
157+ # load accordingly
158+ filename , version = self ._set_data_from_gs (
159+ file_address
160160 )
161+ self .data_version = version
162+
163+ time_period = self ._set_data_time_period (
164+ file_address
165+ )
166+
167+ self .options .data = Dataset .from_file (
168+ filename , time_period = time_period
169+ )
161170
162171 def _initialise_simulations (self ):
163172 self .baseline_simulation = self ._initialise_simulation (
@@ -361,3 +370,37 @@ def check_data_version(self) -> None:
361370 raise ValueError (
362371 f"Data version { self .data_version } does not match expected version { self .options .data_version } ."
363372 )
373+
374+ def _set_data_time_period (self , file_address : str ) -> Optional [int ]:
375+ """
376+ Set the time period based on the file address.
377+ If the file address is a PE dataset, return the time period from the dataset.
378+ If it's a local file, return None.
379+ """
380+ if file_address in DATASET_TIME_PERIODS :
381+ return DATASET_TIME_PERIODS [file_address ]
382+ else :
383+ # Local file, no time period available
384+ return None
385+
386+ def _set_data_from_gs (
387+ self , file_address : str
388+ ) -> tuple [str , str | None ]:
389+ """
390+ Set the data from a GCS path and return the filename and version.
391+ """
392+
393+ bucket , filename = process_gs_path (file_address )
394+ version = self .options .data_version
395+
396+ print (f"Downloading { filename } from bucket { bucket } " , file = sys .stderr )
397+
398+ filepath , version = download (
399+ filepath = filename ,
400+ gcs_bucket = bucket ,
401+ version = version ,
402+ return_version = True ,
403+ )
404+
405+ return filename , version
406+
0 commit comments