Commit 48725dd

Merge pull request #112 from ChEB-AI/feature/classwise_f1_scores
add classwise f1, data split option to evaluation script
2 parents 1d8a7c3 + a3cc197 commit 48725dd

1 file changed

chebai/result/generate_class_properties.py

Lines changed: 92 additions & 42 deletions
@@ -1,9 +1,10 @@
 import json
 from pathlib import Path
+from typing import Literal

-import torch
+import torchmetrics
 from jsonargparse import CLI
-from sklearn.metrics import multilabel_confusion_matrix
+from torchmetrics.classification import MultilabelConfusionMatrix, MultilabelF1Score

 from chebai.preprocessing.datasets.base import XYBaseDataModule
 from chebai.result.utils import (
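
The imports swap sklearn's one-shot multilabel_confusion_matrix for torchmetrics' stateful metric objects, which accumulate statistics batch by batch. A minimal sketch of that update/compute pattern (illustrative only, not part of this commit; shapes and values are made up):

import torch
from torchmetrics.classification import MultilabelConfusionMatrix, MultilabelF1Score

num_labels = 3
cm = MultilabelConfusionMatrix(num_labels=num_labels)
f1 = MultilabelF1Score(num_labels=num_labels, average=None)  # per-class F1

for _ in range(2):  # stand-in for iterating a dataloader
    preds = torch.rand(4, num_labels)               # probabilities, thresholded at 0.5 internally
    targets = torch.randint(0, 2, (4, num_labels))  # binary ground truth
    cm.update(preds, targets)
    f1.update(preds, targets)

print(cm.compute().shape)  # torch.Size([3, 2, 2]): one 2x2 matrix per label
print(f1.compute())        # tensor with 3 per-class F1 scores

Accumulating this way avoids holding every prediction tensor in memory, which is the main win over the previous stack-then-sklearn approach.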
@@ -37,55 +38,85 @@ def load_class_labels(path: Path) -> list[str]:

     @staticmethod
     def compute_classwise_scores(
-        y_true: list[torch.Tensor],
-        y_pred: list[torch.Tensor],
-        raw_preds: torch.Tensor,
+        metrics_obj_dict: dict[str, torchmetrics.Metric],
         class_names: list[str],
     ) -> dict[str, dict[str, float]]:
         """
-        Compute PPV (precision, TP/(TP+FP)), NPV (TN/(TN+FN)) and the number of TNs, FPs, FNs and TPs for each class
-        in a multi-label setting.
+        Compute per-class evaluation metrics for a multi-label classification task.
+
+        This method uses torchmetrics objects (MultilabelConfusionMatrix, F1 scores, etc.)
+        to compute the following metrics for each class:
+        - PPV (Positive Predictive Value or Precision)
+        - NPV (Negative Predictive Value)
+        - True Positives (TP)
+        - False Positives (FP)
+        - True Negatives (TN)
+        - False Negatives (FN)
+        - F1 score

         Args:
-            y_true: List of binary ground-truth label tensors, one tensor per sample.
-            y_pred: List of binary prediction tensors, one tensor per sample.
-            class_names: Ordered list of class names corresponding to class indices.
+            metrics_obj_dict: Dictionary containing pre-updated torchmetrics.Metric objects:
+                {
+                    "cm": MultilabelConfusionMatrix,
+                    "f1": MultilabelF1Score (average=None)
+                }
+            class_names: List of class names in the same order as class indices.

         Returns:
-            Dictionary mapping each class name to its PPV and NPV metrics:
+            Dictionary mapping each class name to a sub-dictionary of computed metrics:
                 {
-                    "class_name": {"PPV": float, "NPV": float, "TN": int, "FP": int, "FN": int, "TP": int},
+                    "class_name_1": {
+                        "PPV": float,
+                        "NPV": float,
+                        "TN": int,
+                        "FP": int,
+                        "FN": int,
+                        "TP": int,
+                        "f1": float,
+                    },
                     ...
                 }
         """
-        # Stack per-sample tensors into (n_samples, n_classes) numpy arrays
-        true_np = torch.stack(y_true).cpu().numpy().astype(int)
-        pred_np = torch.stack(y_pred).cpu().numpy().astype(int)
+        cm_tensor = metrics_obj_dict["cm"].compute()  # shape: (num_classes, 2, 2)
+        f1_tensor = metrics_obj_dict["f1"].compute()  # shape: (num_classes,)

-        # Compute confusion matrix for each class
-        cm = multilabel_confusion_matrix(true_np, pred_np)
+        assert len(class_names) == cm_tensor.shape[0] == f1_tensor.shape[0], (
+            f"Mismatch between number of class names ({len(class_names)}) and metric tensor sizes: "
+            f"confusion matrix has {cm_tensor.shape[0]}, "
+            f"F1 has {f1_tensor.shape[0]}"
+        )

         results: dict[str, dict[str, float]] = {}
         for idx, cls_name in enumerate(class_names):
-            tn, fp, fn, tp = cm[idx].ravel()
-            tpv = tp / (tp + fp) if (tp + fp) > 0 else 0.0
-            npv = tn / (tn + fn) if (tn + fn) > 0 else 0.0
+            tn = cm_tensor[idx][0][0].item()
+            fp = cm_tensor[idx][0][1].item()
+            fn = cm_tensor[idx][1][0].item()
+            tp = cm_tensor[idx][1][1].item()
+
+            ppv = tp / (tp + fp) if (tp + fp) > 0 else 0.0  # Precision
+            npv = tn / (tn + fn) if (tn + fn) > 0 else 0.0  # Negative predictive value
+
             # positive_raw = [p.item() for i, p in enumerate(raw_preds[:, idx]) if true_np[i, idx]]
             # negative_raw = [p.item() for i, p in enumerate(raw_preds[:, idx]) if not true_np[i, idx]]
+
+            f1 = f1_tensor[idx]
+
             results[cls_name] = {
-                "PPV": round(tpv, 4),
+                "PPV": round(ppv, 4),
                 "NPV": round(npv, 4),
                 "TN": int(tn),
                 "FP": int(fp),
                 "FN": int(fn),
                 "TP": int(tp),
+                "f1": round(f1.item(), 4),
                 # "positive_preds": positive_raw,
                 # "negative_preds": negative_raw,
             }
         return results

     def generate_props(
         self,
+        data_partition: Literal["train", "val", "test"],
         model_ckpt_path: str,
         model_config_file_path: str,
         data_config_file_path: str,
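
The element-wise indexing of cm_tensor assumes torchmetrics' per-label layout [[TN, FP], [FN, TP]] (rows are the true class, columns the predicted class, matching sklearn's convention). A small sanity check of that assumption (illustrative, not repo code):

import torch
from torchmetrics.classification import MultilabelConfusionMatrix

cm = MultilabelConfusionMatrix(num_labels=1)
preds   = torch.tensor([[1], [0], [0], [1], [1]])  # update() takes (preds, targets)
targets = torch.tensor([[1], [0], [1], [0], [0]])  # -> TP=1, TN=1, FN=1, FP=2
cm.update(preds, targets)
print(cm.compute()[0])  # tensor([[1, 2], [1, 1]]), i.e. [[TN, FP], [FN, TP]]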
@@ -95,14 +126,13 @@ def generate_props(
         Run inference on validation set, compute TPV/NPV per class, and save to JSON.

         Args:
+            data_partition: Partition of the dataset to use to generate class properties.
             model_ckpt_path: Path to the PyTorch Lightning checkpoint file.
             model_config_file_path: Path to yaml config file of the model.
             data_config_file_path: Path to yaml config file of the data.
             output_path: Optional path where to write the JSON metrics file.
                 Defaults to '<processed_dir_main>/classes.json'.
         """
-        print("Extracting validation data for computation...")
-
         data_cls_path, data_cls_kwargs = parse_config_file(data_config_file_path)
         data_module: XYBaseDataModule = load_data_instance(
             data_cls_path, data_cls_kwargs
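
parse_config_file and load_data_instance are imported from chebai.result.utils; their bodies are not part of this diff. For orientation only, a hypothetical dotted-path loader in the same spirit (name and behavior are assumptions, not the repo's implementation):

import importlib
from typing import Any

def load_instance(class_path: str, kwargs: dict[str, Any]) -> Any:
    # Hypothetical sketch: import "pkg.module.ClassName" from its dotted
    # path and instantiate it with the kwargs parsed from the yaml config.
    module_name, _, cls_name = class_path.rpartition(".")
    cls = getattr(importlib.import_module(module_name), cls_name)
    return cls(**kwargs)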
@@ -128,32 +158,43 @@ def generate_props(
             model_ckpt_path, model_class_path, model_kwargs
         )

-        val_loader = data_module.val_dataloader()
-        print("Running inference on validation data...")
+        if data_partition == "train":
+            data_loader = data_module.train_dataloader()
+        elif data_partition == "val":
+            data_loader = data_module.val_dataloader()
+        elif data_partition == "test":
+            data_loader = data_module.test_dataloader()
+        else:
+            raise ValueError(f"Unknown data partition: {data_partition}")
+        print(f"Running inference on {data_partition} data...")

-        y_true, y_pred = [], []
-        raw_preds = []
-        for batch_idx, batch in enumerate(val_loader):
-            data = model._process_batch(  # pylint: disable=W0212
-                batch, batch_idx=batch_idx
-            )
+        classes_file = Path(data_module.processed_dir_main) / "classes.txt"
+        class_names = self.load_class_labels(classes_file)
+        num_classes = len(class_names)
+        metrics_obj_dict: dict[str, torchmetrics.Metric] = {
+            "cm": MultilabelConfusionMatrix(num_labels=num_classes),
+            "f1": MultilabelF1Score(num_labels=num_classes, average=None),
+        }
+
+        for batch_idx, batch in enumerate(data_loader):
+            data = model._process_batch(batch, batch_idx=batch_idx)
             labels = data["labels"]
-            outputs = model(data, **data.get("model_kwargs", {}))
-            logits = outputs["logits"] if isinstance(outputs, dict) else outputs
-            preds = torch.sigmoid(logits) > 0.5
-            y_pred.extend(preds)
-            y_true.extend(labels)
-            raw_preds.extend(torch.sigmoid(logits))
-        raw_preds = torch.stack(raw_preds)
+            model_output = model(data, **data.get("model_kwargs", {}))
+            preds, targets = model._get_prediction_and_labels(
+                data, labels, model_output
+            )
+            for metric_obj in metrics_obj_dict.values():
+                metric_obj.update(preds, targets)

         print("Computing metrics...")
-        classes_file = Path(data_module.processed_dir_main) / "classes.txt"
         if output_path is None:
-            output_file = Path(data_module.processed_dir_main) / "classes.json"
+            output_file = (
+                Path(data_module.processed_dir_main) / f"classes_{data_partition}.json"
+            )
         else:
             output_file = Path(output_path)

-        class_names = self.load_class_labels(classes_file)
-        metrics = self.compute_classwise_scores(y_true, y_pred, raw_preds, class_names)
+        metrics = self.compute_classwise_scores(metrics_obj_dict, class_names)

         with output_file.open("w") as f:
             json.dump(metrics, f, indent=2)
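
With the default output path, the script now writes classes_<partition>.json keyed by class name. A hypothetical consumer of that file, using only the keys documented in the docstring above (file name and location are illustrative):

import json
from pathlib import Path

metrics = json.loads(Path("classes_val.json").read_text())  # assumed location
# Rank classes by F1 to spot the weakest ones
for cls_name, m in sorted(metrics.items(), key=lambda kv: kv[1]["f1"]):
    print(f"{cls_name}: f1={m['f1']:.3f} PPV={m['PPV']:.3f} NPV={m['NPV']:.3f} "
          f"(TP={m['TP']}, FP={m['FP']}, FN={m['FN']}, TN={m['TN']})")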
@@ -167,6 +208,7 @@ class Main:

     def generate(
         self,
+        data_partition: Literal["train", "val", "test"],
         model_ckpt_path: str,
         model_config_file_path: str,
         data_config_file_path: str,
@@ -176,14 +218,21 @@ def generate(
         CLI command to generate JSON with metrics on validation set.

         Args:
+            data_partition: Partition of dataset to use to generate class properties.
             model_ckpt_path: Path to the PyTorch Lightning checkpoint file.
             model_config_file_path: Path to yaml config file of the model.
             data_config_file_path: Path to yaml config file of the data.
             output_path: Optional path where to write the JSON metrics file.
                 Defaults to '<processed_dir_main>/classes.json'.
         """
+        assert data_partition in [
+            "train",
+            "val",
+            "test",
+        ], f"Invalid data partition: {data_partition}. Choose one of `train`, `val`, `test`."
         generator = ClassesPropertiesGenerator()
         generator.generate_props(
+            data_partition,
             model_ckpt_path,
             model_config_file_path,
             data_config_file_path,
data_config_file_path,
@@ -193,6 +242,7 @@ def generate(

 if __name__ == "__main__":
     # _generate_classes_props_json.py generate \
+    #     --data_partition "val" \
     #     --model_ckpt_path "model/ckpt/path" \
     #     --model_config_file_path "model/config/file/path" \
     #     --data_config_file_path "data/config/file/path" \
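
The usage comment works because the file imports jsonargparse's CLI, which turns a class's methods into subcommands and derives flags from their signatures. A self-contained sketch of that pattern (this stub is illustrative; the repo's actual entry point is not shown in the diff):

from jsonargparse import CLI

class Main:
    def generate(self, data_partition: str, model_ckpt_path: str) -> None:
        """Each public method becomes a subcommand; flags come from its parameters."""
        print(data_partition, model_ckpt_path)

if __name__ == "__main__":
    CLI(Main)  # e.g.: python script.py generate --data_partition val --model_ckpt_path p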
