|
14 | 14 | import random |
15 | 15 | from abc import ABC |
16 | 16 | from collections import OrderedDict |
17 | | -from itertools import cycle |
| 17 | +from itertools import cycle, permutations, product |
18 | 18 | from typing import Any, Generator, Optional, Union |
19 | 19 |
|
20 | 20 | import fastobo |
@@ -145,12 +145,13 @@ def __init__( |
145 | 145 | assert ( |
146 | 146 | int(aug_smiles_variations) > 0 |
147 | 147 | ), "Number of variations must be greater than 0" |
| 148 | + aug_smiles_variations = int(aug_smiles_variations) |
148 | 149 | reader_kwargs = kwargs.get("reader_kwargs", {}) |
149 | 150 | reader_kwargs["canonicalize_smiles"] = False |
150 | 151 | kwargs["reader_kwargs"] = reader_kwargs |
151 | 152 |
|
152 | 153 | self.augment_smiles = bool(augment_smiles) |
153 | | - self.aug_smiles_variations = int(aug_smiles_variations) |
| 154 | + self.aug_smiles_variations = aug_smiles_variations |
154 | 155 | # predict only single class (given as id of one of the classes present in the raw data set) |
155 | 156 | self.single_class = single_class |
156 | 157 | super(_ChEBIDataExtractor, self).__init__(**kwargs) |
@@ -344,22 +345,44 @@ def _perform_smiles_augmentation(self) -> None: |
344 | 345 |
|
345 | 346 | def generate_augmented_smiles(smiles: str) -> list[str]: |
346 | 347 | mol: Chem.Mol = Chem.MolFromSmiles(smiles) |
347 | | - atom_ids = [atom.GetIdx() for atom in mol.GetAtoms()] |
348 | | - random.shuffle(atom_ids) # seed set by lightning |
349 | | - atom_id_iter = cycle(atom_ids) |
350 | | - augmented = { |
351 | | - Chem.MolToSmiles(mol, rootedAtAtom=next(atom_id_iter), doRandom=True) |
352 | | - for _ in range(AUG_SMILES_VARIATIONS) |
353 | | - } |
354 | | - augmented.add(smiles) |
355 | | - return list(augmented) |
356 | | - |
357 | | - data_df["augmented_smiles"] = data_df["SMILES"].apply(generate_augmented_smiles) |
| 348 | + frags = Chem.GetMolFrags(mol, asMols=True, sanitizeFrags=False) |
| 349 | + augmented = set() |
| 350 | + |
| 351 | + frag_smiles: list[set] = [] |
| 352 | + for frag in frags: |
| 353 | + atom_ids = [atom.GetIdx() for atom in frag.GetAtoms()] |
| 354 | + random.shuffle(atom_ids) # seed set by lightning |
| 355 | + atom_id_iter = cycle(atom_ids) |
| 356 | + frag_smiles.append( |
| 357 | + { |
| 358 | + Chem.MolToSmiles( |
| 359 | + frag, rootedAtAtom=next(atom_id_iter), doRandom=True |
| 360 | + ) |
| 361 | + for _ in range(AUG_SMILES_VARIATIONS) |
| 362 | + } |
| 363 | + ) |
| 364 | + if len(frags) > 1: |
| 365 | + # all permutations (ignoring the set order, meaning mixing sets in every order), |
| 366 | + aug_counter: int = 0 |
| 367 | + for perm in permutations(frag_smiles): |
| 368 | + for combo in product(*perm): |
| 369 | + augmented.add(".".join(combo)) |
| 370 | + aug_counter += 1 |
| 371 | + if aug_counter >= AUG_SMILES_VARIATIONS: |
| 372 | + break |
| 373 | + if aug_counter >= AUG_SMILES_VARIATIONS: |
| 374 | + break |
| 375 | + else: |
| 376 | + augmented = frag_smiles[0] |
| 377 | + |
| 378 | + return [smiles] + list(augmented) |
| 379 | + |
| 380 | + data_df["SMILES"] = data_df["SMILES"].apply(generate_augmented_smiles) |
358 | 381 |
|
359 | 382 | # Explode the list of augmented smiles into multiple rows |
360 | 383 | # augmented smiles will have same ident, as of the original, but does it matter ? |
361 | | - exploded_df = data_df.explode("augmented_smiles").reset_index(drop=True) |
362 | | - exploded_df.rename(columns={"augmented_smiles": "SMILES"}, inplace=True) |
| 384 | + # instead its helpful to group augmented smiles generated from the same original SMILES |
| 385 | + exploded_df = data_df.explode("SMILES").reset_index(drop=True) |
363 | 386 | self.save_processed( |
364 | 387 | exploded_df, self.processed_main_file_names_dict["aug_data"] |
365 | 388 | ) |
|
0 commit comments