|
| 1 | +import torch |
| 2 | +from jsonargparse import CLI |
| 3 | +from torchmetrics.functional.classification import multilabel_auroc |
| 4 | + |
| 5 | +from chebai.result.utils import load_results_from_buffer |
| 6 | + |
| 7 | + |
class EvaluatePredictions:
    """Evaluate stored predictions against their labels using Multilabel AUROC."""

    def __init__(self, eval_dir: str):
        """
        Initializes the EvaluatePredictions class.

        Args:
            eval_dir (str): Path to the directory containing evaluation files.
        """
        self.eval_dir = eval_dir
        self.metrics = []
        # Use the GPU when one is available, otherwise fall back to the CPU.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # Filled in by evaluate() from the second dimension of the predictions.
        self.num_labels = None

    @staticmethod
    def validate_eval_dir(label_files: torch.Tensor, pred_files: torch.Tensor) -> None:
        """
        Validates that labels and predictions are compatible: both must be
        present, have at least two dimensions, and agree in their first
        (samples) and second (classes) dimensions.

        Args:
            label_files (torch.Tensor): Tensor containing label data.
            pred_files (torch.Tensor): Tensor containing prediction data.

        Raises:
            ValueError: If either tensor is missing, has fewer than two
                dimensions, or the tensors disagree in their first two dimensions.
        """
        if label_files is None or pred_files is None:
            raise ValueError("Both label and prediction tensors must be provided.")

        # Guard the shape[0] / shape[1] lookups below: without this check a
        # 0-D or 1-D tensor would surface as an IndexError instead of the
        # documented ValueError.
        if label_files.dim() < 2 or pred_files.dim() < 2:
            raise ValueError(
                "Label and prediction tensors must have at least two dimensions "
                "(samples, classes)."
            )

        # Check if the number of labels matches the number of predictions
        if label_files.shape[0] != pred_files.shape[0]:
            raise ValueError(
                "Number of label tensors does not match the number of prediction tensors."
            )

        # Validate that the class dimension (dimension 1) matches
        if label_files.shape[1] != pred_files.shape[1]:
            raise ValueError(
                "Label and prediction tensors must have the same shape in terms of class outputs."
            )

    def evaluate(self) -> None:
        """
        Loads predictions and labels, validates their correspondence, and
        calculates and prints the Multilabel AUROC.

        Raises:
            ValueError: If the loaded tensors fail validation
                (see validate_eval_dir).
        """
        test_preds, test_labels = load_results_from_buffer(self.eval_dir, self.device)
        self.validate_eval_dir(test_labels, test_preds)
        # The number of labels is taken from the class dimension of the predictions.
        self.num_labels = test_preds.shape[1]

        # NOTE(review): torchmetrics documents the target as an integer tensor;
        # confirm the dtype of the labels returned by load_results_from_buffer.
        ml_auroc = multilabel_auroc(
            test_preds, test_labels, num_labels=self.num_labels
        ).item()

        print("Multilabel AUC-ROC:", ml_auroc)
| 62 | + |
| 63 | + |
class Main:
    """Command group handed to the jsonargparse CLI."""

    def evaluate(self, eval_dir: str):
        """
        Runs the prediction evaluation for the given directory.

        Args:
            eval_dir (str): Path to the directory containing evaluation files.
        """
        evaluator = EvaluatePredictions(eval_dir)
        evaluator.evaluate()
| 67 | + |
| 68 | + |
# Run the command-line interface only when executed as a script, not on import.
if __name__ == "__main__":
    # Hand the Main class to jsonargparse's CLI helper.
    CLI(Main)
0 commit comments