Skip to content

Commit 128a395

Browse files
committed
test: add control check for inputs to the metric, update the test
1 parent 3fd5aaa commit 128a395

File tree

2 files changed

+4
-2
lines changed

2 files changed

+4
-2
lines changed

src/pruna/evaluation/metrics/metric_evalharness.py

Lines changed: 3 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -91,8 +91,10 @@ def update(
         outputs : List[Any] | torch.Tensor
             Output data.
         """
-        if len(x) != len(gt) != len(outputs):
+        if not (len(x) == len(gt) == len(outputs)):
+            pruna_logger.error(f"Input, ground truth, and output length mismatch: {len(x)} vs {len(gt)} vs {len(outputs)}")
             raise ValueError(f"Input, ground truth, and output length mismatch: {len(x)} vs {len(gt)} vs {len(outputs)}")
+        pruna_logger.debug(f"Processing {len(x)} samples for {self.metric_name}")
         inputs = metric_data_processor(x, gt, outputs, self.call_type)
         for ref, pred in zip(inputs[0], inputs[1]):
             raw_item = self.metric_fn((ref, pred))

tests/evaluation/test_evalharness_metrics.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -48,7 +48,7 @@ def test_lm_eval_metric_length_mismatch():
     refs = ["a", "b", "c"]
     preds = ["a", "b"]

-    with pytest.raises(ValueError, match="Preds and refs length mismatch"):
+    with pytest.raises(ValueError):
         metric.update(refs, refs, preds)


0 commit comments

Comments
 (0)