Skip to content

Commit 5cc2eba

Browse files
committed
introduce threshold as parameter, use confidence
1 parent 434b52a commit 5cc2eba

File tree

1 file changed

+18
-25
lines changed

1 file changed

+18
-25
lines changed

chebifier/ensemble/base_ensemble.py

Lines changed: 18 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -19,33 +19,37 @@ class BaseEnsemble(ABC):
1919

2020
def __init__(self, model_configs: dict, *, positive_prediction_threshold: float = 0.5):
    """Instantiate every configured model and store the decision threshold.

    Args:
        model_configs: Mapping of model name to that model's configuration
            dict. Each configuration must contain a ``"type"`` key that
            selects the class from ``MODEL_TYPES``; the whole configuration
            (including ``"type"``) is passed to that class as keyword
            arguments.
        positive_prediction_threshold: Value stored on the instance as
            ``self.positive_prediction_threshold``. Defaults to 0.5, the
            value that was previously hard-coded.

    Raises:
        KeyError: If a configuration has no ``"type"`` key or the type is
            not registered in ``MODEL_TYPES``.
        TypeError: If a constructed model is not a ``BasePredictor``.
    """
    self.models = []
    self.positive_prediction_threshold = positive_prediction_threshold
    for model_name, model_config in model_configs.items():
        model_cls = MODEL_TYPES[model_config["type"]]
        model_instance = model_cls(**model_config)
        # Explicit check instead of `assert`, which is removed under `python -O`.
        if not isinstance(model_instance, BasePredictor):
            raise TypeError(
                f"Model '{model_name}' of type {type(model_instance).__name__} "
                "is not a BasePredictor"
            )
        self.models.append(model_instance)
2728

2829
def gather_predictions(self, smiles_list):
    """Collect the predictions of all models into one NaN-padded tensor.

    Every model in ``self.models`` is asked to predict ``smiles_list``. Each
    per-SMILES result is either ``None`` or a mapping from class label to a
    numeric score (presumably a dict; it is iterated for its labels and
    indexed by label).

    Args:
        smiles_list: Sequence of SMILES strings passed unchanged to each
            model's ``predict_smiles_list``.

    Returns:
        A pair ``(ordered_logits, predicted_classes)``:

        * ``ordered_logits`` is a tensor of shape
          ``(len(smiles_list), num_classes, num_models)``. Entries for
          which a model gave no value stay NaN.
        * ``predicted_classes`` maps each class label seen in any
          prediction to its index on the class axis; labels are indexed in
          sorted order.
    """
    # Get predictions from all models for the SMILES list and remember every
    # class label that occurs in any of them.
    model_predictions = []
    seen_classes = set()
    for model in self.models:
        predictions = model.predict_smiles_list(smiles_list)
        model_predictions.append(predictions)
        for logits_for_smiles in predictions:
            if logits_for_smiles is not None:
                seen_classes.update(logits_for_smiles)

    print("Sorting predictions...")
    # Order the classes alphabetically and map each label to its column index.
    predicted_classes = {cls: i for i, cls in enumerate(sorted(seen_classes))}

    # Start from all-NaN so that missing predictions can be told apart from
    # real scores later on.
    ordered_logits = torch.full(
        (len(smiles_list), len(predicted_classes), len(self.models)),
        torch.nan,
    )
    for i, model_prediction in enumerate(model_predictions):
        for j, logits_for_smiles in tqdm.tqdm(
            enumerate(model_prediction),
            total=len(model_prediction),
            desc=f"Sorting predictions for {self.models[i].model_name}",
        ):
            if logits_for_smiles is None:
                continue
            for cls in logits_for_smiles:
                ordered_logits[j, predicted_classes[cls], i] = logits_for_smiles[cls]

    return ordered_logits, predicted_classes
4953

5054

5155
def consolidate_predictions(self, predictions, predicted_classes, classwise_weights, **kwargs):
@@ -70,15 +74,17 @@ def consolidate_predictions(self, predictions, predicted_classes, classwise_weig
7074
positive_mask = (predictions > 0.5) & valid_predictions
7175
negative_mask = (predictions < 0.5) & valid_predictions
7276

77+
confidence = 2 * torch.abs(predictions.nan_to_num() - self.positive_prediction_threshold)
78+
7379
# Extract positive and negative weights
7480
pos_weights = classwise_weights[0] # Shape: (num_classes, num_models)
7581
neg_weights = classwise_weights[1] # Shape: (num_classes, num_models)
7682

7783
# Calculate weighted predictions using broadcasting
7884
# predictions shape: (num_smiles, num_classes, num_models)
7985
# weights shape: (num_classes, num_models)
80-
positive_weighted = positive_mask.float() * (predictions.nan_to_num() - 0.5) * pos_weights.unsqueeze(0)
81-
negative_weighted = negative_mask.float() * (0.5 - predictions.nan_to_num()) * neg_weights.unsqueeze(0)
86+
positive_weighted = positive_mask.float() * confidence * pos_weights.unsqueeze(0)
87+
negative_weighted = negative_mask.float() * confidence * neg_weights.unsqueeze(0)
8288

8389
# Sum over models dimension
8490
positive_sum = positive_weighted.sum(dim=2) # Shape: (num_smiles, num_classes)
@@ -96,18 +102,6 @@ def consolidate_predictions(self, predictions, predicted_classes, classwise_weig
96102

97103
return result
98104

99-
def normalize_smiles_list(self, smiles_list):
100-
new = []
101-
print(f"Normalizing SMILES strings...")
102-
for smiles in tqdm.tqdm(smiles_list):
103-
try:
104-
mol = Chem.MolFromSmiles(smiles)
105-
canonical_smiles = Chem.MolToSmiles(mol)
106-
except Exception as e:
107-
print(f"Failed to parse SMILES '{smiles}': {e}")
108-
canonical_smiles = None
109-
new.append(canonical_smiles)
110-
return new
111105

112106
def calculate_classwise_weights(self, predicted_classes):
113107
"""No weights, simple majority voting"""
@@ -120,8 +114,7 @@ def predict_smiles_list(self, smiles_list, load_preds_if_possible=True) -> list:
120114
preds_file = f"predictions_by_model_{'_'.join(model.model_name for model in self.models)}.pt"
121115
predicted_classes_file = f"predicted_classes_{'_'.join(model.model_name for model in self.models)}.txt"
122116
if not load_preds_if_possible or not os.path.isfile(preds_file):
123-
#smiles_list = self.normalize_smiles_list(smiles_list)
124-
ordered_predictions, predicted_classes = self.gather_predictions(smiles_list)
117+
# gather_predictions returns a (tensor, class-index mapping) pair; unpack it
# rather than binding the whole tuple to both names with a chained assignment.
ordered_predictions, predicted_classes = self.gather_predictions(smiles_list)
125118
# save predictions
126119
torch.save(ordered_predictions, preds_file)
127120
with open(predicted_classes_file, "w") as f:

0 commit comments

Comments
 (0)