@@ -59,6 +59,7 @@ class Dataset:
5959 region_id (str): The region ID, for saving the files accordingly and eventually downloading them if needed
6060 data_dir (str): Path to the directory that contains the raw data and where intermediate results are saved
6161 RGIIds (pd.Series): Series of RGI IDs from the data
62+ output_format (str): File format used for the saved output files, either "csv" or "parquet"
6263 months_tail_pad (list of str): Months to pad the start of the hydrological year
6364 months_head_pad (list of str): Months to pad the end of the hydrological year
6465 """
@@ -70,6 +71,7 @@ def __init__(
7071 region_name : str ,
7172 region_id : int ,
7273 data_path : str ,
74+ output_format : str = "csv" ,
7375 months_tail_pad = None , #: List[str] = ['aug_', 'sep_'], # before 'oct'
7476 months_head_pad = None , #: List[str] = ['oct_'], # after 'sep'
7577 ):
@@ -81,7 +83,8 @@ def __init__(
8183 self .RGIIds = self .data ["RGIId" ]
8284 if not os .path .isdir (self .data_dir ):
8385 os .makedirs (self .data_dir , exist_ok = True )
84-
86+ assert output_format in ["csv" , "parquet" ], "format must be csv or parquet"
87+ self .output_format = output_format
8588 # Padding to allow for flexible month ranges (customize freely)
8689 assert (months_head_pad is None ) == (
8790 months_tail_pad is None
@@ -101,7 +104,9 @@ def get_topo_features(self, vois: list[str], custom_working_dir: str = "") -> No
101104 vois (list[str]): A list of strings naming the topographical variables of interest
102105 custom_working_dir (str, optional): The path to the custom working directory for OGGM data. Defaults to ''.
103106 """
104- output_fname = self ._get_output_filename ("topographical_features" )
107+ output_fname = self ._get_output_filename (
108+ "topographical_features" , self .output_format
109+ )
105110 self .data = get_topographical_features (
106111 self .data , output_fname , vois , self .RGIIds , custom_working_dir , self .cfg
107112 )
@@ -124,7 +129,7 @@ def get_climate_features(
124129 change_units (bool, optional): A boolean indicating whether to change the units of the climate data. Defaults to False.
125130 smoothing_vois (dict, optional): A dictionary containing the variables of interest for smoothing climate artifacts. Defaults to None.
126131 """
127- output_fname = self ._get_output_filename ("climate_features" )
132+ output_fname = self ._get_output_filename ("climate_features" , self . output_format )
128133
129134 smoothing_vois = smoothing_vois or {} # Safely default to empty dict
130135 vois_climate = smoothing_vois .get ("vois_climate" )
@@ -207,9 +212,14 @@ def convert_to_monthly(
207212 """
208213 if meta_data_columns is None :
209214 meta_data_columns = self .cfg .metaData
210- output_fname = self ._get_output_filename ("monthly_dataset" )
215+ output_fname = self ._get_output_filename ("monthly_dataset" , self . output_format )
211216 self .data = transform_to_monthly (
212- self .data , meta_data_columns , vois_climate , vois_topographical , output_fname
217+ self .data ,
218+ meta_data_columns ,
219+ vois_climate ,
220+ vois_topographical ,
221+ output_fname ,
222+ self .output_format ,
213223 )
214224
215225 def get_glacier_mask (
@@ -254,17 +264,19 @@ def create_glacier_grid_RGI(self, custom_working_dir: str = "") -> pd.DataFrame:
254264 df_grid = create_glacier_grid_RGI (ds , years , glacier_indices , gdir , rgi_gl )
255265 return df_grid
256266
257- def _get_output_filename (self , feature_type : str ) -> str :
267+ def _get_output_filename (self , feature_type : str , output_format : str ) -> str :
258268 """
259269 Generates the output filename for a given feature type.
260270
261271 Args:
262272 feature_type (str): The type of feature (e.g., "topographical_features", "climate_features", "monthly_dataset")
263-
273+ output_format (str): Extension of the output file, either "csv" or "parquet"
264274 Returns:
265275 str: The full path to the output file
266276 """
267- return os .path .join (self .data_dir , f"{ self .region } _{ feature_type } .csv" )
277+ return os .path .join (
278+ self .data_dir , f"{ self .region } _{ feature_type } .{ output_format } "
279+ )
268280
269281 def _copy_padded_month_columns (
270282 self , df : pd .DataFrame , prefixes = ("pcsr" ,), overwrite : bool = False
@@ -380,7 +392,6 @@ def __init__(
380392 self .metadata = metadata
381393 self .metadataColumns = metadataColumns or self .cfg .metaData
382394 self .targets = targets
383-
384395 assert len (self .features ) > 0 , "The features variable is empty."
385396
386397 _ , self .month_pos = _rebuild_month_index (months_head_pad , months_tail_pad )
@@ -391,10 +402,8 @@ def __init__(
391402 for i in range (len (self .metadata ))
392403 ]
393404 )
394- self .uniqueID = np .unique (self .ID )
395- self .maxConcatNb = max (
396- [len (np .argwhere (self .ID == id )[:, 0 ]) for id in self .uniqueID ]
397- )
405+ self .uniqueID , counts = np .unique (self .ID , return_counts = True )
406+ self .maxConcatNb = counts .max ()
398407 self .nbFeatures = self .features .shape [1 ]
399408 self .nbMetadata = self .metadata .shape [1 ]
400409 self .norm = Normalizer ({k : cfg .bnds [k ] for k in cfg .featureColumns })
@@ -416,14 +425,19 @@ def mapSplitsToDataset(
416425 corresponding indices the cross validation should use according to
417426 the input splits variable.
418427 """
428+ # Precompute the mapping of unique IDs to indices
429+ uniqueID_to_indices = {
430+ uid : np .where (self .uniqueID == uid )[0 ] for uid in self .uniqueID
431+ }
419432 ret = []
420433 for split in splits :
421434 t = []
422435 for e in split :
423- uniqueSelectedId = np .unique (self .ID [e ])
424- ind = np .argwhere (self .uniqueID [None , :] == uniqueSelectedId [:, None ])[
425- :, 1
426- ]
436+ uniqueSelectedId = np .unique (self .ID [e ]) # Get the unique selected IDs
437+ # Use the precomputed mapping for fast lookups
438+ ind = np .concatenate (
439+ [uniqueID_to_indices [uid ] for uid in uniqueSelectedId ]
440+ )
427441 assert all (uniqueSelectedId == self .uniqueID [ind ])
428442 t .append (ind )
429443 ret .append (tuple (t ))
@@ -470,6 +484,9 @@ def indexToId(self, index):
470484 return self .uniqueID [index ]
471485
472486 def indexToMetadata (self , index ):
487+ """
488+ Returns the metadata for a given index.
489+ """
473490 ind = self ._getInd (index )
474491 return self .metadata [ind ][:, :]
475492
0 commit comments