Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ The `classes_path` is the path to the dataset's `raw/classes.txt` file that cont
You can evaluate a model trained on the ontology extension task in one of two ways:

### 1. Using the Jupyter Notebook
An example notebook is provided at `tutorials/eval_model_basic.ipynb`.
An example notebook is provided at `tutorials/eval_model_basic.ipynb`.
- Load your finetuned model and run the evaluation cells to compute metrics on the test set.

### 2. Using the Lightning CLI
Expand Down
13 changes: 12 additions & 1 deletion chebai/preprocessing/datasets/chebi.py
Original file line number Diff line number Diff line change
Expand Up @@ -358,7 +358,8 @@ def _perform_smiles_augmentation(self) -> None:
self.processed_main_file_names_dict["data"]
)

AUG_SMILES_VARIATIONS = self.aug_smiles_variations
# +1 to account for if original SMILES is generated by random chance
AUG_SMILES_VARIATIONS = self.aug_smiles_variations + 1

def generate_augmented_smiles(smiles: str) -> list[str]:
mol: Chem.Mol = Chem.MolFromSmiles(smiles)
Expand Down Expand Up @@ -397,6 +398,16 @@ def generate_augmented_smiles(smiles: str) -> list[str]:
else:
augmented = frag_smiles[0]

if smiles in augmented:
augmented.remove(smiles)

if len(augmented) > AUG_SMILES_VARIATIONS - 1:
# AUG_SMILES_VARIATIONS = number of new smiles needed to generate + original smiles
# if 3 smiles variations are needed, and 4 are generated because none
# correponds to original smiles and no-one is elimnated in previous if condition
augmented = random.sample(augmented, AUG_SMILES_VARIATIONS - 1)

# original smiles always first in the list
return [smiles] + list(augmented)

data_df["SMILES"] = data_df["SMILES"].apply(generate_augmented_smiles)
Expand Down