@@ -179,36 +179,26 @@ def parse_infill(code, tokenizer):
179179 code_gens = [[] for _ in range (n_tasks )]
180180 for sample , generated_tokens in gen_token_dict .items ():
181181 for s in generated_tokens :
182- if INFILL_MODE :
183- gen_code = parse_infill (
184- tokenizer .decode (
185- s , skip_special_tokens = False , clean_up_tokenization_spaces = False
186- ),
187- tokenizer ,
188- )
189- elif tokenizer .eos_token in task .stop_words :
182+ if INFILL_MODE or tokenizer .eos_token in task .stop_words :
190183 gen_code = tokenizer .decode (
191184 s , skip_special_tokens = False , clean_up_tokenization_spaces = False
192185 )
186+ if INFILL_MODE :
187+ gen_code = parse_infill (gen_code , tokenizer )
193188 else :
194189 gen_code = tokenizer .decode (
195190 s , skip_special_tokens = True , clean_up_tokenization_spaces = True
196191 )
192+ if not INFILL_MODE :
193+ gen_code = gen_code [len (prefix ) :]
197194 if postprocess :
198- if INFILL_MODE :
199- code_gens [sample ].append (
200- task .postprocess_generation (gen_code , int (sample ))
201- )
202- else :
203- code_gens [sample ].append (
204- task .postprocess_generation (
205- gen_code [len (prefix ) :], int (sample )
206- )
207- )
195+ code_gens [sample ].append (
196+ task .postprocess_generation (gen_code , int (sample ))
197+ )
208198 else :
209199 warnings .warn (
210200 "model output is not postprocessed, this might lower evaluation scores"
211201 )
212- code_gens [sample ].append (gen_code [ len ( prefix ) :] )
202+ code_gens [sample ].append (gen_code )
213203
214204 return code_gens
0 commit comments