diff --git a/README.md b/README.md
index 12b0b95f..eeecd714 100644
--- a/README.md
+++ b/README.md
@@ -78,8 +78,21 @@ The `classes_path` is the path to the dataset's `raw/classes.txt` file that cont
 ## Evaluation
 
-An example for evaluating a model trained on the ontology extension task is given in `tutorials/eval_model_basic.ipynb`.
-It takes in the finetuned model as input for performing the evaluation.
+You can evaluate a model trained on the ontology extension task in one of two ways:
+
+### 1. Using the Jupyter Notebook
+An example notebook is provided at `tutorials/eval_model_basic.ipynb`.
+- Load your finetuned model and run the evaluation cells to compute metrics on the test set.
+
+### 2. Using the Lightning CLI
+Alternatively, you can evaluate the model via the CLI:
+
+```bash
+python -m chebai test --trainer=configs/training/default_trainer.yml --trainer.devices=1 --trainer.num_nodes=1 --ckpt_path=[path-to-finetuned-model] --model=configs/model/electra.yml --model.test_metrics=configs/metrics/micro-macro-f1.yml --data=configs/data/chebi/chebi50.yml --data.init_args.batch_size=32 --data.init_args.num_workers=10 --data.init_args.chebi_version=[chebi-version] --model.pass_loss_kwargs=false --model.criterion=configs/loss/bce.yml --model.criterion.init_args.beta=0.99 --data.init_args.splits_file_path=[path-to-splits-file]
+```
+
+> **Note**: It is recommended to use `devices=1` and `num_nodes=1` during testing. Multi-device settings rely on a `DistributedSampler`, which may replicate some samples to keep batch sizes equal across devices; a single device ensures that each sample is evaluated exactly once.
+
 ## Cross-validation
 
 You can do inner k-fold cross-validation, i.e., train models on k train-validation splits that all use the same test
diff --git a/chebai/preprocessing/datasets/base.py b/chebai/preprocessing/datasets/base.py
index a229e7af..12eb634c 100644
--- a/chebai/preprocessing/datasets/base.py
+++ b/chebai/preprocessing/datasets/base.py
@@ -1,6 +1,7 @@
 import os
 import random
 from abc import ABC, abstractmethod
+from pathlib import Path
 from typing import TYPE_CHECKING, Any, Dict, Generator, List, Optional, Tuple, Union
 
 import lightning as pl
@@ -416,10 +417,17 @@ def prepare_data(self, *args, **kwargs) -> None:
 
         self._prepare_data_flag += 1
         self._perform_data_preparation(*args, **kwargs)
+        self._after_prepare_data(*args, **kwargs)
 
     def _perform_data_preparation(self, *args, **kwargs) -> None:
         raise NotImplementedError
 
+    def _after_prepare_data(self, *args, **kwargs) -> None:
+        """
+        Hook to perform additional pre-processing after the pre-processed data is available.
+        """
+        ...
+
     def setup(self, *args, **kwargs) -> None:
         """
         Setup the data module.
@@ -461,14 +469,17 @@ def _set_processed_data_props(self):
         - self._num_of_labels: Number of target labels in the dataset.
         - self._feature_vector_size: Maximum feature vector length across all data points.
         """
-        data_pt = torch.load(
-            os.path.join(self.processed_dir, self.processed_file_names_dict["data"]),
-            weights_only=False,
+        pt_file_path = os.path.join(
+            self.processed_dir, self.processed_file_names_dict["data"]
         )
+        data_pt = torch.load(pt_file_path, weights_only=False)
 
         self._num_of_labels = len(data_pt[0]["labels"])
         self._feature_vector_size = max(len(d["features"]) for d in data_pt)
 
+        print(
+            f"Number of samples in encoded data ({pt_file_path}): {len(data_pt)} samples"
+        )
         print(f"Number of labels for loaded data: {self._num_of_labels}")
         print(f"Feature vector size: {self._feature_vector_size}")
 
@@ -731,6 +742,7 @@ def __init__(
         self.splits_file_path = self._validate_splits_file_path(
            kwargs.get("splits_file_path", None)
         )
+        self._data_pkl_filename: str = "data.pkl"
 
     @staticmethod
     def _validate_splits_file_path(splits_file_path: Optional[str]) -> Optional[str]:
@@ -869,6 +881,21 @@ def save_processed(self, data: pd.DataFrame, filename: str) -> None:
         """
         pd.to_pickle(data, open(os.path.join(self.processed_dir_main, filename), "wb"))
 
+    def get_processed_pickled_df_file(self, filename: str) -> Optional[pd.DataFrame]:
+        """
+        Loads the processed dataset pickle file, if it exists.
+
+        Args:
+            filename (str): The filename of the pickle file.
+
+        Returns:
+            Optional[pd.DataFrame]: The processed dataset as a DataFrame, or None if the file does not exist.
+        """
+        file_path = Path(self.processed_dir_main) / filename
+        if file_path.exists():
+            return pd.read_pickle(file_path)
+        return None
+
     # ------------------------------ Phase: Setup data -----------------------------------
     def setup_processed(self) -> None:
         """
@@ -907,7 +934,9 @@ def _get_data_size(input_file_path: str) -> int:
             int: The size of the data.
         """
         with open(input_file_path, "rb") as f:
-            return len(pd.read_pickle(f))
+            df = pd.read_pickle(f)
+            print(f"Processed data size ({input_file_path}): {len(df)} rows")
+            return len(df)
 
     @abstractmethod
     def _load_dict(self, input_file_path: str) -> Generator[Dict[str, Any], None, None]:
@@ -1223,7 +1252,7 @@ def processed_main_file_names_dict(self) -> dict:
             dict: A dictionary mapping dataset key to their respective file names.
                 For example, {"data": "data.pkl"}.
         """
-        return {"data": "data.pkl"}
+        return {"data": self._data_pkl_filename}
 
     @property
     def raw_file_names(self) -> List[str]:
diff --git a/chebai/preprocessing/datasets/chebi.py b/chebai/preprocessing/datasets/chebi.py
index cbd04895..887107b5 100644
--- a/chebai/preprocessing/datasets/chebi.py
+++ b/chebai/preprocessing/datasets/chebi.py
@@ -11,12 +11,15 @@
 
 import os
 import pickle
+import random
 from abc import ABC
 from collections import OrderedDict
-from typing import TYPE_CHECKING, Any, Dict, Generator, List, Optional, Tuple, Union
+from itertools import cycle, permutations, product
+from typing import TYPE_CHECKING, Any, Dict, Generator, List, Optional, Union
 
 import pandas as pd
 import torch
+from rdkit import Chem
 
 from chebai.preprocessing import reader as dr
 from chebai.preprocessing.datasets.base import XYBaseDataModule, _DynamicDataset
@@ -135,8 +138,27 @@ def __init__(
         self,
         chebi_version_train: Optional[int] = None,
         single_class: Optional[int] = None,
+        augment_smiles: bool = False,
+        aug_smiles_variations: Optional[int] = None,
         **kwargs,
     ):
+        if bool(augment_smiles):
+            assert (
+                aug_smiles_variations is not None and int(aug_smiles_variations) > 0
+            ), "Number of variations must be greater than 0"
+            aug_smiles_variations = int(aug_smiles_variations)
+
+            if not kwargs.get("splits_file_path", None):
+                raise ValueError(
+                    "When using SMILES augmentation, a splits_file_path must be provided to ensure consistent splits."
+                )
+
+            reader_kwargs = kwargs.get("reader_kwargs", {})
+            reader_kwargs["canonicalize_smiles"] = False
+            kwargs["reader_kwargs"] = reader_kwargs
+
+        self.augment_smiles = bool(augment_smiles)
+        self.aug_smiles_variations = aug_smiles_variations
         # predict only single class (given as id of one of the classes present in the raw data set)
         self.single_class = single_class
         super(_ChEBIDataExtractor, self).__init__(**kwargs)
@@ -151,6 +173,8 @@ def __init__(
             _init_kwargs["chebi_version"] = self.chebi_version_train
             self._chebi_version_train_obj = self.__class__(
                 single_class=self.single_class,
+                augment_smiles=self.augment_smiles,
+                aug_smiles_variations=self.aug_smiles_variations,
                 **_init_kwargs,
             )
 
@@ -312,6 +336,75 @@ def _graph_to_raw_dataset(self, g: "nx.DiGraph") -> pd.DataFrame:
 
         return data
 
+    def _after_prepare_data(self, *args, **kwargs) -> None:
+        self._perform_smiles_augmentation()
+
+    def _perform_smiles_augmentation(self) -> None:
+        if not self.augment_smiles:
+            return
+
+        aug_pkl_file_name = self.processed_main_file_names_dict["aug_data"]
+        aug_data_df = self.get_processed_pickled_df_file(aug_pkl_file_name)
+        if aug_data_df is not None:
+            self._data_pkl_filename = aug_pkl_file_name
+            return
+
+        data_df = self.get_processed_pickled_df_file(
+            self.processed_main_file_names_dict["data"]
+        )
+
+        AUG_SMILES_VARIATIONS = self.aug_smiles_variations
+
+        def generate_augmented_smiles(smiles: str) -> list[str]:
+            mol: Chem.Mol = Chem.MolFromSmiles(smiles)
+            if mol is None:
+                return [smiles]  # fall back to the original SMILES if RDKit cannot parse it
+
+            # Sanitization is disabled because it can alter the fragment representation.
+            # We do not want RDKit to "fix" the fragments; we only need them as-is to generate SMILES strings.
+            frags = Chem.GetMolFrags(mol, asMols=True, sanitizeFrags=False)
+            augmented = set()
+
+            frag_smiles: list[set] = []
+            for frag in frags:
+                atom_ids = [atom.GetIdx() for atom in frag.GetAtoms()]
+                random.shuffle(atom_ids)  # seed set by lightning
+                atom_id_iter = cycle(atom_ids)
+                frag_smiles.append(
+                    {
+                        Chem.MolToSmiles(
+                            frag, rootedAtAtom=next(atom_id_iter), doRandom=True
+                        )
+                        for _ in range(AUG_SMILES_VARIATIONS)
+                    }
+                )
+            if len(frags) > 1:
+                # Mix fragment SMILES across all fragment orderings, up to the requested number of variations.
+                aug_counter: int = 0
+                for perm in permutations(frag_smiles):
+                    for combo in product(*perm):
+                        augmented.add(".".join(combo))
+                        aug_counter += 1
+                        if aug_counter >= AUG_SMILES_VARIATIONS:
+                            break
+                    if aug_counter >= AUG_SMILES_VARIATIONS:
+                        break
+            else:
+                augmented = frag_smiles[0]
+
+            return [smiles] + list(augmented)
+
+        data_df["SMILES"] = data_df["SMILES"].apply(generate_augmented_smiles)
+
+        # Explode the list of augmented SMILES into multiple rows.
+        # Augmented SMILES keep the same ident as the original entry, which makes it
+        # possible to group augmented SMILES generated from the same original SMILES.
+        exploded_df = data_df.explode("SMILES").reset_index(drop=True)
+        self.save_processed(
+            exploded_df, self.processed_main_file_names_dict["aug_data"]
+        )
+        self._data_pkl_filename = aug_pkl_file_name
+
     # ------------------------------ Phase: Setup data -----------------------------------
     def setup_processed(self) -> None:
         """
@@ -339,7 +432,7 @@ def setup_processed(self) -> None:
             print("Calling the setup method related to it")
             self._chebi_version_train_obj.setup()
 
-    def _load_dict(self, input_file_path: str) -> Generator[Dict[str, Any], None, None]:
+    def _load_dict(self, input_file_path: str) -> Generator[dict[str, Any], None, None]:
         """
         Loads a dictionary from a pickled file, yielding individual dictionaries for each row.
 
@@ -380,7 +473,7 @@ def _load_dict(self, input_file_path: str) -> Generator[Dict[str, Any], None, None]:
                 )
 
     # ------------------------------ Phase: Dynamic Splits -----------------------------------
-    def _get_data_splits(self) -> Tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]:
+    def _get_data_splits(self) -> tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]:
         """
         Loads encoded/transformed data and generates training, validation, and test splits.
 
@@ -544,6 +637,37 @@ def processed_dir(self) -> str:
     def raw_file_names_dict(self) -> dict:
         return {"chebi": "chebi.obo"}
 
+    @property
+    def processed_main_file_names_dict(self) -> dict:
+        """
+        Returns a dictionary mapping dataset keys to processed data file names.
+
+        Returns:
+            dict: A dictionary mapping dataset keys to their respective file names.
+                For example, {"data": "data.pkl"}.
+        """
+        p_dict = super().processed_main_file_names_dict
+        if self.augment_smiles:
+            p_dict["aug_data"] = f"aug_data_var{self.aug_smiles_variations}.pkl"
+        return p_dict
+
+    @property
+    def processed_file_names_dict(self) -> dict:
+        """
+        Returns a dictionary for the processed and tokenized data files.
+
+        Returns:
+            dict: A dictionary mapping dataset keys to their respective file names.
+                For example, {"data": "data.pt"}.
+        """
+        if not self.augment_smiles:
+            return super().processed_file_names_dict
+        if self.n_token_limit is not None:
+            return {
+                "data": f"aug_data_var{self.aug_smiles_variations}_maxlen{self.n_token_limit}.pt"
+            }
+        return {"data": f"aug_data_var{self.aug_smiles_variations}.pt"}
+
 
 class JCIExtendedBase(_ChEBIDataExtractor):
     @property
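
For reviewers unfamiliar with randomized SMILES: the augmentation added in `_perform_smiles_augmentation` relies on RDKit's `doRandom` and `rootedAtAtom` options, which emit different but equivalent SMILES strings for the same molecule. A minimal standalone sketch of that idea is below; the aspirin SMILES is just an arbitrary example molecule, not taken from the dataset.

```python
# Minimal illustration of the randomized-SMILES trick used above (not part of the diff).
from rdkit import Chem

mol = Chem.MolFromSmiles("CC(=O)Oc1ccccc1C(=O)O")  # aspirin, arbitrary example molecule
variants = {
    # Rooting the traversal at different atoms and enabling doRandom yields
    # distinct, non-canonical SMILES that all describe the same molecule.
    Chem.MolToSmiles(mol, rootedAtAtom=i, doRandom=True)
    for i in range(min(5, mol.GetNumAtoms()))
}
print(variants)
```

For disconnected structures, the method above additionally varies the order of the fragments via `permutations`/`product` before joining them with `"."`, which is why canonicalization is disabled in the reader when augmentation is active.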
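
As a usage sketch of the new constructor options: the snippet below assumes the `ChEBIOver50` subclass that `configs/data/chebi/chebi50.yml` points to; the ChEBI version and splits file path are placeholders, not values prescribed by this diff.

```python
# Hypothetical usage sketch (not part of the diff): enabling SMILES augmentation.
from chebai.preprocessing.datasets.chebi import ChEBIOver50  # assumed subclass

data_module = ChEBIOver50(
    chebi_version=231,                    # placeholder ChEBI release
    splits_file_path="data/splits.csv",   # required when augment_smiles=True
    augment_smiles=True,                  # enable SMILES augmentation
    aug_smiles_variations=5,              # extra SMILES generated per molecule
)

# prepare_data() builds data.pkl as before; the new _after_prepare_data hook then
# writes the exploded aug_data_var5.pkl, which later phases pick up via
# _data_pkl_filename / processed_main_file_names_dict.
data_module.prepare_data()
data_module.setup()
```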