Skip to content

Commit c13423c

Browse files
committed
Merge branch 'dev' into feature/api_downloadble_models
2 parents 02c5409 + ca37c7d commit c13423c

File tree

2 files changed

+6
-5
lines changed

2 files changed

+6
-5
lines changed

chebifier/ensemble/base_ensemble.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -39,15 +39,15 @@ def gather_predictions(self, smiles_list):
3939
predicted_classes.add(cls)
4040
print(f"Sorting predictions...")
4141
predicted_classes = sorted(list(predicted_classes))
42-
predicted_classes = {cls: i for i, cls in enumerate(predicted_classes)}
42+
predicted_classes_dict = {cls: i for i, cls in enumerate(predicted_classes)}
4343
ordered_logits = torch.zeros(len(smiles_list), len(predicted_classes), len(self.models)) * torch.nan
4444
for i, model_prediction in enumerate(model_predictions):
4545
for j, logits_for_smiles in tqdm.tqdm(enumerate(model_prediction),
4646
total=len(model_prediction),
4747
desc=f"Sorting predictions for {self.models[i].model_name}"):
4848
if logits_for_smiles is not None:
4949
for cls in logits_for_smiles:
50-
ordered_logits[j, predicted_classes[cls], i] = logits_for_smiles[cls]
50+
ordered_logits[j, predicted_classes_dict[cls], i] = logits_for_smiles[cls]
5151

5252
return ordered_logits, predicted_classes
5353

@@ -114,7 +114,7 @@ def predict_smiles_list(self, smiles_list, load_preds_if_possible=True) -> list:
114114
preds_file = f"predictions_by_model_{'_'.join(model.model_name for model in self.models)}.pt"
115115
predicted_classes_file = f"predicted_classes_{'_'.join(model.model_name for model in self.models)}.txt"
116116
if not load_preds_if_possible or not os.path.isfile(preds_file):
117-
ordered_predictions = predicted_classes = self.gather_predictions(smiles_list)
117+
ordered_predictions, predicted_classes = self.gather_predictions(smiles_list)
118118
# save predictions
119119
torch.save(ordered_predictions, preds_file)
120120
with open(predicted_classes_file, "w") as f:

chebifier/ensemble/weighted_majority_ensemble.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -44,8 +44,9 @@ def calculate_classwise_weights(self, predicted_classes):
4444
if model.classwise_weights is None:
4545
continue
4646
for cls, weights in model.classwise_weights.items():
47-
f1 = 2 * weights["TP"] / (2 * weights["TP"] + weights["FP"] + weights["FN"])
48-
weights_by_cls[predicted_classes[cls], j] *= f1
47+
if (2 * weights["TP"] + weights["FP"] + weights["FN"]) > 0:
48+
f1 = 2 * weights["TP"] / (2 * weights["TP"] + weights["FP"] + weights["FN"])
49+
weights_by_cls[predicted_classes[cls], j] *= f1
4950

5051
print(f"Calculated model weightings. The average weights are:")
5152
for i, model in enumerate(self.models):

0 commit comments

Comments
 (0)