Skip to content

Commit 264c086

Browse files
committed
test: add control check for inputs to the metric, update the test
1 parent 4ef8b35 commit 264c086

File tree

2 files changed: +6 −3 lines changed

src/pruna/evaluation/metrics/metric_evalharness.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -91,8 +91,11 @@ def update(
9191
outputs : List[Any] | torch.Tensor
9292
Output data.
9393
"""
94-
if len(x) != len(gt) != len(outputs):
95-
raise ValueError(f"Input, ground truth, and output length mismatch: {len(x)} vs {len(gt)} vs {len(outputs)}")
94+
if not (len(x) == len(gt) == len(outputs)):
95+
error_message = f"Input, ground truth, and output length mismatch: {len(x)} vs {len(gt)} vs {len(outputs)}"
96+
pruna_logger.error(error_message)
97+
raise ValueError(error_message)
98+
pruna_logger.debug(f"Processing {len(x)} samples for {self.metric_name}")
9699
inputs = metric_data_processor(x, gt, outputs, self.call_type)
97100
for ref, pred in zip(inputs[0], inputs[1]):
98101
raw_item = self.metric_fn((ref, pred))

tests/evaluation/test_evalharness_metrics.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ def test_lm_eval_metric_length_mismatch():
4848
refs = ["a", "b", "c"]
4949
preds = ["a", "b"]
5050

51-
with pytest.raises(ValueError, match="Preds and refs length mismatch"):
51+
with pytest.raises(ValueError):
5252
metric.update(refs, refs, preds)
5353

5454

0 commit comments

Comments (0)