Commit 77eee8c
Adding the target perplexity fix back (#15)
--------- Co-authored-by: Thomas Wolf <[email protected]>
1 parent 37db422 commit 77eee8c

4 files changed: 23 additions, 13 deletions


src/lighteval/metrics/__init__.py

Lines changed: 13 additions & 4 deletions

@@ -8,11 +8,18 @@

 def apply_target_perplexity_metric(results: list[ModelReturn], formatted_doc: Doc, metrics: list[str]):
     outputs = {}
-    current_results = [results.pop(0) for _ in range(len(formatted_doc.get_golds()))]
+    reference_text = formatted_doc.get_golds()[0]
+    current_result = results.pop(0)
+    target_logprob = current_result.result[0]
+    target_acc = current_result.result[1]

     for metric in metrics:
-        if Metrics[metric].value.category == MetricCategory.PERPLEXITY:
-            outputs.update(Metrics[metric].value.compute(results=current_results))
+        if Metrics[metric].value.category == MetricCategory.TARGET_PERPLEXITY:
+            outputs.update(
+                Metrics[metric].value.compute(
+                    logprobs=target_logprob, target_acc=target_acc, reference_text=reference_text
+                )
+            )

     return results, outputs

@@ -30,7 +37,9 @@ def apply_perplexity_metric(results: list[ModelReturn], formatted_doc: Doc, metr

     for metric in metrics:
         if Metrics[metric].value.category == MetricCategory.PERPLEXITY:
-            outputs.update(Metrics[metric].value.compute(results=current_result, reference_text=reference_text))
+            outputs.update(
+                Metrics[metric].value.compute(logprobs=current_result.result, reference_text=reference_text)
+            )

     return results, outputs
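For reference, a minimal sketch of the call shape the patched apply_target_perplexity_metric now expects: a single model return per document whose result packs the summed target logprob and the accuracy flag. The ModelReturn dataclass and the sample values below are illustrative stand-ins, not lighteval's actual classes.

from dataclasses import dataclass

# Illustrative stand-in for lighteval's ModelReturn; in the patched code,
# result[0] is the target logprob and result[1] the accuracy flag.
@dataclass
class ModelReturn:
    result: tuple[float, int]

results = [ModelReturn(result=(-12.7, 1))]  # one return for the current doc

# Mirrors the patched unpacking logic.
current_result = results.pop(0)
target_logprob = current_result.result[0]
target_acc = current_result.result[1]
print(target_logprob, target_acc)  # -12.7 1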

src/lighteval/metrics/metrics_sample.py

Lines changed: 5 additions & 4 deletions

@@ -1,6 +1,8 @@
 """This module manages all the metrics occurring at the sample level. The results of said metrics are then aggregated
 using simple function (min, mean, max, ...) at the corpus level. Most metrics fall under this category.
 """
+from typing import Union
+
 import nltk
 import numpy as np
 from nltk.metrics.distance import edit_distance

@@ -275,17 +277,16 @@ def compute(self, choices_logprob: list[float], gold_ixs: list[float], formatted
         return 1.0 / (min(ranked_choices) + 1)


-def acc_golds_likelihood(results: list[tuple[float, int]], **kwargs) -> int:
+def acc_golds_likelihood(target_acc: Union[list[int], int], **kwargs) -> int:
     """Tests if at least one of predicted gold targets' log-likelihood is above 0.5.

     Args:
-        results (list[int]): List of tuples containing, for each gold, the predictions log-probabilities associated with whether they are above 0.5 aggregated.
-        formatted_doc (Doc): _description_
+        target_acc (list[int]): List of scores indicating whether the predictions log-probabilities are above 0.5 aggregated.

     Returns:
         int: 1 if at least one of the possible golds had a log-likelihood above 0.5.
     """
-    return max([int(acc_ppl) for _, acc_ppl in results])
+    return max([int(acc_ppl) for acc_ppl in as_list(target_acc)])


 class ROUGE:
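To show the new call shape of acc_golds_likelihood, here is a small self-contained sketch; the as_list helper is re-implemented under the assumption that it simply wraps a scalar into a one-element list, as the patched return statement suggests.

def as_list(x):
    # Assumed behaviour of lighteval's as_list helper: pass lists through,
    # wrap anything else in a one-element list.
    return x if isinstance(x, list) else [x]


def acc_golds_likelihood(target_acc, **kwargs) -> int:
    # Same logic as the patched function: 1 if any gold's flag is set.
    return max([int(acc_ppl) for acc_ppl in as_list(target_acc)])


print(acc_golds_likelihood(1))          # 1 — single gold passed as a scalar
print(acc_golds_likelihood([0, 1, 0]))  # 1 — at least one gold scored above 0.5
print(acc_golds_likelihood([0, 0]))     # 0 — no gold scored above 0.5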

src/lighteval/metrics/sample_preparator.py

Lines changed: 3 additions & 3 deletions

@@ -106,14 +106,14 @@ def count_units(self, text: str) -> int:
         if self.units_type == "bytes":
             return len(text.encode("utf-8"))

-    def prepare(self, results, reference_text, **kwargs):
+    def prepare(self, logprobs: list[float] | float, reference_text: str, **kwargs):
         """Prepares an individual perplexity example to the format expected by metrics computed at the corpus level (aggregated).

         Args:
-            results (list[float]): List of the logprobabilities computed for each item
+            logprobs (list[float]): List of the logprobabilities computed for each item of the sequence or single aggregated logprob over the sequence
             reference_text (str): Current reference text for which to compute the length in self.units_type

         Returns:
             PerplexityCorpusMetricInput: Stores the measured logprobs and associated text lengths, counted in the reference unit.
         """
-        return PerplexityCorpusMetricInput(logprobs=results.result, weights=self.count_units(reference_text))
+        return PerplexityCorpusMetricInput(logprobs=logprobs, weights=self.count_units(reference_text))
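A short sketch of what the preparator now receives: raw logprobs (a list, or a single aggregated float) plus the reference text whose length supplies the weight. The PerplexityCorpusMetricInput dataclass below is a simplified stand-in for lighteval's container, and the word-level fallback in count_units is an assumption.

from dataclasses import dataclass

@dataclass
class PerplexityCorpusMetricInput:
    # Simplified stand-in: measured logprobs plus the text-length weight.
    logprobs: list[float] | float
    weights: int

def count_units(text: str, units_type: str = "bytes") -> int:
    # Same byte-counting branch as the method shown in the diff above.
    if units_type == "bytes":
        return len(text.encode("utf-8"))
    return len(text.split())  # assumed fallback for word-level units

reference_text = "The quick brown fox"
prepared = PerplexityCorpusMetricInput(
    logprobs=[-1.2, -0.4, -2.1],          # per-token logprobs passed by the caller
    weights=count_units(reference_text),  # 19 UTF-8 bytes for this reference
)
print(prepared)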
(Git LFS pointer file)

Lines changed: 2 additions & 2 deletions

@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:a1965f0b9c66cfe1b1f3cc380a80949e32eab92ae8eac079c0339506ce827093
-size 48373142
+oid sha256:408956938a6b7a18b03658bb9772b471efcea4aa04afb0b35d76cecfca6a706e
+size 48376580
