 import json
 from pathlib import Path
+from typing import Literal
 
-import torch
+import torchmetrics
 from jsonargparse import CLI
-from sklearn.metrics import multilabel_confusion_matrix
+from torchmetrics.classification import MultilabelConfusionMatrix, MultilabelF1Score
 
 from chebai.preprocessing.datasets.base import XYBaseDataModule
 from chebai.result.utils import (
@@ -37,55 +38,85 @@ def load_class_labels(path: Path) -> list[str]:
 
     @staticmethod
     def compute_classwise_scores(
-        y_true: list[torch.Tensor],
-        y_pred: list[torch.Tensor],
-        raw_preds: torch.Tensor,
+        metrics_obj_dict: dict[str, torchmetrics.Metric],
         class_names: list[str],
     ) -> dict[str, dict[str, float]]:
         """
-        Compute PPV (precision, TP/(TP+FP)), NPV (TN/(TN+FN)) and the number of TNs, FPs, FNs and TPs for each class
-        in a multi-label setting.
+        Compute per-class evaluation metrics for a multi-label classification task.
+
+        This method uses torchmetrics objects (MultilabelConfusionMatrix, F1 scores, etc.)
+        to compute the following metrics for each class:
+        - PPV (Positive Predictive Value or Precision)
+        - NPV (Negative Predictive Value)
+        - True Positives (TP)
+        - False Positives (FP)
+        - True Negatives (TN)
+        - False Negatives (FN)
+        - F1 score
 
         Args:
-            y_true: List of binary ground-truth label tensors, one tensor per sample.
-            y_pred: List of binary prediction tensors, one tensor per sample.
-            class_names: Ordered list of class names corresponding to class indices.
+            metrics_obj_dict: Dictionary containing pre-updated torchmetrics.Metric objects:
+                {
+                    "cm": MultilabelConfusionMatrix,
61+ "micro-f1": MultilabelF1Score (average=None)
+                }
+            class_names: List of class names in the same order as class indices.
 
         Returns:
-            Dictionary mapping each class name to its PPV and NPV metrics:
+            Dictionary mapping each class name to a sub-dictionary of computed metrics:
             {
-                "class_name": {"PPV": float, "NPV": float, "TN": int, "FP": int, "FN": int, "TP": int},
+                "class_name_1": {
+                    "PPV": float,
+                    "NPV": float,
+                    "TN": int,
+                    "FP": int,
+                    "FN": int,
+                    "TP": int,
+                    "f1": float,
+                },
                 ...
             }
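+
+        Example (illustrative sketch; the tensors and class names below are
+        made up for demonstration, not taken from a real run):
+
+            >>> import torch
+            >>> cm = MultilabelConfusionMatrix(num_labels=2)
+            >>> f1 = MultilabelF1Score(num_labels=2, average=None)
+            >>> preds = torch.tensor([[1, 0], [1, 1]])
+            >>> targets = torch.tensor([[1, 0], [0, 1]])
+            >>> cm.update(preds, targets)
+            >>> f1.update(preds, targets)
+            >>> scores = ClassesPropertiesGenerator.compute_classwise_scores(
+            ...     {"cm": cm, "f1": f1}, ["class_a", "class_b"]
+            ... )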
         """
-        # Stack per-sample tensors into (n_samples, n_classes) numpy arrays
-        true_np = torch.stack(y_true).cpu().numpy().astype(int)
-        pred_np = torch.stack(y_pred).cpu().numpy().astype(int)
+        cm_tensor = metrics_obj_dict["cm"].compute()  # Shape: (num_classes, 2, 2)
+        f1_tensor = metrics_obj_dict["f1"].compute()  # Shape: (num_classes,)
 
-        # Compute confusion matrix for each class
-        cm = multilabel_confusion_matrix(true_np, pred_np)
+        assert len(class_names) == cm_tensor.shape[0] == f1_tensor.shape[0], (
+            f"Mismatch between number of class names ({len(class_names)}) and metric tensor sizes: "
+            f"confusion matrix has {cm_tensor.shape[0]}, "
+            f"F1 has {f1_tensor.shape[0]}"
+        )
 
         results: dict[str, dict[str, float]] = {}
         for idx, cls_name in enumerate(class_names):
-            tn, fp, fn, tp = cm[idx].ravel()
-            tpv = tp / (tp + fp) if (tp + fp) > 0 else 0.0
-            npv = tn / (tn + fn) if (tn + fn) > 0 else 0.0
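+            # torchmetrics returns one 2x2 matrix per label, laid out as
+            # [[TN, FP], [FN, TP]] (rows: true class, columns: predicted class)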
+            tn = cm_tensor[idx][0][0].item()
+            fp = cm_tensor[idx][0][1].item()
+            fn = cm_tensor[idx][1][0].item()
+            tp = cm_tensor[idx][1][1].item()
+
+            ppv = tp / (tp + fp) if (tp + fp) > 0 else 0.0  # Precision
+            npv = tn / (tn + fn) if (tn + fn) > 0 else 0.0  # Negative predictive value
+
             # positive_raw = [p.item() for i, p in enumerate(raw_preds[:, idx]) if true_np[i, idx]]
             # negative_raw = [p.item() for i, p in enumerate(raw_preds[:, idx]) if not true_np[i, idx]]
+
+            f1 = f1_tensor[idx]
+
             results[cls_name] = {
-                "PPV": round(tpv, 4),
+                "PPV": round(ppv, 4),
                 "NPV": round(npv, 4),
                 "TN": int(tn),
                 "FP": int(fp),
                 "FN": int(fn),
                 "TP": int(tp),
+                "f1": round(f1.item(), 4),
                 # "positive_preds": positive_raw,
                 # "negative_preds": negative_raw,
             }
         return results
 
     def generate_props(
         self,
+        data_partition: Literal["train", "val", "test"],
         model_ckpt_path: str,
         model_config_file_path: str,
         data_config_file_path: str,
@@ -95,14 +126,13 @@ def generate_props(
-        Run inference on validation set, compute TPV/NPV per class, and save to JSON.
+        Run inference on the selected data partition, compute per-class metrics, and save them to JSON.
 
         Args:
+            data_partition: Dataset partition used to generate the class properties.
             model_ckpt_path: Path to the PyTorch Lightning checkpoint file.
             model_config_file_path: Path to yaml config file of the model.
             data_config_file_path: Path to yaml config file of the data.
             output_path: Optional path where to write the JSON metrics file.
-                Defaults to '<processed_dir_main>/classes.json'.
+                Defaults to '<processed_dir_main>/classes_<data_partition>.json'.
         """
-        print("Extracting validation data for computation...")
-
         data_cls_path, data_cls_kwargs = parse_config_file(data_config_file_path)
         data_module: XYBaseDataModule = load_data_instance(
             data_cls_path, data_cls_kwargs
@@ -128,32 +158,43 @@ def generate_props(
             model_ckpt_path, model_class_path, model_kwargs
         )
 
-        val_loader = data_module.val_dataloader()
-        print("Running inference on validation data...")
+        if data_partition == "train":
+            data_loader = data_module.train_dataloader()
+        elif data_partition == "val":
+            data_loader = data_module.val_dataloader()
+        elif data_partition == "test":
+            data_loader = data_module.test_dataloader()
+        else:
+            raise ValueError(f"Unknown data partition: {data_partition}")
+        print(f"Running inference on {data_partition} data...")
 
-        y_true, y_pred = [], []
-        raw_preds = []
-        for batch_idx, batch in enumerate(val_loader):
-            data = model._process_batch(  # pylint: disable=W0212
-                batch, batch_idx=batch_idx
-            )
+        classes_file = Path(data_module.processed_dir_main) / "classes.txt"
+        class_names = self.load_class_labels(classes_file)
+        num_classes = len(class_names)
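+        # Stateful torchmetrics objects: update() accumulates statistics batch by
+        # batch, compute() later finalizes them over the whole partition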
+        metrics_obj_dict: dict[str, torchmetrics.Metric] = {
+            "cm": MultilabelConfusionMatrix(num_labels=num_classes),
+            "f1": MultilabelF1Score(num_labels=num_classes, average=None),
+        }
+
+        for batch_idx, batch in enumerate(data_loader):
+            data = model._process_batch(batch, batch_idx=batch_idx)  # pylint: disable=W0212
             labels = data["labels"]
-            outputs = model(data, **data.get("model_kwargs", {}))
-            logits = outputs["logits"] if isinstance(outputs, dict) else outputs
-            preds = torch.sigmoid(logits) > 0.5
-            y_pred.extend(preds)
-            y_true.extend(labels)
-            raw_preds.extend(torch.sigmoid(logits))
-        raw_preds = torch.stack(raw_preds)
+            model_output = model(data, **data.get("model_kwargs", {}))
+            preds, targets = model._get_prediction_and_labels(
+                data, labels, model_output
+            )
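+            # Accumulate this batch's predictions and targets in each stateful metric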
+            for metric_obj in metrics_obj_dict.values():
+                metric_obj.update(preds, targets)
+
         print("Computing metrics...")
-        classes_file = Path(data_module.processed_dir_main) / "classes.txt"
         if output_path is None:
-            output_file = Path(data_module.processed_dir_main) / "classes.json"
+            output_file = (
+                Path(data_module.processed_dir_main) / f"classes_{data_partition}.json"
+            )
         else:
             output_file = Path(output_path)
 
-        class_names = self.load_class_labels(classes_file)
-        metrics = self.compute_classwise_scores(y_true, y_pred, raw_preds, class_names)
+        metrics = self.compute_classwise_scores(metrics_obj_dict, class_names)
 
         with output_file.open("w") as f:
             json.dump(metrics, f, indent=2)
@@ -167,6 +208,7 @@ class Main:
 
     def generate(
         self,
+        data_partition: Literal["train", "val", "test"],
         model_ckpt_path: str,
         model_config_file_path: str,
         data_config_file_path: str,
@@ -176,14 +218,21 @@ def generate(
-        CLI command to generate JSON with metrics on validation set.
+        CLI command to generate a JSON file with per-class metrics on the selected data partition.
 
         Args:
+            data_partition: Dataset partition used to generate the class properties.
             model_ckpt_path: Path to the PyTorch Lightning checkpoint file.
             model_config_file_path: Path to yaml config file of the model.
             data_config_file_path: Path to yaml config file of the data.
             output_path: Optional path where to write the JSON metrics file.
-                Defaults to '<processed_dir_main>/classes.json'.
+                Defaults to '<processed_dir_main>/classes_<data_partition>.json'.
         """
+        assert data_partition in [
+            "train",
+            "val",
+            "test",
+        ], f"Invalid data partition: {data_partition!r}. Choose one of 'train', 'val', 'test'."
         generator = ClassesPropertiesGenerator()
         generator.generate_props(
+            data_partition,
             model_ckpt_path,
             model_config_file_path,
             data_config_file_path,
@@ -193,6 +242,7 @@ def generate(
 
 if __name__ == "__main__":
     # _generate_classes_props_json.py generate \
+    #     --data_partition "val" \
     #     --model_ckpt_path "model/ckpt/path" \
     #     --model_config_file_path "model/config/file/path" \
     #     --data_config_file_path "data/config/file/path" \