
import os
import pickle
+ import random
from abc import ABC
from collections import OrderedDict
- from typing import TYPE_CHECKING, Any, Dict, Generator, List, Optional, Tuple, Union
+ from itertools import cycle, permutations, product
+ from typing import TYPE_CHECKING, Any, Dict, Generator, List, Optional, Union

import pandas as pd
import torch
+ from rdkit import Chem

from chebai.preprocessing import reader as dr
from chebai.preprocessing.datasets.base import XYBaseDataModule, _DynamicDataset
@@ -135,8 +138,27 @@ def __init__(
        self,
        chebi_version_train: Optional[int] = None,
        single_class: Optional[int] = None,
+         augment_smiles: bool = False,
+         aug_smiles_variations: Optional[int] = None,
        **kwargs,
    ):
+         if bool(augment_smiles):
+             assert (
+                 aug_smiles_variations is not None and int(aug_smiles_variations) > 0
+             ), "Number of variations must be greater than 0"
+             aug_smiles_variations = int(aug_smiles_variations)
+
+             if not kwargs.get("splits_file_path", None):
+                 raise ValueError(
+                     "When using SMILES augmentation, a splits_file_path must be provided to ensure consistent splits."
+                 )
+
+             reader_kwargs = kwargs.get("reader_kwargs", {})
+             reader_kwargs["canonicalize_smiles"] = False
+             kwargs["reader_kwargs"] = reader_kwargs
+
+         self.augment_smiles = bool(augment_smiles)
+         self.aug_smiles_variations = aug_smiles_variations
        # predict only a single class (given as the id of one of the classes present in the raw dataset)
        self.single_class = single_class
        super(_ChEBIDataExtractor, self).__init__(**kwargs)
@@ -151,6 +173,8 @@ def __init__(
            _init_kwargs["chebi_version"] = self.chebi_version_train
            self._chebi_version_train_obj = self.__class__(
                single_class=self.single_class,
+                 augment_smiles=self.augment_smiles,
+                 aug_smiles_variations=self.aug_smiles_variations,
                **_init_kwargs,
            )

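As a rough usage sketch (not part of this diff): the new flags are meant to be passed together with an explicit splits file, so that augmented rows cannot leak across splits. The subclass name and file path below are illustrative assumptions only.

# Hypothetical example; ChEBIOver100 stands in for any concrete _ChEBIDataExtractor subclass.
from chebai.preprocessing.datasets.chebi import ChEBIOver100

data_module = ChEBIOver100(
    augment_smiles=True,
    aug_smiles_variations=5,        # must be > 0 when augmentation is enabled
    splits_file_path="splits.csv",  # required, see the ValueError raised above
)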
@@ -312,6 +336,75 @@ def _graph_to_raw_dataset(self, g: "nx.DiGraph") -> pd.DataFrame:

        return data

+     def _after_prepare_data(self, *args, **kwargs) -> None:
+         self._perform_smiles_augmentation()
+
+     def _perform_smiles_augmentation(self) -> None:
+         """Generate augmented SMILES variants and store them in a separate processed file."""
+         if not self.augment_smiles:
+             return
+
+         aug_pkl_file_name = self.processed_main_file_names_dict["aug_data"]
+         aug_data_df = self.get_processed_pickled_df_file(aug_pkl_file_name)
+         if aug_data_df is not None:
+             self._data_pkl_filename = aug_pkl_file_name
+             return
+
+         data_df = self.get_processed_pickled_df_file(
+             self.processed_main_file_names_dict["data"]
+         )
+
+         AUG_SMILES_VARIATIONS = self.aug_smiles_variations
+
+         def generate_augmented_smiles(smiles: str) -> list[str]:
+             mol: Chem.Mol = Chem.MolFromSmiles(smiles)
+             if mol is None:
+                 return [smiles]  # parsing failed, keep only the original SMILES
+
+             # Sanitization is disabled because it can alter the fragment representation;
+             # we only need the fragments as-is to generate SMILES strings from them.
+             frags = Chem.GetMolFrags(mol, asMols=True, sanitizeFrags=False)
+             augmented = set()
+
+             frag_smiles: list[set] = []
+             for frag in frags:
+                 atom_ids = [atom.GetIdx() for atom in frag.GetAtoms()]
+                 random.shuffle(atom_ids)  # seed is set by Lightning
+                 atom_id_iter = cycle(atom_ids)
+                 frag_smiles.append(
+                     {
+                         Chem.MolToSmiles(
+                             frag, rootedAtAtom=next(atom_id_iter), doRandom=True
+                         )
+                         for _ in range(AUG_SMILES_VARIATIONS)
+                     }
+                 )
+             if len(frags) > 1:
+                 # Combine the per-fragment SMILES in every fragment order until
+                 # enough variations have been collected.
+                 aug_counter: int = 0
+                 for perm in permutations(frag_smiles):
+                     for combo in product(*perm):
+                         augmented.add(".".join(combo))
+                         aug_counter += 1
+                         if aug_counter >= AUG_SMILES_VARIATIONS:
+                             break
+                     if aug_counter >= AUG_SMILES_VARIATIONS:
+                         break
+             else:
+                 augmented = frag_smiles[0]
+
+             return [smiles] + list(augmented)
+
+         data_df["SMILES"] = data_df["SMILES"].apply(generate_augmented_smiles)
+
+         # Explode the list of augmented SMILES into multiple rows. Augmented SMILES
+         # keep the ident of the original entry, which is useful for grouping the
+         # variants generated from the same original SMILES.
+         exploded_df = data_df.explode("SMILES").reset_index(drop=True)
+         self.save_processed(
+             exploded_df, self.processed_main_file_names_dict["aug_data"]
+         )
+         self._data_pkl_filename = aug_pkl_file_name
+
    # ------------------------------ Phase: Setup data -----------------------------------
    def setup_processed(self) -> None:
        """
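For context on the `_perform_smiles_augmentation` hunk above: the augmentation relies on RDKit emitting randomized, non-canonical SMILES for the same molecule. A minimal standalone sketch of that mechanism, independent of the data module (the aspirin SMILES is just an example input; only documented RDKit arguments are used):

import random
from rdkit import Chem

smiles = "CC(=O)Oc1ccccc1C(=O)O"  # example molecule (aspirin)
mol = Chem.MolFromSmiles(smiles)
atom_ids = [atom.GetIdx() for atom in mol.GetAtoms()]
random.shuffle(atom_ids)

# Each call writes an equivalent SMILES rooted at a different atom with a randomized
# traversal; collecting into a set removes duplicates, as in the method above.
variants = {
    Chem.MolToSmiles(mol, rootedAtAtom=atom_ids[i % len(atom_ids)], doRandom=True)
    for i in range(5)
}
print(variants)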
@@ -339,7 +432,7 @@ def setup_processed(self) -> None:
        print("Calling the setup method related to it")
        self._chebi_version_train_obj.setup()

-     def _load_dict(self, input_file_path: str) -> Generator[Dict[str, Any], None, None]:
+     def _load_dict(self, input_file_path: str) -> Generator[dict[str, Any], None, None]:
        """
        Loads a dictionary from a pickled file, yielding individual dictionaries for each row.

@@ -380,7 +473,7 @@ def _load_dict(self, input_file_path: str) -> Generator[Dict[str, Any], None, No
        )

    # ------------------------------ Phase: Dynamic Splits -----------------------------------
-     def _get_data_splits(self) -> Tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]:
+     def _get_data_splits(self) -> tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]:
        """
        Loads encoded/transformed data and generates training, validation, and test splits.

@@ -544,6 +637,37 @@ def processed_dir(self) -> str:
    def raw_file_names_dict(self) -> dict:
        return {"chebi": "chebi.obo"}

+     @property
+     def processed_main_file_names_dict(self) -> dict:
+         """
+         Returns a dictionary of processed data file names.
+
+         Returns:
+             dict: A dictionary mapping dataset keys to their respective file names.
+                 For example, {"data": "data.pkl"}.
+         """
+         p_dict = super().processed_main_file_names_dict
+         if self.augment_smiles:
+             p_dict["aug_data"] = f"aug_data_var{self.aug_smiles_variations}.pkl"
+         return p_dict
+
+     @property
+     def processed_file_names_dict(self) -> dict:
+         """
+         Returns a dictionary for the processed and tokenized data files.
+
+         Returns:
+             dict: A dictionary mapping dataset keys to their respective file names.
+                 For example, {"data": "data.pt"}.
+         """
+         if not self.augment_smiles:
+             return super().processed_file_names_dict
+         if self.n_token_limit is not None:
+             return {
+                 "data": f"aug_data_var{self.aug_smiles_variations}_maxlen{self.n_token_limit}.pt"
+             }
+         return {"data": f"aug_data_var{self.aug_smiles_variations}.pt"}
+

class JCIExtendedBase(_ChEBIDataExtractor):
    @property
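For illustration, and assuming the base class's processed_main_file_names_dict is {"data": "data.pkl"} as the docstring example above suggests, a configuration with aug_smiles_variations=5 would resolve the new properties roughly as follows (the n_token_limit value of 886 is an arbitrary example):

# augment_smiles=True, aug_smiles_variations=5, n_token_limit=None
#   processed_main_file_names_dict -> {"data": "data.pkl", "aug_data": "aug_data_var5.pkl"}
#   processed_file_names_dict      -> {"data": "aug_data_var5.pt"}

# augment_smiles=True, aug_smiles_variations=5, n_token_limit=886
#   processed_file_names_dict      -> {"data": "aug_data_var5_maxlen886.pt"}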