Skip to content

Commit 41f6bfa

Browse files
authored
Refactor: Decouple ContextSampler from Task; build_qa_turn (#3429)
Refactor: Decouple `ContextSampler` from `Task`; add `build_qa_turn`

- Refactor `ContextSampler` to be independent of `Task`
- Introduce a `Message` class for dual plain-text/chat rendering
- Extract prompt construction into `build_qa_turn` and `fewshot_context`
- Add helper functions: `maybe_delimit`, `requires_delimiter`
1 parent 95b20f9 commit 41f6bfa

File tree

7 files changed

+1627
-449
lines changed

7 files changed

+1627
-449
lines changed

lm_eval/api/samplers.py

Lines changed: 87 additions & 183 deletions
Original file line numberDiff line numberDiff line change
@@ -1,223 +1,127 @@
1-
import logging
2-
import warnings
3-
from functools import partial
4-
from typing import TYPE_CHECKING, Iterable, Optional, Union
1+
from __future__ import annotations
52

6-
import datasets
3+
import logging
4+
from random import Random
5+
from typing import TYPE_CHECKING
76

87

98
if TYPE_CHECKING:
10-
from random import Random
9+
from collections.abc import Iterable, Sequence
10+
from typing import Any, TypeVar
1111

12-
from lm_eval.api.task import ConfigurableTask, Task
12+
_T = TypeVar("_T")
1313

14-
eval_logger = logging.getLogger("lm-eval")
14+
eval_logger = logging.getLogger(__name__)
1515

1616

1717
class ContextSampler:
1818
def __init__(
1919
self,
20-
docs: list[dict],
21-
task: Union["Task", "ConfigurableTask"],
22-
fewshot_indices: Optional[Iterable] = None,
23-
rnd: Optional["Random"] = None,
20+
df: Sequence[dict[str, Any]] | None = None,
21+
*,
22+
rnd: int | None = None,
23+
fewshot_indices: list[int] | None = None,
24+
**kwargs,
2425
) -> None:
25-
self.rnd = rnd
26-
if not self.rnd:
27-
raise ValueError(
28-
"A `random.Random` generator argument must be provided to `rnd` of FewShotSampler!"
29-
)
26+
self.rnd = Random(rnd)
27+
self.df = df or []
28+
self.fewshot_indices = fewshot_indices
29+
self._loaded = False # to iterate over fewshot_indices when needed
3030

31-
self.task = task
32-
self.config = task._config
31+
def sample(
32+
self,
33+
n: int,
34+
eval_doc: dict[str, Any] | None = None,
35+
df: Sequence[dict[str, Any]] | None = None,
36+
**kwargs,
37+
) -> Sequence[dict[str, Any]]:
38+
"""
39+
Sample n documents from the pool.
3340
34-
self.target_delimiter = self.config.target_delimiter
35-
self.fewshot_delimiter = self.config.fewshot_delimiter
41+
Args:
42+
n: Number of documents to sample
43+
eval_doc: Optional document to exclude from sampling
44+
df: Optional list of documents to sample from
3645
37-
if (
38-
self.config.fewshot_config is not None
39-
and self.config.fewshot_config.get("doc_to_text", None) is not None
40-
):
41-
self.doc_to_text = partial(
42-
self.task.doc_to_text,
43-
doc_to_text=self.config.fewshot_config.get("doc_to_text", None),
44-
)
45-
else:
46-
self.doc_to_text = self.task.doc_to_text
47-
48-
if (
49-
self.config.fewshot_config is not None
50-
and self.config.fewshot_config.get("doc_to_target", None) is not None
51-
):
52-
self.doc_to_target = partial(
53-
self.task.doc_to_target,
54-
doc_to_target=self.config.fewshot_config.get("doc_to_target", None),
55-
)
56-
else:
57-
self.doc_to_target = self.task.doc_to_target
58-
59-
if (
60-
self.config.fewshot_config is not None
61-
and self.config.fewshot_config.get("doc_to_choice", None) is not None
62-
):
63-
self.doc_to_choice = partial(
64-
self.task.doc_to_choice,
65-
doc_to_choice=self.config.fewshot_config.get("doc_to_choice", None),
46+
Returns:
47+
List of sampled documents
48+
"""
49+
assert n >= 0, "Error: number of samples requested must be >=0"
50+
if n == 0:
51+
return []
52+
53+
if df:
54+
self.df = df
55+
56+
assert self.df, "Error: no documents available for sampling."
57+
res = (
58+
self.rnd.sample(self.fewshot_docs(), n)
59+
if not eval_doc
60+
else self.rm_eval_doc(
61+
eval_doc, self.rnd.sample(self.fewshot_docs(), n + 1), n
6662
)
67-
else:
68-
self.doc_to_choice = self.task.doc_to_choice
69-
70-
self.docs = docs # HF dataset split, provided by task._fewshot_docs()
71-
if fewshot_indices: # subset few-shot docs from
72-
if not isinstance(self.docs, datasets.Dataset):
73-
raise ValueError(
74-
"Got `fewshot_indices` but fewshot_docs are not a HF dataset. Don't use both `fewshot_indices` and a user-defined few-shot sample list simultaneously"
75-
)
76-
self.docs = self.docs.select(fewshot_indices)
77-
78-
def get_context(self, doc: dict, num_fewshot: int, gen_prefix: str = None):
79-
# draw an extra fewshot sample if using same split as evaluating on
80-
prefix = gen_prefix + " " if gen_prefix else ""
81-
n_samples = (
82-
num_fewshot + 1
83-
if self.config.fewshot_split == self.config.test_split
84-
else num_fewshot
8563
)
86-
87-
# draw `n_samples` docs from fewshot_docs
88-
fewshotex = self.sample(n_samples)
89-
90-
# get rid of the doc that's the one we're evaluating, if it's in the fewshot
91-
# TODO: should we just stop people from using fewshot from same split as evaluating?
92-
selected_docs = [x for x in fewshotex if x != doc][:num_fewshot]
93-
94-
labeled_examples = ""
95-
for doc in selected_docs:
96-
doc_content = self.doc_to_text(doc)
97-
doc_target = self.doc_to_target(doc)
98-
if self.config.doc_to_choice is None or isinstance(doc_content, str):
99-
labeled_examples += doc_content
100-
else:
101-
labeled_examples += self.doc_to_choice(doc)[doc_content]
102-
103-
if doc_target != "":
104-
if self.target_delimiter.isspace() and str(doc_target)[0].isspace():
105-
# TODO: add logger warn once here.
106-
warnings.warn(
107-
"Both target_delimiter and target start with a space. This may cause issues.",
108-
Warning,
109-
stacklevel=2,
110-
)
111-
labeled_examples += self.target_delimiter
112-
labeled_examples += prefix
113-
labeled_examples += (
114-
str(doc_target[0])
115-
if isinstance(doc_target, list)
116-
else doc_target
117-
if self.config.doc_to_choice is None or isinstance(doc_target, str)
118-
else str(self.doc_to_choice(doc)[doc_target])
119-
)
120-
labeled_examples += self.fewshot_delimiter
121-
122-
return labeled_examples
123-
124-
def get_chat_context(
125-
self,
126-
doc: dict,
127-
num_fewshot: int,
128-
fewshot_as_multiturn: bool = False,
129-
gen_prefix: Optional[str] = None,
130-
):
131-
# TODO: Do we need any other delimiter
132-
prefix = gen_prefix + " " if gen_prefix else ""
133-
chat_history = []
134-
# draw an extra fewshot sample if using same split as evaluating on
135-
n_samples = (
136-
num_fewshot + 1
137-
if self.config.fewshot_split == self.config.test_split
138-
else num_fewshot
64+
assert len(res) == n, (
65+
f"Error: number of fewshot samples returned ({len(res)}) not equal to number requested ({n})."
66+
)
67+
return res
68+
69+
def set_rnd(self, rnd: int | None):
70+
self.rnd = Random(rnd)
71+
return self
72+
73+
def replace_df(self, df: Sequence[dict[str, Any]]):
74+
self.df = df
75+
self._loaded = False
76+
return self
77+
78+
def fewshot_docs(self):
79+
"""Return cached fewshot docs if available"""
80+
if self._loaded:
81+
return self.df
82+
if self.fewshot_indices and self.df and not self._loaded:
83+
self.df = [self.df[i] for i in self.fewshot_indices]
84+
self._loaded = True
85+
return list(self.df)
86+
87+
@staticmethod
88+
def rm_eval_doc(doc: _T, _iter: Iterable[_T], n=None) -> Sequence[_T]:
89+
return (
90+
[x for x in _iter if x != doc]
91+
if n is None
92+
else [x for x in _iter if x != doc][:n]
13993
)
140-
# draw `n_samples` docs from fewshot_docs
141-
fewshotex = self.sample(n_samples)
142-
143-
# get rid of the doc that's the one we're evaluating, if it's in the fewshot
144-
# TODO: should we just stop people from using fewshot from same split as evaluating?
145-
selected_docs = [x for x in fewshotex if x != doc][:num_fewshot]
146-
147-
if fewshot_as_multiturn:
148-
for doc in selected_docs:
149-
doc_content = self.doc_to_text(doc)
150-
doc_target = self.doc_to_target(doc)
151-
chat_history.append(
152-
{
153-
"role": "user",
154-
"content": doc_content
155-
if self.config.doc_to_choice is None
156-
or isinstance(doc_content, str)
157-
else self.doc_to_choice(doc)[doc_content],
158-
}
159-
)
160-
chat_history.append(
161-
{
162-
"role": "assistant",
163-
"content": prefix + str(doc_target[0])
164-
if isinstance(doc_target, list)
165-
else prefix + doc_target
166-
if self.config.doc_to_choice is None
167-
or isinstance(doc_target, str)
168-
else prefix + str(self.doc_to_choice(doc)[doc_target]),
169-
}
170-
)
171-
else:
172-
# get fewshot context as one user turn
173-
chat_history.append(
174-
{
175-
"role": "user",
176-
"content": self.get_context(
177-
doc, num_fewshot, gen_prefix=gen_prefix
178-
),
179-
}
180-
)
181-
182-
return chat_history
183-
184-
def sample(self, n: int):
185-
"""
186-
Draw `n` samples from our fewshot docs. This method should be overridden by subclasses.
187-
"""
188-
189-
return self.rnd.sample(self.docs, n)
19094

19195

19296
class FirstNSampler(ContextSampler):
193-
def sample(self, n: int) -> None:
97+
def sample(self, n: int, eval_doc=None, df=None, **kwargs):
19498
"""
19599
Draw the first `n` samples in order from the specified split.
196100
Used for tasks with "canonical" ordered fewshot examples, such as MMLU and CMMLU.
197101
"""
198-
assert n <= len(self.docs), (
199-
f"Error: number of fewshot samples requested exceeds the {len(self.docs)} that are available."
102+
assert n <= len(self.df), (
103+
f"Error: number of fewshot samples requested exceeds the {len(self.df)} that are available."
200104
)
201-
return self.docs[:n]
105+
return self.df[:n]
202106

203107

204108
class BalancedSampler(ContextSampler):
205-
def sample(self, n: int) -> None:
109+
def sample(self, n: int, eval_doc=None, df=None, **kwargs):
206110
"""
207111
TODO: this should return approximately class-balanced samples from our fewshot examples.
208112
TODO: what order should they be in? maybe random?
209113
"""
210114

211-
pass
115+
raise NotImplementedError
212116

213117

214118
class ManualSampler(ContextSampler):
215-
def sample(self, n: int) -> None:
119+
def sample(self, n: int, eval_doc=None, df=None, **kwargs):
216120
""" """
217-
pass
121+
raise NotImplementedError
218122

219123

220-
SAMPLER_REGISTRY = {
124+
SAMPLER_REGISTRY: dict[str, type[ContextSampler]] = {
221125
"default": ContextSampler,
222126
"first_n": FirstNSampler,
223127
}
@@ -226,7 +130,7 @@ def sample(self, n: int) -> None:
226130
def get_sampler(name: str):
227131
try:
228132
return SAMPLER_REGISTRY[name]
229-
except KeyError:
230-
raise ValueError(
133+
except KeyError as e:
134+
raise KeyError(
231135
f"Attempted to use contextsampler '{name}', but no sampling strategy for this name found! Supported model names: {', '.join(SAMPLER_REGISTRY.keys())}"
232-
)
136+
) from e

0 commit comments

Comments
 (0)