diff --git a/README.md b/README.md index eeecd714..73179ac0 100644 --- a/README.md +++ b/README.md @@ -81,7 +81,7 @@ The `classes_path` is the path to the dataset's `raw/classes.txt` file that cont You can evaluate a model trained on the ontology extension task in one of two ways: ### 1. Using the Jupyter Notebook -An example notebook is provided at `tutorials/eval_model_basic.ipynb`. +An example notebook is provided at `tutorials/eval_model_basic.ipynb`. - Load your finetuned model and run the evaluation cells to compute metrics on the test set. ### 2. Using the Lightning CLI diff --git a/chebai/preprocessing/datasets/chebi.py b/chebai/preprocessing/datasets/chebi.py index edcc8c41..379a7f62 100644 --- a/chebai/preprocessing/datasets/chebi.py +++ b/chebai/preprocessing/datasets/chebi.py @@ -358,7 +358,8 @@ def _perform_smiles_augmentation(self) -> None: self.processed_main_file_names_dict["data"] ) - AUG_SMILES_VARIATIONS = self.aug_smiles_variations + # +1 to account for the case where the original SMILES is generated by random chance + AUG_SMILES_VARIATIONS = self.aug_smiles_variations + 1 def generate_augmented_smiles(smiles: str) -> list[str]: mol: Chem.Mol = Chem.MolFromSmiles(smiles) @@ -397,6 +398,16 @@ def generate_augmented_smiles(smiles: str) -> list[str]: else: augmented = frag_smiles[0] + if smiles in augmented: + augmented.remove(smiles) + + if len(augmented) > AUG_SMILES_VARIATIONS - 1: + # AUG_SMILES_VARIATIONS = number of new smiles needed to generate + original smiles + # if 3 smiles variations are needed, and 4 are generated because none + # corresponds to the original smiles and none is eliminated by the previous if condition + augmented = random.sample(augmented, AUG_SMILES_VARIATIONS - 1) + + # original smiles always first in the list return [smiles] + list(augmented) data_df["SMILES"] = data_df["SMILES"].apply(generate_augmented_smiles)