|
16 | 16 |
|
17 | 17 | from typing import Any, List, Tuple |
18 | 18 |
|
| 19 | +import torch |
19 | 20 | from lm_eval.api import metrics # noqa: F401 # needed to register lm-eval metrics |
20 | 21 | from lm_eval.api import registry as lm_registry |
21 | 22 |
|
22 | 23 | from pruna.evaluation.metrics.metric_stateful import StatefulMetric |
23 | 24 | from pruna.evaluation.metrics.registry import MetricRegistry |
24 | 25 | from pruna.evaluation.metrics.result import MetricResult |
| 26 | +from pruna.evaluation.metrics.utils import metric_data_processor |
25 | 27 | from pruna.logging.logger import pruna_logger |
26 | 28 |
|
27 | 29 | METRIC_EVALHARNESS = "lm_eval_metric" |
@@ -71,12 +73,26 @@ def __init__(self, metric_name: str, call_type: str = "y_gt") -> None: |
71 | 73 |
|
72 | 74 | pruna_logger.info(f"LMEvalMetric initialized: {metric_name} (higher_is_better={self.higher_is_better})") |
73 | 75 |
|
def update(
    self,
    x: List[Any] | torch.Tensor,
    gt: List[Any] | torch.Tensor,
    outputs: List[Any] | torch.Tensor,
) -> None:
    """
    Accumulate per-sample metric items for later aggregation.

    The three inputs are routed through ``metric_data_processor`` according to
    ``self.call_type``. Each aligned pair from the first two returned sequences
    is passed to ``self.metric_fn`` as a ``(ref, pred)`` tuple, and the result
    is appended to ``self.pairs``.

    Parameters
    ----------
    x : List[Any] | torch.Tensor
        Input data.
    gt : List[Any] | torch.Tensor
        Ground truth data.
    outputs : List[Any] | torch.Tensor
        Output data.

    Raises
    ------
    ValueError
        If the two sequences selected for this call type differ in length.
    """
    inputs = metric_data_processor(x, gt, outputs, self.call_type)
    # NOTE(review): this treats inputs[0] as references and inputs[1] as
    # predictions. Which of x / gt / outputs lands in each slot is decided by
    # metric_data_processor for the configured call_type - confirm the order
    # matches what self.metric_fn expects.
    refs, preds = inputs[0], inputs[1]

    # zip() stops at the shorter sequence, so a mismatch would silently drop
    # samples and skew the aggregated score. Fail before mutating self.pairs.
    if len(preds) != len(refs):
        raise ValueError(f"Preds and refs length mismatch: {len(preds)} vs {len(refs)}")

    self.pairs.extend(self.metric_fn((ref, pred)) for ref, pred in zip(refs, preds))
82 | 98 |
|
|
0 commit comments