Skip to content

Commit bdba442

Browse files
committed
script to evaluate go predictions
1 parent 66dd504 commit bdba442

File tree

1 file changed

+70
-0
lines changed

1 file changed

+70
-0
lines changed
Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
import torch
2+
from jsonargparse import CLI
3+
from torchmetrics.functional.classification import multilabel_auroc
4+
5+
from chebai.result.utils import load_results_from_buffer
6+
7+
8+
class EvaluatePredictions:
    """Evaluates stored model predictions against their labels."""

    def __init__(self, eval_dir: str):
        """
        Initializes the EvaluatePredictions class.

        Args:
            eval_dir (str): Path to the directory containing evaluation files.
        """
        self.eval_dir = eval_dir
        self.metrics = []
        # Use the GPU when one is available, otherwise fall back to the CPU.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # Filled in by evaluate() from the second dimension of the predictions.
        self.num_labels = None

    @staticmethod
    def validate_eval_dir(label_files: torch.Tensor, pred_files: torch.Tensor) -> None:
        """
        Validates that the label and prediction tensors are both present and
        have the same shape.

        Args:
            label_files (torch.Tensor): Tensor containing label data.
            pred_files (torch.Tensor): Tensor containing prediction data.

        Raises:
            ValueError: If either tensor is None, has fewer than two dimensions,
                or if the two tensors are mismatched in shape.
        """
        if label_files is None or pred_files is None:
            raise ValueError("Both label and prediction tensors must be provided.")

        # Guard the shape[0] / shape[1] lookups below: on a 0-D or 1-D tensor they
        # would raise IndexError instead of the ValueError documented above.
        if label_files.ndim < 2 or pred_files.ndim < 2:
            raise ValueError(
                "Label and prediction tensors must have at least two dimensions "
                "(samples, classes)."
            )

        # Check if the number of labels matches the number of predictions
        if label_files.shape[0] != pred_files.shape[0]:
            raise ValueError(
                "Number of label tensors does not match the number of prediction tensors."
            )

        # Validate that the class dimension matches between labels and predictions
        if label_files.shape[1] != pred_files.shape[1]:
            raise ValueError(
                "Label and prediction tensors must have the same shape in terms of class outputs."
            )

        # Any remaining (trailing) dimensions have to agree as well.
        if label_files.shape != pred_files.shape:
            raise ValueError("Label and prediction tensors must have the same shape.")

    def evaluate(self) -> None:
        """
        Loads predictions and labels, validates that they correspond, and prints
        the Multilabel AUROC.

        The tensors are obtained from ``load_results_from_buffer`` for
        ``self.eval_dir`` and ``self.device``. The number of labels is taken from
        the second dimension of the predictions and stored in ``self.num_labels``.

        Raises:
            ValueError: If the loaded tensors fail ``validate_eval_dir``.
        """
        test_preds, test_labels = load_results_from_buffer(self.eval_dir, self.device)
        self.validate_eval_dir(test_labels, test_preds)
        self.num_labels = test_preds.shape[1]

        # .item() converts the returned metric tensor into a plain Python number.
        ml_auroc = multilabel_auroc(
            test_preds, test_labels, num_labels=self.num_labels
        ).item()

        print("Multilabel AUC-ROC:", ml_auroc)
62+
63+
64+
class Main:
    """Groups the commands that are handed to the command-line interface."""

    def evaluate(self, eval_dir: str):
        """
        Runs the prediction evaluation for the given directory.

        Args:
            eval_dir (str): Path to the directory containing evaluation files.
        """
        evaluator = EvaluatePredictions(eval_dir)
        evaluator.evaluate()
67+
68+
69+
if __name__ == "__main__":
    # Only start the command-line interface when executed as a script,
    # so importing this module has no side effects.
    CLI(Main)

0 commit comments

Comments
 (0)