Merged
17 changes: 15 additions & 2 deletions README.md
@@ -78,8 +78,21 @@ The `classes_path` is the path to the dataset's `raw/classes.txt` file that cont

## Evaluation

An example for evaluating a model trained on the ontology extension task is given in `tutorials/eval_model_basic.ipynb`.
It takes in the finetuned model as input for performing the evaluation.
You can evaluate a model trained on the ontology extension task in one of two ways:

### 1. Using the Jupyter Notebook
An example notebook is provided at `tutorials/eval_model_basic.ipynb`.
- Load your finetuned model and run the evaluation cells to compute metrics on the test set.

### 2. Using the Lightning CLI
Alternatively, you can evaluate the model via the CLI:

```bash
python -m chebai test \
  --trainer=configs/training/default_trainer.yml \
  --trainer.devices=1 \
  --trainer.num_nodes=1 \
  --ckpt_path=[path-to-finetuned-model] \
  --model=configs/model/electra.yml \
  --model.test_metrics=configs/metrics/micro-macro-f1.yml \
  --data=configs/data/chebi/chebi50.yml \
  --data.init_args.batch_size=32 \
  --data.init_args.num_workers=10 \
  --data.init_args.chebi_version=[chebi-version] \
  --model.pass_loss_kwargs=false \
  --model.criterion=configs/loss/bce.yml \
  --model.criterion.init_args.beta=0.99 \
  --data.init_args.splits_file_path=[path-to-splits-file]
```

> **Note**: It is recommended to use `devices=1` and `num_nodes=1` during testing; multi-device settings use a `DistributedSampler`, which may replicate some samples to maintain equal batch sizes, so using a single device ensures that each sample or batch is evaluated exactly once.
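
The sketch below (plain PyTorch, not part of chebai) illustrates the duplication: `DistributedSampler` pads its index list so that every device receives an equally sized shard.

```python
# Standalone illustration: DistributedSampler pads indices so all replicas get
# equally sized shards, duplicating samples when the dataset size is not
# divisible by the number of devices.
from torch.utils.data import DistributedSampler

dataset = list(range(10))  # 10 test samples
seen = []
for rank in range(4):      # pretend we test on 4 devices
    sampler = DistributedSampler(dataset, num_replicas=4, rank=rank, shuffle=False)
    seen.extend(list(sampler))  # indices this device would evaluate

print(len(seen))     # 12 -> two samples are evaluated twice
print(sorted(seen))  # [0, 0, 1, 1, 2, 3, 4, 5, 6, 7, 8, 9]
```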


## Cross-validation
You can do inner k-fold cross-validation, i.e., train models on k train-validation splits that all use the same test
39 changes: 34 additions & 5 deletions chebai/preprocessing/datasets/base.py
@@ -1,6 +1,7 @@
import os
import random
from abc import ABC, abstractmethod
from pathlib import Path
from typing import TYPE_CHECKING, Any, Dict, Generator, List, Optional, Tuple, Union

import lightning as pl
@@ -416,10 +417,17 @@ def prepare_data(self, *args, **kwargs) -> None:

self._prepare_data_flag += 1
self._perform_data_preparation(*args, **kwargs)
self._after_prepare_data(*args, **kwargs)

def _perform_data_preparation(self, *args, **kwargs) -> None:
raise NotImplementedError

def _after_prepare_data(self, *args, **kwargs) -> None:
"""
Hook called at the end of `prepare_data`, once the prepared data is available.
Subclasses can override it to run additional processing; the default implementation does nothing.
"""
...

def setup(self, *args, **kwargs) -> None:
"""
Setup the data module.
@@ -461,14 +469,17 @@ def _set_processed_data_props(self):
- self._num_of_labels: Number of target labels in the dataset.
- self._feature_vector_size: Maximum feature vector length across all data points.
"""
data_pt = torch.load(
os.path.join(self.processed_dir, self.processed_file_names_dict["data"]),
weights_only=False,
pt_file_path = os.path.join(
self.processed_dir, self.processed_file_names_dict["data"]
)
data_pt = torch.load(pt_file_path, weights_only=False)

self._num_of_labels = len(data_pt[0]["labels"])
self._feature_vector_size = max(len(d["features"]) for d in data_pt)

print(
f"Number of samples in encoded data ({pt_file_path}): {len(data_pt)} samples"
)
print(f"Number of labels for loaded data: {self._num_of_labels}")
print(f"Feature vector size: {self._feature_vector_size}")

@@ -731,6 +742,7 @@ def __init__(
self.splits_file_path = self._validate_splits_file_path(
kwargs.get("splits_file_path", None)
)
self._data_pkl_filename: str = "data.pkl"

@staticmethod
def _validate_splits_file_path(splits_file_path: Optional[str]) -> Optional[str]:
@@ -869,6 +881,21 @@ def save_processed(self, data: pd.DataFrame, filename: str) -> None:
"""
pd.to_pickle(data, open(os.path.join(self.processed_dir_main, filename), "wb"))

def get_processed_pickled_df_file(self, filename: str) -> Optional[pd.DataFrame]:
"""
Loads a processed dataset pickle file from the main processed directory, if it exists.

Args:
filename (str): The filename of the pickle file.

Returns:
Optional[pd.DataFrame]: The processed dataset as a DataFrame, or None if the file does not exist.
"""
file_path = Path(self.processed_dir_main) / filename
if file_path.exists():
return pd.read_pickle(file_path)
return None

# ------------------------------ Phase: Setup data -----------------------------------
def setup_processed(self) -> None:
"""
@@ -907,7 +934,9 @@ def _get_data_size(input_file_path: str) -> int:
int: The size of the data.
"""
with open(input_file_path, "rb") as f:
return len(pd.read_pickle(f))
df = pd.read_pickle(f)
print(f"Processed data size ({input_file_path}): {len(df)} rows")
return len(df)

@abstractmethod
def _load_dict(self, input_file_path: str) -> Generator[Dict[str, Any], None, None]:
@@ -1223,7 +1252,7 @@ def processed_main_file_names_dict(self) -> dict:
dict: A dictionary mapping dataset keys to their respective file names.
For example, {"data": "data.pkl"}.
"""
return {"data": "data.pkl"}
return {"data": self._data_pkl_filename}

@property
def raw_file_names(self) -> List[str]:
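Taken together, the new `_after_prepare_data` hook and the `get_processed_pickled_df_file` helper give subclasses a place to post-process the prepared pickle. A hedged sketch of how a subclass might combine them (the class name, the file name, and the assumption that `_DynamicDataset` exposes both methods are illustrative, not taken from this PR):

```python
# Illustrative sketch only; MyDerivedDataset and "derived_data.pkl" are made up,
# and we assume _DynamicDataset exposes the new hook and helper shown above.
from chebai.preprocessing.datasets.base import _DynamicDataset


class MyDerivedDataset(_DynamicDataset):
    def _after_prepare_data(self, *args, **kwargs) -> None:
        # Called at the end of prepare_data(), after _perform_data_preparation().
        df = self.get_processed_pickled_df_file(
            self.processed_main_file_names_dict["data"]
        )
        if df is None:
            return  # prepared pickle not available, nothing to post-process
        # ... derive an additional DataFrame here ...
        self.save_processed(df, "derived_data.pkl")
        # Point the module at the derived pickle for the remaining phases.
        self._data_pkl_filename = "derived_data.pkl"
```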
130 changes: 127 additions & 3 deletions chebai/preprocessing/datasets/chebi.py
@@ -11,12 +11,15 @@

import os
import pickle
import random
from abc import ABC
from collections import OrderedDict
from typing import TYPE_CHECKING, Any, Dict, Generator, List, Optional, Tuple, Union
from itertools import cycle, permutations, product
from typing import TYPE_CHECKING, Any, Dict, Generator, List, Optional, Union

import pandas as pd
import torch
from rdkit import Chem

from chebai.preprocessing import reader as dr
from chebai.preprocessing.datasets.base import XYBaseDataModule, _DynamicDataset
@@ -135,8 +138,27 @@ def __init__(
self,
chebi_version_train: Optional[int] = None,
single_class: Optional[int] = None,
augment_smiles: bool = False,
aug_smiles_variations: Optional[int] = None,
**kwargs,
):
if bool(augment_smiles):
assert (
aug_smiles_variations is not None and int(aug_smiles_variations) > 0
), "aug_smiles_variations must be a positive integer when augment_smiles is enabled"
aug_smiles_variations = int(aug_smiles_variations)

if not kwargs.get("splits_file_path", None):
raise ValueError(
"When using SMILES augmentation, a splits_file_path must be provided to ensure consistent splits."
)

reader_kwargs = kwargs.get("reader_kwargs", {})
reader_kwargs["canonicalize_smiles"] = False
kwargs["reader_kwargs"] = reader_kwargs

self.augment_smiles = bool(augment_smiles)
self.aug_smiles_variations = aug_smiles_variations
# predict only single class (given as id of one of the classes present in the raw data set)
self.single_class = single_class
super(_ChEBIDataExtractor, self).__init__(**kwargs)
@@ -151,6 +173,8 @@ def __init__(
_init_kwargs["chebi_version"] = self.chebi_version_train
self._chebi_version_train_obj = self.__class__(
single_class=self.single_class,
augment_smiles=self.augment_smiles,
aug_smiles_variations=self.aug_smiles_variations,
**_init_kwargs,
)
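
For orientation, a hedged sketch of how the new constructor arguments might be passed (the class name `ChEBIOver50` and the splits file path are assumptions, not taken from this PR):

```python
# Hypothetical usage sketch; ChEBIOver50 and the splits file path are assumptions.
from chebai.preprocessing.datasets.chebi import ChEBIOver50

dm = ChEBIOver50(
    augment_smiles=True,                  # enable SMILES augmentation
    aug_smiles_variations=5,              # must be a positive integer when augmenting
    splits_file_path="data/splits.csv",   # required so augmented data reuses fixed splits
)
dm.prepare_data()  # runs data preparation, then the augmentation hook
```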

@@ -312,6 +336,75 @@ def _graph_to_raw_dataset(self, g: "nx.DiGraph") -> pd.DataFrame:

return data

def _after_prepare_data(self, *args, **kwargs) -> None:
self._perform_smiles_augmentation()

def _perform_smiles_augmentation(self) -> None:
"""
Generates augmented SMILES variants for the prepared dataset, saves them to a
separate pickle file, and points the data module at that file. Does nothing if
augmentation is disabled; reuses an existing augmented pickle if one is present.
"""
if not self.augment_smiles:
return

aug_pkl_file_name = self.processed_main_file_names_dict["aug_data"]
aug_data_df = self.get_processed_pickled_df_file(aug_pkl_file_name)
if aug_data_df is not None:
self._data_pkl_filename = aug_pkl_file_name
return

data_df = self.get_processed_pickled_df_file(
self.processed_main_file_names_dict["data"]
)

AUG_SMILES_VARIATIONS = self.aug_smiles_variations

def generate_augmented_smiles(smiles: str) -> list[str]:
mol: Chem.Mol = Chem.MolFromSmiles(smiles)
if mol is None:
return [smiles] # if mol is None, return original SMILES

# Sanitization is disabled because it can alter the fragment representation in unwanted ways.
# We only need the fragments as-is to generate SMILES strings, not for RDKit to "fix" them.
frags = Chem.GetMolFrags(mol, asMols=True, sanitizeFrags=False)
augmented = set()

frag_smiles: list[set] = []
for frag in frags:
atom_ids = [atom.GetIdx() for atom in frag.GetAtoms()]
random.shuffle(atom_ids) # seed set by lightning
atom_id_iter = cycle(atom_ids)
frag_smiles.append(
{
Chem.MolToSmiles(
frag, rootedAtAtom=next(atom_id_iter), doRandom=True
)
for _ in range(AUG_SMILES_VARIATIONS)
}
)
if len(frags) > 1:
# Combine the fragments in every order: iterate over all permutations of the fragment SMILES sets and all combinations of their elements
aug_counter: int = 0
for perm in permutations(frag_smiles):
for combo in product(*perm):
augmented.add(".".join(combo))
aug_counter += 1
if aug_counter >= AUG_SMILES_VARIATIONS:
break
if aug_counter >= AUG_SMILES_VARIATIONS:
break
else:
augmented = frag_smiles[0]

return [smiles] + list(augmented)

data_df["SMILES"] = data_df["SMILES"].apply(generate_augmented_smiles)

# Explode the list of augmented smiles into multiple rows
# Augmented SMILES keep the same ident as the original entry; this is intentional,
# as it groups augmented SMILES generated from the same original SMILES.
exploded_df = data_df.explode("SMILES").reset_index(drop=True)
self.save_processed(
exploded_df, self.processed_main_file_names_dict["aug_data"]
)
self._data_pkl_filename = aug_pkl_file_name

# ------------------------------ Phase: Setup data -----------------------------------
def setup_processed(self) -> None:
"""
@@ -339,7 +432,7 @@ def setup_processed(self) -> None:
print("Calling the setup method related to it")
self._chebi_version_train_obj.setup()

def _load_dict(self, input_file_path: str) -> Generator[Dict[str, Any], None, None]:
def _load_dict(self, input_file_path: str) -> Generator[dict[str, Any], None, None]:
**Reviewer (Collaborator):** Is there a reason for this change? The only difference I see is that we lose support for Python < 3.9 with `dict`.

**Author (Member):** Since we've already set Python 3.10 as the minimum requirement in #105, there's no need to import type hints from `typing` or `typing_extensions`. Using built-in type hints (e.g., `dict[...]`) is both explicit and consistent with Python 3.10+ features, which is why this change was made.

"""
Loads a dictionary from a pickled file, yielding individual dictionaries for each row.

@@ -380,7 +473,7 @@ def _load_dict(self, input_file_path: str) -> Generator[Dict[str, Any], None, No
)

# ------------------------------ Phase: Dynamic Splits -----------------------------------
def _get_data_splits(self) -> Tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]:
def _get_data_splits(self) -> tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]:
"""
Loads encoded/transformed data and generates training, validation, and test splits.

@@ -544,6 +637,37 @@ def processed_dir(self) -> str:
def raw_file_names_dict(self) -> dict:
return {"chebi": "chebi.obo"}

@property
def processed_main_file_names_dict(self) -> dict:
"""
Returns a dictionary mapping processed data file names.

Returns:
dict: A dictionary mapping dataset keys to their respective file names.
For example, {"data": "data.pkl"}; when SMILES augmentation is enabled, an additional "aug_data" entry such as "aug_data_var5.pkl" is included.
"""
p_dict = super().processed_main_file_names_dict
if self.augment_smiles:
p_dict["aug_data"] = f"aug_data_var{self.aug_smiles_variations}.pkl"
return p_dict

@property
def processed_file_names_dict(self) -> dict:
"""
Returns a dictionary for the processed and tokenized data files.

Returns:
dict: A dictionary mapping dataset keys to their respective file names.
For example, {"data": "data.pt"}, or an augmentation-specific name such as {"data": "aug_data_var5.pt"} when SMILES augmentation is enabled.
"""
if not self.augment_smiles:
return super().processed_file_names_dict
if self.n_token_limit is not None:
return {
"data": f"aug_data_var{self.aug_smiles_variations}_maxlen{self.n_token_limit}.pt"
}
return {"data": f"aug_data_var{self.aug_smiles_variations}.pt"}
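
As a concrete illustration of the naming scheme (the values 5 and 1500 below are assumed, not taken from this PR):

```python
# Assumed example values, to show the file names produced by the properties above.
aug_smiles_variations = 5
n_token_limit = 1500

print(f"aug_data_var{aug_smiles_variations}.pkl")                       # augmented DataFrame pickle
print(f"aug_data_var{aug_smiles_variations}.pt")                        # tokenized data, no token limit
print(f"aug_data_var{aug_smiles_variations}_maxlen{n_token_limit}.pt")  # tokenized data with a token limit
```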


class JCIExtendedBase(_ChEBIDataExtractor):
@property