Skip to content

Commit 598ed6b

Browse files
committed
fix weight calculation if a model does not make predictions for all classes lists in its trust-weights file
1 parent a11b0b7 commit 598ed6b

File tree

1 file changed

+8
-7
lines changed

1 file changed

+8
-7
lines changed

chebifier/ensemble/weighted_majority_ensemble.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -45,13 +45,14 @@ def calculate_classwise_weights(self, predicted_classes):
4545
if model.classwise_weights is None:
4646
continue
4747
for cls, weights in model.classwise_weights.items():
48-
if (2 * weights["TP"] + weights["FP"] + weights["FN"]) > 0:
49-
f1 = (
50-
2
51-
* weights["TP"]
52-
/ (2 * weights["TP"] + weights["FP"] + weights["FN"])
53-
)
54-
weights_by_cls[predicted_classes[cls], j] *= 1 + f1
48+
if cls in predicted_classes:
49+
if (2 * weights["TP"] + weights["FP"] + weights["FN"]) > 0:
50+
f1 = (
51+
2
52+
* weights["TP"]
53+
/ (2 * weights["TP"] + weights["FP"] + weights["FN"])
54+
)
55+
weights_by_cls[predicted_classes[cls], j] *= 1 + f1
5556

5657
print("Calculated model weightings. The average weights are:")
5758
for i, model in enumerate(self.models):

0 commit comments

Comments
 (0)