1 | 1 | import json |
2 | 2 | from pathlib import Path |
3 | 3 |
4 | | -import torch |
| 4 | +import torchmetrics |
5 | 5 | from jsonargparse import CLI |
6 | | -from sklearn.metrics import multilabel_confusion_matrix |
| 6 | +from torchmetrics.classification import MultilabelConfusionMatrix, MultilabelF1Score |
7 | 7 |
| 8 | +from chebai.callbacks.epoch_metrics import MacroF1 |
8 | 9 | from chebai.preprocessing.datasets.base import XYBaseDataModule |
9 | 10 | from chebai.result.utils import ( |
10 | 11 | load_data_instance, |
@@ -37,48 +38,90 @@ def load_class_labels(path: Path) -> list[str]: |
37 | 38 |
38 | 39 | @staticmethod |
39 | 40 | def compute_classwise_scores( |
40 | | - y_true: list[torch.Tensor], |
41 | | - y_pred: list[torch.Tensor], |
42 | | - raw_preds: torch.Tensor, |
| 41 | + metrics_obj_dict: dict[str, torchmetrics.Metric], |
43 | 42 | class_names: list[str], |
44 | 43 | ) -> dict[str, dict[str, float]]: |
45 | 44 | """ |
46 | | - Compute PPV (precision, TP/(TP+FP)), NPV (TN/(TN+FN)) and the number of TNs, FPs, FNs and TPs for each class |
47 | | - in a multi-label setting. |
| 45 | + Compute per-class evaluation metrics for a multi-label classification task. |
| 46 | +
| 47 | +    This method uses pre-updated torchmetrics objects (MultilabelConfusionMatrix, MacroF1,
| 48 | +    MultilabelF1Score) to compute the following metrics for each class:
| 49 | + - PPV (Positive Predictive Value or Precision) |
| 50 | + - NPV (Negative Predictive Value) |
| 51 | + - True Positives (TP) |
| 52 | + - False Positives (FP) |
| 53 | + - True Negatives (TN) |
| 54 | + - False Negatives (FN) |
| 55 | +    - Per-class F1 score from the MacroF1 metric (reported as "f1-macro")
| 56 | +    - Per-class F1 score from MultilabelF1Score (reported as "f1-micro")
48 | 57 |
49 | 58 | Args: |
50 | | - y_true: List of binary ground-truth label tensors, one tensor per sample. |
51 | | - y_pred: List of binary prediction tensors, one tensor per sample. |
52 | | - class_names: Ordered list of class names corresponding to class indices. |
| 59 | + metrics_obj_dict: Dictionary containing pre-updated torchmetrics.Metric objects: |
| 60 | + { |
| 61 | + "cm": MultilabelConfusionMatrix, |
| 62 | + "macro-f1": MacroF1 (average=None), |
| 63 | + "micro-f1": MultilabelF1Score (average=None) |
| 64 | + } |
| 65 | + class_names: List of class names in the same order as class indices. |
53 | 66 |
54 | 67 | Returns: |
55 | | - Dictionary mapping each class name to its PPV and NPV metrics: |
| 68 | + Dictionary mapping each class name to a sub-dictionary of computed metrics: |
56 | 69 | { |
57 | | - "class_name": {"PPV": float, "NPV": float, "TN": int, "FP": int, "FN": int, "TP": int}, |
| 70 | + "class_name_1": { |
| 71 | + "PPV": float, |
| 72 | + "NPV": float, |
| 73 | + "TN": int, |
| 74 | + "FP": int, |
| 75 | + "FN": int, |
| 76 | + "TP": int, |
| 77 | + "f1-macro": float, |
| 78 | + "f1-micro": float |
| 79 | + }, |
58 | 80 | ... |
59 | 81 | } |
60 | 82 | """ |
61 | | - # Stack per-sample tensors into (n_samples, n_classes) numpy arrays |
62 | | - true_np = torch.stack(y_true).cpu().numpy().astype(int) |
63 | | - pred_np = torch.stack(y_pred).cpu().numpy().astype(int) |
64 | | - |
65 | | - # Compute confusion matrix for each class |
66 | | - cm = multilabel_confusion_matrix(true_np, pred_np) |
| 83 | +        cm_tensor = metrics_obj_dict["cm"].compute()  # Shape: (num_classes, 2, 2); per class: [[TN, FP], [FN, TP]]
| 84 | +        # Per-class F1 tensors, each of shape (num_classes,)
| 85 | +        macro_f1_tensor = metrics_obj_dict["macro-f1"].compute()
| 86 | +        micro_f1_tensor = metrics_obj_dict["micro-f1"].compute()
| 87 | + |
| 88 | + assert ( |
| 89 | + len(class_names) |
| 90 | + == cm_tensor.shape[0] |
| 91 | + == micro_f1_tensor.shape[0] |
| 92 | + == macro_f1_tensor.shape[0] |
| 93 | + ), ( |
| 94 | + f"Mismatch between number of class names ({len(class_names)}) and metric tensor sizes: " |
| 95 | + f"confusion matrix has {cm_tensor.shape[0]}, " |
| 96 | + f"micro F1 has {micro_f1_tensor.shape[0]}, " |
| 97 | + f"macro F1 has {macro_f1_tensor.shape[0]}" |
| 98 | + ) |
67 | 99 |
68 | 100 | results: dict[str, dict[str, float]] = {} |
69 | 101 | for idx, cls_name in enumerate(class_names): |
70 | | - tn, fp, fn, tp = cm[idx].ravel() |
71 | | - tpv = tp / (tp + fp) if (tp + fp) > 0 else 0.0 |
72 | | - npv = tn / (tn + fn) if (tn + fn) > 0 else 0.0 |
| 102 | + tn = cm_tensor[idx][0][0].item() |
| 103 | + fp = cm_tensor[idx][0][1].item() |
| 104 | + fn = cm_tensor[idx][1][0].item() |
| 105 | + tp = cm_tensor[idx][1][1].item() |
| 106 | + |
| 107 | + ppv = tp / (tp + fp) if (tp + fp) > 0 else 0.0 # Precision |
| 108 | + npv = tn / (tn + fn) if (tn + fn) > 0 else 0.0 # Negative predictive value |
| 109 | + |
73 | 110 | # positive_raw = [p.item() for i, p in enumerate(raw_preds[:, idx]) if true_np[i, idx]] |
74 | 111 | # negative_raw = [p.item() for i, p in enumerate(raw_preds[:, idx]) if not true_np[i, idx]] |
| 112 | + |
| 113 | + macro_f1 = macro_f1_tensor[idx] |
| 114 | + micro_f1 = micro_f1_tensor[idx] |
| 115 | + |
75 | 116 | results[cls_name] = { |
76 | | - "PPV": round(tpv, 4), |
| 117 | + "PPV": round(ppv, 4), |
77 | 118 | "NPV": round(npv, 4), |
78 | 119 | "TN": int(tn), |
79 | 120 | "FP": int(fp), |
80 | 121 | "FN": int(fn), |
81 | 122 | "TP": int(tp), |
| 123 | + "f1-macro": round(macro_f1.item(), 4), |
| 124 | + "f1-micro": round(micro_f1.item(), 4), |
82 | 125 | # "positive_preds": positive_raw, |
83 | 126 | # "negative_preds": negative_raw, |
84 | 127 | } |
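
For readers of this hunk, here is a minimal, self-contained sketch of the contract that compute_classwise_scores now relies on: a dict of pre-updated metric objects keyed "cm", "macro-f1" and "micro-f1". The class names and tensors below are invented for illustration, and plain MultilabelF1Score stands in for chebai's MacroF1 wrapper, whose exact constructor arguments are assumed from this diff rather than confirmed:

    # Toy sketch: names, tensors and the MacroF1 stand-in are illustrative only.
    import torch
    from torchmetrics.classification import MultilabelConfusionMatrix, MultilabelF1Score

    class_names = ["class_a", "class_b", "class_c"]  # hypothetical labels
    num_classes = len(class_names)
    metrics_obj_dict = {
        "cm": MultilabelConfusionMatrix(num_labels=num_classes),
        "macro-f1": MultilabelF1Score(num_labels=num_classes, average=None),  # stand-in for MacroF1
        "micro-f1": MultilabelF1Score(num_labels=num_classes, average=None),
    }

    preds = torch.tensor([[0.9, 0.2, 0.7], [0.1, 0.8, 0.4]])  # sigmoid outputs, thresholded at 0.5 by default
    target = torch.tensor([[1, 0, 1], [0, 1, 1]])             # binary ground truth
    for metric in metrics_obj_dict.values():
        metric.update(preds=preds, target=target)

    cm = metrics_obj_dict["cm"].compute()                  # (num_classes, 2, 2)
    f1_per_class = metrics_obj_dict["micro-f1"].compute()  # (num_classes,)
    for idx, name in enumerate(class_names):
        tn, fp, fn, tp = cm[idx].flatten().tolist()        # [[TN, FP], [FN, TP]] per class
        ppv = tp / (tp + fp) if (tp + fp) > 0 else 0.0
        npv = tn / (tn + fn) if (tn + fn) > 0 else 0.0
        print(name, ppv, npv, round(f1_per_class[idx].item(), 4))

In the real code path, preds and targets come from model._get_prediction_and_labels, so whatever activation or thresholding that helper applies carries over unchanged.
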
@@ -131,29 +174,32 @@ def generate_props( |
131 | 174 | val_loader = data_module.val_dataloader() |
132 | 175 | print("Running inference on validation data...") |
133 | 176 |
134 | | - y_true, y_pred = [], [] |
135 | | - raw_preds = [] |
| 177 | + classes_file = Path(data_module.processed_dir_main) / "classes.txt" |
| 178 | + class_names = self.load_class_labels(classes_file) |
| 179 | + num_classes = len(class_names) |
| 180 | + metrics_obj_dict: dict[str, torchmetrics.Metric] = { |
| 181 | + "cm": MultilabelConfusionMatrix(num_labels=num_classes), |
| 182 | + "macro-f1": MacroF1(num_labels=num_classes, average=None), |
| 183 | + "micro-f1": MultilabelF1Score(num_labels=num_classes, average=None), |
| 184 | + } |
| 185 | + |
136 | 186 | for batch_idx, batch in enumerate(val_loader): |
137 | | - data = model._process_batch( # pylint: disable=W0212 |
138 | | - batch, batch_idx=batch_idx |
139 | | - ) |
| 187 | + data = model._process_batch(batch, batch_idx=batch_idx) |
140 | 188 | labels = data["labels"] |
141 | | - outputs = model(data, **data.get("model_kwargs", {})) |
142 | | - logits = outputs["logits"] if isinstance(outputs, dict) else outputs |
143 | | - preds = torch.sigmoid(logits) > 0.5 |
144 | | - y_pred.extend(preds) |
145 | | - y_true.extend(labels) |
146 | | - raw_preds.extend(torch.sigmoid(logits)) |
147 | | - raw_preds = torch.stack(raw_preds) |
| 189 | + model_output = model(data, **data.get("model_kwargs", {})) |
| 190 | + preds, targets = model._get_prediction_and_labels( |
| 191 | + data, labels, model_output |
| 192 | + ) |
| 193 | + for metric_obj in metrics_obj_dict.values(): |
| 194 | + metric_obj.update(preds=preds, target=targets) |
| 195 | + |
148 | 196 | print("Computing metrics...") |
149 | | - classes_file = Path(data_module.processed_dir_main) / "classes.txt" |
150 | 197 | if output_path is None: |
151 | 198 | output_file = Path(data_module.processed_dir_main) / "classes.json" |
152 | 199 | else: |
153 | 200 | output_file = Path(output_path) |
154 | 201 |
155 | | - class_names = self.load_class_labels(classes_file) |
156 | | - metrics = self.compute_classwise_scores(y_true, y_pred, raw_preds, class_names) |
| 202 | + metrics = self.compute_classwise_scores(metrics_obj_dict, class_names) |
157 | 203 |
158 | 204 | with output_file.open("w") as f: |
159 | 205 | json.dump(metrics, f, indent=2) |
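
A side effect of the rewritten validation loop above is that no per-sample tensors are kept in memory any more: each metric object accumulates its state on update() and is evaluated exactly once via compute() inside compute_classwise_scores, which is what lets the old y_true/y_pred/raw_preds lists go away. A small illustrative sketch of that accumulate-then-compute pattern, with random tensors standing in for model outputs:

    # Illustrative only: random tensors replace real model outputs and labels.
    import torch
    from torchmetrics.classification import MultilabelConfusionMatrix

    num_classes = 4
    cm = MultilabelConfusionMatrix(num_labels=num_classes)
    for _ in range(3):                                      # stand-in for val_loader batches
        batch_preds = torch.rand(8, num_classes)            # sigmoid outputs for one batch
        batch_targets = torch.randint(0, 2, (8, num_classes))
        cm.update(preds=batch_preds, target=batch_targets)  # only aggregate counts are stored
    counts = cm.compute()                                   # (num_classes, 2, 2), computed once at the end

The resulting classes.json (or the file given via output_path) then maps each name from classes.txt to the eight rounded fields assembled in compute_classwise_scores.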