File tree Expand file tree Collapse file tree 3 files changed +7
-3
lines changed
Expand file tree Collapse file tree 3 files changed +7
-3
lines changed Original file line number Diff line number Diff line change @@ -212,11 +212,17 @@ def _preprocess(self, source):
212212 source = [source ] if isinstance (source , str ) else source
213213 source = [self .tokenizer .apply_chat_template (sentence , tokenize = False ) for sentence in source ]
214214
215+ return_position_ids = False
216+ return_attention_mask = False
217+ if len (source ) > 1 :
218+ return_position_ids = True
219+ return_attention_mask = True
215220 tokenized_source = self .tokenizer (
216221 source ,
217222 max_length = self .config .src_length ,
218223 truncation = True ,
219- return_position_ids = True if not isinstance (self .tokenizer , ChatGLMTokenizer ) else False ,
224+ return_position_ids = True if not isinstance (self .tokenizer , ChatGLMTokenizer ) else return_position_ids ,
225+ return_attention_mask = return_attention_mask ,
220226 truncation_side = "left" ,
221227 return_tensors = self .return_tensors ,
222228 padding = True ,
Original file line number Diff line number Diff line change @@ -1493,7 +1493,6 @@ def prepare_inputs_for_generation(
14931493 ):
14941494 batch_size , seq_length = input_ids .shape
14951495 position_ids = kwargs .get ("position_ids" , paddle .arange (seq_length ).expand ((batch_size , seq_length )))
1496- attention_mask = kwargs .get ("attention_mask" , None )
14971496 if past_key_values :
14981497 input_ids = input_ids [:, - 1 ].unsqueeze (axis = - 1 )
14991498 position_ids = position_ids [:, - 1 ].unsqueeze (- 1 )
Original file line number Diff line number Diff line change @@ -1429,7 +1429,6 @@ def prepare_inputs_for_generation(
14291429 ):
14301430 batch_size , seq_length = input_ids .shape
14311431 position_ids = kwargs .get ("position_ids" , paddle .arange (seq_length ).expand ((batch_size , seq_length )))
1432- attention_mask = kwargs .get ("attention_mask" , None )
14331432 if past_key_values :
14341433 input_ids = input_ids [:, - 1 ].unsqueeze (axis = - 1 )
14351434 position_ids = position_ids [:, - 1 ].unsqueeze (- 1 )
You canβt perform that action at this time.
0 commit comments