@@ -120,6 +120,8 @@ class _ChEBIDataExtractor(XYBaseDataModule, ABC):
120120 chebi_version will be used for training, validation and test. Defaults to None.
121121 single_class (int, optional): The ID of the single class to predict. If not set, all available labels will be
122122 predicted. Defaults to None.
123+ dynamic_data_split_seed (int, optional): The seed for random data splitting. Defaults to 42.
124+ splits_file_path (str, optional): Path to the splits CSV file. Defaults to None.
123125 **kwargs: Additional keyword arguments (passed to XYBaseDataModule).
124126
125127 Attributes:
@@ -677,7 +679,7 @@ def prepare_data(self, *args, **kwargs):
677679
678680 def _generate_dynamic_splits (self ):
679681 """Generate data splits during run-time and saves in class variables"""
680-
682+ print ( "Generate dynamic splits..." )
681683 # Load encoded data derived from "chebi_version"
682684 try :
683685 filename = self .processed_file_names_dict ["data" ]
@@ -746,7 +748,8 @@ def _generate_dynamic_splits(self):
746748 self .dynamic_df_val = df_val
747749 self .dynamic_df_test = df_test
748750
749- def _retreive_splits_from_csv (self ):
751+ def _retrieve_splits_from_csv (self ):
752+ print (f"Loading splits from { self .splits_file_path } ..." )
750753 splits_df = pd .read_csv (self .splits_file_path )
751754
752755 filename = self .processed_file_names_dict ["data" ]
@@ -782,7 +785,7 @@ def dynamic_split_dfs(self):
782785 self ._generate_dynamic_splits ()
783786 else :
784787 # If user has provided splits file path, use it to get the splits from the data
785- self ._retreive_splits_from_csv ()
788+ self ._retrieve_splits_from_csv ()
786789 return {
787790 "train" : self .dynamic_df_train ,
788791 "validation" : self .dynamic_df_val ,
0 commit comments