
Commit 2b9f335

gather_predictions will return predicted_classes_dict
1 parent 2c2aba2 commit 2b9f335

1 file changed (+48, -22 lines)

chebifier/ensemble/base_ensemble.py

Lines changed: 48 additions & 22 deletions
@@ -1,8 +1,8 @@
 import os
 from abc import ABC
+
 import torch
 import tqdm
-from rdkit import Chem
 
 from chebifier.prediction_models.base_predictor import BasePredictor
 from chebifier.prediction_models.chemlog_predictor import ChemLogPredictor
@@ -12,11 +12,11 @@
 MODEL_TYPES = {
     "electra": ElectraPredictor,
     "resgated": ResGatedPredictor,
-    "chemlog": ChemLogPredictor
+    "chemlog": ChemLogPredictor,
 }
 
-class BaseEnsemble(ABC):
 
+class BaseEnsemble(ABC):
     def __init__(self, model_configs: dict):
         self.models = []
         self.positive_prediction_threshold = 0.5
@@ -37,22 +37,30 @@ def gather_predictions(self, smiles_list):
                 if logits_for_smiles is not None:
                     for cls in logits_for_smiles:
                         predicted_classes.add(cls)
-        print(f"Sorting predictions...")
+        print("Sorting predictions...")
         predicted_classes = sorted(list(predicted_classes))
         predicted_classes_dict = {cls: i for i, cls in enumerate(predicted_classes)}
-        ordered_logits = torch.zeros(len(smiles_list), len(predicted_classes), len(self.models)) * torch.nan
+        ordered_logits = (
+            torch.zeros(len(smiles_list), len(predicted_classes), len(self.models))
+            * torch.nan
+        )
         for i, model_prediction in enumerate(model_predictions):
-            for j, logits_for_smiles in tqdm.tqdm(enumerate(model_prediction),
-                                                  total=len(model_prediction),
-                                                  desc=f"Sorting predictions for {self.models[i].model_name}"):
+            for j, logits_for_smiles in tqdm.tqdm(
+                enumerate(model_prediction),
+                total=len(model_prediction),
+                desc=f"Sorting predictions for {self.models[i].model_name}",
+            ):
                 if logits_for_smiles is not None:
                     for cls in logits_for_smiles:
-                        ordered_logits[j, predicted_classes_dict[cls], i] = logits_for_smiles[cls]
+                        ordered_logits[j, predicted_classes_dict[cls], i] = (
+                            logits_for_smiles[cls]
+                        )
 
-        return ordered_logits, predicted_classes
+        return ordered_logits, predicted_classes_dict
 
-
-    def consolidate_predictions(self, predictions, predicted_classes, classwise_weights, **kwargs):
+    def consolidate_predictions(
+        self, predictions, predicted_classes, classwise_weights, **kwargs
+    ):
         """
         Aggregates predictions from multiple models using weighted majority voting.
         Optimized version using tensor operations instead of for loops.
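Note (not part of the commit): gather_predictions now returns predicted_classes_dict, a mapping from each class label to its column index in ordered_logits, instead of the sorted class list. A minimal, self-contained sketch of how that pair is read; the class IDs and variable values below are placeholders, not output of the real models:

import torch

# Shapes mirror gather_predictions: (num_smiles, num_classes, num_models);
# unfilled entries stay NaN, meaning "this model made no prediction".
ordered_logits = torch.zeros(2, 3, 2) * torch.nan
predicted_classes_dict = {"CHEBI:11111": 0, "CHEBI:22222": 1, "CHEBI:33333": 2}

col = predicted_classes_dict["CHEBI:22222"]   # column for this class
score = ordered_logits[0, col, 1]             # SMILES 0, model 1
print(torch.isnan(score))                     # tensor(True) -> no prediction recorded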
@@ -74,7 +82,9 @@ def consolidate_predictions(self, predictions, predicted_classes, classwise_weights, **kwargs):
         positive_mask = (predictions > 0.5) & valid_predictions
         negative_mask = (predictions < 0.5) & valid_predictions
 
-        confidence = 2 * torch.abs(predictions.nan_to_num() - self.positive_prediction_threshold)
+        confidence = 2 * torch.abs(
+            predictions.nan_to_num() - self.positive_prediction_threshold
+        )
 
         # Extract positive and negative weights
         pos_weights = classwise_weights[0]  # Shape: (num_classes, num_models)
@@ -83,26 +93,34 @@
         # Calculate weighted predictions using broadcasting
         # predictions shape: (num_smiles, num_classes, num_models)
         # weights shape: (num_classes, num_models)
-        positive_weighted = positive_mask.float() * confidence * pos_weights.unsqueeze(0)
-        negative_weighted = negative_mask.float() * confidence * neg_weights.unsqueeze(0)
+        positive_weighted = (
+            positive_mask.float() * confidence * pos_weights.unsqueeze(0)
+        )
+        negative_weighted = (
+            negative_mask.float() * confidence * neg_weights.unsqueeze(0)
+        )
 
         # Sum over models dimension
         positive_sum = positive_weighted.sum(dim=2)  # Shape: (num_smiles, num_classes)
         negative_sum = negative_weighted.sum(dim=2)  # Shape: (num_smiles, num_classes)
 
         # Determine which classes to include for each SMILES
         net_score = positive_sum - negative_sum  # Shape: (num_smiles, num_classes)
-        class_decisions = (net_score > 0) & has_valid_predictions  # Shape: (num_smiles, num_classes)
+        class_decisions = (
+            net_score > 0
+        ) & has_valid_predictions  # Shape: (num_smiles, num_classes)
 
         # Convert tensor decisions to result list using list comprehension for efficiency
         result = [
-            [class_indices[idx.item()] for idx in torch.nonzero(class_decisions[i], as_tuple=True)[0]]
+            [
+                class_indices[idx.item()]
+                for idx in torch.nonzero(class_decisions[i], as_tuple=True)[0]
+            ]
             for i in range(num_smiles)
         ]
 
         return result
 
-
     def calculate_classwise_weights(self, predicted_classes):
         """No weights, simple majority voting"""
         positive_weights = torch.ones(len(predicted_classes), len(self.models))
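For readers skimming the voting logic above, a tiny standalone sketch (not part of the commit) that runs the same arithmetic for one SMILES, one class, and three models with the unit weights produced by calculate_classwise_weights; the probabilities and names are illustrative:

import torch

threshold = 0.5
# One SMILES, one class, three models: a confident "yes", a weaker "no", an abstention (NaN)
predictions = torch.tensor([[[0.9, 0.2, float("nan")]]])  # (num_smiles, num_classes, num_models)

valid = ~torch.isnan(predictions)
positive_mask = (predictions > threshold) & valid
negative_mask = (predictions < threshold) & valid
confidence = 2 * torch.abs(predictions.nan_to_num() - threshold)

weights = torch.ones(1, 3)  # (num_classes, num_models), all ones = plain majority voting
positive_sum = (positive_mask.float() * confidence * weights.unsqueeze(0)).sum(dim=2)  # ~0.8
negative_sum = (negative_mask.float() * confidence * weights.unsqueeze(0)).sum(dim=2)  # ~0.6
net_score = positive_sum - negative_sum
print(net_score > 0)  # tensor([[True]]) -> the class is kept for this SMILES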
@@ -114,18 +132,26 @@ def predict_smiles_list(self, smiles_list, load_preds_if_possible=True) -> list:
         preds_file = f"predictions_by_model_{'_'.join(model.model_name for model in self.models)}.pt"
         predicted_classes_file = f"predicted_classes_{'_'.join(model.model_name for model in self.models)}.txt"
         if not load_preds_if_possible or not os.path.isfile(preds_file):
-            ordered_predictions, predicted_classes = self.gather_predictions(smiles_list)
+            ordered_predictions, predicted_classes = self.gather_predictions(
+                smiles_list
+            )
             # save predictions
             torch.save(ordered_predictions, preds_file)
             with open(predicted_classes_file, "w") as f:
                 for cls in predicted_classes:
                     f.write(f"{cls}\n")
         else:
-            print(f"Loading predictions from {preds_file} and label indexes from {predicted_classes_file}")
+            print(
+                f"Loading predictions from {preds_file} and label indexes from {predicted_classes_file}"
+            )
             ordered_predictions = torch.load(preds_file)
             with open(predicted_classes_file, "r") as f:
-                predicted_classes = {line.strip(): i for i, line in enumerate(f.readlines())}
+                predicted_classes = {
+                    line.strip(): i for i, line in enumerate(f.readlines())
+                }
 
         classwise_weights = self.calculate_classwise_weights(predicted_classes)
-        aggregated_predictions = self.consolidate_predictions(ordered_predictions, predicted_classes, classwise_weights)
+        aggregated_predictions = self.consolidate_predictions(
+            ordered_predictions, predicted_classes, classwise_weights
+        )
         return aggregated_predictions
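As an aside (not part of the diff), the caching in predict_smiles_list boils down to a tensor file plus a text file of class labels whose line order defines the class-to-column mapping. A self-contained sketch of that round trip, with shortened file names and placeholder CHEBI IDs:

import torch

ordered_predictions = torch.rand(2, 3, 2)  # (num_smiles, num_classes, num_models)
predicted_classes = ["CHEBI:11111", "CHEBI:22222", "CHEBI:33333"]  # placeholder labels

# Save, as in the "compute" branch above (real file names include the joined model names)
torch.save(ordered_predictions, "predictions_by_model_demo.pt")
with open("predicted_classes_demo.txt", "w") as f:
    for cls in predicted_classes:
        f.write(f"{cls}\n")

# Reload, as in the "cached" branch; line order gives the column index per class
reloaded = torch.load("predictions_by_model_demo.pt")
with open("predicted_classes_demo.txt") as f:
    predicted_classes_dict = {line.strip(): i for i, line in enumerate(f)}
print(predicted_classes_dict)  # {'CHEBI:11111': 0, 'CHEBI:22222': 1, 'CHEBI:33333': 2}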
