Skip to content

Commit 2218fc8

Browse files
committed
avoid generation of original smiles in augmentation
1 parent 97839d9 commit 2218fc8

File tree

1 file changed

+15
-4
lines changed

1 file changed

+15
-4
lines changed

chebai/preprocessing/datasets/chebi.py

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -144,9 +144,9 @@ def __init__(
144144
**kwargs,
145145
):
146146
if bool(augment_smiles):
147-
assert (
148-
int(aug_smiles_variations) > 0
149-
), "Number of variations must be greater than 0"
147+
assert int(aug_smiles_variations) > 0, (
148+
"Number of variations must be greater than 0"
149+
)
150150
aug_smiles_variations = int(aug_smiles_variations)
151151

152152
if not kwargs.get("splits_file_path", None):
@@ -358,7 +358,8 @@ def _perform_smiles_augmentation(self) -> None:
358358
self.processed_main_file_names_dict["data"]
359359
)
360360

361-
AUG_SMILES_VARIATIONS = self.aug_smiles_variations
361+
# +1 to account for if original SMILES is generated by random chance
362+
AUG_SMILES_VARIATIONS = self.aug_smiles_variations + 1
362363

363364
def generate_augmented_smiles(smiles: str) -> list[str]:
364365
mol: Chem.Mol = Chem.MolFromSmiles(smiles)
@@ -397,6 +398,16 @@ def generate_augmented_smiles(smiles: str) -> list[str]:
397398
else:
398399
augmented = frag_smiles[0]
399400

401+
if smiles in augmented:
402+
augmented.remove(smiles)
403+
404+
if len(augmented) > AUG_SMILES_VARIATIONS - 1:
405+
# AUG_SMILES_VARIATIONS = number of new smiles needed to generate + original smiles
406+
# if 3 smiles variations are needed, and 4 are generated because none
407+
# correponds to original smiles and no-one is elimnated in previous if condition
408+
augmented = random.sample(augmented, AUG_SMILES_VARIATIONS - 1)
409+
410+
# original smiles always first in the list
400411
return [smiles] + list(augmented)
401412

402413
data_df["SMILES"] = data_df["SMILES"].apply(generate_augmented_smiles)

0 commit comments

Comments
 (0)