Commit 8688c3b

update to add f1-score metrics
1 parent fcaceb5 commit 8688c3b

File tree: 1 file changed

chebai/result/generate_class_properties.py

Lines changed: 83 additions & 37 deletions
@@ -1,10 +1,11 @@
 import json
 from pathlib import Path
 
-import torch
+import torchmetrics
 from jsonargparse import CLI
-from sklearn.metrics import multilabel_confusion_matrix
+from torchmetrics.classification import MultilabelConfusionMatrix, MultilabelF1Score
 
+from chebai.callbacks.epoch_metrics import MacroF1
 from chebai.preprocessing.datasets.base import XYBaseDataModule
 from chebai.result.utils import (
     load_data_instance,
@@ -37,48 +38,90 @@ def load_class_labels(path: Path) -> list[str]:
 
     @staticmethod
     def compute_classwise_scores(
-        y_true: list[torch.Tensor],
-        y_pred: list[torch.Tensor],
-        raw_preds: torch.Tensor,
+        metrics_obj_dict: dict[str, torchmetrics.Metric],
         class_names: list[str],
     ) -> dict[str, dict[str, float]]:
         """
-        Compute PPV (precision, TP/(TP+FP)), NPV (TN/(TN+FN)) and the number of TNs, FPs, FNs and TPs for each class
-        in a multi-label setting.
+        Compute per-class evaluation metrics for a multi-label classification task.
+
+        This method uses torchmetrics objects (MultilabelConfusionMatrix, F1 scores, etc.)
+        to compute the following metrics for each class:
+        - PPV (Positive Predictive Value or Precision)
+        - NPV (Negative Predictive Value)
+        - True Positives (TP)
+        - False Positives (FP)
+        - True Negatives (TN)
+        - False Negatives (FN)
+        - Macro F1 score
+        - Micro F1 score
 
         Args:
-            y_true: List of binary ground-truth label tensors, one tensor per sample.
-            y_pred: List of binary prediction tensors, one tensor per sample.
-            class_names: Ordered list of class names corresponding to class indices.
+            metrics_obj_dict: Dictionary containing pre-updated torchmetrics.Metric objects:
+                {
+                    "cm": MultilabelConfusionMatrix,
+                    "macro-f1": MacroF1 (average=None),
+                    "micro-f1": MultilabelF1Score (average=None)
+                }
+            class_names: List of class names in the same order as class indices.
 
         Returns:
-            Dictionary mapping each class name to its PPV and NPV metrics:
+            Dictionary mapping each class name to a sub-dictionary of computed metrics:
             {
-                "class_name": {"PPV": float, "NPV": float, "TN": int, "FP": int, "FN": int, "TP": int},
+                "class_name_1": {
+                    "PPV": float,
+                    "NPV": float,
+                    "TN": int,
+                    "FP": int,
+                    "FN": int,
+                    "TP": int,
+                    "f1-macro": float,
+                    "f1-micro": float
+                },
                 ...
             }
         """
-        # Stack per-sample tensors into (n_samples, n_classes) numpy arrays
-        true_np = torch.stack(y_true).cpu().numpy().astype(int)
-        pred_np = torch.stack(y_pred).cpu().numpy().astype(int)
-
-        # Compute confusion matrix for each class
-        cm = multilabel_confusion_matrix(true_np, pred_np)
+        cm_tensor = metrics_obj_dict["cm"].compute()  # Shape: (num_classes, 2, 2)
+        # shape: (num_classes,)
+        macro_f1_tensor = metrics_obj_dict["macro-f1"].compute()
+        micro_f1_tensor = metrics_obj_dict["micro-f1"].compute()
+
+        assert (
+            len(class_names)
+            == cm_tensor.shape[0]
+            == micro_f1_tensor.shape[0]
+            == macro_f1_tensor.shape[0]
+        ), (
+            f"Mismatch between number of class names ({len(class_names)}) and metric tensor sizes: "
+            f"confusion matrix has {cm_tensor.shape[0]}, "
+            f"micro F1 has {micro_f1_tensor.shape[0]}, "
+            f"macro F1 has {macro_f1_tensor.shape[0]}"
+        )
 
         results: dict[str, dict[str, float]] = {}
         for idx, cls_name in enumerate(class_names):
-            tn, fp, fn, tp = cm[idx].ravel()
-            tpv = tp / (tp + fp) if (tp + fp) > 0 else 0.0
-            npv = tn / (tn + fn) if (tn + fn) > 0 else 0.0
+            tn = cm_tensor[idx][0][0].item()
+            fp = cm_tensor[idx][0][1].item()
+            fn = cm_tensor[idx][1][0].item()
+            tp = cm_tensor[idx][1][1].item()
+
+            ppv = tp / (tp + fp) if (tp + fp) > 0 else 0.0  # Precision
+            npv = tn / (tn + fn) if (tn + fn) > 0 else 0.0  # Negative predictive value
+
             # positive_raw = [p.item() for i, p in enumerate(raw_preds[:, idx]) if true_np[i, idx]]
             # negative_raw = [p.item() for i, p in enumerate(raw_preds[:, idx]) if not true_np[i, idx]]
+
+            macro_f1 = macro_f1_tensor[idx]
+            micro_f1 = micro_f1_tensor[idx]
+
             results[cls_name] = {
-                "PPV": round(tpv, 4),
+                "PPV": round(ppv, 4),
                 "NPV": round(npv, 4),
                 "TN": int(tn),
                 "FP": int(fp),
                 "FN": int(fn),
                 "TP": int(tp),
+                "f1-macro": round(macro_f1.item(), 4),
+                "f1-micro": round(micro_f1.item(), 4),
                 # "positive_preds": positive_raw,
                 # "negative_preds": negative_raw,
             }
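
For context, this hunk swaps sklearn's multilabel_confusion_matrix for stateful torchmetrics objects that are updated batch by batch and read out once via compute(). The standalone sketch below is not part of the commit (the class count and tensors are made up for illustration); it shows the update/compute pattern and how the per-class (2, 2) confusion matrix maps to TN/FP/FN/TP, which is exactly how the new loop indexes it:

import torch
from torchmetrics.classification import MultilabelConfusionMatrix, MultilabelF1Score

num_classes = 3  # illustrative; the real code derives this from classes.txt
cm = MultilabelConfusionMatrix(num_labels=num_classes)
per_class_f1 = MultilabelF1Score(num_labels=num_classes, average=None)

# Two tiny "batches" of sigmoid probabilities and binary targets
for preds, target in [
    (torch.tensor([[0.9, 0.2, 0.7]]), torch.tensor([[1, 0, 1]])),
    (torch.tensor([[0.4, 0.8, 0.1]]), torch.tensor([[0, 1, 1]])),
]:
    cm.update(preds, target)           # accumulates internal state per batch
    per_class_f1.update(preds, target)

cm_tensor = cm.compute()               # shape (num_classes, 2, 2)
f1_tensor = per_class_f1.compute()     # shape (num_classes,)

for idx in range(num_classes):
    tn, fp = cm_tensor[idx][0][0].item(), cm_tensor[idx][0][1].item()
    fn, tp = cm_tensor[idx][1][0].item(), cm_tensor[idx][1][1].item()
    print(idx, {"TN": tn, "FP": fp, "FN": fn, "TP": tp, "F1": round(f1_tensor[idx].item(), 4)})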
@@ -131,29 +174,32 @@ def generate_props(
         val_loader = data_module.val_dataloader()
         print("Running inference on validation data...")
 
-        y_true, y_pred = [], []
-        raw_preds = []
+        classes_file = Path(data_module.processed_dir_main) / "classes.txt"
+        class_names = self.load_class_labels(classes_file)
+        num_classes = len(class_names)
+        metrics_obj_dict: dict[str, torchmetrics.Metric] = {
+            "cm": MultilabelConfusionMatrix(num_labels=num_classes),
+            "macro-f1": MacroF1(num_labels=num_classes, average=None),
+            "micro-f1": MultilabelF1Score(num_labels=num_classes, average=None),
+        }
+
         for batch_idx, batch in enumerate(val_loader):
-            data = model._process_batch(  # pylint: disable=W0212
-                batch, batch_idx=batch_idx
-            )
+            data = model._process_batch(batch, batch_idx=batch_idx)
             labels = data["labels"]
-            outputs = model(data, **data.get("model_kwargs", {}))
-            logits = outputs["logits"] if isinstance(outputs, dict) else outputs
-            preds = torch.sigmoid(logits) > 0.5
-            y_pred.extend(preds)
-            y_true.extend(labels)
-            raw_preds.extend(torch.sigmoid(logits))
-        raw_preds = torch.stack(raw_preds)
+            model_output = model(data, **data.get("model_kwargs", {}))
+            preds, targets = model._get_prediction_and_labels(
+                data, labels, model_output
+            )
+            for metric_obj in metrics_obj_dict.values():
+                metric_obj.update(preds=preds, target=targets)
+
         print("Computing metrics...")
-        classes_file = Path(data_module.processed_dir_main) / "classes.txt"
         if output_path is None:
             output_file = Path(data_module.processed_dir_main) / "classes.json"
         else:
             output_file = Path(output_path)
 
-        class_names = self.load_class_labels(classes_file)
-        metrics = self.compute_classwise_scores(y_true, y_pred, raw_preds, class_names)
+        metrics = self.compute_classwise_scores(metrics_obj_dict, class_names)
 
         with output_file.open("w") as f:
             json.dump(metrics, f, indent=2)
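
Downstream, generate_props still writes the per-class dictionary to classes.json; the new keys simply appear alongside the old ones. A hypothetical consumer of that file (the path and the sort key are chosen here for illustration only) could rank classes by the new per-class F1 values like this:

import json
from pathlib import Path

with Path("classes.json").open() as f:
    per_class = json.load(f)

# Rank classes by their per-class micro F1, worst first, to spot weak labels.
for name, m in sorted(per_class.items(), key=lambda kv: kv[1]["f1-micro"]):
    print(f'{name}: PPV={m["PPV"]:.4f} NPV={m["NPV"]:.4f} '
          f'TP={m["TP"]} FP={m["FP"]} FN={m["FN"]} TN={m["TN"]} '
          f'f1-micro={m["f1-micro"]:.4f} f1-macro={m["f1-macro"]:.4f}')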
