Commit 434b52a: add weighted majority voting, gnn and chemlog
1 parent: 6f2c3ab
7 files changed: +278 −51 lines changed

7 files changed

+278
-51
lines changed

chebifier/cli.py

Lines changed: 9 additions & 2 deletions

@@ -5,20 +5,27 @@
 import yaml
 import sys
 from chebifier.ensemble.base_ensemble import BaseEnsemble
+from chebifier.ensemble.weighted_majority_ensemble import WMVwithPPVNPVEnsemble, WMVwithF1Ensemble


 @click.group()
 def cli():
     """Command line interface for Chebifier."""
     pass

+ENSEMBLES = {
+    "mv": BaseEnsemble,
+    "wmv-ppvnpv": WMVwithPPVNPVEnsemble,
+    "wmv-f1": WMVwithF1Ensemble
+}

 @cli.command()
 @click.argument('config_file', type=click.Path(exists=True))
 @click.option('--smiles', '-s', multiple=True, help='SMILES strings to predict')
 @click.option('--smiles-file', '-f', type=click.Path(exists=True), help='File containing SMILES strings (one per line)')
 @click.option('--output', '-o', type=click.Path(), help='Output file to save predictions (optional)')
-def predict(config_file, smiles, smiles_file, output):
+@click.option('--ensemble-type', '-e', type=click.Choice(ENSEMBLES.keys()), default='mv', help='Type of ensemble to use (default: Majority Voting)')
+def predict(config_file, smiles, smiles_file, output, ensemble_type):
     """Predict ChEBI classes for SMILES strings using an ensemble model.

     CONFIG_FILE is the path to a YAML configuration file for the ensemble model.

@@ -28,7 +35,7 @@ def predict(config_file, smiles, smiles_file, output):
         config = yaml.safe_load(f)

     # Instantiate ensemble model
-    ensemble = BaseEnsemble(config)
+    ensemble = ENSEMBLES[ensemble_type](config)

     # Collect SMILES strings from arguments and/or file
     smiles_list = list(smiles)
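
For orientation, a minimal sketch of what the new `--ensemble-type` / `-e` flag does: it selects an ensemble class from `ENSEMBLES` and instantiates it with the parsed YAML config, exactly as in the `predict` body above. The config path and SMILES below are placeholders, not files or data from the repository.

# Sketch only; "my_ensemble.yml" is a hypothetical config file.
import yaml
from chebifier.cli import ENSEMBLES

with open("my_ensemble.yml") as f:
    config = yaml.safe_load(f)

ensemble = ENSEMBLES["wmv-f1"](config)  # resolves to WMVwithF1Ensemble
predictions = ensemble.predict_smiles_list(["CC(=O)Nc1ccc(O)cc1"])  # paracetamol as example input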
chebifier/ensemble/base_ensemble.py

Lines changed: 75 additions & 46 deletions

@@ -1,14 +1,18 @@
+import os
 from abc import ABC
 import torch
 import tqdm
 from rdkit import Chem

 from chebifier.prediction_models.base_predictor import BasePredictor
+from chebifier.prediction_models.chemlog_predictor import ChemLogPredictor
 from chebifier.prediction_models.electra_predictor import ElectraPredictor
+from chebifier.prediction_models.gnn_predictor import ResGatedPredictor

 MODEL_TYPES = {
     "electra": ElectraPredictor,
-    # todo add other model types here
+    "resgated": ResGatedPredictor,
+    "chemlog": ChemLogPredictor
 }

 class BaseEnsemble(ABC):
@@ -22,70 +26,73 @@ def __init__(self, model_configs: dict):
             self.models.append(model_instance)

     def gather_predictions(self, smiles_list):
-        """
-
-        :param smiles_list: list of SMILES strings to predict
-        :return:
-        ordered_predictions: torch.Tensor of shape (num_smiles, num_classes, num_models)
-        predicted_classes: list of ChEBI IDs predicted by the models
-        """
         model_predictions = []
         predicted_classes = set()
         for model in self.models:
             model_predictions.append(model.predict_smiles_list(smiles_list))
-            for predicted_smiles in model_predictions[-1]:
-                if predicted_smiles is not None:
-                    for cls in predicted_smiles:
+            for predicted_labels_for_smiles in model_predictions[-1]:
+                if predicted_labels_for_smiles is not None:
+                    for cls in predicted_labels_for_smiles:
                         predicted_classes.add(cls)
         print(f"Sorting predictions...")
         predicted_classes = sorted(list(predicted_classes))
+        predicted_classes = {cls: i for i, cls in enumerate(predicted_classes)}
         ordered_predictions = torch.zeros(len(smiles_list), len(predicted_classes), len(self.models)) * torch.nan
         for i, model_prediction in enumerate(model_predictions):
-            for j, predicted_smiles in tqdm.tqdm(enumerate(model_prediction),
+            for j, predicted_labels_for_smiles in tqdm.tqdm(enumerate(model_prediction),
                                                  total=len(model_prediction),
                                                  desc=f"Sorting predictions for {self.models[i].model_name}"):
-                if predicted_smiles is not None:
-                    for cls in predicted_smiles:
-                        ordered_predictions[j, predicted_classes.index(cls), i] = predicted_smiles[cls]
+                if predicted_labels_for_smiles is not None:
+                    for cls in predicted_labels_for_smiles:
+                        ordered_predictions[j, predicted_classes[cls], i] = predicted_labels_for_smiles[cls]
         return ordered_predictions, predicted_classes


-    def aggregate_predictions(self, predictions, predicted_classes, **kwargs):
+    def consolidate_predictions(self, predictions, predicted_classes, classwise_weights, **kwargs):
         """
-        Aggregates predictions from multiple models using majority voting.
-
-        :param predictions: torch.Tensor of shape (num_smiles, num_classes, num_models)
-        :param predicted_classes: list of ChEBI IDs predicted by the models
-        :param kwargs: Additional arguments
-        :return: list of lists, where each inner list contains the class IDs that received
-        positive predictions from the majority of models for a given SMILES
+        Aggregates predictions from multiple models using weighted majority voting.
+        Optimized version using tensor operations instead of for loops.
         """
         num_smiles, num_classes, num_models = predictions.shape
-        result = []

-        for i in tqdm.tqdm(range(num_smiles), total=num_smiles, desc="Aggregating predictions"):
-            smiles_result = []
-            for j in range(num_classes):
-                # Get predictions for this SMILES and class across all models
-                class_predictions = predictions[i, j, :]
+        # Create a mapping from class indices to class names for faster lookup
+        class_names = list(predicted_classes.keys())
+        class_indices = {predicted_classes[cls]: cls for cls in class_names}
+
+        # Get predictions for all classes
+        valid_predictions = ~torch.isnan(predictions)
+        valid_counts = valid_predictions.sum(dim=2)  # Sum over models dimension
+
+        # Skip classes with no valid predictions
+        has_valid_predictions = valid_counts > 0
+
+        # Calculate positive and negative predictions for all classes at once
+        positive_mask = (predictions > 0.5) & valid_predictions
+        negative_mask = (predictions < 0.5) & valid_predictions

-                # Count models that made a prediction (not NaN)
-                valid_predictions = ~torch.isnan(class_predictions)
-                num_valid_predictions = valid_predictions.sum().item()
+        # Extract positive and negative weights
+        pos_weights = classwise_weights[0]  # Shape: (num_classes, num_models)
+        neg_weights = classwise_weights[1]  # Shape: (num_classes, num_models)

-                # If no valid predictions, skip this class
-                if num_valid_predictions == 0:
-                    continue
+        # Calculate weighted predictions using broadcasting
+        # predictions shape: (num_smiles, num_classes, num_models)
+        # weights shape: (num_classes, num_models)
+        positive_weighted = positive_mask.float() * (predictions.nan_to_num() - 0.5) * pos_weights.unsqueeze(0)
+        negative_weighted = negative_mask.float() * (0.5 - predictions.nan_to_num()) * neg_weights.unsqueeze(0)

-                # Count positive predictions (assuming positive is > 0)
-                positive_predictions = class_predictions > 0
-                num_positive = (positive_predictions & valid_predictions).sum().item()
+        # Sum over models dimension
+        positive_sum = positive_weighted.sum(dim=2)  # Shape: (num_smiles, num_classes)
+        negative_sum = negative_weighted.sum(dim=2)  # Shape: (num_smiles, num_classes)

-                # If majority of models that made a prediction are positive, add this class
-                if num_positive > num_valid_predictions / 2:
-                    smiles_result.append(predicted_classes[j])
+        # Determine which classes to include for each SMILES
+        net_score = positive_sum - negative_sum  # Shape: (num_smiles, num_classes)
+        class_decisions = (net_score > 0) & has_valid_predictions  # Shape: (num_smiles, num_classes)

-            result.append(smiles_result)
+        # Convert tensor decisions to result list using list comprehension for efficiency
+        result = [
+            [class_indices[idx.item()] for idx in torch.nonzero(class_decisions[i], as_tuple=True)[0]]
+            for i in range(num_smiles)
+        ]

         return result

@@ -102,8 +109,30 @@ def normalize_smiles_list(self, smiles_list):
             new.append(canonical_smiles)
         return new

-    def predict_smiles_list(self, smiles_list) -> list:
-        #smiles_list = self.normalize_smiles_list(smiles_list)
-        ordered_predictions, predicted_classes = self.gather_predictions(smiles_list)
-        aggregated_predictions = self.aggregate_predictions(ordered_predictions, predicted_classes)
+    def calculate_classwise_weights(self, predicted_classes):
+        """No weights, simple majority voting"""
+        positive_weights = torch.ones(len(predicted_classes), len(self.models))
+        negative_weights = torch.ones(len(predicted_classes), len(self.models))
+
+        return positive_weights, negative_weights
+
+    def predict_smiles_list(self, smiles_list, load_preds_if_possible=True) -> list:
+        preds_file = f"predictions_by_model_{'_'.join(model.model_name for model in self.models)}.pt"
+        predicted_classes_file = f"predicted_classes_{'_'.join(model.model_name for model in self.models)}.txt"
+        if not load_preds_if_possible or not os.path.isfile(preds_file):
+            #smiles_list = self.normalize_smiles_list(smiles_list)
+            ordered_predictions, predicted_classes = self.gather_predictions(smiles_list)
+            # save predictions
+            torch.save(ordered_predictions, preds_file)
+            with open(predicted_classes_file, "w") as f:
+                for cls in predicted_classes:
+                    f.write(f"{cls}\n")
+        else:
+            print(f"Loading predictions from {preds_file} and label indexes from {predicted_classes_file}")
+            ordered_predictions = torch.load(preds_file)
+            with open(predicted_classes_file, "r") as f:
+                predicted_classes = {line.strip(): i for i, line in enumerate(f.readlines())}
+
+        classwise_weights = self.calculate_classwise_weights(predicted_classes)
+        aggregated_predictions = self.consolidate_predictions(ordered_predictions, predicted_classes, classwise_weights)
         return aggregated_predictions
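
For intuition, here is a small self-contained walk-through of the vectorised vote in `consolidate_predictions` (a sketch with made-up scores, not code from the commit): scores above 0.5 count as positive votes weighted by their margin from 0.5, scores below 0.5 count as negative votes, and NaN marks a class a model does not cover.

import torch

# Toy scores, shape (num_smiles=1, num_classes=2, num_models=2).
# Model 0 scores class A at 0.9 and class B at 0.2; model 1 does not
# cover class A (NaN) and scores class B at 0.7.
predictions = torch.tensor([[[0.9, float("nan")],
                             [0.2, 0.7]]])
pos_w = torch.ones(2, 2)  # unit weights, i.e. the plain majority-voting default
neg_w = torch.ones(2, 2)

valid = ~torch.isnan(predictions)
pos = (predictions > 0.5) & valid
neg = (predictions < 0.5) & valid
p = (pos.float() * (predictions.nan_to_num() - 0.5) * pos_w.unsqueeze(0)).sum(dim=2)
n = (neg.float() * (0.5 - predictions.nan_to_num()) * neg_w.unsqueeze(0)).sum(dim=2)
print(p - n)  # ≈ tensor([[ 0.4, -0.1]]): class A is accepted, class B is rejected

Note also that `predict_smiles_list` now caches the raw per-model scores in a `.pt` file named after the member models, so a repeated run with `load_preds_if_possible=True` skips the expensive `gather_predictions` step and only re-runs the weighting and voting.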
chebifier/ensemble/weighted_majority_ensemble.py

Lines changed: 54 additions & 0 deletions

@@ -0,0 +1,54 @@
+import torch
+
+from chebifier.ensemble.base_ensemble import BaseEnsemble
+
+
+
+class WMVwithPPVNPVEnsemble(BaseEnsemble):
+
+    def calculate_classwise_weights(self, predicted_classes):
+        """
+        Given the positions of predicted classes in the predictions tensor, assign weights to each class. The
+        result is two tensors of shape (num_predicted_classes, num_models). The weight for each class is the model_weight
+        (default: 1) multiplied by the class-specific positive / negative weight (default 1).
+        """
+        positive_weights = torch.ones(len(predicted_classes), len(self.models))
+        negative_weights = torch.ones(len(predicted_classes), len(self.models))
+        for j, model in enumerate(self.models):
+            positive_weights[:, j] *= model.model_weight
+            negative_weights[:, j] *= model.model_weight
+            if model.classwise_weights is None:
+                continue
+            for cls, weights in model.classwise_weights.items():
+                positive_weights[predicted_classes[cls], j] *= weights["PPV"]
+                negative_weights[predicted_classes[cls], j] *= weights["NPV"]
+
+        print(f"Calculated model weightings. The averages for positive / negative weights are:")
+        for i, model in enumerate(self.models):
+            print(f"{model.model_name}: {positive_weights[:, i].mean().item():.3f} / {negative_weights[:, i].mean().item():.3f}")
+
+        return positive_weights, negative_weights
+
+
+class WMVwithF1Ensemble(BaseEnsemble):
+
+    def calculate_classwise_weights(self, predicted_classes):
+        """
+        Given the positions of predicted classes in the predictions tensor, assign weights to each class. The
+        result is two tensors of shape (num_predicted_classes, num_models). The weight for each class is the model_weight
+        (default: 1) multiplied by the class-specific validation-f1 (default 1).
+        """
+        weights_by_cls = torch.ones(len(predicted_classes), len(self.models))
+        for j, model in enumerate(self.models):
+            weights_by_cls[:, j] *= model.model_weight
+            if model.classwise_weights is None:
+                continue
+            for cls, weights in model.classwise_weights.items():
+                f1 = 2 * weights["TP"] / (2 * weights["TP"] + weights["FP"] + weights["FN"])
+                weights_by_cls[predicted_classes[cls], j] *= f1
+
+        print(f"Calculated model weightings. The average weights are:")
+        for i, model in enumerate(self.models):
+            print(f"{model.model_name}: {weights_by_cls[:, i].mean().item():.3f}")
+
+        return weights_by_cls, weights_by_cls
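
To make the two weighting schemes concrete, a toy calculation with hypothetical numbers (not taken from the repository); a model's `classwise_weights` entry is assumed to be a per-class dict carrying the PPV/NPV/TP/FP/FN keys read above.

# Hypothetical validation statistics for one model on ChEBI class "25676".
weights = {"TP": 80, "FP": 10, "FN": 20, "PPV": 80 / 90, "NPV": 0.98}

# WMVwithF1Ensemble scales both the positive and the negative vote of this
# model for the class by its validation F1:
f1 = 2 * weights["TP"] / (2 * weights["TP"] + weights["FP"] + weights["FN"])
print(round(f1, 3))  # 0.842

# WMVwithPPVNPVEnsemble instead scales positive votes by PPV (precision) and
# negative votes by NPV, so a model that is precise but insensitive for a class
# counts more when it says "yes" than when it says "no".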
chebifier/prediction_models/base_predictor.py

Lines changed: 8 additions & 2 deletions

@@ -1,10 +1,16 @@
 from abc import ABC
-
+import json

 class BasePredictor(ABC):

-    def __init__(self, model_name: str, **kwargs):
+    def __init__(self, model_name: str, model_weight: int = 1, classwise_weights_path: str = None, **kwargs):
         self.model_name = model_name
+        self.model_weight = model_weight
+        if classwise_weights_path is not None:
+            self.classwise_weights = json.load(open(classwise_weights_path, encoding="utf-8"))
+        else:
+            self.classwise_weights = None
+

     def predict_smiles_list(self, smiles_list: list[str]) -> dict:
         raise NotImplementedError
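
A minimal sketch of the new constructor arguments (the subclass, file contents, and numbers below are invented for illustration): `model_weight` scales a model's votes globally, while `classwise_weights_path` points to a JSON file keyed by ChEBI class.

import json, tempfile
from chebifier.prediction_models.base_predictor import BasePredictor

class DummyPredictor(BasePredictor):  # stand-in predictor, not part of the repo
    def predict_smiles_list(self, smiles_list):
        return [{"25676": 0.9} for _ in smiles_list]

# Write a hypothetical classwise-weights file with the keys the ensembles read.
with tempfile.NamedTemporaryFile("w", suffix=".json", delete=False) as f:
    json.dump({"25676": {"PPV": 0.9, "NPV": 0.99, "TP": 90, "FP": 10, "FN": 15}}, f)

model = DummyPredictor("dummy", model_weight=2, classwise_weights_path=f.name)
print(model.model_weight, model.classwise_weights["25676"]["PPV"])  # 2 0.9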
chebifier/prediction_models/chemlog_predictor.py

Lines changed: 35 additions & 0 deletions

@@ -0,0 +1,35 @@
+import tqdm
+
+from chebifier.prediction_models.base_predictor import BasePredictor
+from chemlog.alg_classification.charge_classifier import AlgChargeClassifier
+from chemlog.alg_classification.peptide_size_classifier import AlgPeptideSizeClassifier
+from chemlog.alg_classification.proteinogenics_classifier import AlgProteinogenicsClassifier
+from chemlog.alg_classification.substructure_classifier import AlgSubstructureClassifier
+from chemlog.cli import strategy_call, _smiles_to_mol, CLASSIFIERS
+
+class ChemLogPredictor(BasePredictor):
+
+    def __init__(self, model_name: str, **kwargs):
+        super().__init__(model_name, **kwargs)
+        self.strategy = "algo"
+        self.classifier_instances = {
+            k: v() for k, v in CLASSIFIERS[self.strategy].items()
+        }
+        self.peptide_labels = ["15841", "16670", "24866", "25676", "25696", "25697", "27369", "46761", "47923",
+                               "48030", "48545", "60194", "60334", "60466", "64372", "65061", "90799", "155837"]
+
+        print(f"Initialised ChemLog model {self.model_name}")
+
+    def predict_smiles_list(self, smiles_list: list[str]) -> list:
+        results = []
+        for i, smiles in tqdm.tqdm(enumerate(smiles_list)):
+            mol = _smiles_to_mol(smiles)
+            if mol is None:
+                results.append(None)
+            else:
+                results.append({label: 1 if label in strategy_call(self.strategy, self.classifier_instances, mol)["chebi_classes"] else 0 for label in self.peptide_labels})
+
+        for classifier in self.classifier_instances.values():
+            classifier.on_finish()
+
+        return results
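
Finally, a usage sketch for the new rule-based predictor (requires the chemlog package to be installed; the SMILES is an arbitrary dipeptide, and the actual 0/1 assignments depend on ChemLog's classification, so they are not shown here).

from chebifier.prediction_models.chemlog_predictor import ChemLogPredictor

predictor = ChemLogPredictor("chemlog")
preds = predictor.predict_smiles_list(["NC(Cc1ccccc1)C(=O)NCC(=O)O"])  # Phe-Gly dipeptide
print(preds[0])  # dict mapping the 18 peptide-related ChEBI IDs above to 0 or 1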
