@@ -144,9 +144,9 @@ def __init__(
144144 ** kwargs ,
145145 ):
146146 if bool (augment_smiles ):
147- assert (
148- int ( aug_smiles_variations ) > 0
149- ), "Number of variations must be greater than 0"
147+ assert int ( aug_smiles_variations ) > 0 , (
148+ "Number of variations must be greater than 0"
149+ )
150150 aug_smiles_variations = int (aug_smiles_variations )
151151
152152 if not kwargs .get ("splits_file_path" , None ):
@@ -358,7 +358,8 @@ def _perform_smiles_augmentation(self) -> None:
358358 self .processed_main_file_names_dict ["data" ]
359359 )
360360
361- AUG_SMILES_VARIATIONS = self .aug_smiles_variations
361+ # +1 to account for the case where the original SMILES is generated by random chance
362+ AUG_SMILES_VARIATIONS = self .aug_smiles_variations + 1
362363
363364 def generate_augmented_smiles (smiles : str ) -> list [str ]:
364365 mol : Chem .Mol = Chem .MolFromSmiles (smiles )
@@ -397,6 +398,16 @@ def generate_augmented_smiles(smiles: str) -> list[str]:
397398 else :
398399 augmented = frag_smiles [0 ]
399400
401+ if smiles in augmented :
402+ augmented .remove (smiles )
403+
404+ if len (augmented ) > AUG_SMILES_VARIATIONS - 1 :
405+ # AUG_SMILES_VARIATIONS = number of new SMILES to generate + the original SMILES.
406+ # E.g. if 3 SMILES variations are needed and 4 were generated because none
407+ # corresponds to the original SMILES (so none was eliminated by the previous if condition), keep only 3
408+ augmented = random .sample (augmented , AUG_SMILES_VARIATIONS - 1 )
409+
410+ # original smiles always first in the list
400411 return [smiles ] + list (augmented )
401412
402413 data_df ["SMILES" ] = data_df ["SMILES" ].apply (generate_augmented_smiles )
0 commit comments