Skip to content

Commit 8f61ff7

Browse files
committed
add automatic inconsistency removal to ensemble
1 parent fd0b633 commit 8f61ff7

File tree

2 files changed

+34
-22
lines changed

2 files changed

+34
-22
lines changed

chebifier/cli.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,8 @@ def cli():
2525
@click.option('--smiles-file', '-f', type=click.Path(exists=True), help='File containing SMILES strings (one per line)')
2626
@click.option('--output', '-o', type=click.Path(), help='Output file to save predictions (optional)')
2727
@click.option('--ensemble-type', '-e', type=click.Choice(ENSEMBLES.keys()), default='mv', help='Type of ensemble to use (default: Majority Voting)')
28-
def predict(config_file, smiles, smiles_file, output, ensemble_type):
28+
@click.option("--chebi-version", "-v", type=int, default=241, help="ChEBI version to use for checking consistency (default: 241)")
29+
def predict(config_file, smiles, smiles_file, output, ensemble_type, chebi_version):
2930
"""Predict ChEBI classes for SMILES strings using an ensemble model.
3031
3132
CONFIG_FILE is the path to a YAML configuration file for the ensemble model.
@@ -35,7 +36,7 @@ def predict(config_file, smiles, smiles_file, output, ensemble_type):
3536
config = yaml.safe_load(f)
3637

3738
# Instantiate ensemble model
38-
ensemble = ENSEMBLES[ensemble_type](config)
39+
ensemble = ENSEMBLES[ensemble_type](config, chebi_version=chebi_version)
3940

4041
# Collect SMILES strings from arguments and/or file
4142
smiles_list = list(smiles)

chebifier/ensemble/base_ensemble.py

Lines changed: 31 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22
from abc import ABC
33
import torch
44
import tqdm
5+
from chebai.preprocessing.datasets.chebi import ChEBIOver50
6+
from chebai.result.analyse_sem import PredictionSmoother
57
from rdkit import Chem
68

79
from chebifier.prediction_models.base_predictor import BasePredictor
@@ -17,7 +19,7 @@
1719

1820
class BaseEnsemble(ABC):
1921

20-
def __init__(self, model_configs: dict):
22+
def __init__(self, model_configs: dict, chebi_version: int = 241):
2123
self.models = []
2224
self.positive_prediction_threshold = 0.5
2325
for model_name, model_config in model_configs.items():
@@ -26,6 +28,12 @@ def __init__(self, model_configs: dict):
2628
assert isinstance(model_instance, BasePredictor)
2729
self.models.append(model_instance)
2830

31+
self.smoother = PredictionSmoother(ChEBIOver50(chebi_version=chebi_version), disjoint_files=[
32+
os.path.join("data", "disjoint_chebi.csv"),
33+
os.path.join("data", "disjoint_additional.csv")
34+
])
35+
36+
2937
def gather_predictions(self, smiles_list):
3038
# get predictions from all models for the SMILES list
3139
# order them alphabetically by label class
@@ -52,17 +60,13 @@ def gather_predictions(self, smiles_list):
5260
return ordered_logits, predicted_classes
5361

5462

55-
def consolidate_predictions(self, predictions, predicted_classes, classwise_weights, **kwargs):
63+
def consolidate_predictions(self, predictions, classwise_weights, **kwargs):
5664
"""
5765
Aggregates predictions from multiple models using weighted majority voting.
5866
Optimized version using tensor operations instead of for loops.
5967
"""
6068
num_smiles, num_classes, num_models = predictions.shape
6169

62-
# Create a mapping from class indices to class names for faster lookup
63-
class_names = list(predicted_classes.keys())
64-
class_indices = {predicted_classes[cls]: cls for cls in class_names}
65-
6670
# Get predictions for all classes
6771
valid_predictions = ~torch.isnan(predictions)
6872
valid_counts = valid_predictions.sum(dim=2) # Sum over models dimension
@@ -94,14 +98,9 @@ def consolidate_predictions(self, predictions, predicted_classes, classwise_weig
9498
net_score = positive_sum - negative_sum # Shape: (num_smiles, num_classes)
9599
class_decisions = (net_score > 0) & has_valid_predictions # Shape: (num_smiles, num_classes)
96100

97-
# Convert tensor decisions to result list using list comprehension for efficiency
98-
result = [
99-
[class_indices[idx.item()] for idx in torch.nonzero(class_decisions[i], as_tuple=True)[0]]
100-
for i in range(num_smiles)
101-
]
102101

103-
return result
104102

103+
return class_decisions
105104

106105
def calculate_classwise_weights(self, predicted_classes):
107106
"""No weights, simple majority voting"""
@@ -128,14 +127,26 @@ def predict_smiles_list(self, smiles_list, load_preds_if_possible=True) -> list:
128127
predicted_classes = {line.strip(): i for i, line in enumerate(f.readlines())}
129128

130129
classwise_weights = self.calculate_classwise_weights(predicted_classes)
131-
aggregated_predictions = self.consolidate_predictions(ordered_predictions, predicted_classes, classwise_weights)
132-
return aggregated_predictions
130+
class_decisions = self.consolidate_predictions(ordered_predictions, classwise_weights)
131+
# Smooth predictions
132+
class_names = list(predicted_classes.keys())
133+
self.smoother.label_names = class_names
134+
class_decisions = self.smoother(class_decisions)
135+
136+
class_names = list(predicted_classes.keys())
137+
class_indices = {predicted_classes[cls]: cls for cls in class_names}
138+
result = [
139+
[class_indices[idx.item()] for idx in torch.nonzero(i, as_tuple=True)[0]]
140+
for i in class_decisions
141+
]
142+
143+
return result
133144

134145
if __name__ == "__main__":
135146
ensemble = BaseEnsemble({"resgated_0ps1g189":{
136147
"type": "resgated",
137-
"ckpt_path": "../python-chebai/logs/downloaded_ckpts/electra_resgated_comp/resgated_80-10-10_0ps1g189_epoch=122.ckpt",
138-
"target_labels_path": "../python-chebai/data/chebi_v241/ChEBI50/processed/classes.txt",
148+
"ckpt_path": "data/0ps1g189/epoch=122.ckpt",
149+
"target_labels_path": "data/chebi_v241/ChEBI50/processed/classes.txt",
139150
"molecular_properties": [
140151
"chebai_graph.preprocessing.properties.AtomType",
141152
"chebai_graph.preprocessing.properties.NumAtomBonds",
@@ -148,14 +159,14 @@ def predict_smiles_list(self, smiles_list, load_preds_if_possible=True) -> list:
148159
"chebai_graph.preprocessing.properties.BondAromaticity",
149160
"chebai_graph.preprocessing.properties.RDKit2DNormalized",
150161
],
151-
"classwise_weights_path" : "../python-chebai/metrics_0ps1g189_80-10-10.json"
162+
#"classwise_weights_path" : "../python-chebai/metrics_0ps1g189_80-10-10.json"
152163
},
153164

154165
"electra_14ko0zcf": {
155166
"type": "electra",
156-
"ckpt_path": "../python-chebai/logs/downloaded_ckpts/electra_resgated_comp/electra_80-10-10_14ko0zcf_epoch=193.ckpt",
157-
"target_labels_path": "../python-chebai/data/chebi_v241/ChEBI50/processed/classes.txt",
158-
"classwise_weights_path": "../python-chebai/metrics_electra_14ko0zcf_80-10-10.json",
167+
"ckpt_path": "data/14ko0zcf/epoch=193.ckpt",
168+
"target_labels_path": "data/chebi_v241/ChEBI50/processed/classes.txt",
169+
#"classwise_weights_path": "../python-chebai/metrics_electra_14ko0zcf_80-10-10.json",
159170
}
160171
})
161172
r = ensemble.predict_smiles_list(["[NH3+]CCCC[C@H](NC(=O)[C@@H]([NH3+])CC([O-])=O)C([O-])=O"], load_preds_if_possible=False)

0 commit comments

Comments
 (0)