@@ -156,6 +156,51 @@ def __init__(
156156 ** _init_kwargs ,
157157 )
158158
159+ self .splits_file_path = self ._validate_splits_file_path (
160+ kwargs .get ("splits_file_path" , None )
161+ )
162+
163+ @staticmethod
164+ def _validate_splits_file_path (splits_file_path = None ):
165+ """
166+ Validates the provided splits file path.
167+
168+ Args:
169+ splits_file_path (str or None): Path to the splits CSV file.
170+
171+ Returns:
172+ str or None: Validated splits file path if checks pass, None if splits_file_path is None.
173+
174+ Raises:
175+ FileNotFoundError: If the splits file does not exist.
176+ ValueError: If the splits file is empty or missing required columns ('id' and/or 'split'), or not a CSV file.
177+ """
178+ if splits_file_path is None :
179+ return None
180+
181+ if not os .path .isfile (splits_file_path ):
182+ raise FileNotFoundError (f"File { splits_file_path } does not exist" )
183+
184+ file_size = os .path .getsize (splits_file_path )
185+ if file_size == 0 :
186+ raise ValueError (f"File { splits_file_path } is empty" )
187+
188+ # Check if the file has a CSV extension
189+ if not splits_file_path .lower ().endswith (".csv" ):
190+ raise ValueError (f"File { splits_file_path } is not a CSV file" )
191+
192+ # Read the CSV file into a DataFrame
193+ splits_df = pd .read_csv (splits_file_path )
194+
195+ # Check if 'id' and 'split' columns are in the DataFrame
196+ required_columns = {"id" , "split" }
197+ if not required_columns .issubset (splits_df .columns ):
198+ raise ValueError (
199+ f"CSV file { splits_file_path } is missing required columns ('id' and/or 'split')."
200+ )
201+
202+ return splits_file_path
203+
159204 def extract_class_hierarchy (self , chebi_path ):
160205 """
161206 Extracts the class hierarchy from the ChEBI ontology.
@@ -632,7 +677,7 @@ def prepare_data(self, *args, **kwargs):
632677 # Generate the "chebi_version_train" data if it doesn't exist
633678 self ._chebi_version_train_obj .prepare_data (* args , ** kwargs )
634679
635- def _get_dynamic_splits (self ):
680+ def _generate_dynamic_splits (self ):
636681 """Generate data splits during run-time and saves in class variables"""
637682
638683 # Load encoded data derived from "chebi_version"
@@ -687,10 +732,43 @@ def _get_dynamic_splits(self):
687732 )
688733 df_test = df_test_chebi_ver
689734
735+ # Generate splits.csv file to store ids of each corresponding split
736+ split_assignment_list : List [pd .DataFrame ] = [
737+ pd .DataFrame ({"id" : df_train ["ident" ], "split" : "train" }),
738+ pd .DataFrame ({"id" : df_val ["ident" ], "split" : "validation" }),
739+ pd .DataFrame ({"id" : df_test ["ident" ], "split" : "test" }),
740+ ]
741+ combined_split_assignment = pd .concat (split_assignment_list , ignore_index = True )
742+ combined_split_assignment .to_csv (
743+ os .path .join (self .processed_dir_main , "splits.csv" )
744+ )
745+
746+ # Store the splits in class variables
690747 self .dynamic_df_train = df_train
691748 self .dynamic_df_val = df_val
692749 self .dynamic_df_test = df_test
693750
751+ def _retreive_splits_from_csv (self ):
752+ splits_df = pd .read_csv (self .splits_file_path )
753+
754+ filename = self .processed_file_names_dict ["data" ]
755+ data_chebi_version = torch .load (os .path .join (self .processed_dir , filename ))
756+ df_chebi_version = pd .DataFrame (data_chebi_version )
757+
758+ train_ids = splits_df [splits_df ["split" ] == "train" ]["id" ]
759+ validation_ids = splits_df [splits_df ["split" ] == "validation" ]["id" ]
760+ test_ids = splits_df [splits_df ["split" ] == "test" ]["id" ]
761+
762+ self .dynamic_df_train = df_chebi_version [
763+ df_chebi_version ["ident" ].isin (train_ids )
764+ ]
765+ self .dynamic_df_val = df_chebi_version [
766+ df_chebi_version ["ident" ].isin (validation_ids )
767+ ]
768+ self .dynamic_df_test = df_chebi_version [
769+ df_chebi_version ["ident" ].isin (test_ids )
770+ ]
771+
694772 @property
695773 def dynamic_split_dfs (self ):
696774 if any (
@@ -701,7 +779,10 @@ def dynamic_split_dfs(self):
701779 self .dynamic_df_train ,
702780 ]
703781 ):
704- self ._get_dynamic_splits ()
782+ if self .splits_file_path is None :
783+ self ._generate_dynamic_splits ()
784+ else :
785+ self ._retreive_splits_from_csv ()
705786 return {
706787 "train" : self .dynamic_df_train ,
707788 "validation" : self .dynamic_df_val ,
0 commit comments