 from abc import ABC
 from collections import OrderedDict
 from itertools import cycle
-from multiprocessing import Pool
-from typing import Any, Dict, Generator, List, Optional, Tuple, Union
+from typing import Any, Generator, Optional, Union

 import fastobo
 import networkx as nx
@@ -142,15 +141,16 @@ def __init__(
         aug_smiles_variations: Optional[int] = None,
         **kwargs,
     ):
-        if augment_smiles:
-            assert aug_smiles_variations is not None, ""
+        if bool(augment_smiles):
             assert (
-                aug_smiles_variations > 0
+                int(aug_smiles_variations) > 0
             ), "Number of variations must be greater than 0"
-            kwargs.setdefault("reader_kwargs", {}).update(canonicalize_smiles=False)
+            reader_kwargs = kwargs.get("reader_kwargs", {})
+            reader_kwargs["canonicalize_smiles"] = False
+            kwargs["reader_kwargs"] = reader_kwargs

-        self.augment_smiles = augment_smiles
-        self.aug_smiles_variations = aug_smiles_variations
+        self.augment_smiles = bool(augment_smiles)
+        self.aug_smiles_variations = int(aug_smiles_variations)
         # predict only single class (given as id of one of the classes present in the raw data set)
         self.single_class = single_class
         super(_ChEBIDataExtractor, self).__init__(**kwargs)
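The net effect of this hunk: enabling augmentation requires a positive number of variations and forces the reader to keep non-canonical SMILES. A minimal standalone sketch of that kwargs handling, with a hypothetical helper name that is not part of the diff:

def _normalize_augmentation_kwargs(augment_smiles, aug_smiles_variations, **kwargs):
    # Mirrors the constructor logic above: augmentation implies raw (non-canonical)
    # SMILES, so canonicalisation is switched off in the reader kwargs.
    if bool(augment_smiles):
        assert int(aug_smiles_variations) > 0, "Number of variations must be greater than 0"
        reader_kwargs = kwargs.get("reader_kwargs", {})
        reader_kwargs["canonicalize_smiles"] = False
        kwargs["reader_kwargs"] = reader_kwargs
    return kwargs

print(_normalize_augmentation_kwargs(True, 5))
# {'reader_kwargs': {'canonicalize_smiles': False}}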
@@ -321,14 +321,28 @@ def _graph_to_raw_dataset(self, g: nx.DiGraph) -> pd.DataFrame:
         # This filters the DataFrame to include only the rows where at least one value in the row from 4th column
         # onwards is True/non-zero.
         data = data[data.iloc[:, self._LABELS_START_IDX :].any(axis=1)]
-        if self.augment_smiles:
-            return self._perform_smiles_augmentation(data)
         return data

-    def _perform_smiles_augmentation(self, data_df: pd.DataFrame) -> pd.DataFrame:
+    def _after_prepare_data(self, *args, **kwargs) -> None:
+        self._perform_smiles_augmentation()
+
+    def _perform_smiles_augmentation(self) -> None:
+        if not self.augment_smiles:
+            return
+
+        aug_data_df = self.get_processed_pickled_df_file(
+            self.processed_main_file_names_dict["aug_data"]
+        )
+        if aug_data_df is not None:
+            return
+
+        data_df = self.get_processed_pickled_df_file(
+            self.processed_main_file_names_dict["data"]
+        )
+
         AUG_SMILES_VARIATIONS = self.aug_smiles_variations

-        def generate_augmented_smiles(smiles: str):
+        def generate_augmented_smiles(smiles: str) -> list[str]:
             mol: Chem.Mol = Chem.MolFromSmiles(smiles)
             atom_ids = [atom.GetIdx() for atom in mol.GetAtoms()]
             random.shuffle(atom_ids)  # seed set by lightning
@@ -340,15 +354,15 @@ def generate_augmented_smiles(smiles: str):
             augmented.add(smiles)
             return list(augmented)

-        with Pool() as pool:
-            data_df["augmented_smiles"] = pool.map(
-                generate_augmented_smiles, data_df["SMILES"]
-            )
+        data_df["augmented_smiles"] = data_df["SMILES"].apply(generate_augmented_smiles)
+
         # Explode the list of augmented smiles into multiple rows
         # augmented smiles will have same ident, as of the original, but does it matter ?
         exploded_df = data_df.explode("augmented_smiles").reset_index(drop=True)
         exploded_df.rename(columns={"augmented_smiles": "SMILES"}, inplace=True)
-        return exploded_df
+        self.save_processed(
+            exploded_df, self.processed_main_file_names_dict["aug_data"]
+        )

     # ------------------------------ Phase: Setup data -----------------------------------
     def setup_processed(self) -> None:
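The loop that actually builds the SMILES variations sits between these two hunks and is not shown, so the rooted, non-canonical MolToSmiles call below is an assumption about its shape, not the committed implementation. A self-contained sketch of SMILES augmentation by atom re-rooting with RDKit:

import random

from rdkit import Chem

def randomized_smiles_sketch(smiles: str, n_variations: int) -> list[str]:
    # Parse once, then write non-canonical SMILES rooted at randomly chosen atoms;
    # different roots/traversals usually give different strings for the same molecule.
    mol = Chem.MolFromSmiles(smiles)
    atom_ids = [atom.GetIdx() for atom in mol.GetAtoms()]
    random.shuffle(atom_ids)
    augmented = set()
    for root in atom_ids[:n_variations]:
        augmented.add(Chem.MolToSmiles(mol, canonical=False, doRandom=True, rootedAtAtom=root))
    augmented.add(smiles)  # the hunk above also keeps the original string
    return list(augmented)

print(randomized_smiles_sketch("CC(=O)Oc1ccccc1C(=O)O", 3))  # aspirin, up to 4 strings

After the per-row lists are built, the hunk explodes them into one row per SMILES and saves the result under the "aug_data" file name via save_processed instead of returning a DataFrame.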
@@ -377,7 +391,7 @@ def setup_processed(self) -> None:
             print("Calling the setup method related to it")
             self._chebi_version_train_obj.setup()

-    def _load_dict(self, input_file_path: str) -> Generator[Dict[str, Any], None, None]:
+    def _load_dict(self, input_file_path: str) -> Generator[dict[str, Any], None, None]:
         """
         Loads a dictionary from a pickled file, yielding individual dictionaries for each row.

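Only the return annotation changes here (typing.Dict to the built-in dict); the body of _load_dict is outside the hunk. As a rough sketch of the shape the signature and docstring imply, with illustrative key names:

import pandas as pd
from typing import Any, Generator

def load_dict_sketch(input_file_path: str) -> Generator[dict[str, Any], None, None]:
    # Read the pickled DataFrame and yield one plain dict per row.
    df = pd.read_pickle(input_file_path)
    for row in df.itertuples(index=False):
        yield {"ident": row[0], "features": row[1], "labels": list(row[2:])}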
@@ -418,7 +432,7 @@ def _load_dict(self, input_file_path: str) -> Generator[Dict[str, Any], None, No
                 )

     # ------------------------------ Phase: Dynamic Splits -----------------------------------
-    def _get_data_splits(self) -> Tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]:
+    def _get_data_splits(self) -> tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]:
         """
         Loads encoded/transformed data and generates training, validation, and test splits.

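Likewise, only the annotation changes here (typing.Tuple to tuple); the splitting logic itself is not part of the diff. Purely for illustration, a generic two-stage split with scikit-learn (not the repository's own splitter) that yields the promised (train, validation, test) triple:

import pandas as pd
from sklearn.model_selection import train_test_split

df = pd.DataFrame({"ident": range(10), "SMILES": ["C", "CC"] * 5})
train_df, holdout_df = train_test_split(df, test_size=0.4, random_state=0)
val_df, test_df = train_test_split(holdout_df, test_size=0.5, random_state=0)
print(len(train_df), len(val_df), len(test_df))  # 6 2 2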
@@ -591,9 +605,10 @@ def processed_main_file_names_dict(self) -> dict:
             dict: A dictionary mapping dataset key to their respective file names.
             For example, {"data": "data.pkl"}.
         """
-        if not self.augment_smiles:
-            return super().processed_main_file_names_dict
-        return {"data": f"aug_data_var{self.aug_smiles_variations}.pkl"}
+        p_dict = super().processed_main_file_names_dict
+        if self.augment_smiles:
+            p_dict["aug_data"] = f"aug_data_var{self.aug_smiles_variations}.pkl"
+        return p_dict

     @property
     def processed_file_names_dict(self) -> dict:
@@ -606,6 +621,10 @@ def processed_file_names_dict(self) -> dict:
         """
         if not self.augment_smiles:
             return super().processed_file_names_dict
+        if self.n_token_limit is not None:
+            return {
+                "data": f"aug_data_var{self.aug_smiles_variations}_maxlen{self.n_token_limit}.pt"
+            }
         return {"data": f"aug_data_var{self.aug_smiles_variations}.pt"}

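Together with the previous hunk, augmented runs now keep the base "data" entry and add an "aug_data" pickle, while the encoded .pt file name optionally records the token limit. A small illustration with assumed values (5 variations, a token limit of 512):

aug_smiles_variations = 5  # assumed example value
n_token_limit = 512  # assumed example value
main_files = {"data": "data.pkl", "aug_data": f"aug_data_var{aug_smiles_variations}.pkl"}
encoded_files = {"data": f"aug_data_var{aug_smiles_variations}_maxlen{n_token_limit}.pt"}
print(main_files)  # {'data': 'data.pkl', 'aug_data': 'aug_data_var5.pkl'}
print(encoded_files)  # {'data': 'aug_data_var5_maxlen512.pt'}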
@@ -644,7 +663,7 @@ def _name(self) -> str:
         """
         return f"ChEBI{self.THRESHOLD}"

-    def select_classes(self, g: nx.DiGraph, *args, **kwargs) -> List:
+    def select_classes(self, g: nx.DiGraph, *args, **kwargs) -> list:
         """
         Selects classes from the ChEBI dataset based on the number of successors meeting a specified threshold.

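select_classes only gains the built-in list annotation; its docstring describes a successor-count threshold. A toy sketch of one reading of that rule, counting transitive descendants in a networkx DiGraph (the committed implementation may count successors differently):

import networkx as nx

THRESHOLD = 2  # illustrative; concrete subclasses set THRESHOLD (cf. the _name property above)
g = nx.DiGraph([("A", "B"), ("A", "C"), ("A", "D"), ("B", "D")])
selected = [n for n in g.nodes if len(nx.descendants(g, n)) >= THRESHOLD]
print(selected)  # ['A']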
@@ -856,7 +875,7 @@ def chebi_to_int(s: str) -> int:
     return int(s[s.index(":") + 1 :])


-def term_callback(doc: fastobo.term.TermFrame) -> Union[Dict, bool]:
+def term_callback(doc: fastobo.term.TermFrame) -> Union[dict, bool]:
     """
     Extracts information from a ChEBI term document.
     This function takes a ChEBI term document as input and extracts relevant information such as the term ID, parents,
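For reference, chebi_to_int (visible in the hunk header context) simply strips everything up to the colon in a ChEBI identifier:

def chebi_to_int(s: str) -> int:
    return int(s[s.index(":") + 1 :])

assert chebi_to_int("CHEBI:33429") == 33429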