From 30c5065c570ff0c66be10416871280fc7c0896c6 Mon Sep 17 00:00:00 2001
From: aditya0by0
Date: Thu, 31 Jul 2025 00:11:58 +0200
Subject: [PATCH 01/13] smiles augmentation poc code

---
 chebai/preprocessing/datasets/chebi.py | 40 +++++++++++++++++++++++++-
 1 file changed, 39 insertions(+), 1 deletion(-)

diff --git a/chebai/preprocessing/datasets/chebi.py b/chebai/preprocessing/datasets/chebi.py
index 1df144d9..449a30e6 100644
--- a/chebai/preprocessing/datasets/chebi.py
+++ b/chebai/preprocessing/datasets/chebi.py
@@ -11,15 +11,18 @@
 import os
 import pickle
+import random
 from abc import ABC
 from collections import OrderedDict
-from typing import Any, Dict, Generator, List, Optional, Tuple, Union
+from itertools import cycle
+from typing import Any, Dict, Generator, List, Literal, Optional, Tuple, Union

 import fastobo
 import networkx as nx
 import pandas as pd
 import requests
 import torch
+from rdkit import Chem

 from chebai.preprocessing import reader as dr
 from chebai.preprocessing.datasets.base import XYBaseDataModule, _DynamicDataset
@@ -134,8 +137,17 @@ def __init__(
         self,
         chebi_version_train: Optional[int] = None,
         single_class: Optional[int] = None,
+        augment_smiles: bool = False,
+        aug_smiles_variations: Literal["max"] | int | None = None,
         **kwargs,
     ):
+        if augment_smiles:
+            assert aug_smiles_variations is not None, "aug_smiles_variations must be set"
+            assert aug_smiles_variations == "max" or (
+                int(aug_smiles_variations) and int(aug_smiles_variations) >= 1
+            ), "aug_smiles_variations must be 'max' or a positive int"
+        self.augment_smiles = augment_smiles
+        self.aug_smiles_variations = aug_smiles_variations
         # predict only single class (given as id of one of the classes present in the raw data set)
         self.single_class = single_class
         super(_ChEBIDataExtractor, self).__init__(**kwargs)
@@ -304,8 +316,34 @@ def _graph_to_raw_dataset(self, g: nx.DiGraph) -> pd.DataFrame:
         # This filters the DataFrame to include only the rows where at least one value in the row from 4th column
         # onwards is True/non-zero.
         data = data[data.iloc[:, self._LABELS_START_IDX :].any(axis=1)]
+        if self.augment_smiles:
+            data = self._perform_smiles_augmentation()
+
         return data

+    def _perform_smiles_augmentation(self, data_df: pd.DataFrame) -> pd.DataFrame:
+        data_df["augmented_smiles"] = data_df["SMILES"].apply(self.augment_smiles())
+        # Explode the list of augmented smiles into multiple rows
+        # Augmented SMILES will keep the same ident as the original - but does that matter?
+        exploded_df = data_df.explode("augmented_smiles").reset_index(drop=True)
+        exploded_df.rename(columns={"augmented_smile", "SMILES"})
+        return exploded_df
+
+    def augment_smiles(self, smiles: str):
+        mol: Chem.Mol = Chem.MolFromSmiles(smiles)
+        # ChEBI SMILES may differ from RDKit SMILES for the same canonical mol
+        # TODO: if a generated SMILES equals mol_smiles, remove it
+        # mol_smiles = Chem.MolToSmiles(smiles)
+        atom_ids = [atom.GetIdx() for atom in mol.GetAtoms()]
+        random.shuffle(atom_ids)  # seed set by lightning
+        atom_id_iter = cycle(atom_ids)
+        return list(
+            {
+                Chem.MolToSmiles(mol, rootedAtAtom=next(atom_id_iter), doRandom=True)
+                for _ in range(self.aug_smiles_variations)
+            }
+        ) + [smiles]
+
     # ------------------------------ Phase: Setup data -----------------------------------
     def setup_processed(self) -> None:
         """
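For reference, the core trick of PATCH 01 can be exercised standalone. The following is a minimal sketch assuming only RDKit; the molecule and the variation count are arbitrary illustrations, not values from the patch:

```python
import random
from itertools import cycle

from rdkit import Chem

# Randomized SMILES: rooting the traversal at shuffled atom indices and
# enabling doRandom yields syntactically different spellings of one molecule.
mol = Chem.MolFromSmiles("CC(=O)Oc1ccccc1C(=O)O")  # aspirin, arbitrary example
atom_ids = [atom.GetIdx() for atom in mol.GetAtoms()]
random.shuffle(atom_ids)
atom_id_iter = cycle(atom_ids)

variants = {
    Chem.MolToSmiles(mol, rootedAtAtom=next(atom_id_iter), doRandom=True)
    for _ in range(5)
}
print(variants)  # up to 5 distinct strings; each parses back to the same mol
```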
From f127b5ed6e793006fa75bc39ea434ea20988375e Mon Sep 17 00:00:00 2001
From: aditya0by0
Date: Sat, 9 Aug 2025 13:09:18 +0200
Subject: [PATCH 02/13] if aug is true, set reader's canonicalize to False

https://github.com/ChEB-AI/python-chebai/pull/118#issuecomment-3170598184
---
 chebai/preprocessing/datasets/chebi.py | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/chebai/preprocessing/datasets/chebi.py b/chebai/preprocessing/datasets/chebi.py
index 449a30e6..01714096 100644
--- a/chebai/preprocessing/datasets/chebi.py
+++ b/chebai/preprocessing/datasets/chebi.py
@@ -146,6 +146,8 @@ def __init__(
             assert aug_smiles_variations == "max" or (
                 int(aug_smiles_variations) and int(aug_smiles_variations) >= 1
             ), "aug_smiles_variations must be 'max' or a positive int"
+            kwargs.setdefault("reader_kwargs", {}).update(canonicalize_smiles=False)
+
         self.augment_smiles = augment_smiles
         self.aug_smiles_variations = aug_smiles_variations
         # predict only single class (given as id of one of the classes present in the raw data set)
@@ -162,6 +164,8 @@ def __init__(
             _init_kwargs["chebi_version"] = self.chebi_version_train
             self._chebi_version_train_obj = self.__class__(
                 single_class=self.single_class,
                augment_smiles=self.augment_smiles,
+                aug_smiles_variations=self.aug_smiles_variations,
                 **_init_kwargs,
             )
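A quick illustration of why PATCH 02 disables the reader's canonicalization for augmented data (assuming `canonicalize_smiles` applies standard RDKit canonicalization): canonical SMILES is a normal form, so it would collapse every augmented variant back into one identical string and undo the augmentation.

```python
from rdkit import Chem

variants = ["OCC", "C(O)C", "CCO"]  # three spellings of ethanol
canonical = {Chem.MolToSmiles(Chem.MolFromSmiles(s)) for s in variants}
print(canonical)  # {'CCO'} - all variants collapse to one canonical form
```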
From c51a27b5f85c503056217a9cd8d8b2e0ccb219b7 Mon Sep 17 00:00:00 2001
From: aditya0by0
Date: Sat, 9 Aug 2025 15:34:49 +0200
Subject: [PATCH 03/13] multi-processing for augment func

---
 chebai/preprocessing/datasets/chebi.py | 76 ++++++++++++++++++--------
 1 file changed, 52 insertions(+), 24 deletions(-)

diff --git a/chebai/preprocessing/datasets/chebi.py b/chebai/preprocessing/datasets/chebi.py
index 01714096..6db83842 100644
--- a/chebai/preprocessing/datasets/chebi.py
+++ b/chebai/preprocessing/datasets/chebi.py
@@ -15,7 +15,8 @@
 import random
 from abc import ABC
 from collections import OrderedDict
 from itertools import cycle
-from typing import Any, Dict, Generator, List, Literal, Optional, Tuple, Union
+from multiprocessing import Pool
+from typing import Any, Dict, Generator, List, Optional, Tuple, Union

 import fastobo
 import networkx as nx
@@ -138,14 +139,14 @@ def __init__(
         chebi_version_train: Optional[int] = None,
         single_class: Optional[int] = None,
         augment_smiles: bool = False,
-        aug_smiles_variations: Literal["max"] | int | None = None,
+        aug_smiles_variations: Optional[int] = None,
         **kwargs,
     ):
         if augment_smiles:
             assert aug_smiles_variations is not None, "aug_smiles_variations must be set"
-            assert aug_smiles_variations == "max" or (
-                int(aug_smiles_variations) and int(aug_smiles_variations) >= 1
-            ), "aug_smiles_variations must be 'max' or a positive int"
+            assert (
+                aug_smiles_variations > 0
+            ), "Number of variations must be greater than 0"
             kwargs.setdefault("reader_kwargs", {}).update(canonicalize_smiles=False)

         self.augment_smiles = augment_smiles
@@ -321,33 +322,34 @@ def _graph_to_raw_dataset(self, g: nx.DiGraph) -> pd.DataFrame:
         # onwards is True/non-zero.
         data = data[data.iloc[:, self._LABELS_START_IDX :].any(axis=1)]
         if self.augment_smiles:
-            data = self._perform_smiles_augmentation()
-
+            return self._perform_smiles_augmentation(data)
         return data

     def _perform_smiles_augmentation(self, data_df: pd.DataFrame) -> pd.DataFrame:
-        data_df["augmented_smiles"] = data_df["SMILES"].apply(self.augment_smiles())
+        AUG_SMILES_VARIATIONS = self.aug_smiles_variations
+
+        def generate_augmented_smiles(smiles: str):
+            mol: Chem.Mol = Chem.MolFromSmiles(smiles)
+            atom_ids = [atom.GetIdx() for atom in mol.GetAtoms()]
+            random.shuffle(atom_ids)  # seed set by lightning
+            atom_id_iter = cycle(atom_ids)
+            augmented = {
+                Chem.MolToSmiles(mol, rootedAtAtom=next(atom_id_iter), doRandom=True)
+                for _ in range(AUG_SMILES_VARIATIONS)
+            }
+            augmented.add(smiles)
+            return list(augmented)
+
+        with Pool() as pool:
+            data_df["augmented_smiles"] = pool.map(
+                generate_augmented_smiles, data_df["SMILES"]
+            )
         # Explode the list of augmented smiles into multiple rows
         # Augmented SMILES will keep the same ident as the original - but does that matter?
         exploded_df = data_df.explode("augmented_smiles").reset_index(drop=True)
-        exploded_df.rename(columns={"augmented_smile", "SMILES"})
+        exploded_df.rename(columns={"augmented_smiles": "SMILES"}, inplace=True)
         return exploded_df

-    def augment_smiles(self, smiles: str):
-        mol: Chem.Mol = Chem.MolFromSmiles(smiles)
-        # ChEBI SMILES may differ from RDKit SMILES for the same canonical mol
-        # TODO: if a generated SMILES equals mol_smiles, remove it
-        # mol_smiles = Chem.MolToSmiles(smiles)
-        atom_ids = [atom.GetIdx() for atom in mol.GetAtoms()]
-        random.shuffle(atom_ids)  # seed set by lightning
-        atom_id_iter = cycle(atom_ids)
-        return list(
-            {
-                Chem.MolToSmiles(mol, rootedAtAtom=next(atom_id_iter), doRandom=True)
-                for _ in range(self.aug_smiles_variations)
-            }
-        ) + [smiles]
-
     # ------------------------------ Phase: Setup data -----------------------------------
     def setup_processed(self) -> None:
         """
@@ -580,6 +582,32 @@ def processed_dir(self) -> str:
     def raw_file_names_dict(self) -> dict:
         return {"chebi": "chebi.obo"}

+    @property
+    def processed_main_file_names_dict(self) -> dict:
+        """
+        Returns a dictionary mapping dataset keys to processed data file names.
+
+        Returns:
+            dict: A dictionary mapping dataset keys to their respective file names.
+                For example, {"data": "data.pkl"}.
+        """
+        if not self.augment_smiles:
+            return super().processed_main_file_names_dict
+        return {"data": f"aug_data_var{self.aug_smiles_variations}.pkl"}
+
+    @property
+    def processed_file_names_dict(self) -> dict:
+        """
+        Returns a dictionary for the processed and tokenized data files.
+
+        Returns:
+            dict: A dictionary mapping dataset keys to their respective file names.
+                For example, {"data": "data.pt"}.
+        """
+        if not self.augment_smiles:
+            return super().processed_file_names_dict
+        return {"data": f"aug_data_var{self.aug_smiles_variations}.pt"}
+

 class JCIExtendedBase(_ChEBIDataExtractor):
     @property

From d86401152a04b585c5e4da7326cc64eb7dbdaeaf Mon Sep 17 00:00:00 2001
From: aditya0by0
Date: Sun, 10 Aug 2025 15:37:59 +0200
Subject: [PATCH 04/13] perform aug after pkl file is created

---
 chebai/preprocessing/datasets/base.py  | 23 +++++++++
 chebai/preprocessing/datasets/chebi.py | 67 +++++++++++++++++---------
 2 files changed, 66 insertions(+), 24 deletions(-)

diff --git a/chebai/preprocessing/datasets/base.py b/chebai/preprocessing/datasets/base.py
index 4a1898bc..d63396cd 100644
--- a/chebai/preprocessing/datasets/base.py
+++ b/chebai/preprocessing/datasets/base.py
@@ -1,6 +1,7 @@
 import os
 import random
 from abc import ABC, abstractmethod
+from pathlib import Path
 from typing import Any, Dict, Generator, List, Optional, Tuple, Union

 import lightning as pl
@@ -419,10 +420,17 @@ def prepare_data(self, *args, **kwargs) -> None:

         self._prepare_data_flag += 1
         self._perform_data_preparation(*args, **kwargs)
+        self._after_prepare_data(*args, **kwargs)

     def _perform_data_preparation(self, *args, **kwargs) -> None:
         raise NotImplementedError

+    def _after_prepare_data(self, *args, **kwargs) -> None:
+        """
+        Hook to perform additional pre-processing after the pre-processed data is available.
+        """
+        ...
+
     def setup(self, *args, **kwargs) -> None:
         """
         Setup the data module.
@@ -872,6 +880,21 @@ def save_processed(self, data: pd.DataFrame, filename: str) -> None:
         """
         pd.to_pickle(data, open(os.path.join(self.processed_dir_main, filename), "wb"))

+    def get_processed_pickled_df_file(self, filename: str) -> pd.DataFrame | None:
+        """
+        Gets the processed dataset pickle file.
+
+        Args:
+            filename (str): The filename of the pickle file.
+
+        Returns:
+            Optional[pd.DataFrame]: The processed dataset as a DataFrame, or None if the file does not exist.
+        """
+        file_path = Path(self.processed_dir_main) / filename
+        if file_path.exists():
+            return pd.read_pickle(file_path)
+        return None
+
     # ------------------------------ Phase: Setup data -----------------------------------
     def setup_processed(self) -> None:
         """
diff --git a/chebai/preprocessing/datasets/chebi.py b/chebai/preprocessing/datasets/chebi.py
index 6db83842..2afeef0e 100644
--- a/chebai/preprocessing/datasets/chebi.py
+++ b/chebai/preprocessing/datasets/chebi.py
@@ -15,8 +15,7 @@
 from abc import ABC
 from collections import OrderedDict
 from itertools import cycle
-from multiprocessing import Pool
-from typing import Any, Dict, Generator, List, Optional, Tuple, Union
+from typing import Any, Generator, Optional, Union

 import fastobo
 import networkx as nx
@@ -142,15 +141,16 @@ def __init__(
         aug_smiles_variations: Optional[int] = None,
         **kwargs,
     ):
-        if augment_smiles:
-            assert aug_smiles_variations is not None, "aug_smiles_variations must be set"
+        if bool(augment_smiles):
             assert (
-                aug_smiles_variations > 0
+                int(aug_smiles_variations) > 0
             ), "Number of variations must be greater than 0"
-            kwargs.setdefault("reader_kwargs", {}).update(canonicalize_smiles=False)
+            reader_kwargs = kwargs.get("reader_kwargs", {})
+            reader_kwargs["canonicalize_smiles"] = False
+            kwargs["reader_kwargs"] = reader_kwargs

-        self.augment_smiles = augment_smiles
-        self.aug_smiles_variations = aug_smiles_variations
+        self.augment_smiles = bool(augment_smiles)
+        self.aug_smiles_variations = int(aug_smiles_variations)
         # predict only single class (given as id of one of the classes present in the raw data set)
         self.single_class = single_class
         super(_ChEBIDataExtractor, self).__init__(**kwargs)
@@ -321,14 +321,28 @@ def _graph_to_raw_dataset(self, g: nx.DiGraph) -> pd.DataFrame:
         # This filters the DataFrame to include only the rows where at least one value in the row from 4th column
         # onwards is True/non-zero.
         data = data[data.iloc[:, self._LABELS_START_IDX :].any(axis=1)]
-        if self.augment_smiles:
-            return self._perform_smiles_augmentation(data)
         return data

-    def _perform_smiles_augmentation(self, data_df: pd.DataFrame) -> pd.DataFrame:
+    def _after_prepare_data(self, *args, **kwargs) -> None:
+        self._perform_smiles_augmentation()
+
+    def _perform_smiles_augmentation(self) -> None:
+        if not self.augment_smiles:
+            return
+
+        aug_data_df = self.get_processed_pickled_df_file(
+            self.processed_main_file_names_dict["aug_data"]
+        )
+        if aug_data_df is not None:
+            return
+
+        data_df = self.get_processed_pickled_df_file(
+            self.processed_main_file_names_dict["data"]
+        )
+
         AUG_SMILES_VARIATIONS = self.aug_smiles_variations

-        def generate_augmented_smiles(smiles: str):
+        def generate_augmented_smiles(smiles: str) -> list[str]:
             mol: Chem.Mol = Chem.MolFromSmiles(smiles)
             atom_ids = [atom.GetIdx() for atom in mol.GetAtoms()]
             random.shuffle(atom_ids)  # seed set by lightning
             atom_id_iter = cycle(atom_ids)
             augmented = {
                 Chem.MolToSmiles(mol, rootedAtAtom=next(atom_id_iter), doRandom=True)
                 for _ in range(AUG_SMILES_VARIATIONS)
             }
             augmented.add(smiles)
             return list(augmented)

-        with Pool() as pool:
-            data_df["augmented_smiles"] = pool.map(
-                generate_augmented_smiles, data_df["SMILES"]
-            )
+        data_df["augmented_smiles"] = data_df["SMILES"].apply(generate_augmented_smiles)
+
         # Explode the list of augmented smiles into multiple rows
         # Augmented SMILES will keep the same ident as the original - but does that matter?
         exploded_df = data_df.explode("augmented_smiles").reset_index(drop=True)
         exploded_df.rename(columns={"augmented_smiles": "SMILES"}, inplace=True)
-        return exploded_df
+        self.save_processed(
+            exploded_df, self.processed_main_file_names_dict["aug_data"]
+        )

     # ------------------------------ Phase: Setup data -----------------------------------
     def setup_processed(self) -> None:
@@ -377,7 +391,7 @@ def setup_processed(self) -> None:
             print("Calling the setup method related to it")
             self._chebi_version_train_obj.setup()

-    def _load_dict(self, input_file_path: str) -> Generator[Dict[str, Any], None, None]:
+    def _load_dict(self, input_file_path: str) -> Generator[dict[str, Any], None, None]:
         """
         Loads a dictionary from a pickled file, yielding individual dictionaries for each row.
@@ -418,7 +432,7 @@ def _load_dict(self, input_file_path: str) -> Generator[dict[str, Any], None, No
         )

     # ------------------------------ Phase: Dynamic Splits -----------------------------------
-    def _get_data_splits(self) -> Tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]:
+    def _get_data_splits(self) -> tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]:
         """
         Loads encoded/transformed data and generates training, validation, and test splits.
@@ -591,9 +605,10 @@ def processed_main_file_names_dict(self) -> dict:
             dict: A dictionary mapping dataset keys to their respective file names.
                 For example, {"data": "data.pkl"}.
         """
-        if not self.augment_smiles:
-            return super().processed_main_file_names_dict
-        return {"data": f"aug_data_var{self.aug_smiles_variations}.pkl"}
+        p_dict = super().processed_main_file_names_dict
+        if self.augment_smiles:
+            p_dict["aug_data"] = f"aug_data_var{self.aug_smiles_variations}.pkl"
+        return p_dict

     @property
     def processed_file_names_dict(self) -> dict:
@@ -606,6 +621,10 @@ def processed_file_names_dict(self) -> dict:
         """
         if not self.augment_smiles:
             return super().processed_file_names_dict
+        if self.n_token_limit is not None:
+            return {
+                "data": f"aug_data_var{self.aug_smiles_variations}_maxlen{self.n_token_limit}.pt"
+            }
         return {"data": f"aug_data_var{self.aug_smiles_variations}.pt"}


@@ -644,7 +663,7 @@ def _name(self) -> str:
         """
         return f"ChEBI{self.THRESHOLD}"

-    def select_classes(self, g: nx.DiGraph, *args, **kwargs) -> List:
+    def select_classes(self, g: nx.DiGraph, *args, **kwargs) -> list:
         """
         Selects classes from the ChEBI dataset based on the number of successors meeting a specified threshold.
@@ -856,7 +875,7 @@ def chebi_to_int(s: str) -> int:
     return int(s[s.index(":") + 1 :])


-def term_callback(doc: fastobo.term.TermFrame) -> Union[Dict, bool]:
+def term_callback(doc: fastobo.term.TermFrame) -> Union[dict, bool]:
     """
     Extracts information from a ChEBI term document.

     This function takes a ChEBI term document as input and extracts relevant information such as the term ID, parents,
From 5d3aa0a8068eb69f2ef3120b40b76a024e3b8577 Mon Sep 17 00:00:00 2001
From: aditya0by0
Date: Fri, 22 Aug 2025 13:35:34 +0200
Subject: [PATCH 05/13] aug smiles based on fragmentation - https://github.com/ChEB-AI/python-chebai/issues/44

---
 chebai/preprocessing/datasets/chebi.py | 53 ++++++++++++++++++--------
 1 file changed, 38 insertions(+), 15 deletions(-)

diff --git a/chebai/preprocessing/datasets/chebi.py b/chebai/preprocessing/datasets/chebi.py
index 2afeef0e..6072d404 100644
--- a/chebai/preprocessing/datasets/chebi.py
+++ b/chebai/preprocessing/datasets/chebi.py
@@ -14,7 +14,7 @@
 import random
 from abc import ABC
 from collections import OrderedDict
-from itertools import cycle
+from itertools import cycle, permutations, product
 from typing import Any, Generator, Optional, Union

 import fastobo
@@ -145,12 +145,13 @@ def __init__(
             assert (
                 int(aug_smiles_variations) > 0
             ), "Number of variations must be greater than 0"
+            aug_smiles_variations = int(aug_smiles_variations)
             reader_kwargs = kwargs.get("reader_kwargs", {})
             reader_kwargs["canonicalize_smiles"] = False
             kwargs["reader_kwargs"] = reader_kwargs

         self.augment_smiles = bool(augment_smiles)
-        self.aug_smiles_variations = int(aug_smiles_variations)
+        self.aug_smiles_variations = aug_smiles_variations
         # predict only single class (given as id of one of the classes present in the raw data set)
         self.single_class = single_class
         super(_ChEBIDataExtractor, self).__init__(**kwargs)
@@ -344,22 +345,44 @@ def _perform_smiles_augmentation(self) -> None:

         def generate_augmented_smiles(smiles: str) -> list[str]:
             mol: Chem.Mol = Chem.MolFromSmiles(smiles)
-            atom_ids = [atom.GetIdx() for atom in mol.GetAtoms()]
-            random.shuffle(atom_ids)  # seed set by lightning
-            atom_id_iter = cycle(atom_ids)
-            augmented = {
-                Chem.MolToSmiles(mol, rootedAtAtom=next(atom_id_iter), doRandom=True)
-                for _ in range(AUG_SMILES_VARIATIONS)
-            }
-            augmented.add(smiles)
-            return list(augmented)
+            frags = Chem.GetMolFrags(mol, asMols=True, sanitizeFrags=False)
+            augmented = set()
+
+            frag_smiles: list[set] = []
+            for frag in frags:
+                atom_ids = [atom.GetIdx() for atom in frag.GetAtoms()]
+                random.shuffle(atom_ids)  # seed set by lightning
+                atom_id_iter = cycle(atom_ids)
+                frag_smiles.append(
+                    {
+                        Chem.MolToSmiles(
+                            frag, rootedAtAtom=next(atom_id_iter), doRandom=True
+                        )
+                        for _ in range(AUG_SMILES_VARIATIONS)
+                    }
+                )
+            if len(frags) > 1:
+                # all permutations (ignoring set order, i.e., mixing the sets in every order)
+                aug_counter: int = 0
+                for perm in permutations(frag_smiles):
+                    for combo in product(*perm):
+                        augmented.add(".".join(combo))
+                        aug_counter += 1
+                        if aug_counter >= AUG_SMILES_VARIATIONS:
+                            break
+                    if aug_counter >= AUG_SMILES_VARIATIONS:
+                        break
+            else:
+                augmented = frag_smiles[0]
+
+            return [smiles] + list(augmented)

-        data_df["augmented_smiles"] = data_df["SMILES"].apply(generate_augmented_smiles)
+        data_df["SMILES"] = data_df["SMILES"].apply(generate_augmented_smiles)

         # Explode the list of augmented smiles into multiple rows
         # Augmented SMILES will keep the same ident as the original - but does that matter?
+        # Instead, it is helpful to group augmented SMILES generated from the same original SMILES
-        exploded_df = data_df.explode("augmented_smiles").reset_index(drop=True)
-        exploded_df.rename(columns={"augmented_smiles": "SMILES"}, inplace=True)
+        exploded_df = data_df.explode("SMILES").reset_index(drop=True)
         self.save_processed(
             exploded_df, self.processed_main_file_names_dict["aug_data"]
         )
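The fragment-based recombination of PATCH 05 can also be tried in isolation. A sketch assuming only RDKit; the molecule and the variant count are arbitrary, and note that `Chem.MolFromSmiles` returns None for unparseable input, which PATCH 08 below guards against:

```python
from itertools import permutations, product

from rdkit import Chem

mol = Chem.MolFromSmiles("CC(=O)[O-].[Na+]")  # sodium acetate: two fragments
frags = Chem.GetMolFrags(mol, asMols=True, sanitizeFrags=False)

# A set of randomized SMILES spellings per fragment
frag_smiles = [
    {Chem.MolToSmiles(frag, doRandom=True) for _ in range(3)} for frag in frags
]

recombined = set()
for perm in permutations(frag_smiles):  # every ordering of the fragment sets
    for combo in product(*perm):  # one spelling per fragment, in that order
        recombined.add(".".join(combo))
print(recombined)  # e.g. {'CC(=O)[O-].[Na+]', '[Na+].CC(=O)[O-]', ...}
```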
From 06c09043344a708956834978ed954ed92649e1f1 Mon Sep 17 00:00:00 2001
From: aditya0by0
Date: Fri, 22 Aug 2025 14:56:57 +0200
Subject: [PATCH 06/13] separate splits file for aug data

---
 chebai/preprocessing/datasets/base.py  | 17 ++++++++++++++---
 chebai/preprocessing/datasets/chebi.py | 12 ++++++++++++
 2 files changed, 26 insertions(+), 3 deletions(-)

diff --git a/chebai/preprocessing/datasets/base.py b/chebai/preprocessing/datasets/base.py
index d63396cd..b8394dbd 100644
--- a/chebai/preprocessing/datasets/base.py
+++ b/chebai/preprocessing/datasets/base.py
@@ -856,7 +856,7 @@ def _graph_to_raw_dataset(self, graph: nx.DiGraph) -> pd.DataFrame:
         pass

     @abstractmethod
-    def select_classes(self, g: nx.DiGraph, *args, **kwargs) -> List:
+    def select_classes(self, g: nx.DiGraph, *args, **kwargs) -> list:
         """
         Selects classes from the dataset based on a specified criteria.
@@ -880,7 +880,7 @@ def save_processed(self, data: pd.DataFrame, filename: str) -> None:
         """
         pd.to_pickle(data, open(os.path.join(self.processed_dir_main, filename), "wb"))

-    def get_processed_pickled_df_file(self, filename: str) -> pd.DataFrame | None:
+    def get_processed_pickled_df_file(self, filename: str) -> Optional[pd.DataFrame]:
         """
         Gets the processed dataset pickle file.
@@ -1004,7 +1004,8 @@ def _generate_dynamic_splits(self) -> None:

         combined_split_assignment = pd.concat(split_assignment_list, ignore_index=True)
         combined_split_assignment.to_csv(
-            os.path.join(self.processed_dir_main, "splits.csv"), index=False
+            os.path.join(self.processed_dir_main, self.splits_file_name),
+            index=False,
         )

         # Store the splits in class variables
@@ -1264,3 +1265,13 @@ def processed_file_names_dict(self) -> dict:
         if self.n_token_limit is not None:
             return {"data": f"data_maxlen{self.n_token_limit}.pt"}
         return {"data": "data.pt"}
+
+    @property
+    def splits_file_name(self) -> str:
+        """
+        Returns the name of the splits file.
+
+        Returns:
+            str: The name of the splits file.
+        """
+        return "splits.csv"
diff --git a/chebai/preprocessing/datasets/chebi.py b/chebai/preprocessing/datasets/chebi.py
index 6072d404..9a52d22f 100644
--- a/chebai/preprocessing/datasets/chebi.py
+++ b/chebai/preprocessing/datasets/chebi.py
@@ -650,6 +650,18 @@ def processed_file_names_dict(self) -> dict:
             }
         return {"data": f"aug_data_var{self.aug_smiles_variations}.pt"}

+    @property
+    def splits_file_name(self) -> str:
+        """
+        Returns the name of the splits file.
+
+        Returns:
+            str: The name of the splits file.
+        """
+        if self.augment_smiles:
+            return f"aug_splits_var{self.aug_smiles_variations}.csv"
+        return super().splits_file_name
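To make the split bookkeeping concrete: the explode step used by the augmentation patches duplicates the original ident across rows, which is why the split assignment needs care. A small pandas sketch (the ident value is illustrative):

```python
import pandas as pd

df = pd.DataFrame({"ident": ["CHEBI:16236"], "SMILES": [["CCO", "OCC", "C(O)C"]]})
exploded = df.explode("SMILES").reset_index(drop=True)
print(exploded)
#          ident SMILES
# 0  CHEBI:16236    CCO
# 1  CHEBI:16236    OCC
# 2  CHEBI:16236  C(O)C
# All three rows share one ident, so train/val/test assignment must happen
# per ident (per original molecule), not per exploded row.
```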
- exploded_df = data_df.explode("augmented_smiles").reset_index(drop=True) - exploded_df.rename(columns={"augmented_smiles": "SMILES"}, inplace=True) + # instead its helpful to group augmented smiles generated from the same original SMILES + exploded_df = data_df.explode("SMILES").reset_index(drop=True) self.save_processed( exploded_df, self.processed_main_file_names_dict["aug_data"] ) From 06c09043344a708956834978ed954ed92649e1f1 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Fri, 22 Aug 2025 14:56:57 +0200 Subject: [PATCH 06/13] separate splits file for aug data --- chebai/preprocessing/datasets/base.py | 17 ++++++++++++++--- chebai/preprocessing/datasets/chebi.py | 12 ++++++++++++ 2 files changed, 26 insertions(+), 3 deletions(-) diff --git a/chebai/preprocessing/datasets/base.py b/chebai/preprocessing/datasets/base.py index d63396cd..b8394dbd 100644 --- a/chebai/preprocessing/datasets/base.py +++ b/chebai/preprocessing/datasets/base.py @@ -856,7 +856,7 @@ def _graph_to_raw_dataset(self, graph: nx.DiGraph) -> pd.DataFrame: pass @abstractmethod - def select_classes(self, g: nx.DiGraph, *args, **kwargs) -> List: + def select_classes(self, g: nx.DiGraph, *args, **kwargs) -> list: """ Selects classes from the dataset based on a specified criteria. @@ -880,7 +880,7 @@ def save_processed(self, data: pd.DataFrame, filename: str) -> None: """ pd.to_pickle(data, open(os.path.join(self.processed_dir_main, filename), "wb")) - def get_processed_pickled_df_file(self, filename: str) -> pd.DataFrame | None: + def get_processed_pickled_df_file(self, filename: str) -> Optional[pd.DataFrame]: """ Gets the processed dataset pickle file. @@ -1004,7 +1004,8 @@ def _generate_dynamic_splits(self) -> None: combined_split_assignment = pd.concat(split_assignment_list, ignore_index=True) combined_split_assignment.to_csv( - os.path.join(self.processed_dir_main, "splits.csv"), index=False + os.path.join(self.processed_dir_main, self.splits_file_name), + index=False, ) # Store the splits in class variables @@ -1264,3 +1265,13 @@ def processed_file_names_dict(self) -> dict: if self.n_token_limit is not None: return {"data": f"data_maxlen{self.n_token_limit}.pt"} return {"data": "data.pt"} + + @property + def splits_file_name(self) -> str: + """ + Returns the name of the splits file. + + Returns: + str: The name of the splits file. + """ + return "splits.csv" diff --git a/chebai/preprocessing/datasets/chebi.py b/chebai/preprocessing/datasets/chebi.py index 6072d404..9a52d22f 100644 --- a/chebai/preprocessing/datasets/chebi.py +++ b/chebai/preprocessing/datasets/chebi.py @@ -650,6 +650,18 @@ def processed_file_names_dict(self) -> dict: } return {"data": f"aug_data_var{self.aug_smiles_variations}.pt"} + @property + def splits_file_name(self) -> str: + """ + Returns the name of the splits file. + + Returns: + str: The name of the splits file. 
+ """ + if self.augment_smiles: + return f"aug_splits_var{self.aug_smiles_variations}.csv" + return super().splits_file_name + class JCIExtendedBase(_ChEBIDataExtractor): @property From 97e99fcd68ae76f1fccb5124818927a44b23b828 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Fri, 22 Aug 2025 15:14:35 +0200 Subject: [PATCH 07/13] Update chebi.py --- chebai/preprocessing/datasets/chebi.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/chebai/preprocessing/datasets/chebi.py b/chebai/preprocessing/datasets/chebi.py index 901f3659..defbdd19 100644 --- a/chebai/preprocessing/datasets/chebi.py +++ b/chebai/preprocessing/datasets/chebi.py @@ -897,7 +897,7 @@ def _extract_class_hierarchy(self, chebi_path: str) -> nx.DiGraph: ) return g - def select_classes(self, g: nx.DiGraph, *args, **kwargs) -> List: + def select_classes(self, g: nx.DiGraph, *args, **kwargs) -> list: """Only selects classes that meet the threshold AND are subclasses of the top class ID (including itself).""" smiles = nx.get_node_attributes(g, "smiles") nodes = list( From 0b714a684de8f303f85e796e35a3d61995fb674d Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Fri, 22 Aug 2025 16:06:56 +0200 Subject: [PATCH 08/13] fix mol none error --- chebai/preprocessing/datasets/chebi.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/chebai/preprocessing/datasets/chebi.py b/chebai/preprocessing/datasets/chebi.py index defbdd19..9e73d949 100644 --- a/chebai/preprocessing/datasets/chebi.py +++ b/chebai/preprocessing/datasets/chebi.py @@ -343,6 +343,11 @@ def _perform_smiles_augmentation(self) -> None: def generate_augmented_smiles(smiles: str) -> list[str]: mol: Chem.Mol = Chem.MolFromSmiles(smiles) + if mol is None: + return [smiles] # if mol is None, return original SMILES + + # sanitization set to False, as it can alter the fragment representation in ways you might not want. + # As we don’t want RDKit to "fix" fragments, only need the fragments as-is, to generate SMILES strings. frags = Chem.GetMolFrags(mol, asMols=True, sanitizeFrags=False) augmented = set() From 5dc209dc70155dd2c81c52267534178983197e9a Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Fri, 22 Aug 2025 17:03:31 +0200 Subject: [PATCH 09/13] fix - use aug pkl to generate pt file --- chebai/preprocessing/datasets/base.py | 3 ++- chebai/preprocessing/datasets/chebi.py | 7 ++++--- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/chebai/preprocessing/datasets/base.py b/chebai/preprocessing/datasets/base.py index b8394dbd..45118b94 100644 --- a/chebai/preprocessing/datasets/base.py +++ b/chebai/preprocessing/datasets/base.py @@ -742,6 +742,7 @@ def __init__( self.splits_file_path = self._validate_splits_file_path( kwargs.get("splits_file_path", None) ) + self._data_pkl_filename: str = "data.pkl" @staticmethod def _validate_splits_file_path(splits_file_path: Optional[str]) -> Optional[str]: @@ -1241,7 +1242,7 @@ def processed_main_file_names_dict(self) -> dict: dict: A dictionary mapping dataset key to their respective file names. For example, {"data": "data.pkl"}. 
""" - return {"data": "data.pkl"} + return {"data": self._data_pkl_filename} @property def raw_file_names(self) -> List[str]: diff --git a/chebai/preprocessing/datasets/chebi.py b/chebai/preprocessing/datasets/chebi.py index 9e73d949..d4e16905 100644 --- a/chebai/preprocessing/datasets/chebi.py +++ b/chebai/preprocessing/datasets/chebi.py @@ -329,10 +329,10 @@ def _perform_smiles_augmentation(self) -> None: if not self.augment_smiles: return - aug_data_df = self.get_processed_pickled_df_file( - self.processed_main_file_names_dict["aug_data"] - ) + aug_pkl_file_name = self.processed_main_file_names_dict["aug_data"] + aug_data_df = self.get_processed_pickled_df_file(aug_pkl_file_name) if aug_data_df is not None: + self._data_pkl_filename = aug_pkl_file_name return data_df = self.get_processed_pickled_df_file( @@ -389,6 +389,7 @@ def generate_augmented_smiles(smiles: str) -> list[str]: self.save_processed( exploded_df, self.processed_main_file_names_dict["aug_data"] ) + self._data_pkl_filename = aug_pkl_file_name # ------------------------------ Phase: Setup data ----------------------------------- def setup_processed(self) -> None: From 40af1bc434a7a7ef4d41200296f035205b4c5662 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Fri, 22 Aug 2025 20:39:12 +0200 Subject: [PATCH 10/13] prints stats for loaded data sizes --- chebai/preprocessing/datasets/base.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/chebai/preprocessing/datasets/base.py b/chebai/preprocessing/datasets/base.py index 45118b94..cf44788f 100644 --- a/chebai/preprocessing/datasets/base.py +++ b/chebai/preprocessing/datasets/base.py @@ -472,14 +472,17 @@ def _set_processed_data_props(self): - self._num_of_labels: Number of target labels in the dataset. - self._feature_vector_size: Maximum feature vector length across all data points. """ - data_pt = torch.load( - os.path.join(self.processed_dir, self.processed_file_names_dict["data"]), - weights_only=False, + pt_file_path = os.path.join( + self.processed_dir, self.processed_file_names_dict["data"] ) + data_pt = torch.load(pt_file_path, weights_only=False) self._num_of_labels = len(data_pt[0]["labels"]) self._feature_vector_size = max(len(d["features"]) for d in data_pt) + print( + f"Number of samples in encoded data ({pt_file_path}): {len(data_pt)} samples" + ) print(f"Number of labels for loaded data: {self._num_of_labels}") print(f"Feature vector size: {self._feature_vector_size}") @@ -934,7 +937,9 @@ def _get_data_size(input_file_path: str) -> int: int: The size of the data. 
""" with open(input_file_path, "rb") as f: - return len(pd.read_pickle(f)) + df = pd.read_pickle(f) + print(f"Processed data size ({input_file_path}): {len(df)} rows") + return len(df) @abstractmethod def _load_dict(self, input_file_path: str) -> Generator[Dict[str, Any], None, None]: From c5339acf7894dfcc9ba80ac4ad08de5b0feb01ad Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Tue, 2 Sep 2025 12:28:01 +0200 Subject: [PATCH 11/13] for aug data splits should be consistent with original data --- chebai/preprocessing/datasets/chebi.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/chebai/preprocessing/datasets/chebi.py b/chebai/preprocessing/datasets/chebi.py index d4e16905..561c2cbe 100644 --- a/chebai/preprocessing/datasets/chebi.py +++ b/chebai/preprocessing/datasets/chebi.py @@ -146,6 +146,12 @@ def __init__( int(aug_smiles_variations) > 0 ), "Number of variations must be greater than 0" aug_smiles_variations = int(aug_smiles_variations) + + if not kwargs.get("splits_file_path", None): + raise ValueError( + "When using SMILES augmentation, a splits_file_path must be provided to ensure consistent splits." + ) + reader_kwargs = kwargs.get("reader_kwargs", {}) reader_kwargs["canonicalize_smiles"] = False kwargs["reader_kwargs"] = reader_kwargs From be74a1635ca092dfacf2550933b062c35dc0404e Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Sun, 21 Sep 2025 13:12:27 +0200 Subject: [PATCH 12/13] no splits file for aug data - it must use org splits --- chebai/preprocessing/datasets/base.py | 13 +------------ chebai/preprocessing/datasets/chebi.py | 12 ------------ 2 files changed, 1 insertion(+), 24 deletions(-) diff --git a/chebai/preprocessing/datasets/base.py b/chebai/preprocessing/datasets/base.py index cf44788f..18656cf6 100644 --- a/chebai/preprocessing/datasets/base.py +++ b/chebai/preprocessing/datasets/base.py @@ -1010,8 +1010,7 @@ def _generate_dynamic_splits(self) -> None: combined_split_assignment = pd.concat(split_assignment_list, ignore_index=True) combined_split_assignment.to_csv( - os.path.join(self.processed_dir_main, self.splits_file_name), - index=False, + os.path.join(self.processed_dir_main, "splits.csv"), index=False ) # Store the splits in class variables @@ -1271,13 +1270,3 @@ def processed_file_names_dict(self) -> dict: if self.n_token_limit is not None: return {"data": f"data_maxlen{self.n_token_limit}.pt"} return {"data": "data.pt"} - - @property - def splits_file_name(self) -> str: - """ - Returns the name of the splits file. - - Returns: - str: The name of the splits file. - """ - return "splits.csv" diff --git a/chebai/preprocessing/datasets/chebi.py b/chebai/preprocessing/datasets/chebi.py index 561c2cbe..75553c06 100644 --- a/chebai/preprocessing/datasets/chebi.py +++ b/chebai/preprocessing/datasets/chebi.py @@ -660,18 +660,6 @@ def processed_file_names_dict(self) -> dict: } return {"data": f"aug_data_var{self.aug_smiles_variations}.pt"} - @property - def splits_file_name(self) -> str: - """ - Returns the name of the splits file. - - Returns: - str: The name of the splits file. 
- """ - if self.augment_smiles: - return f"aug_splits_var{self.aug_smiles_variations}.csv" - return super().splits_file_name - class JCIExtendedBase(_ChEBIDataExtractor): @property From 40b98bc73438df3076ba6950fc0517e5e28d204e Mon Sep 17 00:00:00 2001 From: Aditya Khedekar <65857172+aditya0by0@users.noreply.github.com> Date: Sun, 21 Sep 2025 20:31:30 +0200 Subject: [PATCH 13/13] test evaluation using CLI --- README.md | 17 +++++++++++++++-- 1 file changed, 15 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 12b0b95f..eeecd714 100644 --- a/README.md +++ b/README.md @@ -78,8 +78,21 @@ The `classes_path` is the path to the dataset's `raw/classes.txt` file that cont ## Evaluation -An example for evaluating a model trained on the ontology extension task is given in `tutorials/eval_model_basic.ipynb`. -It takes in the finetuned model as input for performing the evaluation. +You can evaluate a model trained on the ontology extension task in one of two ways: + +### 1. Using the Jupyter Notebook +An example notebook is provided at `tutorials/eval_model_basic.ipynb`. +- Load your finetuned model and run the evaluation cells to compute metrics on the test set. + +### 2. Using the Lightning CLI +Alternatively, you can evaluate the model via the CLI: + +```bash +python -m chebai test --trainer=configs/training/default_trainer.yml --trainer.devices=1 --trainer.num_nodes=1 --ckpt_path=[path-to-finetuned-model] --model=configs/model/electra.yml --model.test_metrics=configs/metrics/micro-macro-f1.yml --data=configs/data/chebi/chebi50.yml --data.init_args.batch_size=32 --data.init_args.num_workers=10 --data.init_args.chebi_version=[chebi-version] --model.pass_loss_kwargs=false --model.criterion=configs/loss/bce.yml --model.criterion.init_args.beta=0.99 --data.init_args.splits_file_path=[path-to-splits-file] +``` + +> **Note**: It is recommended to use `devices=1` and `num_nodes=1` during testing; multi-device settings use a `DistributedSampler`, which may replicate some samples to maintain equal batch sizes, so using a single device ensures that each sample or batch is evaluated exactly once. + ## Cross-validation You can do inner k-fold cross-validation, i.e., train models on k train-validation splits that all use the same test