Skip to content

Commit d864011

Browse files
committed
perform SMILES augmentation after the processed pickle (.pkl) file is created
1 parent c51a27b commit d864011

File tree

2 files changed

+66
-24
lines changed

2 files changed

+66
-24
lines changed

chebai/preprocessing/datasets/base.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import os
22
import random
33
from abc import ABC, abstractmethod
4+
from pathlib import Path
45
from typing import Any, Dict, Generator, List, Optional, Tuple, Union
56

67
import lightning as pl
@@ -419,10 +420,17 @@ def prepare_data(self, *args, **kwargs) -> None:
419420

420421
self._prepare_data_flag += 1
421422
self._perform_data_preparation(*args, **kwargs)
423+
self._after_prepare_data(*args, **kwargs)
422424

423425
def _perform_data_preparation(self, *args, **kwargs) -> None:
424426
raise NotImplementedError
425427

428+
def _after_prepare_data(self, *args, **kwargs) -> None:
429+
"""
430+
Hook to perform additional pre-processing after pre-processed data is available.
431+
"""
432+
...
433+
426434
def setup(self, *args, **kwargs) -> None:
427435
"""
428436
Setup the data module.
@@ -872,6 +880,21 @@ def save_processed(self, data: pd.DataFrame, filename: str) -> None:
872880
"""
873881
pd.to_pickle(data, open(os.path.join(self.processed_dir_main, filename), "wb"))
874882

883+
def get_processed_pickled_df_file(self, filename: str) -> pd.DataFrame | None:
884+
"""
885+
Gets the processed dataset pickle file.
886+
887+
Args:
888+
filename (str): The filename for the pickle file.
889+
890+
Returns:
891+
pd.DataFrame: The processed dataset as a DataFrame.
892+
"""
893+
file_path = Path(self.processed_dir_main) / filename
894+
if file_path.exists():
895+
return pd.read_pickle(file_path)
896+
return None
897+
875898
# ------------------------------ Phase: Setup data -----------------------------------
876899
def setup_processed(self) -> None:
877900
"""

chebai/preprocessing/datasets/chebi.py

Lines changed: 43 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,7 @@
1515
from abc import ABC
1616
from collections import OrderedDict
1717
from itertools import cycle
18-
from multiprocessing import Pool
19-
from typing import Any, Dict, Generator, List, Optional, Tuple, Union
18+
from typing import Any, Generator, Optional, Union
2019

2120
import fastobo
2221
import networkx as nx
@@ -142,15 +141,16 @@ def __init__(
142141
aug_smiles_variations: Optional[int] = None,
143142
**kwargs,
144143
):
145-
if augment_smiles:
146-
assert aug_smiles_variations is not None, ""
144+
if bool(augment_smiles):
147145
assert (
148-
aug_smiles_variations > 0
146+
int(aug_smiles_variations) > 0
149147
), "Number of variations must be greater than 0"
150-
kwargs.setdefault("reader_kwargs", {}).update(canonicalize_smiles=False)
148+
reader_kwargs = kwargs.get("reader_kwargs", {})
149+
reader_kwargs["canonicalize_smiles"] = False
150+
kwargs["reader_kwargs"] = reader_kwargs
151151

152-
self.augment_smiles = augment_smiles
153-
self.aug_smiles_variations = aug_smiles_variations
152+
self.augment_smiles = bool(augment_smiles)
153+
self.aug_smiles_variations = int(aug_smiles_variations)
154154
# predict only single class (given as id of one of the classes present in the raw data set)
155155
self.single_class = single_class
156156
super(_ChEBIDataExtractor, self).__init__(**kwargs)
@@ -321,14 +321,28 @@ def _graph_to_raw_dataset(self, g: nx.DiGraph) -> pd.DataFrame:
321321
# This filters the DataFrame to include only the rows where at least one value in the row from 4th column
322322
# onwards is True/non-zero.
323323
data = data[data.iloc[:, self._LABELS_START_IDX :].any(axis=1)]
324-
if self.augment_smiles:
325-
return self._perform_smiles_augmentation(data)
326324
return data
327325

328-
def _perform_smiles_augmentation(self, data_df: pd.DataFrame) -> pd.DataFrame:
326+
def _after_prepare_data(self, *args, **kwargs) -> None:
327+
self._perform_smiles_augmentation()
328+
329+
def _perform_smiles_augmentation(self) -> None:
330+
if not self.augment_smiles:
331+
return
332+
333+
aug_data_df = self.get_processed_pickled_df_file(
334+
self.processed_main_file_names_dict["aug_data"]
335+
)
336+
if aug_data_df is not None:
337+
return
338+
339+
data_df = self.get_processed_pickled_df_file(
340+
self.processed_main_file_names_dict["data"]
341+
)
342+
329343
AUG_SMILES_VARIATIONS = self.aug_smiles_variations
330344

331-
def generate_augmented_smiles(smiles: str):
345+
def generate_augmented_smiles(smiles: str) -> list[str]:
332346
mol: Chem.Mol = Chem.MolFromSmiles(smiles)
333347
atom_ids = [atom.GetIdx() for atom in mol.GetAtoms()]
334348
random.shuffle(atom_ids) # seed set by lightning
@@ -340,15 +354,15 @@ def generate_augmented_smiles(smiles: str):
340354
augmented.add(smiles)
341355
return list(augmented)
342356

343-
with Pool() as pool:
344-
data_df["augmented_smiles"] = pool.map(
345-
generate_augmented_smiles, data_df["SMILES"]
346-
)
357+
data_df["augmented_smiles"] = data_df["SMILES"].apply(generate_augmented_smiles)
358+
347359
# Explode the list of augmented smiles into multiple rows
348360
# augmented smiles will have same ident, as of the original, but does it matter ?
349361
exploded_df = data_df.explode("augmented_smiles").reset_index(drop=True)
350362
exploded_df.rename(columns={"augmented_smiles": "SMILES"}, inplace=True)
351-
return exploded_df
363+
self.save_processed(
364+
exploded_df, self.processed_main_file_names_dict["aug_data"]
365+
)
352366

353367
# ------------------------------ Phase: Setup data -----------------------------------
354368
def setup_processed(self) -> None:
@@ -377,7 +391,7 @@ def setup_processed(self) -> None:
377391
print("Calling the setup method related to it")
378392
self._chebi_version_train_obj.setup()
379393

380-
def _load_dict(self, input_file_path: str) -> Generator[Dict[str, Any], None, None]:
394+
def _load_dict(self, input_file_path: str) -> Generator[dict[str, Any], None, None]:
381395
"""
382396
Loads a dictionary from a pickled file, yielding individual dictionaries for each row.
383397
@@ -418,7 +432,7 @@ def _load_dict(self, input_file_path: str) -> Generator[Dict[str, Any], None, No
418432
)
419433

