Skip to content

Commit c4bb8d6

Browse files
committed
refactor: change metric update function to fit evaluation agent expectations
1 parent 01f868b commit c4bb8d6

File tree

3 files changed

+25
-9
lines changed

3 files changed

+25
-9
lines changed

.github/actions/setup-uv-project/action.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,4 +12,4 @@ runs:
1212
github-token: ${{ github.token }}
1313

1414
- shell: bash
15-
run: uv sync --extra dev --extra evalharness
15+
run: uv sync --extra dev --extra lmharness

src/pruna/evaluation/metrics/metric_evalharness.py

Lines changed: 22 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -16,12 +16,14 @@
1616

1717
from typing import Any, List, Tuple
1818

19+
import torch
1920
from lm_eval.api import metrics # noqa: F401 # needed to register lm-eval metrics
2021
from lm_eval.api import registry as lm_registry
2122

2223
from pruna.evaluation.metrics.metric_stateful import StatefulMetric
2324
from pruna.evaluation.metrics.registry import MetricRegistry
2425
from pruna.evaluation.metrics.result import MetricResult
26+
from pruna.evaluation.metrics.utils import metric_data_processor
2527
from pruna.logging.logger import pruna_logger
2628

2729
METRIC_EVALHARNESS = "lm_eval_metric"
@@ -71,12 +73,26 @@ def __init__(self, metric_name: str, call_type: str = "y_gt") -> None:
7173

7274
pruna_logger.info(f"LMEvalMetric initialized: {metric_name} (higher_is_better={self.higher_is_better})")
7375

74-
def update(self, preds, refs) -> None:
75-
"""Accumulate predictions and references for later aggregation."""
76-
if len(preds) != len(refs):
77-
raise ValueError(f"Preds and refs length mismatch: {len(preds)} vs {len(refs)}")
78-
79-
for ref, pred in zip(refs, preds):
76+
def update(
77+
self,
78+
x: List[Any] | torch.Tensor,
79+
gt: List[Any] | torch.Tensor,
80+
outputs: List[Any] | torch.Tensor,
81+
) -> None:
82+
"""
83+
Accumulate predictions and references for later aggregation.
84+
85+
Parameters
86+
----------
87+
x : List[Any] | torch.Tensor
88+
Input data.
89+
gt : List[Any] | torch.Tensor
90+
Ground truth data.
91+
outputs : List[Any] | torch.Tensor
92+
Output data.
93+
"""
94+
inputs = metric_data_processor(x, gt, outputs, self.call_type)
95+
for ref, pred in zip(inputs[0], inputs[1]):
8096
raw_item = self.metric_fn((ref, pred))
8197
self.pairs.append(raw_item)
8298

tests/evaluation/test_evalharness_metrics.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ def test_lm_eval_metric_bleu_like():
1818
preds = ["the cat is on mat", "a quick brown fox"]
1919

2020
metric = LMEvalMetric(metric_name="bleu")
21-
metric.update(preds, refs)
21+
metric.update(refs, refs, preds)
2222
result = metric.compute()
2323

2424
assert isinstance(result, MetricResult)
@@ -49,7 +49,7 @@ def test_lm_eval_metric_length_mismatch():
4949
preds = ["a", "b"]
5050

5151
with pytest.raises(ValueError, match="Preds and refs length mismatch"):
52-
metric.update(preds, refs)
52+
metric.update(refs, refs, preds)
5353

5454

5555
@pytest.mark.cpu

0 commit comments

Comments
 (0)