Skip to content

Commit 8c5ebcd

Browse files
authored
Merge pull request #136 from ChEB-AI/fix/duplicate_smiles_aug
Avoid generation of original SMILES in augmentation
2 parents b1af011 + df02f55 commit 8c5ebcd

File tree

1 file changed

+12
-1
lines changed

1 file changed

+12
-1
lines changed

chebai/preprocessing/datasets/chebi.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -358,7 +358,8 @@ def _perform_smiles_augmentation(self) -> None:
358358
self.processed_main_file_names_dict["data"]
359359
)
360360

361-
AUG_SMILES_VARIATIONS = self.aug_smiles_variations
361+
# +1 to account for if original SMILES is generated by random chance
362+
AUG_SMILES_VARIATIONS = self.aug_smiles_variations + 1
362363

363364
def generate_augmented_smiles(smiles: str) -> list[str]:
364365
mol: Chem.Mol = Chem.MolFromSmiles(smiles)
@@ -397,6 +398,16 @@ def generate_augmented_smiles(smiles: str) -> list[str]:
397398
else:
398399
augmented = frag_smiles[0]
399400

401+
if smiles in augmented:
402+
augmented.remove(smiles)
403+
404+
if len(augmented) > AUG_SMILES_VARIATIONS - 1:
405+
# AUG_SMILES_VARIATIONS = number of new smiles needed to generate + original smiles
406+
# if 3 smiles variations are needed, and 4 are generated because none
407+
# correponds to original smiles and no-one is elimnated in previous if condition
408+
augmented = random.sample(augmented, AUG_SMILES_VARIATIONS - 1)
409+
410+
# original smiles always first in the list
400411
return [smiles] + list(augmented)
401412

402413
data_df["SMILES"] = data_df["SMILES"].apply(generate_augmented_smiles)

0 commit comments

Comments
 (0)