
Commit 06a1869

Merge pull request #115 from ChEB-AI/feature/augment_smiles
Data Augmentation : SMILES
2 parents 8e51a61 + 40b98bc commit 06a1869

File tree: 3 files changed (+176, -10 lines)


README.md

Lines changed: 15 additions & 2 deletions
````diff
@@ -78,8 +78,21 @@ The `classes_path` is the path to the dataset's `raw/classes.txt` file that cont
 
 ## Evaluation
 
-An example for evaluating a model trained on the ontology extension task is given in `tutorials/eval_model_basic.ipynb`.
-It takes in the finetuned model as input for performing the evaluation.
+You can evaluate a model trained on the ontology extension task in one of two ways:
+
+### 1. Using the Jupyter Notebook
+An example notebook is provided at `tutorials/eval_model_basic.ipynb`.
+- Load your finetuned model and run the evaluation cells to compute metrics on the test set.
+
+### 2. Using the Lightning CLI
+Alternatively, you can evaluate the model via the CLI:
+
+```bash
+python -m chebai test --trainer=configs/training/default_trainer.yml --trainer.devices=1 --trainer.num_nodes=1 --ckpt_path=[path-to-finetuned-model] --model=configs/model/electra.yml --model.test_metrics=configs/metrics/micro-macro-f1.yml --data=configs/data/chebi/chebi50.yml --data.init_args.batch_size=32 --data.init_args.num_workers=10 --data.init_args.chebi_version=[chebi-version] --model.pass_loss_kwargs=false --model.criterion=configs/loss/bce.yml --model.criterion.init_args.beta=0.99 --data.init_args.splits_file_path=[path-to-splits-file]
+```
+
+> **Note**: It is recommended to use `devices=1` and `num_nodes=1` during testing; multi-device settings use a `DistributedSampler`, which may replicate some samples to maintain equal batch sizes, so using a single device ensures that each sample or batch is evaluated exactly once.
+
 
 ## Cross-validation
 You can do inner k-fold cross-validation, i.e., train models on k train-validation splits that all use the same test
````
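The `DistributedSampler` behaviour mentioned in the note above can be seen with a small standalone snippet (not part of the repository): with `drop_last=False`, the sampler pads its index list so every replica receives the same number of samples, which duplicates a few samples whenever the dataset size is not divisible by the number of replicas.

```python
# Minimal illustration of the DistributedSampler padding behaviour mentioned in the
# README note above (standalone sketch, not chebai code).
from torch.utils.data.distributed import DistributedSampler

dataset = list(range(10))  # 10 samples, pretend we test on 3 devices

for rank in range(3):
    sampler = DistributedSampler(
        dataset, num_replicas=3, rank=rank, shuffle=False, drop_last=False
    )
    # Each rank receives 4 indices (12 in total), so two samples are evaluated twice.
    print(rank, list(sampler))
```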

chebai/preprocessing/datasets/base.py

Lines changed: 34 additions & 5 deletions
```diff
@@ -1,6 +1,7 @@
 import os
 import random
 from abc import ABC, abstractmethod
+from pathlib import Path
 from typing import TYPE_CHECKING, Any, Dict, Generator, List, Optional, Tuple, Union
 
 import lightning as pl
@@ -416,10 +417,17 @@ def prepare_data(self, *args, **kwargs) -> None:
 
         self._prepare_data_flag += 1
         self._perform_data_preparation(*args, **kwargs)
+        self._after_prepare_data(*args, **kwargs)
 
     def _perform_data_preparation(self, *args, **kwargs) -> None:
         raise NotImplementedError
 
+    def _after_prepare_data(self, *args, **kwargs) -> None:
+        """
+        Hook to perform additional pre-processing after pre-processed data is available.
+        """
+        ...
+
     def setup(self, *args, **kwargs) -> None:
         """
         Setup the data module.
@@ -461,14 +469,17 @@ def _set_processed_data_props(self):
         - self._num_of_labels: Number of target labels in the dataset.
         - self._feature_vector_size: Maximum feature vector length across all data points.
         """
-        data_pt = torch.load(
-            os.path.join(self.processed_dir, self.processed_file_names_dict["data"]),
-            weights_only=False,
+        pt_file_path = os.path.join(
+            self.processed_dir, self.processed_file_names_dict["data"]
         )
+        data_pt = torch.load(pt_file_path, weights_only=False)
 
         self._num_of_labels = len(data_pt[0]["labels"])
         self._feature_vector_size = max(len(d["features"]) for d in data_pt)
 
+        print(
+            f"Number of samples in encoded data ({pt_file_path}): {len(data_pt)} samples"
+        )
         print(f"Number of labels for loaded data: {self._num_of_labels}")
         print(f"Feature vector size: {self._feature_vector_size}")
 
@@ -731,6 +742,7 @@ def __init__(
         self.splits_file_path = self._validate_splits_file_path(
             kwargs.get("splits_file_path", None)
         )
+        self._data_pkl_filename: str = "data.pkl"
 
     @staticmethod
     def _validate_splits_file_path(splits_file_path: Optional[str]) -> Optional[str]:
@@ -869,6 +881,21 @@ def save_processed(self, data: pd.DataFrame, filename: str) -> None:
         """
         pd.to_pickle(data, open(os.path.join(self.processed_dir_main, filename), "wb"))
 
+    def get_processed_pickled_df_file(self, filename: str) -> Optional[pd.DataFrame]:
+        """
+        Gets the processed dataset pickle file.
+
+        Args:
+            filename (str): The filename for the pickle file.
+
+        Returns:
+            pd.DataFrame: The processed dataset as a DataFrame.
+        """
+        file_path = Path(self.processed_dir_main) / filename
+        if file_path.exists():
+            return pd.read_pickle(file_path)
+        return None
+
     # ------------------------------ Phase: Setup data -----------------------------------
     def setup_processed(self) -> None:
         """
@@ -907,7 +934,9 @@ def _get_data_size(input_file_path: str) -> int:
             int: The size of the data.
         """
         with open(input_file_path, "rb") as f:
-            return len(pd.read_pickle(f))
+            df = pd.read_pickle(f)
+            print(f"Processed data size ({input_file_path}): {len(df)} rows")
+            return len(df)
 
     @abstractmethod
     def _load_dict(self, input_file_path: str) -> Generator[Dict[str, Any], None, None]:
@@ -1223,7 +1252,7 @@ def processed_main_file_names_dict(self) -> dict:
             dict: A dictionary mapping dataset key to their respective file names.
             For example, {"data": "data.pkl"}.
         """
-        return {"data": "data.pkl"}
+        return {"data": self._data_pkl_filename}
 
     @property
     def raw_file_names(self) -> List[str]:
```
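For context, the new `_after_prepare_data` hook is a no-op in the base class and is meant to be overridden by subclasses; `prepare_data` invokes it right after `_perform_data_preparation`. Below is a minimal sketch of a hypothetical subclass, assuming the hook lives on `XYBaseDataModule` (the datamodule base class that `chebi.py` imports); the class name and method bodies are illustrative, not from this commit.

```python
# Hypothetical subclass (not from this commit) sketching how the _after_prepare_data
# hook introduced above is meant to be used.
from chebai.preprocessing.datasets.base import XYBaseDataModule


class MyAugmentingDataModule(XYBaseDataModule):
    def _perform_data_preparation(self, *args, **kwargs) -> None:
        ...  # build and save the processed dataset as usual

    def _after_prepare_data(self, *args, **kwargs) -> None:
        # Post-process the already-prepared data, e.g. derive an augmented copy of the
        # processed pickle, as the ChEBI extractor in chebi.py does for SMILES below.
        ...
```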

chebai/preprocessing/datasets/chebi.py

Lines changed: 127 additions & 3 deletions
```diff
@@ -11,12 +11,15 @@
 
 import os
 import pickle
+import random
 from abc import ABC
 from collections import OrderedDict
-from typing import TYPE_CHECKING, Any, Dict, Generator, List, Optional, Tuple, Union
+from itertools import cycle, permutations, product
+from typing import TYPE_CHECKING, Any, Dict, Generator, List, Optional, Union
 
 import pandas as pd
 import torch
+from rdkit import Chem
 
 from chebai.preprocessing import reader as dr
 from chebai.preprocessing.datasets.base import XYBaseDataModule, _DynamicDataset
@@ -135,8 +138,27 @@ def __init__(
         self,
         chebi_version_train: Optional[int] = None,
         single_class: Optional[int] = None,
+        augment_smiles: bool = False,
+        aug_smiles_variations: Optional[int] = None,
         **kwargs,
     ):
+        if bool(augment_smiles):
+            assert (
+                int(aug_smiles_variations) > 0
+            ), "Number of variations must be greater than 0"
+            aug_smiles_variations = int(aug_smiles_variations)
+
+            if not kwargs.get("splits_file_path", None):
+                raise ValueError(
+                    "When using SMILES augmentation, a splits_file_path must be provided to ensure consistent splits."
+                )
+
+            reader_kwargs = kwargs.get("reader_kwargs", {})
+            reader_kwargs["canonicalize_smiles"] = False
+            kwargs["reader_kwargs"] = reader_kwargs
+
+        self.augment_smiles = bool(augment_smiles)
+        self.aug_smiles_variations = aug_smiles_variations
         # predict only single class (given as id of one of the classes present in the raw data set)
         self.single_class = single_class
         super(_ChEBIDataExtractor, self).__init__(**kwargs)
@@ -151,6 +173,8 @@ def __init__(
             _init_kwargs["chebi_version"] = self.chebi_version_train
             self._chebi_version_train_obj = self.__class__(
                 single_class=self.single_class,
+                augment_smiles=self.augment_smiles,
+                aug_smiles_variations=self.aug_smiles_variations,
                 **_init_kwargs,
             )
 
@@ -312,6 +336,75 @@ def _graph_to_raw_dataset(self, g: "nx.DiGraph") -> pd.DataFrame:
 
         return data
 
+    def _after_prepare_data(self, *args, **kwargs) -> None:
+        self._perform_smiles_augmentation()
+
+    def _perform_smiles_augmentation(self) -> None:
+        if not self.augment_smiles:
+            return
+
+        aug_pkl_file_name = self.processed_main_file_names_dict["aug_data"]
+        aug_data_df = self.get_processed_pickled_df_file(aug_pkl_file_name)
+        if aug_data_df is not None:
+            self._data_pkl_filename = aug_pkl_file_name
+            return
+
+        data_df = self.get_processed_pickled_df_file(
+            self.processed_main_file_names_dict["data"]
+        )
+
+        AUG_SMILES_VARIATIONS = self.aug_smiles_variations
+
+        def generate_augmented_smiles(smiles: str) -> list[str]:
+            mol: Chem.Mol = Chem.MolFromSmiles(smiles)
+            if mol is None:
+                return [smiles]  # if mol is None, return the original SMILES
+
+            # Sanitization is set to False because it can alter the fragment representation
+            # in unwanted ways; we need the fragments as-is to generate SMILES strings.
+            frags = Chem.GetMolFrags(mol, asMols=True, sanitizeFrags=False)
+            augmented = set()
+
+            frag_smiles: list[set] = []
+            for frag in frags:
+                atom_ids = [atom.GetIdx() for atom in frag.GetAtoms()]
+                random.shuffle(atom_ids)  # seed set by lightning
+                atom_id_iter = cycle(atom_ids)
+                frag_smiles.append(
+                    {
+                        Chem.MolToSmiles(
+                            frag, rootedAtAtom=next(atom_id_iter), doRandom=True
+                        )
+                        for _ in range(AUG_SMILES_VARIATIONS)
+                    }
+                )
+            if len(frags) > 1:
+                # combine the per-fragment SMILES sets in every order (all permutations)
+                aug_counter: int = 0
+                for perm in permutations(frag_smiles):
+                    for combo in product(*perm):
+                        augmented.add(".".join(combo))
+                        aug_counter += 1
+                        if aug_counter >= AUG_SMILES_VARIATIONS:
+                            break
+                    if aug_counter >= AUG_SMILES_VARIATIONS:
+                        break
+            else:
+                augmented = frag_smiles[0]
+
+            return [smiles] + list(augmented)
+
+        data_df["SMILES"] = data_df["SMILES"].apply(generate_augmented_smiles)
+
+        # Explode the list of augmented SMILES into multiple rows. Augmented SMILES keep
+        # the same ident as the original, which is helpful for grouping augmented SMILES
+        # generated from the same original SMILES.
+        exploded_df = data_df.explode("SMILES").reset_index(drop=True)
+        self.save_processed(
+            exploded_df, self.processed_main_file_names_dict["aug_data"]
+        )
+        self._data_pkl_filename = aug_pkl_file_name
+
     # ------------------------------ Phase: Setup data -----------------------------------
     def setup_processed(self) -> None:
         """
@@ -339,7 +432,7 @@ def setup_processed(self) -> None:
         print("Calling the setup method related to it")
         self._chebi_version_train_obj.setup()
 
-    def _load_dict(self, input_file_path: str) -> Generator[Dict[str, Any], None, None]:
+    def _load_dict(self, input_file_path: str) -> Generator[dict[str, Any], None, None]:
         """
         Loads a dictionary from a pickled file, yielding individual dictionaries for each row.
 
@@ -380,7 +473,7 @@ def _load_dict(self, input_file_path: str) -> Generator[Dict[str, Any], None, No
             )
 
     # ------------------------------ Phase: Dynamic Splits -----------------------------------
-    def _get_data_splits(self) -> Tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]:
+    def _get_data_splits(self) -> tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]:
         """
         Loads encoded/transformed data and generates training, validation, and test splits.
 
@@ -544,6 +637,37 @@ def processed_dir(self) -> str:
     def raw_file_names_dict(self) -> dict:
         return {"chebi": "chebi.obo"}
 
+    @property
+    def processed_main_file_names_dict(self) -> dict:
+        """
+        Returns a dictionary mapping processed data file names.
+
+        Returns:
+            dict: A dictionary mapping dataset key to their respective file names.
+            For example, {"data": "data.pkl"}.
+        """
+        p_dict = super().processed_main_file_names_dict
+        if self.augment_smiles:
+            p_dict["aug_data"] = f"aug_data_var{self.aug_smiles_variations}.pkl"
+        return p_dict
+
+    @property
+    def processed_file_names_dict(self) -> dict:
+        """
+        Returns a dictionary for the processed and tokenized data files.
+
+        Returns:
+            dict: A dictionary mapping dataset keys to their respective file names.
+            For example, {"data": "data.pt"}.
+        """
+        if not self.augment_smiles:
+            return super().processed_file_names_dict
+        if self.n_token_limit is not None:
+            return {
+                "data": f"aug_data_var{self.aug_smiles_variations}_maxlen{self.n_token_limit}.pt"
+            }
+        return {"data": f"aug_data_var{self.aug_smiles_variations}.pt"}
+
 
 class JCIExtendedBase(_ChEBIDataExtractor):
     @property
```
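To make the augmentation idea above easier to try in isolation, here is a minimal standalone sketch of random SMILES enumeration with RDKit (`doRandom=True` together with `rootedAtAtom`). The function name and example molecule are illustrative and not part of the repository; the committed code additionally shuffles atom indices per fragment and recombines multi-fragment molecules via `permutations`/`product`, which this sketch omits. With augmentation enabled, the resulting files are named after the variation count, e.g. `aug_data_var5.pkl` for `aug_smiles_variations=5`.

```python
# Standalone sketch of random SMILES enumeration (illustrative, not repository code).
import random

from rdkit import Chem


def random_smiles_variants(smiles: str, n: int = 5) -> list[str]:
    """Return the original SMILES plus up to n randomized variants of it."""
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        return [smiles]
    atom_ids = [atom.GetIdx() for atom in mol.GetAtoms()]
    # Root the SMILES traversal at a random atom and let RDKit randomize the atom order.
    variants = {
        Chem.MolToSmiles(mol, rootedAtAtom=random.choice(atom_ids), doRandom=True)
        for _ in range(n)
    }
    return [smiles] + sorted(variants)


print(random_smiles_variants("CC(=O)Oc1ccccc1C(=O)O"))  # aspirin, a single-fragment example
```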
