File tree Expand file tree Collapse file tree 4 files changed +7
-4
lines changed
Expand file tree Collapse file tree 4 files changed +7
-4
lines changed Original file line number Diff line number Diff line change @@ -226,11 +226,17 @@ def _preprocess(self, source):
226226 source = [source ] if isinstance (source , str ) else source
227227 source = [self .tokenizer .apply_chat_template (sentence , tokenize = False ) for sentence in source ]
228228
229+ return_position_ids = False
230+ return_attention_mask = False
231+ if len (source ) > 1 :
232+ return_position_ids = True
233+ return_attention_mask = True
229234 tokenized_source = self .tokenizer (
230235 source ,
231236 max_length = self .config .src_length ,
232237 truncation = True ,
233- return_position_ids = True if not isinstance (self .tokenizer , ChatGLMTokenizer ) else False ,
238+ return_position_ids = True if not isinstance (self .tokenizer , ChatGLMTokenizer ) else return_position_ids ,
239+ return_attention_mask = return_attention_mask ,
234240 truncation_side = "left" ,
235241 return_tensors = self .return_tensors ,
236242 padding = True ,
Original file line number Diff line number Diff line change @@ -1772,7 +1772,6 @@ def prepare_inputs_for_generation(
17721772 ):
17731773 batch_size , seq_length = input_ids .shape
17741774 position_ids = kwargs .get ("position_ids" , paddle .arange (seq_length ).expand ((batch_size , seq_length )))
1775- attention_mask = kwargs .get ("attention_mask" , None )
17761775 if past_key_values :
17771776 input_ids = input_ids [:, - 1 ].unsqueeze (axis = - 1 )
17781777 position_ids = position_ids [:, - 1 ].unsqueeze (- 1 )
Original file line number Diff line number Diff line change @@ -1493,7 +1493,6 @@ def prepare_inputs_for_generation(
14931493 ):
14941494 batch_size , seq_length = input_ids .shape
14951495 position_ids = kwargs .get ("position_ids" , paddle .arange (seq_length ).expand ((batch_size , seq_length )))
1496- attention_mask = kwargs .get ("attention_mask" , None )
14971496 if past_key_values :
14981497 input_ids = input_ids [:, - 1 ].unsqueeze (axis = - 1 )
14991498 position_ids = position_ids [:, - 1 ].unsqueeze (- 1 )
Original file line number Diff line number Diff line change @@ -1429,7 +1429,6 @@ def prepare_inputs_for_generation(
14291429 ):
14301430 batch_size , seq_length = input_ids .shape
14311431 position_ids = kwargs .get ("position_ids" , paddle .arange (seq_length ).expand ((batch_size , seq_length )))
1432- attention_mask = kwargs .get ("attention_mask" , None )
14331432 if past_key_values :
14341433 input_ids = input_ids [:, - 1 ].unsqueeze (axis = - 1 )
14351434 position_ids = position_ids [:, - 1 ].unsqueeze (- 1 )
You canβt perform that action at this time.
0 commit comments