Skip to content

Commit 5e0b683

Browse files
committed
update documentation for generate_class_properties.py
1 parent 7668858 commit 5e0b683

File tree

1 file changed

+19
-10
lines changed

1 file changed

+19
-10
lines changed

chebai/result/_generate_classes_props_json.py renamed to chebai/result/generate_class_properties.py

Lines changed: 19 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,8 @@
1515

1616
class ClassesPropertiesGenerator:
1717
"""
18-
Computes PPV (Positive Predictive Value) and NPV (Negative Predictive Value)
18+
Computes PPV (Positive Predictive Value) and NPV (Negative Predictive Value), and counts the number of
19+
true positives (TP), false positives (FP), true negatives (TN), and false negatives (FN)
1920
for each class in a multi-label classification problem using a PyTorch Lightning model.
2021
"""
2122

@@ -35,23 +36,25 @@ def load_class_labels(path: Path) -> list[str]:
3536
return [line.strip() for line in f if line.strip()]
3637

3738
@staticmethod
38-
def compute_tpv_npv(
39+
def compute_classwise_scores(
3940
y_true: list[torch.Tensor],
4041
y_pred: list[torch.Tensor],
42+
raw_preds: torch.Tensor,
4143
class_names: list[str],
4244
) -> dict[str, dict[str, float]]:
4345
"""
44-
Compute TPV (precision) and NPV for each class in a multi-label setting.
46+
Compute PPV (precision, TP/(TP+FP)), NPV (TN/(TN+FN)) and the number of TNs, FPs, FNs and TPs for each class
47+
in a multi-label setting.
4548
4649
Args:
4750
y_true: List of binary ground-truth label tensors, one tensor per sample.
4851
y_pred: List of binary prediction tensors, one tensor per sample.
4952
class_names: Ordered list of class names corresponding to class indices.
5053
5154
Returns:
52-
Dictionary mapping each class name to its TPV and NPV metrics:
55+
Dictionary mapping each class name to its PPV and NPV metrics and its TN, FP, FN and TP counts:
5356
{
54-
"class_name": {"PPV": float, "NPV": float},
57+
"class_name": {"PPV": float, "NPV": float, "TN": int, "FP": int, "FN": int, "TP": int},
5558
...
5659
}
5760
"""
@@ -67,13 +70,17 @@ def compute_tpv_npv(
6770
tn, fp, fn, tp = cm[idx].ravel()
6871
tpv = tp / (tp + fp) if (tp + fp) > 0 else 0.0
6972
npv = tn / (tn + fn) if (tn + fn) > 0 else 0.0
73+
# positive_raw = [p.item() for i, p in enumerate(raw_preds[:, idx]) if true_np[i, idx]]
74+
# negative_raw = [p.item() for i, p in enumerate(raw_preds[:, idx]) if not true_np[i, idx]]
7075
results[cls_name] = {
7176
"PPV": round(tpv, 4),
7277
"NPV": round(npv, 4),
7378
"TN": int(tn),
7479
"FP": int(fp),
7580
"FN": int(fn),
7681
"TP": int(tp),
82+
# "positive_preds": positive_raw,
83+
# "negative_preds": negative_raw,
7784
}
7885
return results
7986

@@ -125,6 +132,7 @@ def generate_props(
125132
print("Running inference on validation data...")
126133

127134
y_true, y_pred = [], []
135+
raw_preds = []
128136
for batch_idx, batch in enumerate(val_loader):
129137
data = model._process_batch( # pylint: disable=W0212
130138
batch, batch_idx=batch_idx
@@ -135,20 +143,21 @@ def generate_props(
135143
preds = torch.sigmoid(logits) > 0.5
136144
y_pred.extend(preds)
137145
y_true.extend(labels)
138-
139-
print("Computing TPV and NPV metrics...")
146+
raw_preds.extend(torch.sigmoid(logits))
147+
raw_preds = torch.stack(raw_preds)
148+
print("Computing metrics...")
140149
classes_file = Path(data_module.processed_dir_main) / "classes.txt"
141150
if output_path is None:
142151
output_file = Path(data_module.processed_dir_main) / "classes.json"
143152
else:
144153
output_file = Path(output_path)
145154

146155
class_names = self.load_class_labels(classes_file)
147-
metrics = self.compute_tpv_npv(y_true, y_pred, class_names)
156+
metrics = self.compute_classwise_scores(y_true, y_pred, raw_preds, class_names)
148157

149158
with output_file.open("w") as f:
150159
json.dump(metrics, f, indent=2)
151-
print(f"Saved TPV/NPV metrics to {output_file}")
160+
print(f"Saved metrics to {output_file}")
152161

153162

154163
class Main:
@@ -164,7 +173,7 @@ def generate(
164173
output_path: str | None = None,
165174
) -> None:
166175
"""
167-
CLI command to generate TPV/NPV JSON.
176+
CLI command to generate JSON with metrics on validation set.
168177
169178
Args:
170179
model_ckpt_path: Path to the PyTorch Lightning checkpoint file.

0 commit comments

Comments
 (0)