Skip to content

Commit 2da2149

Browse files
committed
there is no macro or micro for classwise f1
1 parent ec55bfb commit 2da2149

File tree

1 file changed

+10
-25
lines changed

1 file changed

+10
-25
lines changed

chebai/result/generate_class_properties.py

Lines changed: 10 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
from jsonargparse import CLI
66
from torchmetrics.classification import MultilabelConfusionMatrix, MultilabelF1Score
77

8-
from chebai.callbacks.epoch_metrics import MacroF1
98
from chebai.preprocessing.datasets.base import XYBaseDataModule
109
from chebai.result.utils import (
1110
load_data_instance,
@@ -52,14 +51,12 @@ def compute_classwise_scores(
5251
- False Positives (FP)
5352
- True Negatives (TN)
5453
- False Negatives (FN)
55-
- Macro F1 score
56-
- Micro F1 score
54+
- F1 score
5755
5856
Args:
5957
metrics_obj_dict: Dictionary containing pre-updated torchmetrics.Metric objects:
6058
{
6159
"cm": MultilabelConfusionMatrix,
62-
"macro-f1": MacroF1 (average=None),
6360
"micro-f1": MultilabelF1Score (average=None)
6461
}
6562
class_names: List of class names in the same order as class indices.
@@ -74,27 +71,18 @@ def compute_classwise_scores(
7471
"FP": int,
7572
"FN": int,
7673
"TP": int,
77-
"f1-macro": float,
78-
"f1-micro": float
74+
"f1": float,
7975
},
8076
...
8177
}
8278
"""
8379
cm_tensor = metrics_obj_dict["cm"].compute() # Shape: (num_classes, 2, 2)
84-
# shape: (num_classes,)
85-
macro_f1_tensor = metrics_obj_dict["macro-f1"].compute()
86-
micro_f1_tensor = metrics_obj_dict["micro-f1"].compute()
87-
88-
assert (
89-
len(class_names)
90-
== cm_tensor.shape[0]
91-
== micro_f1_tensor.shape[0]
92-
== macro_f1_tensor.shape[0]
93-
), (
80+
f1_tensor = metrics_obj_dict["f1"].compute() # shape: (num_classes,)
81+
82+
assert len(class_names) == cm_tensor.shape[0] == f1_tensor.shape[0], (
9483
f"Mismatch between number of class names ({len(class_names)}) and metric tensor sizes: "
9584
f"confusion matrix has {cm_tensor.shape[0]}, "
96-
f"micro F1 has {micro_f1_tensor.shape[0]}, "
97-
f"macro F1 has {macro_f1_tensor.shape[0]}"
85+
f"F1 has {f1_tensor.shape[0]}, "
9886
)
9987

10088
results: dict[str, dict[str, float]] = {}
@@ -110,8 +98,7 @@ def compute_classwise_scores(
11098
# positive_raw = [p.item() for i, p in enumerate(raw_preds[:, idx]) if true_np[i, idx]]
11199
# negative_raw = [p.item() for i, p in enumerate(raw_preds[:, idx]) if not true_np[i, idx]]
112100

113-
macro_f1 = macro_f1_tensor[idx]
114-
micro_f1 = micro_f1_tensor[idx]
101+
f1 = f1_tensor[idx]
115102

116103
results[cls_name] = {
117104
"PPV": round(ppv, 4),
@@ -120,8 +107,7 @@ def compute_classwise_scores(
120107
"FP": int(fp),
121108
"FN": int(fn),
122109
"TP": int(tp),
123-
"f1-macro": round(macro_f1.item(), 4),
124-
"f1-micro": round(micro_f1.item(), 4),
110+
"f1": round(f1.item(), 4),
125111
# "positive_preds": positive_raw,
126112
# "negative_preds": negative_raw,
127113
}
@@ -179,8 +165,7 @@ def generate_props(
179165
num_classes = len(class_names)
180166
metrics_obj_dict: dict[str, torchmetrics.Metric] = {
181167
"cm": MultilabelConfusionMatrix(num_labels=num_classes),
182-
"macro-f1": MacroF1(num_labels=num_classes, average=None),
183-
"micro-f1": MultilabelF1Score(num_labels=num_classes, average=None),
168+
"f1": MultilabelF1Score(num_labels=num_classes, average=None),
184169
}
185170

186171
for batch_idx, batch in enumerate(val_loader):
@@ -191,7 +176,7 @@ def generate_props(
191176
data, labels, model_output
192177
)
193178
for metric_obj in metrics_obj_dict.values():
194-
metric_obj.update(preds=preds, target=targets)
179+
metric_obj.update(preds, targets)
195180

196181
print("Computing metrics...")
197182
if output_path is None:

0 commit comments

Comments
 (0)