
Commit 6c0fce1

add fmax to evaluation script
1 parent 264bd94 commit 6c0fce1

File tree

1 file changed: +38 -1 lines changed


chebai/result/evaluate_predictions.py

Lines changed: 38 additions & 1 deletion
@@ -1,7 +1,11 @@
+from typing import Tuple
+
+import numpy as np
 import torch
 from jsonargparse import CLI
 from torchmetrics.functional.classification import multilabel_auroc
 
+from chebai.callbacks.epoch_metrics import MacroF1
 from chebai.result.utils import load_results_from_buffer
 
 
@@ -48,7 +52,7 @@ def validate_eval_dir(label_files: torch.Tensor, pred_files: torch.Tensor) -> None:
 
     def evaluate(self) -> None:
         """
-        Loads predictions and labels, validates file correspondence, and calculates Multilabel AUROC.
+        Loads predictions and labels, validates file correspondence, and calculates Multilabel AUROC and Fmax.
         """
        test_preds, test_labels = load_results_from_buffer(self.eval_dir, self.device)
        self.validate_eval_dir(test_labels, test_preds)
@@ -60,11 +64,44 @@ def evaluate(self) -> None:
 
         print("Multilabel AUC-ROC:", ml_auroc)
 
+        fmax, threshold = self.calculate_fmax(test_preds, test_labels)
+        print(f"F-max : {fmax}, threshold: {threshold}")
+
+    def calculate_fmax(
+        self, test_preds: torch.Tensor, test_labels: torch.Tensor
+    ) -> Tuple[float, float]:
+        """
+        Calculates the Fmax metric using the F1 score at various thresholds.
+
+        Args:
+            test_preds (torch.Tensor): Predicted scores for the labels.
+            test_labels (torch.Tensor): True labels for the evaluation.
+
+        Returns:
+            Tuple[float, float]: The maximum F1 score and the corresponding threshold.
+        """
+        thresholds = np.linspace(0, 1, 100)
+        fmax = 0.0
+        best_threshold = 0.0
+
+        for t in thresholds:
+            custom_f1_metric = MacroF1(num_labels=self.num_labels, threshold=t)
+            custom_f1_metric.update(test_preds, test_labels)
+            custom_f1_metric_score = custom_f1_metric.compute().item()
+
+            # Check if the current score is the best we've seen
+            if custom_f1_metric_score > fmax:
+                fmax = custom_f1_metric_score
+                best_threshold = t
+
+        return fmax, best_threshold
+
 
 class Main:
     def evaluate(self, eval_dir: str):
         EvaluatePredictions(eval_dir).evaluate()
 
 
 if __name__ == "__main__":
+    # evaluate_predictions.py evaluate <path/to/file>
     CLI(Main)
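
For reference, below is a minimal standalone sketch of the same threshold-sweep idea: score the predictions at 100 candidate thresholds and keep the best macro F1. It stands in torchmetrics' multilabel_f1_score for the project's MacroF1 callback, so the numbers may differ slightly from the committed implementation; the function name fmax_sketch and the toy tensors are illustrative only.

import numpy as np
import torch
from torchmetrics.functional.classification import multilabel_f1_score


def fmax_sketch(preds: torch.Tensor, labels: torch.Tensor, num_labels: int):
    """Return (best macro F1, threshold at which it was reached)."""
    fmax, best_threshold = 0.0, 0.0
    for t in np.linspace(0, 1, 100):
        # Binarize the scores at threshold t and compute macro-averaged F1.
        score = multilabel_f1_score(
            preds, labels, num_labels=num_labels, threshold=float(t), average="macro"
        ).item()
        if score > fmax:
            fmax, best_threshold = score, float(t)
    return fmax, best_threshold


if __name__ == "__main__":
    # Tiny synthetic example: 4 samples, 3 labels.
    preds = torch.tensor(
        [[0.9, 0.2, 0.7], [0.1, 0.8, 0.4], [0.6, 0.3, 0.9], [0.2, 0.7, 0.1]]
    )
    labels = torch.tensor([[1, 0, 1], [0, 1, 0], [1, 0, 1], [0, 1, 0]])
    print(fmax_sketch(preds, labels, num_labels=3))

Like the committed calculate_fmax, this recomputes the metric once per threshold, which keeps memory use small at the cost of 100 passes over the predictions.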

0 commit comments