File tree Expand file tree Collapse file tree 4 files changed +7
-4
lines changed
Expand file tree Collapse file tree 4 files changed +7
-4
lines changed Original file line number Diff line number Diff line change @@ -221,11 +221,17 @@ def _preprocess(self, source):
221221 source = [source ] if isinstance (source , str ) else source
222222 source = [self .tokenizer .apply_chat_template (sentence , tokenize = False ) for sentence in source ]
223223
224+ return_position_ids = False
225+ return_attention_mask = False
226+ if len (source ) > 1 :
227+ return_position_ids = True
228+ return_attention_mask = True
224229 tokenized_source = self .tokenizer (
225230 source ,
226231 max_length = self .config .src_length ,
227232 truncation = True ,
228- return_position_ids = True if not isinstance (self .tokenizer , ChatGLMTokenizer ) else False ,
233+ return_position_ids = True if not isinstance (self .tokenizer , ChatGLMTokenizer ) else return_position_ids ,
234+ return_attention_mask = return_attention_mask ,
229235 truncation_side = "left" ,
230236 return_tensors = self .return_tensors ,
231237 padding = True ,
Original file line number Diff line number Diff line change @@ -1772,7 +1772,6 @@ def prepare_inputs_for_generation(
17721772 ):
17731773 batch_size , seq_length = input_ids .shape
17741774 position_ids = kwargs .get ("position_ids" , paddle .arange (seq_length ).expand ((batch_size , seq_length )))
1775- attention_mask = kwargs .get ("attention_mask" , None )
17761775 if past_key_values :
17771776 input_ids = input_ids [:, - 1 ].unsqueeze (axis = - 1 )
17781777 position_ids = position_ids [:, - 1 ].unsqueeze (- 1 )
Original file line number Diff line number Diff line change @@ -1493,7 +1493,6 @@ def prepare_inputs_for_generation(
14931493 ):
14941494 batch_size , seq_length = input_ids .shape
14951495 position_ids = kwargs .get ("position_ids" , paddle .arange (seq_length ).expand ((batch_size , seq_length )))
1496- attention_mask = kwargs .get ("attention_mask" , None )
14971496 if past_key_values :
14981497 input_ids = input_ids [:, - 1 ].unsqueeze (axis = - 1 )
14991498 position_ids = position_ids [:, - 1 ].unsqueeze (- 1 )
Original file line number Diff line number Diff line change @@ -1429,7 +1429,6 @@ def prepare_inputs_for_generation(
14291429 ):
14301430 batch_size , seq_length = input_ids .shape
14311431 position_ids = kwargs .get ("position_ids" , paddle .arange (seq_length ).expand ((batch_size , seq_length )))
1432- attention_mask = kwargs .get ("attention_mask" , None )
14331432 if past_key_values :
14341433 input_ids = input_ids [:, - 1 ].unsqueeze (axis = - 1 )
14351434 position_ids = position_ids [:, - 1 ].unsqueeze (- 1 )
You canβt perform that action at this time.
0 commit comments