1515
1616class ClassesPropertiesGenerator :
1717 """
18- Computes PPV (Positive Predictive Value) and NPV (Negative Predictive Value)
18+ Computes PPV (Positive Predictive Value) and NPV (Negative Predictive Value) and counts the number of
19+ true positives (TP), false positives (FP), true negatives (TN), and false negatives (FN)
1920 for each class in a multi-label classification problem using a PyTorch Lightning model.
2021 """
2122
@@ -35,23 +36,25 @@ def load_class_labels(path: Path) -> list[str]:
3536 return [line .strip () for line in f if line .strip ()]
3637
3738 @staticmethod
38- def compute_tpv_npv (
39+ def compute_classwise_scores (
3940 y_true : list [torch .Tensor ],
4041 y_pred : list [torch .Tensor ],
42+ raw_preds : torch .Tensor ,
4143 class_names : list [str ],
4244 ) -> dict [str , dict [str , float ]]:
4345 """
44- Compute TPV (precision) and NPV for each class in a multi-label setting.
46+ Compute PPV (precision, TP/(TP+FP)), NPV (TN/(TN+FN)) and the number of TNs, FPs, FNs and TPs for each class
47+ in a multi-label setting.
4548
4649 Args:
4750 y_true: List of binary ground-truth label tensors, one tensor per sample.
4851 y_pred: List of binary prediction tensors, one tensor per sample.
4952 class_names: Ordered list of class names corresponding to class indices.
5053
5154 Returns:
52- Dictionary mapping each class name to its TPV and NPV metrics:
55+ Dictionary mapping each class name to its PPV and NPV metrics:
5356 {
54- "class_name": {"PPV": float, "NPV": float},
57+ "class_name": {"PPV": float, "NPV": float, "TN": int, "FP": int, "FN": int, "TP": int },
5558 ...
5659 }
5760 """
@@ -67,13 +70,17 @@ def compute_tpv_npv(
6770 tn , fp , fn , tp = cm [idx ].ravel ()
6871 tpv = tp / (tp + fp ) if (tp + fp ) > 0 else 0.0
6972 npv = tn / (tn + fn ) if (tn + fn ) > 0 else 0.0
73+ # positive_raw = [p.item() for i, p in enumerate(raw_preds[:, idx]) if true_np[i, idx]]
74+ # negative_raw = [p.item() for i, p in enumerate(raw_preds[:, idx]) if not true_np[i, idx]]
7075 results [cls_name ] = {
7176 "PPV" : round (tpv , 4 ),
7277 "NPV" : round (npv , 4 ),
7378 "TN" : int (tn ),
7479 "FP" : int (fp ),
7580 "FN" : int (fn ),
7681 "TP" : int (tp ),
82+ # "positive_preds": positive_raw,
83+ # "negative_preds": negative_raw,
7784 }
7885 return results
7986
@@ -125,6 +132,7 @@ def generate_props(
125132 print ("Running inference on validation data..." )
126133
127134 y_true , y_pred = [], []
135+ raw_preds = []
128136 for batch_idx , batch in enumerate (val_loader ):
129137 data = model ._process_batch ( # pylint: disable=W0212
130138 batch , batch_idx = batch_idx
@@ -135,20 +143,21 @@ def generate_props(
135143 preds = torch .sigmoid (logits ) > 0.5
136144 y_pred .extend (preds )
137145 y_true .extend (labels )
138-
139- print ("Computing TPV and NPV metrics..." )
146+ raw_preds .extend (torch .sigmoid (logits ))
147+ raw_preds = torch .stack (raw_preds )
148+ print ("Computing metrics..." )
140149 classes_file = Path (data_module .processed_dir_main ) / "classes.txt"
141150 if output_path is None :
142151 output_file = Path (data_module .processed_dir_main ) / "classes.json"
143152 else :
144153 output_file = Path (output_path )
145154
146155 class_names = self .load_class_labels (classes_file )
147- metrics = self .compute_tpv_npv (y_true , y_pred , class_names )
156+ metrics = self .compute_classwise_scores (y_true , y_pred , raw_preds , class_names )
148157
149158 with output_file .open ("w" ) as f :
150159 json .dump (metrics , f , indent = 2 )
151- print (f"Saved TPV/NPV metrics to { output_file } " )
160+ print (f"Saved metrics to { output_file } " )
152161
153162
154163class Main :
@@ -164,7 +173,7 @@ def generate(
164173 output_path : str | None = None ,
165174 ) -> None :
166175 """
167- CLI command to generate TPV/NPV JSON.
176+ CLI command to generate JSON with metrics on validation set .
168177
169178 Args:
170179 model_ckpt_path: Path to the PyTorch Lightning checkpoint file.
0 commit comments