@@ -119,6 +119,23 @@ def __init__(
119119 os .makedirs (os .path .join (self .processed_dir , self .fold_dir ), exist_ok = True )
120120 self .save_hyperparameters ()
121121
122+ self ._num_of_labels = None
123+ self ._feature_vector_size = None
124+ self ._prepare_data_flag = 1
125+ self ._setup_data_flag = 1
126+
@property
def num_of_labels(self):
    """Number of labels of the processed dataset.

    The value is cached in ``_num_of_labels``, which starts out as ``None``
    and is filled in by ``_set_processed_data_props``.

    Returns:
        The cached number of labels.

    Raises:
        AssertionError: If the value has not been set yet.
    """
    if self._num_of_labels is None:
        # Raised explicitly instead of via ``assert`` so the check still runs
        # under ``python -O``; the exception type and message are unchanged.
        raise AssertionError("num of labels must be set")
    return self._num_of_labels
131+
@property
def feature_vector_size(self):
    """Size of the feature vector of the processed dataset.

    The value is cached in ``_feature_vector_size``, which starts out as
    ``None`` and is filled in by ``_set_processed_data_props``.

    Returns:
        The cached feature vector size.

    Raises:
        AssertionError: If the value has not been set yet.
    """
    if self._feature_vector_size is None:
        # Raised explicitly instead of via ``assert`` so the check still runs
        # under ``python -O``; the exception type and message are unchanged.
        raise AssertionError("size of feature vector must be set")
    return self._feature_vector_size
138+
122139 @property
123140 def identifier (self ) -> tuple :
124141 """Identifier for the dataset."""
@@ -390,7 +407,17 @@ def predict_dataloader(
390407 """
391408 return self .dataloader (self .prediction_kind , shuffle = False , ** kwargs )
392409
393- def setup (self , ** kwargs ):
def prepare_data(self, *args, **kwargs) -> None:
    """Run the data preparation at most once per instance.

    The first call delegates to ``_perform_data_preparation``; later calls
    return immediately. ``_prepare_data_flag`` equal to ``1`` means
    "not prepared yet".

    If the preparation raises, the guard is reset so that a later call
    retries instead of silently returning as if the data had been prepared.

    Args:
        *args: Positional arguments forwarded to ``_perform_data_preparation``.
        **kwargs: Keyword arguments forwarded to ``_perform_data_preparation``.
    """
    if self._prepare_data_flag != 1:
        return

    # Mark as started before delegating so a re-entrant call is a no-op.
    self._prepare_data_flag += 1
    try:
        self._perform_data_preparation(*args, **kwargs)
    except BaseException:
        # Preparation did not complete: restore the guard to allow a retry.
        self._prepare_data_flag -= 1
        raise
416+
def _perform_data_preparation(self, *args, **kwargs) -> None:
    """Hook that carries out the actual data preparation.

    Called by ``prepare_data`` with the arguments it received. This base
    version provides no implementation; subclasses override it.

    Raises:
        NotImplementedError: Always.
    """
    raise NotImplementedError()
419+
420+ def setup (self , * args , ** kwargs ) -> None :
394421 """
395422 Setup the data module.
396423
@@ -399,6 +426,11 @@ def setup(self, **kwargs):
399426 Args:
400427 **kwargs: Additional keyword arguments.
401428 """
429+ if self ._setup_data_flag != 1 :
430+ return
431+
432+ self ._setup_data_flag += 1
433+
402434 rank_zero_info (f"Check for processed data in { self .processed_dir } " )
403435 rank_zero_info (f"Cross-validation enabled: { self .use_inner_cross_validation } " )
404436 if any (
@@ -410,6 +442,21 @@ def setup(self, **kwargs):
410442 if not ("keep_reader" in kwargs and kwargs ["keep_reader" ]):
411443 self .reader .on_finish ()
412444
445+ self ._set_processed_data_props ()
446+
def _set_processed_data_props(self):
    """Cache label count and feature-vector size from the processed data file.

    Loads the file registered under ``processed_file_names_dict["data"]``
    inside ``processed_dir`` and stores:

    - ``_num_of_labels``: length of the ``"labels"`` entry of the first record.
    - ``_feature_vector_size``: the largest ``"features"`` length over all
      records.
    """
    data_path = os.path.join(
        self.processed_dir, self.processed_file_names_dict["data"]
    )
    # NOTE(review): ``weights_only=False`` lets ``torch.load`` unpickle
    # arbitrary objects; this is only safe for trusted, locally produced
    # files - confirm that is always the case here.
    data_pt = torch.load(data_path, weights_only=False)

    self._num_of_labels = len(data_pt[0]["labels"])
    self._feature_vector_size = max(len(d["features"]) for d in data_pt)

    # Log through ``rank_zero_info`` (as ``setup`` does) rather than ``print``
    # so multi-process runs do not emit the message once per rank.
    rank_zero_info(f"Number of labels for loaded data: {self._num_of_labels}")
    rank_zero_info(f"Feature vector size: {self._feature_vector_size}")
459+
413460 def setup_processed (self ):
414461 """
415462 Setup the processed data.
@@ -482,18 +529,6 @@ def raw_file_names_dict(self) -> dict:
482529 """
483530 raise NotImplementedError
484531
485- @property
486- def label_number (self ) -> int :
487- """
488- Returns the number of labels.
489-
490- This property should be implemented by subclasses to provide the number of labels.
491-
492- Returns:
493- int: The number of labels. Returns -1 for seq2seq encoding.
494- """
495- raise NotImplementedError
496-
497532
498533class MergedDataset (XYBaseDataModule ):
499534 MERGED = []
@@ -531,7 +566,7 @@ def __init__(
531566 os .makedirs (self .processed_dir , exist_ok = True )
532567 super (pl .LightningDataModule , self ).__init__ (** kwargs )
533568
534- def prepare_data (self ):
569+ def _perform_data_preparation (self ):
535570 """
536571 Placeholder for data preparation logic.
537572 """
@@ -547,9 +582,15 @@ def setup(self, **kwargs):
547582 Args:
548583 **kwargs: Additional keyword arguments.
549584 """
585+ if self ._setup_data_flag != 1 :
586+ return
587+
588+ self ._setup_data_flag += 1
550589 for s in self .subsets :
551590 s .setup (** kwargs )
552591
592+ self ._set_processed_data_props ()
593+
553594 def dataloader (self , kind : str , ** kwargs ) -> DataLoader :
554595 """
555596 Creates a DataLoader for a specific subset.
@@ -623,13 +664,6 @@ def processed_file_names(self) -> List[str]:
623664 """
624665 return ["test.pt" , "train.pt" , "validation.pt" ]
625666
626- @property
627- def label_number (self ) -> int :
628- """
629- Returns the number of labels from the first subset.
630- """
631- return self .subsets [0 ].label_number
632-
633667 @property
634668 def limits (self ):
635669 """
@@ -725,7 +759,7 @@ def _validate_splits_file_path(splits_file_path: Optional[str]) -> Optional[str]
725759 return splits_file_path
726760
727761 # ------------------------------ Phase: Prepare data -----------------------------------
728- def prepare_data (self , * args : Any , ** kwargs : Any ) -> None :
762+ def _perform_data_preparation (self , * args : Any , ** kwargs : Any ) -> None :
729763 """
730764 Prepares the data for the dataset.
731765
0 commit comments