@@ -46,7 +46,9 @@ def __iter__(self):
            elif isinstance(prompt_contents, dict):
                assert set(prompt_contents.keys()) == {"prefix", "suffix"}
                infill.append(True)
-               prompt = self.prefix + self._make_infill_prompt(**prompt_contents)
+               prompt = self._make_infill_prompt(
+                   **prompt_contents, preprefix=self.prefix
+               )
            else:
                raise ValueError(f"Unsupported prompt format: {type(prompt_contents)}")
            prompts.append(prompt)
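The change above stops prepending self.prefix to the already-formatted infill prompt and instead threads it through as a new preprefix argument, so it can be placed inside the model's FIM template. A minimal sketch of the difference, assuming a SantaCoder-style template (both helper functions are illustrative, not harness code):

# Illustrative only: with the old call, the task prefix landed *before* the
# <fim-prefix> sentinel, i.e. outside the FIM template the model was trained on.
def make_prompt_old(prefix, suffix, preprefix):
    return preprefix + f"<fim-prefix>{prefix}<fim-suffix>{suffix}<fim-middle>"

# New behaviour: the task prefix becomes part of the template's prefix section.
def make_prompt_new(prefix, suffix, preprefix=""):
    return f"<fim-prefix>{preprefix}{prefix}<fim-suffix>{suffix}<fim-middle>"

print(make_prompt_old("def f(x):\n", "    return y\n", "# context\n"))
print(make_prompt_new("def f(x):\n", "    return y\n", "# context\n"))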
@@ -83,18 +85,18 @@ def __iter__(self):
8385 "input_len" : outputs .attention_mask [sample ].sum (),
8486 }
8587
86- def _make_infill_prompt (self , prefix , suffix ):
88+ def _make_infill_prompt (self , prefix , suffix , preprefix = "" ):
8789 """Make a prompt for infilling.
8890 Currently supported only for official InCoder and SantaCoder implementations.
8991 """
9092 model_id = self .tokenizer .name_or_path
9193 if model_id in ["facebook/incoder-1B" , "facebook/incoder-6B" ]:
9294 self .tokenizer .add_special_tokens ({"pad_token" : "<pad>" })
93- return f"{ prefix } <|mask:0|>{ suffix } <|mask:0|>"
95+ return f"{ preprefix } { prefix } <|mask:0|>{ suffix } <|mask:0|>"
9496 elif model_id in ["bigcode/santacoder" ]:
95- return f"<fim-prefix>{ prefix } <fim-suffix>{ suffix } <fim-middle>"
96- elif model_id in ["bigcode/large-model" ]:
97- return f"<fim_prefix>{ prefix } <fim_suffix>{ suffix } <fim_middle>"
97+ return f"<fim-prefix>{ preprefix } { prefix } <fim-suffix>{ suffix } <fim-middle>"
98+ elif model_id in ["bigcode/large-model" , "bigcode/temp-model" ]:
99+ return f"<fim_prefix>{ preprefix } { prefix } <fim_suffix>{ suffix } <fim_middle>"
98100 else :
99101 raise ValueError (f"Infilling not yet supported for: { model_id } " )
100102
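For reference, these are the three template families the method now emits, with preprefix folded into the prefix section of each. A standalone sketch of the dispatch, assuming only the model ids and sentinel spellings shown in the diff (hyphenated tokens for santacoder, underscored for large-model/temp-model); the helper itself is illustrative:

# Standalone sketch of the dispatch above; everything except the model ids
# and sentinel tokens is invented for illustration.
TEMPLATES = {
    "facebook/incoder-1B": "{pre}{prefix}<|mask:0|>{suffix}<|mask:0|>",
    "facebook/incoder-6B": "{pre}{prefix}<|mask:0|>{suffix}<|mask:0|>",
    "bigcode/santacoder": "<fim-prefix>{pre}{prefix}<fim-suffix>{suffix}<fim-middle>",
    "bigcode/large-model": "<fim_prefix>{pre}{prefix}<fim_suffix>{suffix}<fim_middle>",
    "bigcode/temp-model": "<fim_prefix>{pre}{prefix}<fim_suffix>{suffix}<fim_middle>",
}

def render_infill_prompt(model_id, prefix, suffix, pre=""):
    if model_id not in TEMPLATES:
        raise ValueError(f"Infilling not yet supported for: {model_id}")
    return TEMPLATES[model_id].format(pre=pre, prefix=prefix, suffix=suffix)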
@@ -160,7 +162,7 @@ def parse_infill(code, tokenizer):
        prefix, rest = code.split("<fim-suffix>", 1)
        suffix, infill = rest.split("<fim-middle>", 1)
        infill = infill.split("<|endoftext|>")[0]
-   elif model_id in ["bigcode/large-model"]:
+   elif model_id in ["bigcode/large-model", "bigcode/temp-model"]:
        prefix, rest = code.split("<fim_suffix>", 1)
        suffix, infill = rest.split("<fim_middle>", 1)
        infill = infill.split("<|endoftext|>")[0]
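The parsing on the generation side mirrors the template: split on the suffix and middle sentinels, then truncate at end-of-text. A rough, self-contained version of that logic for the underscore-style tokens (the sample output string is made up for the example):

# Rough illustration of the split logic above; not the harness function.
def split_fim(code):
    prefix, rest = code.split("<fim_suffix>", 1)
    suffix, infill = rest.split("<fim_middle>", 1)
    infill = infill.split("<|endoftext|>")[0]
    return prefix, suffix, infill

sample = (
    "<fim_prefix>def add(a, b):\n<fim_suffix>    return c\n"
    "<fim_middle>    c = a + b\n<|endoftext|>trailing junk"
)
prefix, suffix, infill = split_fim(sample)
assert infill == "    c = a + b\n"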
@@ -193,9 +195,16 @@ def parse_infill(code, tokenizer):
                s, skip_special_tokens=True, clean_up_tokenization_spaces=True
            )
            if postprocess:
-               code_gens[sample].append(
-                   task.postprocess_generation(gen_code[len(prefix) :], int(sample))
-               )
+               if INFILL_MODE:
+                   code_gens[sample].append(
+                       task.postprocess_generation(gen_code, int(sample))
+                   )
+               else:
+                   code_gens[sample].append(
+                       task.postprocess_generation(
+                           gen_code[len(prefix) :], int(sample)
+                       )
+                   )
            else:
                warnings.warn(
                    "model output is not postprocessed, this might lower evaluation scores"