11import os
22from collections import OrderedDict
3- from random import randint
4- from typing import List , Literal
3+ from typing import List , Literal , Optional
54
65import pandas as pd
76from jsonargparse import CLI
@@ -59,12 +58,12 @@ def __init__(self, data_dir: str, go_branch: Literal["cc", "mf", "bp"]):
5958 raise ValueError (f"go_branch must be one of { valid_go_branches } " )
6059 self ._go_branch = go_branch
6160
62- self ._data_dir = os .path .join (data_dir , go_branch )
63- self ._train_df : pd .DataFrame = None
64- self ._test_df : pd .DataFrame = None
65- self ._validation_df : pd .DataFrame = None
66- self ._terms_df : pd .DataFrame = None
67- self ._classes : List [str ] = None
61+ self ._data_dir : str = os .path .join (rf" { data_dir } " , go_branch )
62+ self ._train_df : Optional [ pd .DataFrame ] = None
63+ self ._test_df : Optional [ pd .DataFrame ] = None
64+ self ._validation_df : Optional [ pd .DataFrame ] = None
65+ self ._terms_df : Optional [ pd .DataFrame ] = None
66+ self ._classes : Optional [ List [str ] ] = None
6867
6968 def _load_data (self ) -> None :
7069 """
@@ -114,7 +113,13 @@ def migrate(self) -> None:
114113 print ("Migration started......" )
115114 self ._load_data ()
116115 if not all (
117- [self ._train_df , self ._validation_df , self ._test_df , self ._terms_df ]
116+ df is not None
117+ for df in [
118+ self ._train_df ,
119+ self ._validation_df ,
120+ self ._test_df ,
121+ self ._terms_df ,
122+ ]
118123 ):
119124 raise Exception (
120125 "Data splits or terms data is not available in instance variables."
@@ -124,7 +129,9 @@ def migrate(self) -> None:
124129 data_df = self ._extract_required_data_from_splits ()
125130 data_with_labels_df = self ._generate_labels (data_df )
126131
127- if not all ([data_with_labels_df , splits_df , self ._classes ]):
132+ if not all (
133+ var is not None for var in [data_with_labels_df , splits_df , self ._classes ]
134+ ):
128135 raise Exception (
129136 "Data splits or terms data is not available in instance variables."
130137 )
@@ -184,8 +191,8 @@ def _generate_labels(self, data_df: pd.DataFrame) -> pd.DataFrame:
184191 pd.DataFrame: DataFrame with new label columns.
185192 """
186193 print ("Generating labels based on terms.pkl file......." )
187- parsed_go_ids : pd .Series = self ._terms_df .apply (
188- lambda row : self . extract_go_id ( row [ " gos" ] )
194+ parsed_go_ids : pd .Series = self ._terms_df [ "gos" ] .apply (
195+ lambda gos : _GOUniProtDataExtractor . _parse_go_id ( gos )
189196 )
190197 all_go_ids_list = parsed_go_ids .values .tolist ()
191198 self ._classes = all_go_ids_list
@@ -203,7 +210,7 @@ def _generate_labels(self, data_df: pd.DataFrame) -> pd.DataFrame:
203210 return data_df
204211
205212 @staticmethod
206- def extract_go_id (go_list : List [str ]) -> List [str ]:
213+ def extract_go_id (go_list : List [str ]) -> List [int ]:
207214 """
208215 Extracts and parses GO IDs from a list of GO annotations.
209216
@@ -230,13 +237,13 @@ def save_migrated_data(
230237 print ("Saving transformed data......" )
231238 go_class_instance : _GOUniProtDataExtractor = self ._CORRESPONDING_GO_CLASSES [
232239 self ._go_branch
233- ](go_branch = self ._go_branch , max_sequence_length = self ._MAXLEN )
240+ ](go_branch = self ._go_branch . upper () , max_sequence_length = self ._MAXLEN )
234241
235242 go_class_instance .save_processed (
236- data_df , go_class_instance .processed_file_names_dict ["data" ]
243+ data_df , go_class_instance .processed_main_file_names_dict ["data" ]
237244 )
238245 print (
239- f"{ go_class_instance .processed_file_names_dict ['data' ]} saved to { go_class_instance .processed_dir_main } "
246+ f"{ go_class_instance .processed_main_file_names_dict ['data' ]} saved to { go_class_instance .processed_dir_main } "
240247 )
241248
242249 splits_df .to_csv (
@@ -263,7 +270,8 @@ class Main:
263270 Initiates the migration process for the specified data directory and GO branch.
264271 """
265272
266- def migrate (self , data_dir : str , go_branch : str ) -> None :
273+ @staticmethod
274+ def migrate (data_dir : str , go_branch : Literal ["cc" , "mf" , "bp" ]) -> None :
267275 """
268276 Initiates the migration process by creating a DeepGoDataMigration instance
269277 and invoking its migrate method.
@@ -278,29 +286,12 @@ def migrate(self, data_dir: str, go_branch: str) -> None:
278286 DeepGoDataMigration (data_dir , go_branch ).migrate ()
279287
280288
281- class Main1 :
282- def __init__ (self , max_prize : int = 100 ):
283- """
284- Args:
285- max_prize: Maximum prize that can be awarded.
286- """
287- self .max_prize = max_prize
288-
289- def person (self , name : str , additional_prize : int = 0 ):
290- """
291- Args:
292- name: Name of the winner.
293- additional_prize: Additional prize that can be added to the prize amount.
294- """
295- prize = randint (0 , self .max_prize ) + additional_prize
296- return f"{ name } won { prize } €!"
297-
298-
299289if __name__ == "__main__" :
300- # Example: python script_name.py migrate data_dir="data/deep_go_se_training_data" go_branch="bp"
290+ # Example: python script_name.py migrate --data_dir="data/deep_go_se_training_data" --go_branch="bp"
301291 # --data_dir specifies the directory containing the data files.
302292 # --go_branch specifies the GO branch (cc, mf, or bp) you want to use for the migration.
303293 CLI (
304- Main1 ,
294+ Main ,
305295 description = "DeepGoDataMigration CLI tool to handle migration of GO data for specified branches (cc, mf, bp)." ,
296+ as_positional = False ,
306297 )
0 commit comments