Skip to content

Commit c51a27b

Browse files
committed
multi-processing for augment func
1 parent f127b5e commit c51a27b

File tree

1 file changed

+52
-24
lines changed

1 file changed

+52
-24
lines changed

chebai/preprocessing/datasets/chebi.py

Lines changed: 52 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,8 @@
1515
from abc import ABC
1616
from collections import OrderedDict
1717
from itertools import cycle
18-
from typing import Any, Dict, Generator, List, Literal, Optional, Tuple, Union
18+
from multiprocessing import Pool
19+
from typing import Any, Dict, Generator, List, Optional, Tuple, Union
1920

2021
import fastobo
2122
import networkx as nx
@@ -138,14 +139,14 @@ def __init__(
138139
chebi_version_train: Optional[int] = None,
139140
single_class: Optional[int] = None,
140141
augment_smiles: bool = False,
141-
aug_smiles_variations: Literal["max"] | int | None = None,
142+
aug_smiles_variations: Optional[int] = None,
142143
**kwargs,
143144
):
144145
if augment_smiles:
145146
assert aug_smiles_variations is not None, ""
146-
assert aug_smiles_variations == "max" or (
147-
int(aug_smiles_variations) and int(aug_smiles_variations) >= 1
148-
), ""
147+
assert (
148+
aug_smiles_variations > 0
149+
), "Number of variations must be greater than 0"
149150
kwargs.setdefault("reader_kwargs", {}).update(canonicalize_smiles=False)
150151

151152
self.augment_smiles = augment_smiles
@@ -321,33 +322,34 @@ def _graph_to_raw_dataset(self, g: nx.DiGraph) -> pd.DataFrame:
321322
# onwards is True/non-zero.
322323
data = data[data.iloc[:, self._LABELS_START_IDX :].any(axis=1)]
323324
if self.augment_smiles:
324-
data = self._perform_smiles_augmentation()
325-
325+
return self._perform_smiles_augmentation(data)
326326
return data
327327

328328
def _perform_smiles_augmentation(self, data_df: pd.DataFrame) -> pd.DataFrame:
329-
data_df["augmented_smiles"] = data_df["SMILES"].apply(self.augment_smiles())
329+
AUG_SMILES_VARIATIONS = self.aug_smiles_variations
330+
331+
def generate_augmented_smiles(smiles: str):
332+
mol: Chem.Mol = Chem.MolFromSmiles(smiles)
333+
atom_ids = [atom.GetIdx() for atom in mol.GetAtoms()]
334+
random.shuffle(atom_ids) # seed set by lightning
335+
atom_id_iter = cycle(atom_ids)
336+
augmented = {
337+
Chem.MolToSmiles(mol, rootedAtAtom=next(atom_id_iter), doRandom=True)
338+
for _ in range(AUG_SMILES_VARIATIONS)
339+
}
340+
augmented.add(smiles)
341+
return list(augmented)
342+
343+
with Pool() as pool:
344+
data_df["augmented_smiles"] = pool.map(
345+
generate_augmented_smiles, data_df["SMILES"]
346+
)
330347
# Explode the list of augmented smiles into multiple rows
331348
# augmented SMILES keep the same ident as the original molecule -- is that a problem?
332349
exploded_df = data_df.explode("augmented_smiles").reset_index(drop=True)
333-
exploded_df.rename(columns={"augmented_smile", "SMILES"})
350+
exploded_df.rename(columns={"augmented_smiles": "SMILES"}, inplace=True)
334351
return exploded_df
335352

336-
def augment_smiles(self, smiles: str):
337-
mol: Chem.Mol = Chem.MolFromSmiles(smiles)
338-
# As chebi smiles might be different than rdkit smiles, for same canonical mol
339-
# TODO: if same smiles is generated as mol_smiles remove it
340-
# mol_smiles = Chem.MolToSmiles(smiles)
341-
atom_ids = [atom.GetIdx() for atom in mol.GetAtoms()]
342-
random.shuffle(atom_ids) # seed set by lightning
343-
atom_id_iter = cycle(atom_ids)
344-
return list(
345-
{
346-
Chem.MolToSmiles(mol, rootedAtAtom=next(atom_id_iter), doRandom=True)
347-
for _ in range(self.aug_smiles_variations)
348-
}
349-
) + [smiles]
350-
351353
# ------------------------------ Phase: Setup data -----------------------------------
352354
def setup_processed(self) -> None:
353355
"""
@@ -580,6 +582,32 @@ def processed_dir(self) -> str:
580582
def raw_file_names_dict(self) -> dict:
581583
return {"chebi": "chebi.obo"}
582584

585+
@property
586+
def processed_main_file_names_dict(self) -> dict:
587+
"""
588+
Returns a dictionary mapping dataset keys to processed data file names.
589+
590+
Returns:
591+
dict: A dictionary mapping dataset keys to their respective file names.
592+
For example, {"data": "data.pkl"}.
593+
"""
594+
if not self.augment_smiles:
595+
return super().processed_main_file_names_dict
596+
return {"data": f"aug_data_var{self.aug_smiles_variations}.pkl"}
597+
598+
@property
599+
def processed_file_names_dict(self) -> dict:
600+
"""
601+
Returns a dictionary for the processed and tokenized data files.
602+
603+
Returns:
604+
dict: A dictionary mapping dataset keys to their respective file names.
605+
For example, {"data": "data.pt"}.
606+
"""
607+
if not self.augment_smiles:
608+
return super().processed_file_names_dict
609+
return {"data": f"aug_data_var{self.aug_smiles_variations}.pt"}
610+
583611

584612
class JCIExtendedBase(_ChEBIDataExtractor):
585613
@property

0 commit comments

Comments
 (0)