@@ -5,7 +5,6 @@
 from jsonargparse import CLI
 from torchmetrics.classification import MultilabelConfusionMatrix, MultilabelF1Score

-from chebai.callbacks.epoch_metrics import MacroF1
 from chebai.preprocessing.datasets.base import XYBaseDataModule
 from chebai.result.utils import (
     load_data_instance,
@@ -52,14 +51,12 @@ def compute_classwise_scores(
         - False Positives (FP)
         - True Negatives (TN)
         - False Negatives (FN)
-        - Macro F1 score
-        - Micro F1 score
+        - F1 score

     Args:
         metrics_obj_dict: Dictionary containing pre-updated torchmetrics.Metric objects:
             {
                 "cm": MultilabelConfusionMatrix,
-                "macro-f1": MacroF1(average=None),
6360 "micro-f1": MultilabelF1Score (average=None)
             }
         class_names: List of class names in the same order as class indices.
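Aside (not part of the commit): the per-class scores documented here follow the standard definitions, with PPV = TP / (TP + FP), TPR = TP / (TP + FN), and F1 as the harmonic mean of the two. A minimal sketch with a hypothetical helper name:

    def scores_from_counts(tp: int, fp: int, fn: int) -> tuple[float, float, float]:
        # Standard definitions; zero denominators fall back to 0.0.
        ppv = tp / (tp + fp) if tp + fp else 0.0
        tpr = tp / (tp + fn) if tp + fn else 0.0
        f1 = 2 * ppv * tpr / (ppv + tpr) if ppv + tpr else 0.0
        return ppv, tpr, f1

    print(scores_from_counts(tp=24, fp=2, fn=4))  # (0.9230..., 0.8571..., 0.8888...)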
@@ -74,27 +71,18 @@ def compute_classwise_scores(
7471 "FP": int,
7572 "FN": int,
7673 "TP": int,
77- "f1-macro": float,
78- "f1-micro": float
74+ "f1": float,
7975 },
8076 ...
8177 }
8278 """
8379 cm_tensor = metrics_obj_dict ["cm" ].compute () # Shape: (num_classes, 2, 2)
84- # shape: (num_classes,)
85- macro_f1_tensor = metrics_obj_dict ["macro-f1" ].compute ()
86- micro_f1_tensor = metrics_obj_dict ["micro-f1" ].compute ()
87-
88- assert (
89- len (class_names )
90- == cm_tensor .shape [0 ]
91- == micro_f1_tensor .shape [0 ]
92- == macro_f1_tensor .shape [0 ]
93- ), (
+    f1_tensor = metrics_obj_dict["f1"].compute()  # shape: (num_classes,)
+
+    assert len(class_names) == cm_tensor.shape[0] == f1_tensor.shape[0], (
         f"Mismatch between number of class names ({len(class_names)}) and metric tensor sizes: "
         f"confusion matrix has {cm_tensor.shape[0]}, "
-        f"micro F1 has {micro_f1_tensor.shape[0]}, "
-        f"macro F1 has {macro_f1_tensor.shape[0]}"
+        f"F1 has {f1_tensor.shape[0]}"
     )

     results: dict[str, dict[str, float]] = {}
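For orientation (an aside, not part of the commit): MultilabelConfusionMatrix yields one 2×2 matrix per label, laid out as [[TN, FP], [FN, TP]], so the per-class counts used below can be read off by flattening each slice. A small self-contained check:

    import torch
    from torchmetrics.classification import MultilabelConfusionMatrix

    cm = MultilabelConfusionMatrix(num_labels=2)
    cm.update(torch.tensor([[1, 0], [1, 1]]), torch.tensor([[1, 0], [0, 1]]))
    cm_tensor = cm.compute()  # shape: (2, 2, 2)
    tn, fp, fn, tp = cm_tensor[0].flatten().tolist()
    print(tn, fp, fn, tp)  # 0 1 0 1: one false positive, one true positive for label 0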
@@ -110,8 +98,7 @@ def compute_classwise_scores(
         # positive_raw = [p.item() for i, p in enumerate(raw_preds[:, idx]) if true_np[i, idx]]
         # negative_raw = [p.item() for i, p in enumerate(raw_preds[:, idx]) if not true_np[i, idx]]

-        macro_f1 = macro_f1_tensor[idx]
-        micro_f1 = micro_f1_tensor[idx]
+        f1 = f1_tensor[idx]

         results[cls_name] = {
             "PPV": round(ppv, 4),
@@ -120,8 +107,7 @@ def compute_classwise_scores(
             "FP": int(fp),
             "FN": int(fn),
             "TP": int(tp),
-            "f1-macro": round(macro_f1.item(), 4),
-            "f1-micro": round(micro_f1.item(), 4),
+            "f1": round(f1.item(), 4),
             # "positive_preds": positive_raw,
             # "negative_preds": negative_raw,
         }
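To make the consolidated key concrete, an illustrative entry of the returned mapping (class name and numbers made up):

    # {
    #     "CHEBI:33429": {
    #         "PPV": 0.9231, "TPR": 0.8571,
    #         "TN": 120, "FP": 2, "FN": 4, "TP": 24,
    #         "f1": 0.8889,
    #     },
    #     ...
    # }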
@@ -179,8 +165,7 @@ def generate_props(
     num_classes = len(class_names)
     metrics_obj_dict: dict[str, torchmetrics.Metric] = {
         "cm": MultilabelConfusionMatrix(num_labels=num_classes),
-        "macro-f1": MacroF1(num_labels=num_classes, average=None),
-        "micro-f1": MultilabelF1Score(num_labels=num_classes, average=None),
+        "f1": MultilabelF1Score(num_labels=num_classes, average=None),
     }

     for batch_idx, batch in enumerate(val_loader):
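With average=None, MultilabelF1Score already returns one score per label, which is presumably why a single metric object suffices here. A quick check (example tensors are made up):

    import torch
    from torchmetrics.classification import MultilabelF1Score

    f1 = MultilabelF1Score(num_labels=3, average=None)
    f1.update(torch.tensor([[1, 0, 1], [0, 1, 1]]), torch.tensor([[1, 0, 0], [0, 1, 1]]))
    print(f1.compute())  # tensor([1.0000, 1.0000, 0.6667]): one F1 per label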
@@ -191,7 +176,7 @@ def generate_props(
             data, labels, model_output
         )
         for metric_obj in metrics_obj_dict.values():
-            metric_obj.update(preds=preds, target=targets)
+            metric_obj.update(preds, targets)

     print("Computing metrics...")
     if output_path is None:
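The update call is now positional, which works uniformly across the metrics in the dict regardless of their keyword names. A minimal sketch of the loop after this change (batch tensors are made up):

    import torch
    from torchmetrics.classification import MultilabelConfusionMatrix, MultilabelF1Score

    metrics_obj_dict = {
        "cm": MultilabelConfusionMatrix(num_labels=2),
        "f1": MultilabelF1Score(num_labels=2, average=None),
    }
    preds = torch.tensor([[1, 0]])    # one made-up batch
    targets = torch.tensor([[1, 1]])
    for metric_obj in metrics_obj_dict.values():
        metric_obj.update(preds, targets)  # positional call fits both metrics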