|
11 | 11 |
|
12 | 12 | import os |
13 | 13 | import pickle |
| 14 | +import random |
14 | 15 | from abc import ABC |
15 | 16 | from collections import OrderedDict |
16 | | -from typing import Any, Dict, Generator, List, Optional, Tuple, Union |
| 17 | +from itertools import cycle |
| 18 | +from typing import Any, Dict, Generator, List, Literal, Optional, Tuple, Union |
17 | 19 |
|
18 | 20 | import fastobo |
19 | 21 | import networkx as nx |
20 | 22 | import pandas as pd |
21 | 23 | import requests |
22 | 24 | import torch |
| 25 | +from rdkit import Chem |
23 | 26 |
|
24 | 27 | from chebai.preprocessing import reader as dr |
25 | 28 | from chebai.preprocessing.datasets.base import XYBaseDataModule, _DynamicDataset |
@@ -134,8 +137,17 @@ def __init__( |
134 | 137 | self, |
135 | 138 | chebi_version_train: Optional[int] = None, |
136 | 139 | single_class: Optional[int] = None, |
| 140 | + augment_smiles: bool = False, |
| 141 | + aug_smiles_variations: Literal["max"] | int | None = None, |
137 | 142 | **kwargs, |
138 | 143 | ): |
| 144 | + if augment_smiles: |
| 145 | + assert aug_smiles_variations is not None, "" |
| 146 | + assert aug_smiles_variations == "max" or ( |
| 147 | + int(aug_smiles_variations) and int(aug_smiles_variations) >= 1 |
| 148 | + ), "" |
| 149 | + self.augment_smiles = augment_smiles |
| 150 | + self.aug_smiles_variations = aug_smiles_variations |
139 | 151 | # predict only single class (given as id of one of the classes present in the raw data set) |
140 | 152 | self.single_class = single_class |
141 | 153 | super(_ChEBIDataExtractor, self).__init__(**kwargs) |
@@ -304,8 +316,34 @@ def _graph_to_raw_dataset(self, g: nx.DiGraph) -> pd.DataFrame: |
304 | 316 | # This filters the DataFrame to include only the rows where at least one value in the row from 4th column |
305 | 317 | # onwards is True/non-zero. |
306 | 318 | data = data[data.iloc[:, self._LABELS_START_IDX :].any(axis=1)] |
| 319 | + if self.augment_smiles: |
| 320 | + data = self._perform_smiles_augmentation() |
| 321 | + |
307 | 322 | return data |
308 | 323 |
|
| 324 | + def _perform_smiles_augmentation(self, data_df: pd.DataFrame) -> pd.DataFrame: |
| 325 | + data_df["augmented_smiles"] = data_df["SMILES"].apply(self.augment_smiles()) |
| 326 | + # Explode the list of augmented smiles into multiple rows |
| 327 | + # augmented smiles will have same ident, as of the original, but does it matter ? |
| 328 | + exploded_df = data_df.explode("augmented_smiles").reset_index(drop=True) |
| 329 | + exploded_df.rename(columns={"augmented_smile", "SMILES"}) |
| 330 | + return exploded_df |
| 331 | + |
| 332 | + def augment_smiles(self, smiles: str): |
| 333 | + mol: Chem.Mol = Chem.MolFromSmiles(smiles) |
| 334 | + # As chebi smiles might be different than rdkit smiles, for same canonical mol |
| 335 | + # TODO: if same smiles is generated as mol_smiles remove it |
| 336 | + # mol_smiles = Chem.MolToSmiles(smiles) |
| 337 | + atom_ids = [atom.GetIdx() for atom in mol.GetAtoms()] |
| 338 | + random.shuffle(atom_ids) # seed set by lightning |
| 339 | + atom_id_iter = cycle(atom_ids) |
| 340 | + return list( |
| 341 | + { |
| 342 | + Chem.MolToSmiles(mol, rootedAtAtom=next(atom_id_iter), doRandom=True) |
| 343 | + for _ in range(self.aug_smiles_variations) |
| 344 | + } |
| 345 | + ) + [smiles] |
| 346 | + |
309 | 347 | # ------------------------------ Phase: Setup data ----------------------------------- |
310 | 348 | def setup_processed(self) -> None: |
311 | 349 | """ |
|
0 commit comments