|
15 | 15 | from abc import ABC |
16 | 16 | from collections import OrderedDict |
17 | 17 | from itertools import cycle |
18 | | -from typing import Any, Dict, Generator, List, Literal, Optional, Tuple, Union |
| 18 | +from multiprocessing import Pool |
| 19 | +from typing import Any, Dict, Generator, List, Optional, Tuple, Union |
19 | 20 |
|
20 | 21 | import fastobo |
21 | 22 | import networkx as nx |
@@ -138,14 +139,14 @@ def __init__( |
138 | 139 | chebi_version_train: Optional[int] = None, |
139 | 140 | single_class: Optional[int] = None, |
140 | 141 | augment_smiles: bool = False, |
141 | | - aug_smiles_variations: Literal["max"] | int | None = None, |
| 142 | + aug_smiles_variations: Optional[int] = None, |
142 | 143 | **kwargs, |
143 | 144 | ): |
144 | 145 | if augment_smiles: |
145 | 146 | assert aug_smiles_variations is not None, "" |
146 | | - assert aug_smiles_variations == "max" or ( |
147 | | - int(aug_smiles_variations) and int(aug_smiles_variations) >= 1 |
148 | | - ), "" |
| 147 | + assert ( |
| 148 | + aug_smiles_variations > 0 |
| 149 | + ), "Number of variations must be greater than 0" |
149 | 150 | kwargs.setdefault("reader_kwargs", {}).update(canonicalize_smiles=False) |
150 | 151 |
|
151 | 152 | self.augment_smiles = augment_smiles |
@@ -321,33 +322,34 @@ def _graph_to_raw_dataset(self, g: nx.DiGraph) -> pd.DataFrame: |
321 | 322 | # onwards is True/non-zero. |
322 | 323 | data = data[data.iloc[:, self._LABELS_START_IDX :].any(axis=1)] |
323 | 324 | if self.augment_smiles: |
324 | | - data = self._perform_smiles_augmentation() |
325 | | - |
| 325 | + return self._perform_smiles_augmentation(data) |
326 | 326 | return data |
327 | 327 |
|
328 | 328 | def _perform_smiles_augmentation(self, data_df: pd.DataFrame) -> pd.DataFrame: |
329 | | - data_df["augmented_smiles"] = data_df["SMILES"].apply(self.augment_smiles()) |
| 329 | + AUG_SMILES_VARIATIONS = self.aug_smiles_variations |
| 330 | + |
| 331 | + def generate_augmented_smiles(smiles: str): |
| 332 | + mol: Chem.Mol = Chem.MolFromSmiles(smiles) |
| 333 | + atom_ids = [atom.GetIdx() for atom in mol.GetAtoms()] |
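| 334 | + random.shuffle(atom_ids) # seed set by lightning in the parent process; NOTE(review): this runs inside Pool workers, so confirm the seed actually carries over |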
| 335 | + atom_id_iter = cycle(atom_ids) |
| 336 | + augmented = { |
| 337 | + Chem.MolToSmiles(mol, rootedAtAtom=next(atom_id_iter), doRandom=True) |
| 338 | + for _ in range(AUG_SMILES_VARIATIONS) |
| 339 | + } |
| 340 | + augmented.add(smiles) |
| 341 | + return list(augmented) |
| 342 | + |
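| 343 | + with Pool() as pool: # NOTE(review): pool.map must pickle generate_augmented_smiles, but it is a nested (local) function and pickle only supports module-level functions; move it to module level |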
| 344 | + data_df["augmented_smiles"] = pool.map( |
| 345 | + generate_augmented_smiles, data_df["SMILES"] |
| 346 | + ) |
330 | 347 | # Explode the list of augmented smiles into multiple rows |
331 | 348 | # Each augmented SMILES row keeps the same ident as its original row; TODO: check whether duplicate idents matter downstream. |
332 | 349 | exploded_df = data_df.explode("augmented_smiles").reset_index(drop=True) |
333 | | - exploded_df.rename(columns={"augmented_smile", "SMILES"}) |
| 350 | + exploded_df.rename(columns={"augmented_smiles": "SMILES"}, inplace=True) |
334 | 351 | return exploded_df |
335 | 352 |
|
336 | | - def augment_smiles(self, smiles: str): |
337 | | - mol: Chem.Mol = Chem.MolFromSmiles(smiles) |
338 | | - # As chebi smiles might be different than rdkit smiles, for same canonical mol |
339 | | - # TODO: if same smiles is generated as mol_smiles remove it |
340 | | - # mol_smiles = Chem.MolToSmiles(smiles) |
341 | | - atom_ids = [atom.GetIdx() for atom in mol.GetAtoms()] |
342 | | - random.shuffle(atom_ids) # seed set by lightning |
343 | | - atom_id_iter = cycle(atom_ids) |
344 | | - return list( |
345 | | - { |
346 | | - Chem.MolToSmiles(mol, rootedAtAtom=next(atom_id_iter), doRandom=True) |
347 | | - for _ in range(self.aug_smiles_variations) |
348 | | - } |
349 | | - ) + [smiles] |
350 | | - |
351 | 353 | # ------------------------------ Phase: Setup data ----------------------------------- |
352 | 354 | def setup_processed(self) -> None: |
353 | 355 | """ |
@@ -580,6 +582,32 @@ def processed_dir(self) -> str: |
580 | 582 | def raw_file_names_dict(self) -> dict: |
581 | 583 | return {"chebi": "chebi.obo"} |
582 | 584 |
|
| 585 | + @property |
| 586 | + def processed_main_file_names_dict(self) -> dict: |
| 587 | + """ |
| 588 | + Returns a dictionary mapping dataset keys to processed data file names. |
| 589 | +
|
| 590 | + Returns: |
| 591 | + dict: A dictionary mapping dataset keys to their respective file names. |
| 592 | + For example, {"data": "data.pkl"}, or {"data": "aug_data_var5.pkl"} when SMILES augmentation is enabled with 5 variations. |
| 593 | + """ |
| 594 | + if not self.augment_smiles: |
| 595 | + return super().processed_main_file_names_dict |
| 596 | + return {"data": f"aug_data_var{self.aug_smiles_variations}.pkl"} |
| 597 | + |
| 598 | + @property |
| 599 | + def processed_file_names_dict(self) -> dict: |
| 600 | + """ |
| 601 | + Returns a dictionary mapping dataset keys to the processed and tokenized data file names. |
| 602 | +
|
| 603 | + Returns: |
| 604 | + dict: A dictionary mapping dataset keys to their respective file names. |
| 605 | + For example, {"data": "data.pt"}, or {"data": "aug_data_var5.pt"} when SMILES augmentation is enabled with 5 variations. |
| 606 | + """ |
| 607 | + if not self.augment_smiles: |
| 608 | + return super().processed_file_names_dict |
| 609 | + return {"data": f"aug_data_var{self.aug_smiles_variations}.pt"} |
| 610 | + |
583 | 611 |
|
584 | 612 | class JCIExtendedBase(_ChEBIDataExtractor): |
585 | 613 | @property |
|
0 commit comments