Skip to content

Commit 5d3aa0a

Browse files
committed
aug smiles based on fragmentation - #44
1 parent d864011 commit 5d3aa0a

File tree

1 file changed

+38
-15
lines changed

1 file changed

+38
-15
lines changed

chebai/preprocessing/datasets/chebi.py

Lines changed: 38 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
import random
1515
from abc import ABC
1616
from collections import OrderedDict
17-
from itertools import cycle
17+
from itertools import cycle, permutations, product
1818
from typing import Any, Generator, Optional, Union
1919

2020
import fastobo
@@ -145,12 +145,13 @@ def __init__(
145145
assert (
146146
int(aug_smiles_variations) > 0
147147
), "Number of variations must be greater than 0"
148+
aug_smiles_variations = int(aug_smiles_variations)
148149
reader_kwargs = kwargs.get("reader_kwargs", {})
149150
reader_kwargs["canonicalize_smiles"] = False
150151
kwargs["reader_kwargs"] = reader_kwargs
151152

152153
self.augment_smiles = bool(augment_smiles)
153-
self.aug_smiles_variations = int(aug_smiles_variations)
154+
self.aug_smiles_variations = aug_smiles_variations
154155
# predict only single class (given as id of one of the classes present in the raw data set)
155156
self.single_class = single_class
156157
super(_ChEBIDataExtractor, self).__init__(**kwargs)
@@ -344,22 +345,44 @@ def _perform_smiles_augmentation(self) -> None:
344345

345346
def generate_augmented_smiles(smiles: str) -> list[str]:
346347
mol: Chem.Mol = Chem.MolFromSmiles(smiles)
347-
atom_ids = [atom.GetIdx() for atom in mol.GetAtoms()]
348-
random.shuffle(atom_ids) # seed set by lightning
349-
atom_id_iter = cycle(atom_ids)
350-
augmented = {
351-
Chem.MolToSmiles(mol, rootedAtAtom=next(atom_id_iter), doRandom=True)
352-
for _ in range(AUG_SMILES_VARIATIONS)
353-
}
354-
augmented.add(smiles)
355-
return list(augmented)
356-
357-
data_df["augmented_smiles"] = data_df["SMILES"].apply(generate_augmented_smiles)
348+
frags = Chem.GetMolFrags(mol, asMols=True, sanitizeFrags=False)
349+
augmented = set()
350+
351+
frag_smiles: list[set] = []
352+
for frag in frags:
353+
atom_ids = [atom.GetIdx() for atom in frag.GetAtoms()]
354+
random.shuffle(atom_ids) # seed set by lightning
355+
atom_id_iter = cycle(atom_ids)
356+
frag_smiles.append(
357+
{
358+
Chem.MolToSmiles(
359+
frag, rootedAtAtom=next(atom_id_iter), doRandom=True
360+
)
361+
for _ in range(AUG_SMILES_VARIATIONS)
362+
}
363+
)
364+
if len(frags) > 1:
365+
# all permutations (ignoring the set order, meaning mixing sets in every order),
366+
aug_counter: int = 0
367+
for perm in permutations(frag_smiles):
368+
for combo in product(*perm):
369+
augmented.add(".".join(combo))
370+
aug_counter += 1
371+
if aug_counter >= AUG_SMILES_VARIATIONS:
372+
break
373+
if aug_counter >= AUG_SMILES_VARIATIONS:
374+
break
375+
else:
376+
augmented = frag_smiles[0]
377+
378+
return [smiles] + list(augmented)
379+
380+
data_df["SMILES"] = data_df["SMILES"].apply(generate_augmented_smiles)
358381

359382
# Explode the list of augmented smiles into multiple rows
360383
# augmented smiles will have same ident, as of the original, but does it matter ?
361-
exploded_df = data_df.explode("augmented_smiles").reset_index(drop=True)
362-
exploded_df.rename(columns={"augmented_smiles": "SMILES"}, inplace=True)
384+
# instead its helpful to group augmented smiles generated from the same original SMILES
385+
exploded_df = data_df.explode("SMILES").reset_index(drop=True)
363386
self.save_processed(
364387
exploded_df, self.processed_main_file_names_dict["aug_data"]
365388
)

0 commit comments

Comments
 (0)