33import sys
44from pydantic import BaseModel , Field
55from typing import Literal
6- from .utils .data .datasets import get_default_dataset , process_gs_path , POLICYENGINE_DATASETS , DATASET_TIME_PERIODS
6+ from .utils .data .datasets import (
7+ get_default_dataset ,
8+ process_gs_path ,
9+ POLICYENGINE_DATASETS ,
10+ DATASET_TIME_PERIODS ,
11+ )
712from policyengine_core .simulations import Simulation as CountrySimulation
813from policyengine_core .simulations import (
914 Microsimulation as CountryMicrosimulation ,
3136
3237CountryType = Literal ["uk" , "us" ]
3338ScopeType = Literal ["household" , "macro" ]
34- DataType = (
35- str | Dataset | None
36- )
39+ DataType = str | Dataset | None
3740TimePeriodType = int
3841ReformType = ParametricReform | Type [StructuralReform ] | None
3942RegionType = Optional [str ]
@@ -95,7 +98,7 @@ def __init__(self, **options: SimulationOptions):
9598 self .options = SimulationOptions (** options )
9699 self .check_model_version ()
97100 if not isinstance (self .options .data , Dataset ):
98- self ._set_data (self .options .data )
101+ self ._set_data (self .options .data )
99102 self ._initialise_simulations ()
100103 self .check_data_version ()
101104 self ._add_output_functions ()
@@ -139,8 +142,7 @@ def _set_data(self, file_address: str | None = None) -> None:
139142 # If None is passed, user wants default dataset; get URL, then continue initializing.
140143 if file_address is None :
141144 file_address = get_default_dataset (
142- country = self .options .country ,
143- region = self .options .region
145+ country = self .options .country , region = self .options .region
144146 )
145147 print (
146148 f"No data provided, using default dataset: { file_address } " ,
@@ -155,15 +157,11 @@ def _set_data(self, file_address: str | None = None) -> None:
155157 else :
156158 # All official PolicyEngine datasets are stored in GCS;
157159 # load accordingly
158- filename , version = self ._set_data_from_gs (
159- file_address
160- )
160+ filename , version = self ._set_data_from_gs (file_address )
161161 self .data_version = version
162162
163- time_period = self ._set_data_time_period (
164- file_address
165- )
166-
163+ time_period = self ._set_data_time_period (file_address )
164+
167165 self .options .data = Dataset .from_file (
168166 filename , time_period = time_period
169167 )
@@ -370,7 +368,7 @@ def check_data_version(self) -> None:
370368 raise ValueError (
371369 f"Data version { self .data_version } does not match expected version { self .options .data_version } ."
372370 )
373-
371+
374372 def _set_data_time_period (self , file_address : str ) -> Optional [int ]:
375373 """
376374 Set the time period based on the file address.
@@ -383,9 +381,7 @@ def _set_data_time_period(self, file_address: str) -> Optional[int]:
383381 # Local file, no time period available
384382 return None
385383
386- def _set_data_from_gs (
387- self , file_address : str
388- ) -> tuple [str , str | None ]:
384+ def _set_data_from_gs (self , file_address : str ) -> tuple [str , str | None ]:
389385 """
390386 Set the data from a GCS path and return the filename and version.
391387 """
@@ -403,4 +399,3 @@ def _set_data_from_gs(
403399 )
404400
405401 return filename , version
406-
0 commit comments