420434
# ------------------------------ Phase: Dynamic Splits -----------------------------------
421-
def _get_data_splits(self) -> Tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]:
435+
def _get_data_splits(self) -> tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]:
422436
"""
423437
Loads encoded/transformed data and generates training, validation, and test splits.
424438
@@ -591,9 +605,10 @@ def processed_main_file_names_dict(self) -> dict:
591605
dict: A dictionary mapping dataset key to their respective file names.
592606
For example, {"data": "data.pkl"}.
593607
"""
594-
if not self.augment_smiles:
595-
return super().processed_main_file_names_dict
596-
return {"data": f"aug_data_var{self.aug_smiles_variations}.pkl"}
608+
p_dict = super().processed_main_file_names_dict
609+
if self.augment_smiles:
610+
p_dict["aug_data"] = f"aug_data_var{self.aug_smiles_variations}.pkl"
611+
return p_dict
597612

598613
@property
599614
def processed_file_names_dict(self) -> dict:
@@ -606,6 +621,10 @@ def processed_file_names_dict(self) -> dict:
606621
"""
607622
if not self.augment_smiles:
608623
return super().processed_file_names_dict
624+
if self.n_token_limit is not None:
625+
return {
626+
"data": f"aug_data_var{self.aug_smiles_variations}_maxlen{self.n_token_limit}.pt"
627+
}
609628
return {"data": f"aug_data_var{self.aug_smiles_variations}.pt"}
610629

611630

@@ -644,7 +663,7 @@ def _name(self) -> str:
644663
"""
645664
return f"ChEBI{self.THRESHOLD}"
646665

647-
def select_classes(self, g: nx.DiGraph, *args, **kwargs) -> List:
666+
def select_classes(self, g: nx.DiGraph, *args, **kwargs) -> list:
648667
"""
649668
Selects classes from the ChEBI dataset based on the number of successors meeting a specified threshold.
650669
@@ -856,7 +875,7 @@ def chebi_to_int(s: str) -> int:
856875
return int(s[s.index(":") + 1 :])
857876

858877

859-
def term_callback(doc: fastobo.term.TermFrame) -> Union[Dict, bool]:
878+
def term_callback(doc: fastobo.term.TermFrame) -> Union[dict, bool]:
860879
"""
861880
Extracts information from a ChEBI term document.
862881
This function takes a ChEBI term document as input and extracts relevant information such as the term ID, parents,

0 commit comments

Comments
 (0)