 import os
 import random
 from abc import ABC, abstractmethod
+from pathlib import Path
 from typing import TYPE_CHECKING, Any, Dict, Generator, List, Optional, Tuple, Union

 import lightning as pl
@@ -76,6 +77,7 @@ def __init__(
         label_filter: Optional[int] = None,
         balance_after_filter: Optional[float] = None,
         num_workers: int = 1,
+        persistent_workers: bool = True,
         chebi_version: int = 200,
         inner_k_folds: int = -1,  # use inner cross-validation if > 1
         fold_index: Optional[int] = None,
@@ -99,6 +101,7 @@ def __init__(
         ), "Filter balancing requires a filter"
         self.balance_after_filter = balance_after_filter
         self.num_workers = num_workers
+        self.persistent_workers: bool = bool(persistent_workers)
         self.chebi_version = chebi_version
         assert type(inner_k_folds) is int
         self.inner_k_folds = inner_k_folds
@@ -363,7 +366,7 @@ def train_dataloader(self, *args, **kwargs) -> Union[DataLoader, List[DataLoader]]
             "train",
             shuffle=True,
             num_workers=self.num_workers,
-            persistent_workers=True,
+            persistent_workers=self.persistent_workers,
             **kwargs,
         )

@@ -382,7 +385,7 @@ def val_dataloader(self, *args, **kwargs) -> Union[DataLoader, List[DataLoader]]
             "validation",
             shuffle=False,
             num_workers=self.num_workers,
-            persistent_workers=True,
+            persistent_workers=self.persistent_workers,
             **kwargs,
         )

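The new flag is handed straight to torch.utils.data.DataLoader. A minimal,
self-contained sketch of what it controls (toy dataset, names are illustrative):

import torch
from torch.utils.data import DataLoader, TensorDataset

dataset = TensorDataset(torch.arange(10).float())

# persistent_workers=True keeps worker processes alive across epochs and
# avoids re-forking them; False (the DataLoader default) tears them down
# after each epoch, which frees memory between epochs.
loader = DataLoader(dataset, batch_size=2, num_workers=1, persistent_workers=True)

for epoch in range(2):
    for (batch,) in loader:
        pass  # the same workers serve the second epoch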
@@ -420,10 +423,17 @@ def prepare_data(self, *args, **kwargs) -> None:

         self._prepare_data_flag += 1
         self._perform_data_preparation(*args, **kwargs)
+        self._after_prepare_data(*args, **kwargs)

     def _perform_data_preparation(self, *args, **kwargs) -> None:
         raise NotImplementedError

+    def _after_prepare_data(self, *args, **kwargs) -> None:
+        """
+        Hook for additional processing once the pre-processed data is available.
+        """
+        ...
+
     def setup(self, *args, **kwargs) -> None:
         """
         Setup the data module.
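A sketch of how a subclass might use the new hook; BaseDataModule below is a
simplified stand-in for the real base class, reduced to the calls relevant here:

class BaseDataModule:
    def prepare_data(self, *args, **kwargs) -> None:
        self._perform_data_preparation(*args, **kwargs)
        self._after_prepare_data(*args, **kwargs)

    def _perform_data_preparation(self, *args, **kwargs) -> None:
        raise NotImplementedError

    def _after_prepare_data(self, *args, **kwargs) -> None:
        ...  # default: do nothing

class MyDataModule(BaseDataModule):
    def _perform_data_preparation(self, *args, **kwargs) -> None:
        print("writing pre-processed files")

    def _after_prepare_data(self, *args, **kwargs) -> None:
        # runs once the pre-processed files exist on disk
        print("running extra post-processing")

MyDataModule().prepare_data()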
@@ -466,14 +476,17 @@ def _set_processed_data_props(self):
         - self._num_of_labels: Number of target labels in the dataset.
         - self._feature_vector_size: Maximum feature vector length across all data points.
         """
-        data_pt = torch.load(
-            os.path.join(self.processed_dir, self.processed_file_names_dict["data"]),
-            weights_only=False,
+        pt_file_path = os.path.join(
+            self.processed_dir, self.processed_file_names_dict["data"]
         )
+        data_pt = torch.load(pt_file_path, weights_only=False)

         self._num_of_labels = len(data_pt[0]["labels"])
         self._feature_vector_size = max(len(d["features"]) for d in data_pt)

+        print(
+            f"Number of samples in encoded data ({pt_file_path}): {len(data_pt)} samples"
+        )
         print(f"Number of labels for loaded data: {self._num_of_labels}")
         print(f"Feature vector size: {self._feature_vector_size}")

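_set_processed_data_props only assumes the loaded object is a sequence of dicts
with "labels" and "features" entries; a toy illustration of the derived values:

# Toy stand-in for what torch.load returns from the processed .pt file.
data_pt = [
    {"features": [4, 8, 15], "labels": [True, False]},
    {"features": [16, 23, 42, 7], "labels": [False, True]},
]

num_of_labels = len(data_pt[0]["labels"])                       # -> 2
feature_vector_size = max(len(d["features"]) for d in data_pt)  # -> 4
print(f"Number of samples in encoded data: {len(data_pt)} samples")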
@@ -747,6 +760,7 @@ def __init__(
         )
         self.apply_label_filter = apply_label_filter
         self.apply_id_filter = apply_id_filter
+        self._data_pkl_filename: str = "data.pkl"

     @staticmethod
     def _validate_splits_file_path(splits_file_path: Optional[str]) -> Optional[str]:
@@ -885,6 +899,21 @@ def save_processed(self, data: pd.DataFrame, filename: str) -> None:
         """
         pd.to_pickle(data, open(os.path.join(self.processed_dir_main, filename), "wb"))

+    def get_processed_pickled_df_file(self, filename: str) -> Optional[pd.DataFrame]:
+        """
+        Loads the processed dataset pickle file, if it exists.
+
+        Args:
+            filename (str): The filename of the pickle file.
+
+        Returns:
+            Optional[pd.DataFrame]: The dataset, or None if the file does not exist.
+        """
+        file_path = Path(self.processed_dir_main) / filename
+        if file_path.exists():
+            return pd.read_pickle(file_path)
+        return None
+
     # ------------------------------ Phase: Setup data -----------------------------------
     def setup_processed(self) -> None:
         """
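A standalone rendering of the new accessor's contract, for illustration (the
directory and filename below are placeholders):

import pandas as pd
from pathlib import Path
from typing import Optional

def get_processed_pickled_df(processed_dir_main: str, filename: str) -> Optional[pd.DataFrame]:
    # mirrors the method above: return the pickled DataFrame, or None if absent
    file_path = Path(processed_dir_main) / filename
    if file_path.exists():
        return pd.read_pickle(file_path)
    return None

df = get_processed_pickled_df("/tmp", "data.pkl")
print(df.shape if df is not None else "no processed pickle yet")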
@@ -923,7 +952,9 @@ def _get_data_size(input_file_path: str) -> int:
             int: The size of the data.
         """
         with open(input_file_path, "rb") as f:
-            return len(pd.read_pickle(f))
+            df = pd.read_pickle(f)
+            print(f"Processed data size ({input_file_path}): {len(df)} rows")
+            return len(df)

     @abstractmethod
     def _load_dict(self, input_file_path: str) -> Generator[Dict[str, Any], None, None]:
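For reference, a self-contained demonstration of what the reworked helper now
reports (the temporary path is illustrative):

import pandas as pd

path = "/tmp/data_size_demo.pkl"
pd.to_pickle(pd.DataFrame({"features": [[1], [2, 3]], "labels": [[0], [1]]}), path)

with open(path, "rb") as f:
    df = pd.read_pickle(f)
    print(f"Processed data size ({path}): {len(df)} rows")  # -> 2 rows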
@@ -1260,7 +1291,7 @@ def processed_main_file_names_dict(self) -> dict:
             dict: A dictionary mapping dataset keys to their respective file names.
             For example, {"data": "data.pkl"}.
         """
-        return {"data": "data.pkl"}
+        return {"data": self._data_pkl_filename}

     @property
     def raw_file_names(self) -> List[str]:
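Because the property now reads self._data_pkl_filename, a subclass can point
the pipeline at a differently named pickle without overriding the property.
The class names below are hypothetical:

class ExtractorBase:  # stand-in for the extractor class touched in this diff
    def __init__(self):
        self._data_pkl_filename: str = "data.pkl"

    @property
    def processed_main_file_names_dict(self) -> dict:
        return {"data": self._data_pkl_filename}

class AugmentedExtractor(ExtractorBase):
    def __init__(self):
        super().__init__()
        self._data_pkl_filename = "data_augmented.pkl"

print(AugmentedExtractor().processed_main_file_names_dict)  # {'data': 'data_augmented.pkl'}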