@@ -117,6 +117,25 @@ def __init__(
         os.makedirs(os.path.join(self.processed_dir, self.fold_dir), exist_ok=True)
         self.save_hyperparameters()
 
+        self._num_of_labels = None
+        self._feature_vector_size = None
+        self._prepare_data_flag = 1
+        self._setup_data_flag = 1
+        self.prepare_data()
+        self.setup()
+
+    @property
+    def num_of_labels(self):
+        assert self._num_of_labels is not None, "num of labels must be set"
+        return self._num_of_labels
+
+    @property
+    def feature_vector_size(self):
+        assert (
+            self._feature_vector_size is not None
+        ), "size of feature vector must be set"
+        return self._feature_vector_size
+
     @property
     def identifier(self) -> tuple:
         """Identifier for the dataset."""
@@ -381,6 +400,12 @@ def predict_dataloader(
381400 """
382401 return self .dataloader (self .prediction_kind , shuffle = False , ** kwargs )
383402
403+ def prepare_data (self ) -> None :
404+ if self ._prepare_data_flag != 1 :
405+ return
406+
407+ self ._prepare_data_flag += 1
408+
384409 def setup (self , ** kwargs ):
385410 """
386411 Setup the data module.
@@ -390,6 +415,11 @@ def setup(self, **kwargs):
         Args:
             **kwargs: Additional keyword arguments.
         """
+        if self._setup_data_flag != 1:
+            return
+
+        self._setup_data_flag += 1
+
         rank_zero_info(f"Check for processed data in {self.processed_dir}")
         rank_zero_info(f"Cross-validation enabled: {self.use_inner_cross_validation}")
         if any(
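The `_prepare_data_flag` / `_setup_data_flag` counters implement a simple run-once guard: each hook does its work only on the first call (the counter moves from 1 to 2), and any later call, for example when a Lightning `Trainer` invokes the hooks again, returns immediately. A stripped-down sketch of the same pattern, independent of the class in this diff:

```python
class RunOnce:
    def __init__(self):
        self._setup_data_flag = 1  # 1 means "not run yet"
        self.setup()               # eager call, mirroring __init__ above

    def setup(self):
        if self._setup_data_flag != 1:
            return                 # already ran once; do nothing
        self._setup_data_flag += 1
        print("expensive setup runs exactly once")

obj = RunOnce()   # prints once
obj.setup()       # no-op
obj.setup()       # no-op
```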
@@ -401,20 +431,20 @@ def setup(self, **kwargs):
401431 if not ("keep_reader" in kwargs and kwargs ["keep_reader" ]):
402432 self .reader .on_finish ()
403433
404- self ._add_num_of_labels_to_hparams ()
434+ self ._set_processed_data_props ()
405435
406- def _add_num_of_labels_to_hparams (self ):
407- num_of_labels = len (
408- torch .load (
409- os .path .join (
410- self .processed_dir , self .processed_file_names_dict ["data" ]
411- ),
412- weights_only = False ,
413- )[0 ]["labels" ]
414- )
436+ def _set_processed_data_props (self ):
415437
416- print (f"Number of labels for loaded data: { num_of_labels } " )
417- self .hparams .num_of_labels = num_of_labels
438+ single_data_instance = torch .load (
439+ os .path .join (self .processed_dir , self .processed_file_names_dict ["data" ]),
440+ weights_only = False ,
441+ )[0 ]
442+
443+ self ._num_of_labels = len (single_data_instance ["labels" ])
444+ self ._feature_vector_size = len (single_data_instance ["features" ])
445+
446+ print (f"Number of labels for loaded data: { self ._num_of_labels } " )
447+ print (f"Feature vector size: { self ._feature_vector_size } " )
418448
419449 def setup_processed (self ):
420450 """
@@ -541,6 +571,7 @@ def prepare_data(self):
541571 """
542572 Placeholder for data preparation logic.
543573 """
574+ super ().prepare_data ()
544575 for s in self .subsets :
545576 s .prepare_data ()
546577
@@ -553,10 +584,14 @@ def setup(self, **kwargs):
         Args:
             **kwargs: Additional keyword arguments.
         """
+        if self._setup_data_flag != 1:
+            return
+
+        self._setup_data_flag += 1
         for s in self.subsets:
             s.setup(**kwargs)
 
-        self._add_num_of_labels_to_hparams()
+        self._set_processed_data_props()
 
     def dataloader(self, kind: str, **kwargs) -> DataLoader:
         """
@@ -752,6 +787,7 @@ def prepare_data(self, *args: Any, **kwargs: Any) -> None:
         Returns:
             None
         """
+        super().prepare_data()
         print("Checking for processed data in", self.processed_dir_main)
 
         processed_name = self.processed_main_file_names_dict["data"]