Skip to content

Commit 30c5065

Browse files
committed
smiles augmentation poc code
1 parent 1d8a7c3 commit 30c5065

File tree

1 file changed

+39
-1
lines changed

1 file changed

+39
-1
lines changed

chebai/preprocessing/datasets/chebi.py

Lines changed: 39 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,15 +11,18 @@
1111

1212
import os
1313
import pickle
14+
import random
1415
from abc import ABC
1516
from collections import OrderedDict
16-
from typing import Any, Dict, Generator, List, Optional, Tuple, Union
17+
from itertools import cycle
18+
from typing import Any, Dict, Generator, List, Literal, Optional, Tuple, Union
1719

1820
import fastobo
1921
import networkx as nx
2022
import pandas as pd
2123
import requests
2224
import torch
25+
from rdkit import Chem
2326

2427
from chebai.preprocessing import reader as dr
2528
from chebai.preprocessing.datasets.base import XYBaseDataModule, _DynamicDataset
@@ -134,8 +137,17 @@ def __init__(
134137
self,
135138
chebi_version_train: Optional[int] = None,
136139
single_class: Optional[int] = None,
140+
augment_smiles: bool = False,
141+
aug_smiles_variations: Literal["max"] | int | None = None,
137142
**kwargs,
138143
):
144+
if augment_smiles:
145+
assert aug_smiles_variations is not None, ""
146+
assert aug_smiles_variations == "max" or (
147+
int(aug_smiles_variations) and int(aug_smiles_variations) >= 1
148+
), ""
149+
self.augment_smiles = augment_smiles
150+
self.aug_smiles_variations = aug_smiles_variations
139151
# predict only single class (given as id of one of the classes present in the raw data set)
140152
self.single_class = single_class
141153
super(_ChEBIDataExtractor, self).__init__(**kwargs)
@@ -304,8 +316,34 @@ def _graph_to_raw_dataset(self, g: nx.DiGraph) -> pd.DataFrame:
304316
# This filters the DataFrame to include only the rows where at least one value in the row from 4th column
305317
# onwards is True/non-zero.
306318
data = data[data.iloc[:, self._LABELS_START_IDX :].any(axis=1)]
319+
if self.augment_smiles:
320+
data = self._perform_smiles_augmentation()
321+
307322
return data
308323

324+
def _perform_smiles_augmentation(self, data_df: pd.DataFrame) -> pd.DataFrame:
325+
data_df["augmented_smiles"] = data_df["SMILES"].apply(self.augment_smiles())
326+
# Explode the list of augmented smiles into multiple rows
327+
# augmented smiles will have same ident, as of the original, but does it matter ?
328+
exploded_df = data_df.explode("augmented_smiles").reset_index(drop=True)
329+
exploded_df.rename(columns={"augmented_smile", "SMILES"})
330+
return exploded_df
331+
332+
def augment_smiles(self, smiles: str):
333+
mol: Chem.Mol = Chem.MolFromSmiles(smiles)
334+
# As chebi smiles might be different than rdkit smiles, for same canonical mol
335+
# TODO: if same smiles is generated as mol_smiles remove it
336+
# mol_smiles = Chem.MolToSmiles(smiles)
337+
atom_ids = [atom.GetIdx() for atom in mol.GetAtoms()]
338+
random.shuffle(atom_ids) # seed set by lightning
339+
atom_id_iter = cycle(atom_ids)
340+
return list(
341+
{
342+
Chem.MolToSmiles(mol, rootedAtAtom=next(atom_id_iter), doRandom=True)
343+
for _ in range(self.aug_smiles_variations)
344+
}
345+
) + [smiles]
346+
309347
# ------------------------------ Phase: Setup data -----------------------------------
310348
def setup_processed(self) -> None:
311349
"""

0 commit comments

Comments
 (0)