 import os
 from collections import OrderedDict
-from typing import List, Literal, Optional
+from typing import List, Literal, Optional, Tuple
 
 import pandas as pd
+from iterstrat.ml_stratifiers import MultilabelStratifiedShuffleSplit
 from jsonargparse import CLI
 
-from chebai.preprocessing.datasets.go_uniprot import (
-    GOUniProtOver50,
-    GOUniProtOver250,
-    _GOUniProtDataExtractor,
-)
+from chebai.preprocessing.datasets.go_uniprot import DeepGO1MigratedData
 
 
 class DeepGo1DataMigration:
1613 """
1714 A class to handle data migration and processing for the DeepGO project.
18- It migrates the deepGO data to our data structure followed for GO-UniProt data.
15+ It migrates the DeepGO data to our data structure followed for GO-UniProt data.
1916
20- It migrates the data of DeepGO model of the below research paper :
17+ This class handles data from the DeepGO model as described in :
2118 Maxat Kulmanov, Mohammed Asif Khan, Robert Hoehndorf,
2219 DeepGO: predicting protein functions from sequence and interactions using a deep ontology-aware classifier,
2320 Bioinformatics, Volume 34, Issue 4, February 2018, Pages 660–668
24- (https://doi.org/10.1093/bioinformatics/btx624),
25-
26- Attributes:
27- _CORRESPONDING_GO_CLASSES (dict): Mapping of GO branches to specific data extractor classes.
28- _MAXLEN (int): Maximum sequence length for sequences.
29- _LABELS_START_IDX (int): Starting index for labels in the dataset.
30-
31- Methods:
32- __init__(data_dir, go_branch): Initializes the data directory and GO branch.
33- _load_data(): Loads train, validation, test, and terms data from the specified directory.
34- _record_splits(): Creates a DataFrame with IDs and their corresponding split.
35- migrate(): Executes the migration process including data loading, processing, and saving.
36- _extract_required_data_from_splits(): Extracts required columns from the splits data.
37- _get_labels_columns(data_df): Generates label columns for the data based on GO terms.
38- extract_go_id(go_list): Extracts GO IDs from a list.
39- save_migrated_data(data_df, splits_df): Saves the processed data and splits.
21+ (https://doi.org/10.1093/bioinformatics/btx624).
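+
+    Example (illustrative; the data directory path is hypothetical):
+        DeepGo1DataMigration(data_dir="data/deep_go1", go_branch="mf").migrate()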
     """
 
-    # Number of annotations for each go_branch as per the research paper
-    _CORRESPONDING_GO_CLASSES = {
-        "cc": GOUniProtOver50,
-        "mf": GOUniProtOver50,
-        "bp": GOUniProtOver250,
-    }
-
+    # Max sequence length as per DeepGO1
     _MAXLEN = 1002
-    _LABELS_START_IDX = _GOUniProtDataExtractor._LABELS_START_IDX
+    _LABELS_START_IDX = DeepGO1MigratedData._LABELS_START_IDX
 
     def __init__(self, data_dir: str, go_branch: Literal["cc", "mf", "bp"]):
         """
         Initializes the data migration object with a data directory and GO branch.
 
         Args:
             data_dir (str): Directory containing the data files.
-            go_branch (Literal["cc", "mf", "bp"]): GO branch to use (cellular_component, molecular_function, or biological_process).
+            go_branch (Literal["cc", "mf", "bp"]): GO branch to use.
         """
-        valid_go_branches = list(self._CORRESPONDING_GO_CLASSES.keys())
+        valid_go_branches = list(DeepGO1MigratedData.GO_BRANCH_MAPPING.keys())
         if go_branch not in valid_go_branches:
             raise ValueError(f"go_branch must be one of {valid_go_branches}")
         self._go_branch = go_branch
@@ -69,49 +45,104 @@ def __init__(self, data_dir: str, go_branch: Literal["cc", "mf", "bp"]):
         self._terms_df: Optional[pd.DataFrame] = None
         self._classes: Optional[List[str]] = None
 
+    def migrate(self) -> None:
+        """
+        Executes the data migration by loading, processing, and saving the data.
+        """
+        print("Starting the migration process...")
+        self._load_data()
+        if not all(
+            df is not None
+            for df in [
+                self._train_df,
+                self._validation_df,
+                self._test_df,
+                self._terms_df,
+            ]
+        ):
+            raise Exception(
+                "Data splits or terms data is not available in instance variables."
+            )
+        splits_df = self._record_splits()
+        data_with_labels_df = self._extract_required_data_from_splits()
+
+        if not all(
+            var is not None for var in [data_with_labels_df, splits_df, self._classes]
+        ):
+            raise Exception(
+                "Extracted data, splits, or classes are not available in instance variables."
+            )
+
+        self.save_migrated_data(data_with_labels_df, splits_df)
+
     def _load_data(self) -> None:
         """
         Loads the test, train, validation, and terms data from the pickled files
         in the data directory.
         """
         try:
-            print(f"Loading data from {self._data_dir}......")
+            print(f"Loading data files from directory: {self._data_dir}")
             self._test_df = pd.DataFrame(
                 pd.read_pickle(
                     os.path.join(self._data_dir, f"test-{self._go_branch}.pkl")
                 )
             )
-            self._train_df = pd.DataFrame(
+
+            # DeepGO 1 lacks a validation split, so we will create one by further splitting the training set.
+            # Although this reduces the training data slightly compared to the original DeepGO setup,
+            # given the data size, the impact should be minimal.
+            train_df = pd.DataFrame(
                 pd.read_pickle(
                     os.path.join(self._data_dir, f"train-{self._go_branch}.pkl")
                 )
             )
-            # self._validation_df = pd.DataFrame(
-            #     pd.read_pickle(os.path.join(self._data_dir, f"valid-{self._go_branch}.pkl"))
-            # )
-
-            # DeepGO1 data does not include a separate validation split, but our data structure requires one.
-            # To accommodate this, we will create a placeholder validation split by duplicating a small subset of the
-            # training data. However, to ensure a fair comparison with DeepGO1, we will retain the full training set
-            # without creating an exclusive validation split from it.
-            # Therefore, any metrics calculated on this placeholder validation set should be disregarded, as they do not
-            # reflect true validation performance.
-            self._validation_df = self._train_df[len(self._train_df) - 5 :]
+
+            self._train_df, self._validation_df = self._get_train_val_split(train_df)
+
             self._terms_df = pd.DataFrame(
                 pd.read_pickle(os.path.join(self._data_dir, f"{self._go_branch}.pkl"))
             )
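+            # The terms file holds the selected GO terms (column "functions"), which
+            # later define the label columns of the migrated dataset.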
 
         except FileNotFoundError as e:
             print(f"Error loading data: {e}")
 
+    @staticmethod
+    def _get_train_val_split(
+        train_df: pd.DataFrame,
+    ) -> Tuple[pd.DataFrame, pd.DataFrame]:
+        """
+        Splits the training data into a smaller training set and a validation set.
+
+        Args:
+            train_df (pd.DataFrame): Original training DataFrame.
+
+        Returns:
+            Tuple[pd.DataFrame, pd.DataFrame]: Training and validation DataFrames.
+        """
+        labels_list_train = train_df["labels"].tolist()
+        train_split = 0.85
+        test_size = ((1 - train_split) ** 2) / train_split
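+        # With train_split = 0.85 this evaluates to (0.15 ** 2) / 0.85 ≈ 0.0265, so
+        # roughly 2.6% of the loaded training set is held out for validation. That
+        # equals 15% of a 15% slice of the full dataset, assuming the original test
+        # split was 15% (an assumption; the ratio is not stated here).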
+
+        splitter = MultilabelStratifiedShuffleSplit(
+            n_splits=1, test_size=test_size, random_state=42
+        )
+
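+        # Note: the stratified splitter uses its second argument (y) for stratification
+        # and the first only for its length, so passing the labels twice is sufficient.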
+        train_indices, validation_indices = next(
+            splitter.split(labels_list_train, labels_list_train)
+        )
+
+        df_validation = train_df.iloc[validation_indices]
+        df_train = train_df.iloc[train_indices]
+        return df_train, df_validation
+
     def _record_splits(self) -> pd.DataFrame:
         """
         Creates a DataFrame that stores the IDs and their corresponding data splits.
 
         Returns:
             pd.DataFrame: A combined DataFrame containing split assignments.
         """
-        print("Recording splits...")
+        print("Recording data splits for train, validation, and test sets.")
         split_assignment_list: List[pd.DataFrame] = [
             pd.DataFrame({"id": self._train_df["proteins"], "split": "train"}),
             pd.DataFrame(
@@ -123,50 +154,18 @@ def _record_splits(self) -> pd.DataFrame:
         combined_split_assignment = pd.concat(split_assignment_list, ignore_index=True)
         return combined_split_assignment
 
-    def migrate(self) -> None:
-        """
-        Executes the data migration by loading, processing, and saving the data.
-        """
-        print("Migration started......")
-        self._load_data()
-        if not all(
-            df is not None
-            for df in [
-                self._train_df,
-                self._validation_df,
-                self._test_df,
-                self._terms_df,
-            ]
-        ):
-            raise Exception(
-                "Data splits or terms data is not available in instance variables."
-            )
-        splits_df = self._record_splits()
-
-        data_with_labels_df = self._extract_required_data_from_splits()
-
-        if not all(
-            var is not None for var in [data_with_labels_df, splits_df, self._classes]
-        ):
-            raise Exception(
-                "Data splits or terms data is not available in instance variables."
-            )
-
-        self.save_migrated_data(data_with_labels_df, splits_df)
-
     def _extract_required_data_from_splits(self) -> pd.DataFrame:
         """
         Extracts required columns from the combined data splits.
 
         Returns:
             pd.DataFrame: A DataFrame containing the essential columns for processing.
         """
-        print("Combining the data splits with required data.....")
+        print("Combining data splits into a single DataFrame with required columns.")
         required_columns = [
             "proteins",
             "accessions",
             "sequences",
-            # Note: The GO classes here only directly related one, and not transitive GO classes
             "gos",
             "labels",
         ]
@@ -183,7 +182,7 @@ def _extract_required_data_from_splits(self) -> pd.DataFrame:
             lambda row: self.extract_go_id(row["gos"]), axis=1
         )
 
-        labels_df = self._get_labels_colums(new_df)
+        labels_df = self._get_labels_columns(new_df)
 
         data_df = pd.DataFrame(
             OrderedDict(
@@ -198,28 +197,32 @@ def _extract_required_data_from_splits(self) -> pd.DataFrame:
 
         return df
 
-    def _get_labels_colums(self, data_df: pd.DataFrame) -> pd.DataFrame:
+    @staticmethod
+    def extract_go_id(go_list: List[str]) -> List[int]:
         """
-        Generates a DataFrame with one-hot encoded columns for each GO term label,
-        based on the terms provided in `self._terms_df` and the existing labels in `data_df`.
+        Extracts and parses GO IDs from a list of GO annotations.
 
-        This method extracts GO IDs from the `functions` column of `self._terms_df`,
-        creating a list of all unique GO IDs. It then uses this list to create new
-        columns in the returned DataFrame, where each row has binary values
-        (0 or 1) indicating the presence of each GO ID in the corresponding entry of
-        `data_df['labels']`.
+        Args:
+            go_list (List[str]): List of GO annotation strings.
+
+        Returns:
+            List[int]: List of parsed GO IDs.
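+
+        Example (illustrative; assumes ``_parse_go_id`` extracts the numeric part,
+        e.g. "GO:0005575" -> 5575; its definition lives in DeepGO1MigratedData):
+            extract_go_id(["GO:0005575", "GO:0003674"]) -> [5575, 3674]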
+        """
+        return [DeepGO1MigratedData._parse_go_id(go_id_str) for go_id_str in go_list]
+
+    def _get_labels_columns(self, data_df: pd.DataFrame) -> pd.DataFrame:
+        """
+        Generates label columns based on the provided selected terms.
 
         Args:
-            data_df (pd.DataFrame): DataFrame containing data with a 'labels' column,
-                which holds lists of GO ID labels for each row.
+            data_df (pd.DataFrame): DataFrame with GO annotations and labels.
 
         Returns:
-            pd.DataFrame: A DataFrame with the same index as `data_df` and one column
-                per GO ID, containing binary values indicating label presence.
+            pd.DataFrame: DataFrame with label columns.
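+                Each selected GO term becomes one column holding binary values that
+                indicate the term's presence in the corresponding ``data_df["labels"]`` entry.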
         """
-        print("Generating labels based on terms.pkl file.......")
+        print("Generating label columns from provided selected terms.")
         parsed_go_ids: pd.Series = self._terms_df["functions"].apply(
-            lambda gos: _GOUniProtDataExtractor._parse_go_id(gos)
+            lambda gos: DeepGO1MigratedData._parse_go_id(gos)
         )
         all_go_ids_list = parsed_go_ids.values.tolist()
         self._classes = all_go_ids_list
@@ -230,21 +233,6 @@ def _get_labels_colums(self, data_df: pd.DataFrame) -> pd.DataFrame:
 
 
         return new_label_columns
-    @staticmethod
-    def extract_go_id(go_list: List[str]) -> List[int]:
-        """
-        Extracts and parses GO IDs from a list of GO annotations.
-
-        Args:
-            go_list (List[str]): List of GO annotation strings.
-
-        Returns:
-            List[str]: List of parsed GO IDs.
-        """
-        return [
-            _GOUniProtDataExtractor._parse_go_id(go_id_str) for go_id_str in go_list
-        ]
-
     def save_migrated_data(
         self, data_df: pd.DataFrame, splits_df: pd.DataFrame
     ) -> None:
@@ -255,31 +243,38 @@ def save_migrated_data(
             data_df (pd.DataFrame): Data with GO labels.
             splits_df (pd.DataFrame): Split assignment DataFrame.
         """
-        print("Saving transformed data......")
-        go_class_instance: _GOUniProtDataExtractor = self._CORRESPONDING_GO_CLASSES[
-            self._go_branch
-        ](go_branch=self._go_branch.upper(), max_sequence_length=self._MAXLEN)
+        print("Saving transformed data files.")
 
-        go_class_instance.save_processed(
-            data_df, go_class_instance.processed_main_file_names_dict["data"]
+        deepgo_migr_inst: DeepGO1MigratedData = DeepGO1MigratedData(
+            go_branch=DeepGO1MigratedData.GO_BRANCH_MAPPING[self._go_branch],
+            max_sequence_length=self._MAXLEN,
+        )
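+        # GO_BRANCH_MAPPING translates the short branch code ("cc"/"mf"/"bp") into the
+        # go_branch value expected by DeepGO1MigratedData; the exact mapping is defined
+        # in DeepGO1MigratedData (the replaced code simply upper-cased the code).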
+
+        # Save data file
+        deepgo_migr_inst.save_processed(
+            data_df, deepgo_migr_inst.processed_main_file_names_dict["data"]
         )
         print(
-            f"{go_class_instance.processed_main_file_names_dict['data']} saved to {go_class_instance.processed_dir_main}"
+            f"{deepgo_migr_inst.processed_main_file_names_dict['data']} saved to {deepgo_migr_inst.processed_dir_main}"
         )
 
+        # Save splits file
         splits_df.to_csv(
-            os.path.join(go_class_instance.processed_dir_main, "splits.csv"),
+            os.path.join(deepgo_migr_inst.processed_dir_main, "splits_deep_go1.csv"),
             index=False,
         )
-        print(f"splits.csv saved to {go_class_instance.processed_dir_main}")
+        print(f"splits_deep_go1.csv saved to {deepgo_migr_inst.processed_dir_main}")
 
+        # Save classes file
         classes = sorted(self._classes)
         with open(
-            os.path.join(go_class_instance.processed_dir_main, "classes.txt"), "wt"
+            os.path.join(deepgo_migr_inst.processed_dir_main, "classes_deep_go1.txt"),
+            "wt",
         ) as fout:
             fout.writelines(str(node) + "\n" for node in classes)
-        print(f"classes.txt saved to {go_class_instance.processed_dir_main}")
-        print("Migration completed!")
+        print(f"classes_deep_go1.txt saved to {deepgo_migr_inst.processed_dir_main}")
+
+        print("Migration process completed!")
 
 
 class Main: