Skip to content

Commit 9fc5d20

Browse files
committed
script to generate classes props
1 parent c0cb6c9 commit 9fc5d20

File tree

1 file changed

+195
-0
lines changed

1 file changed

+195
-0
lines changed
Lines changed: 195 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,195 @@
1+
"""Generate TPV/NPV JSON for multi-class classification models."""
2+
3+
import json
4+
from pathlib import Path
5+
6+
import pandas as pd
7+
import torch
8+
from jsonargparse import CLI
9+
from sklearn.metrics import confusion_matrix
10+
from torch.utils.data import DataLoader
11+
12+
from chebai.ensemble._utils import load_class
13+
from chebai.models.base import ChebaiBaseNet
14+
from chebai.preprocessing.collate import Collator
15+
16+
17+
class ClassesPropertiesGenerator:
18+
"""
19+
Computes TPV (Precision/ True Predictive Value) and NPV (Negative Predictive Value)
20+
for each class in a multi-class classification problem using a PyTorch Lightning model.
21+
"""
22+
23+
@staticmethod
24+
def load_class_labels(path: str) -> list[str]:
25+
"""
26+
Load a list of class names from a .json or .txt file.
27+
28+
Args:
29+
path (str): Path to class labels file.
30+
31+
Returns:
32+
list[str]: List of class names.
33+
"""
34+
with open(path) as f:
35+
return [line.strip() for line in f if line.strip()]
36+
37+
@staticmethod
38+
def compute_tpv_npv(
39+
y_true: list[int], y_pred: list[int], class_names: list[str]
40+
) -> dict[str, dict[str, float]]:
41+
"""
42+
Compute TPV and NPV for each class using the confusion matrix.
43+
44+
Args:
45+
y_true (list[int]): Ground truth labels.
46+
y_pred (list[int]): Predicted labels.
47+
class_names (list[str]): List of class names corresponding to class indices.
48+
49+
Returns:
50+
dict[str, dict[str, float]]: Dictionary with class names as keys and TPV/NPV as values.
51+
"""
52+
cm = confusion_matrix(y_true, y_pred, labels=list(range(len(class_names))))
53+
metrics = {}
54+
55+
for i, cls in enumerate(class_names):
56+
TP = cm[i, i]
57+
FP = cm[:, i].sum() - TP
58+
FN = cm[i, :].sum() - TP
59+
TN = cm.sum() - (TP + FP + FN)
60+
61+
TPV = TP / (TP + FP) if (TP + FP) > 0 else 0.0
62+
NPV = TN / (TN + FN) if (TN + FN) > 0 else 0.0
63+
64+
metrics[cls] = {"TPV": round(TPV, 4), "NPV": round(NPV, 4)}
65+
66+
return metrics
67+
68+
def generate_props(
69+
self,
70+
model_path: str,
71+
model_class_path: str,
72+
splits_path: str,
73+
data_path: str,
74+
classes_file_path: str,
75+
collator_class_path: str,
76+
output_path: str,
77+
batch_size: int = 32,
78+
) -> None:
79+
"""
80+
Main method to compute TPV/NPV from validation data and save as JSON.
81+
82+
Args:
83+
model_path (str): Path to the PyTorch Lightning model checkpoint.
84+
model_class_path (str): Full path to the model class to load.
85+
splits_path (str): CSV file with 'id' and 'split' columns.
86+
data_path (str): processed `data.pt` file path.
87+
classes_file_path (str): Path to file containing class names `classes.txt`.
88+
collator_class_path (str): Full path to the collator class.
89+
output_path (str): Output path for the saving JSON file.
90+
batch_size (int): Batch size for inference.
91+
"""
92+
print("Extracting validation data for computation...")
93+
splits_df = pd.read_csv(splits_path)
94+
validation_ids = set(splits_df[splits_df["split"] == "validation"]["id"])
95+
data_df = pd.DataFrame(torch.load(data_path, weights_only=False))
96+
val_df = data_df[data_df["ident"].isin(validation_ids)]
97+
98+
# Load model
99+
print(f"Loading model from {model_path} ...")
100+
model_cls = load_class(model_class_path)
101+
if not issubclass(model_cls, ChebaiBaseNet):
102+
raise TypeError("Loaded model is not a valid LightningModule.")
103+
model = model_cls.load_from_checkpoint(model_path, input_dim=3)
104+
model.freeze()
105+
model.eval()
106+
107+
# Load collator
108+
collator_cls = load_class(collator_class_path)
109+
if not issubclass(collator_cls, Collator):
110+
raise TypeError(f"{collator_cls} must be subclass of Collator")
111+
collator = collator_cls()
112+
113+
val_loader = DataLoader(
114+
val_df.to_dict(orient="records"),
115+
collate_fn=collator,
116+
batch_size=batch_size,
117+
shuffle=False,
118+
)
119+
120+
print("Running inference on validation data...")
121+
y_true, y_pred = [], []
122+
for batch_idx, batch in enumerate(val_loader):
123+
data = model._process_batch(batch, batch_idx=batch_idx)
124+
labels = data["labels"]
125+
model_output = model(data, **data.get("model_kwargs", dict()))
126+
sigmoid_logits = torch.sigmoid(model_output["logits"])
127+
preds = sigmoid_logits > 0.5
128+
y_pred.extend(preds)
129+
y_true.extend(labels)
130+
131+
# Compute and save metrics
132+
print("Computing TPV and NPV metrics...")
133+
classes_file_path = Path(classes_file_path)
134+
if output_path is None:
135+
output_path = classes_file_path.parent / "classes.json"
136+
class_names = self.load_class_labels(classes_file_path)
137+
metrics = self.compute_tpv_npv(y_true, y_pred, class_names)
138+
with open(output_path, "w") as f:
139+
json.dump(metrics, f, indent=2)
140+
print(f"Saved TPV/NPV metrics to {output_path}")
141+
142+
143+
class Main:
144+
"""
145+
Command-line interface wrapper for the ClassesPropertiesGenerator.
146+
"""
147+
148+
def generate(
149+
self,
150+
model_path: str,
151+
splits_path: str,
152+
data_path: str,
153+
classes_file_path: str,
154+
model_class_path: str,
155+
collator_class_path: str = "chebai.preprocessing.collate.RaggedCollator",
156+
batch_size: int = 32,
157+
output_path: str = None, # Default path will be the directory of classes_file_path
158+
) -> None:
159+
"""
160+
Entry point for CLI use.
161+
162+
Args:
163+
model_path (str): Path to the PyTorch Lightning model checkpoint.
164+
model_class_path (str): Full path to the model class to load.
165+
splits_path (str): CSV file with 'id' and 'split' columns.
166+
data_path (str): processed `data.pt` file path.
167+
classes_file_path (str): Path to file containing class names `classes.txt`.
168+
collator_class_path (str): Full path to the collator class.
169+
output_path (str): Output path for the saving JSON file.
170+
batch_size (int): Batch size for inference.
171+
"""
172+
generator = ClassesPropertiesGenerator()
173+
generator.generate_props(
174+
model_path=model_path,
175+
model_class_path=model_class_path,
176+
splits_path=splits_path,
177+
data_path=data_path,
178+
classes_file_path=classes_file_path,
179+
collator_class_path=collator_class_path,
180+
output_path=output_path,
181+
batch_size=batch_size,
182+
)
183+
184+
185+
if __name__ == "__main__":
186+
# _generate_classes_props_json.py generate \
187+
# --model_path "model/ckpt/path" \
188+
# --splits_path "splits/file/path" \
189+
# --data_path "data.pt/file/path" \
190+
# --classes_file_path "classes/file/path" \
191+
# --model_class_path "model.class.path" \
192+
# --collator_class_path "collator.class.path" \
193+
# --batch_size 32 \ # Optional, default is 32
194+
# --output_path "output/file/path" # Optional, default will be the directory of classes_file_path
195+
CLI(Main, as_positional=False)

0 commit comments

Comments
 (0)