Skip to content

Commit 532909c

Browse files
Ensure backwards compatibility in fewshot_context by using kwargs (#3079)
Signed-off-by: kiersten-stokes <[email protected]>
1 parent 8bc4620 commit 532909c

File tree

2 files changed

+11
-26
lines changed

2 files changed

+11
-26
lines changed

lm_eval/api/task.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -458,11 +458,13 @@ def build_all_requests(
458458
# sample fewshot context #TODO: need to offset doc_id by rank now!
459459
fewshot_ctx = self.fewshot_context(
460460
doc,
461-
0 if self.config.num_fewshot is None else self.config.num_fewshot,
462-
system_instruction,
463-
apply_chat_template,
464-
fewshot_as_multiturn,
465-
chat_template,
461+
num_fewshot=0
462+
if self.config.num_fewshot is None
463+
else self.config.num_fewshot,
464+
system_instruction=system_instruction,
465+
apply_chat_template=apply_chat_template,
466+
fewshot_as_multiturn=fewshot_as_multiturn,
467+
chat_template=chat_template,
466468
gen_prefix=self.doc_to_prefix(doc),
467469
)
468470

lm_eval/tasks/unitxt/task.py

Lines changed: 4 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66

77
import importlib.util
88
import re
9-
from collections.abc import Callable
109
from functools import partial
1110
from typing import Any, Dict, Optional
1211

@@ -110,34 +109,18 @@ def doc_to_target(self, doc):
110109
def get_arguments(self, doc, ctx):
111110
return (ctx, {"until": ["\n"]})
112111

113-
def fewshot_context(
114-
self,
115-
doc: str,
116-
num_fewshot: int,
117-
system_instruction: Optional[str] = None,
118-
apply_chat_template: bool = False,
119-
fewshot_as_multiturn: bool = False,
120-
chat_template: Optional[Callable] = None,
121-
gen_prefix: Optional[str] = None,
122-
) -> str:
112+
def fewshot_context(self, doc, **kwargs) -> str:
123113
if isinstance(self.doc_to_text(doc), list):
124-
if apply_chat_template:
114+
if kwargs.get("apply_chat_template"):
115+
chat_template = kwargs.get("chat_template")
125116
formated_source = chat_template(self.doc_to_text(doc))
126117
return formated_source
127118
else:
128119
raise Exception(
129120
"Got chat template format from Unitxt, but apply_chat_template is false. Add '--apply_chat_template' to command line."
130121
)
131122
else:
132-
return super().fewshot_context(
133-
doc=doc,
134-
num_fewshot=num_fewshot,
135-
system_instruction=system_instruction,
136-
apply_chat_template=apply_chat_template,
137-
fewshot_as_multiturn=fewshot_as_multiturn,
138-
chat_template=chat_template,
139-
gen_prefix=gen_prefix,
140-
)
123+
return super().fewshot_context(doc=doc, **kwargs)
141124

142125
def construct_requests(self, doc, ctx, **kwargs):
143126
"""Uses RequestFactory to construct Requests and returns an iterable of

0 commit comments

Comments
 (0)