From 2218fc8b483d8182bf295d021d2f28318f68562e Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Sat, 6 Dec 2025 12:54:00 +0100 Subject: [PATCH 1/2] avoid generation of original smiles in augmentation --- chebai/preprocessing/datasets/chebi.py | 19 +++++++++++++++---- 1 file changed, 15 insertions(+), 4 deletions(-) diff --git a/chebai/preprocessing/datasets/chebi.py b/chebai/preprocessing/datasets/chebi.py index edcc8c41..c6659d05 100644 --- a/chebai/preprocessing/datasets/chebi.py +++ b/chebai/preprocessing/datasets/chebi.py @@ -144,9 +144,9 @@ def __init__( **kwargs, ): if bool(augment_smiles): - assert ( - int(aug_smiles_variations) > 0 - ), "Number of variations must be greater than 0" + assert int(aug_smiles_variations) > 0, ( + "Number of variations must be greater than 0" + ) aug_smiles_variations = int(aug_smiles_variations) if not kwargs.get("splits_file_path", None): @@ -358,7 +358,8 @@ def _perform_smiles_augmentation(self) -> None: self.processed_main_file_names_dict["data"] ) - AUG_SMILES_VARIATIONS = self.aug_smiles_variations + # +1 to account for if original SMILES is generated by random chance + AUG_SMILES_VARIATIONS = self.aug_smiles_variations + 1 def generate_augmented_smiles(smiles: str) -> list[str]: mol: Chem.Mol = Chem.MolFromSmiles(smiles) @@ -397,6 +398,16 @@ def generate_augmented_smiles(smiles: str) -> list[str]: else: augmented = frag_smiles[0] + if smiles in augmented: + augmented.remove(smiles) + + if len(augmented) > AUG_SMILES_VARIATIONS - 1: + # AUG_SMILES_VARIATIONS = number of new smiles needed to generate + original smiles + # if 3 smiles variations are needed, and 4 are generated because none + # correponds to original smiles and no-one is elimnated in previous if condition + augmented = random.sample(augmented, AUG_SMILES_VARIATIONS - 1) + + # original smiles always first in the list return [smiles] + list(augmented) data_df["SMILES"] = data_df["SMILES"].apply(generate_augmented_smiles) From df02f553010a6da19def7d3a4ee93662aec08464 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Sat, 6 Dec 2025 13:22:35 +0100 Subject: [PATCH 2/2] pre-commit formatting --- README.md | 2 +- chebai/preprocessing/datasets/chebi.py | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/README.md b/README.md index eeecd714..73179ac0 100644 --- a/README.md +++ b/README.md @@ -81,7 +81,7 @@ The `classes_path` is the path to the dataset's `raw/classes.txt` file that cont You can evaluate a model trained on the ontology extension task in one of two ways: ### 1. Using the Jupyter Notebook -An example notebook is provided at `tutorials/eval_model_basic.ipynb`. +An example notebook is provided at `tutorials/eval_model_basic.ipynb`. - Load your finetuned model and run the evaluation cells to compute metrics on the test set. ### 2. Using the Lightning CLI diff --git a/chebai/preprocessing/datasets/chebi.py b/chebai/preprocessing/datasets/chebi.py index c6659d05..379a7f62 100644 --- a/chebai/preprocessing/datasets/chebi.py +++ b/chebai/preprocessing/datasets/chebi.py @@ -144,9 +144,9 @@ def __init__( **kwargs, ): if bool(augment_smiles): - assert int(aug_smiles_variations) > 0, ( - "Number of variations must be greater than 0" - ) + assert ( + int(aug_smiles_variations) > 0 + ), "Number of variations must be greater than 0" aug_smiles_variations = int(aug_smiles_variations) if not kwargs.get("splits_file_path", None):