@@ -1,7 +1,11 @@
+from typing import Tuple
+
+import numpy as np
 import torch
 from jsonargparse import CLI
 from torchmetrics.functional.classification import multilabel_auroc
 
+from chebai.callbacks.epoch_metrics import MacroF1
 from chebai.result.utils import load_results_from_buffer
 
 
@@ -48,7 +52,7 @@ def validate_eval_dir(label_files: torch.Tensor, pred_files: torch.Tensor) -> None:
 
     def evaluate(self) -> None:
         """
-        Loads predictions and labels, validates file correspondence, and calculates Multilabel AUROC.
+        Loads predictions and labels, validates file correspondence, and calculates Multilabel AUROC and Fmax.
         """
         test_preds, test_labels = load_results_from_buffer(self.eval_dir, self.device)
         self.validate_eval_dir(test_labels, test_preds)
@@ -60,11 +64,44 @@ def evaluate(self) -> None:
 
         print("Multilabel AUC-ROC:", ml_auroc)
 
+        fmax, threshold = self.calculate_fmax(test_preds, test_labels)
+        print(f"F-max: {fmax}, threshold: {threshold}")
+
+    def calculate_fmax(
+        self, test_preds: torch.Tensor, test_labels: torch.Tensor
+    ) -> Tuple[float, float]:
+        """
+        Calculates the Fmax metric using the F1 score at various thresholds.
+
+        Args:
+            test_preds (torch.Tensor): Predicted scores for the labels.
+            test_labels (torch.Tensor): True labels for the evaluation.
+
+        Returns:
+            Tuple[float, float]: The maximum F1 score and the corresponding threshold.
+        """
+        thresholds = np.linspace(0, 1, 100)
+        fmax = 0.0
+        best_threshold = 0.0
+
+        for t in thresholds:
+            custom_f1_metric = MacroF1(num_labels=self.num_labels, threshold=t)
+            custom_f1_metric.update(test_preds, test_labels)
+            custom_f1_metric_score = custom_f1_metric.compute().item()
+
+            # Check if the current score is the best we've seen
+            if custom_f1_metric_score > fmax:
+                fmax = custom_f1_metric_score
+                best_threshold = t
+
+        return fmax, best_threshold
+
 
 class Main:
     def evaluate(self, eval_dir: str):
         EvaluatePredictions(eval_dir).evaluate()
 
 
 if __name__ == "__main__":
+    # Usage: python evaluate_predictions.py evaluate <path/to/eval_dir>
     CLI(Main)
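
For reference, a minimal standalone sketch of what the new `calculate_fmax` loop computes: a macro F1 score at each of 100 candidate thresholds, keeping the best score and the threshold that produced it. It swaps chebai's `MacroF1` for torchmetrics' `MultilabelF1Score` and uses random tensors in place of the buffered predictions, so the shapes and names here are illustrative assumptions rather than part of the change.

```python
# Illustrative Fmax sweep (assumptions: MultilabelF1Score stands in for chebai's
# MacroF1; preds/labels are random stand-ins for load_results_from_buffer output).
import numpy as np
import torch
from torchmetrics.classification import MultilabelF1Score

num_labels = 10
preds = torch.rand(32, num_labels)               # predicted scores in [0, 1]
labels = torch.randint(0, 2, (32, num_labels))   # binary ground-truth labels

fmax, best_threshold = 0.0, 0.0
for t in np.linspace(0, 1, 100):
    # A fresh metric per threshold, so no state is carried over between thresholds
    f1 = MultilabelF1Score(num_labels=num_labels, threshold=float(t), average="macro")
    score = f1(preds, labels).item()
    if score > fmax:
        fmax, best_threshold = score, float(t)

print(f"F-max: {fmax}, threshold: {best_threshold}")
```

Instantiating a new metric inside the loop mirrors the diff's `MacroF1(num_labels=..., threshold=t)` call and keeps each threshold's F1 computation independent.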