Skip to content

Commit 9b87874

Browse files
committed
DRY refactor
1 parent 332a963 commit 9b87874

File tree

1 file changed

+9
-19
lines changed

1 file changed

+9
-19
lines changed

lm_eval/utils.py

Lines changed: 9 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -179,36 +179,26 @@ def parse_infill(code, tokenizer):
179179
code_gens = [[] for _ in range(n_tasks)]
180180
for sample, generated_tokens in gen_token_dict.items():
181181
for s in generated_tokens:
182-
if INFILL_MODE:
183-
gen_code = parse_infill(
184-
tokenizer.decode(
185-
s, skip_special_tokens=False, clean_up_tokenization_spaces=False
186-
),
187-
tokenizer,
188-
)
189-
elif tokenizer.eos_token in task.stop_words:
182+
if INFILL_MODE or tokenizer.eos_token in task.stop_words:
190183
gen_code = tokenizer.decode(
191184
s, skip_special_tokens=False, clean_up_tokenization_spaces=False
192185
)
186+
if INFILL_MODE:
187+
gen_code = parse_infill(gen_code, tokenizer)
193188
else:
194189
gen_code = tokenizer.decode(
195190
s, skip_special_tokens=True, clean_up_tokenization_spaces=True
196191
)
192+
if not INFILL_MODE:
193+
gen_code = gen_code[len(prefix) :]
197194
if postprocess:
198-
if INFILL_MODE:
199-
code_gens[sample].append(
200-
task.postprocess_generation(gen_code, int(sample))
201-
)
202-
else:
203-
code_gens[sample].append(
204-
task.postprocess_generation(
205-
gen_code[len(prefix) :], int(sample)
206-
)
207-
)
195+
code_gens[sample].append(
196+
task.postprocess_generation(gen_code, int(sample))
197+
)
208198
else:
209199
warnings.warn(
210200
"model output is not postprocessed, this might lower evaluation scores"
211201
)
212-
code_gens[sample].append(gen_code[len(prefix) :])
202+
code_gens[sample].append(gen_code)
213203

214204
return code_gens

0 commit comments

Comments
 (0)