 from ..transformers.model_outputs import ModelOutput
 from ..transformers.utils import get_scale_by_dtype
 from ..utils.log import logger
+from ..utils.masking_utils import _expand_2d_mask, _make_causal_mask
+from ..utils.tools import get_env_device
 from .configuration_utils import DEFAULT_MAX_NEW_TOKENS, GenerationConfig
 from .logits_process import (
     ForcedBOSTokenLogitsProcessor,
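
For reference, the two masking helpers imported here are what the new `_prepare_decoder_attention_mask` below relies on. A minimal sketch of the shapes they are assumed to produce (an illustration only, not the actual `masking_utils` implementation):

```python
import paddle

def _make_causal_mask(input_shape, past_key_values_length):
    """Assumed: boolean causal mask of shape [bsz, 1, tgt_len, past_len + tgt_len]."""
    batch_size, target_length = input_shape
    mask = paddle.tril(paddle.ones((target_length, target_length), dtype="bool"))
    if past_key_values_length > 0:
        # positions already in the KV cache are always visible
        past = paddle.ones((target_length, past_key_values_length), dtype="bool")
        mask = paddle.concat([past, mask], axis=-1)
    return mask.unsqueeze([0, 1]).expand(
        [batch_size, 1, target_length, target_length + past_key_values_length]
    )

def _expand_2d_mask(mask, dtype, tgt_length):
    """Assumed: expand a [bsz, src_len] padding mask to a boolean [bsz, 1, tgt_len, src_len] mask."""
    batch_size, src_length = mask.shape[0], mask.shape[-1]
    mask = mask.unsqueeze([1, 2]).astype("bool")  # `dtype` kept for API parity, unused in this sketch
    return mask.expand([batch_size, 1, tgt_length, src_length])
```
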
@@ -339,13 +341,61 @@ def prepare_attention_mask_for_generation(input_ids, pad_token_id, eos_token_id)
         is_pad_token_not_equal_to_eos_token_id = (eos_token_id is None) or (
             (eos_token_id is not None) and (pad_token_id != eos_token_id)
         )
-        if is_pad_token_in_inputs_ids and is_pad_token_not_equal_to_eos_token_id:
-            attention_mask = (input_ids == pad_token_id).astype(paddle.get_default_dtype()) * get_scale_by_dtype(
-                return_positive=False
-            )
+        inputs_tensor = input_ids
+
+        # No information for attention mask inference -> return default attention mask
+        default_attention_mask = paddle.ones(input_ids.shape[:2], dtype=paddle.get_default_dtype())
+        if pad_token_id is None:
+            return default_attention_mask
+        can_infer_attention_mask = is_pad_token_in_inputs_ids * is_pad_token_not_equal_to_eos_token_id
+        attention_mask_from_padding = (inputs_tensor != pad_token_id).astype(paddle.get_default_dtype())
+
+        attention_mask = (
+            attention_mask_from_padding * can_infer_attention_mask + default_attention_mask * ~can_infer_attention_mask
+        )
+        return attention_mask
+
+    @staticmethod
+    def _prepare_decoder_attention_mask(attention_mask, input_shape, past_key_values_length, dtype):
+        if attention_mask is not None:
+            # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+            if len(attention_mask.shape) == 2:
+                expanded_attn_mask = _expand_2d_mask(attention_mask, dtype, tgt_length=input_shape[-1])
+                # For decoding phase in generation, seq_length = 1, we don't need to add causal mask
+                if input_shape[-1] > 1:
+                    combined_attention_mask = _make_causal_mask(
+                        input_shape, past_key_values_length=past_key_values_length
+                    )
+                    if get_env_device() in ["npu", "mlu", "intel_hpu"]:
+                        expanded_attn_mask = expanded_attn_mask.astype("bool") & combined_attention_mask.astype("bool")
+                    else:
+                        expanded_attn_mask = expanded_attn_mask & combined_attention_mask
+            # [bsz, seq_len, seq_len] -> [bsz, 1, seq_len, seq_len]
+            elif len(attention_mask.shape) == 3:
+                expanded_attn_mask = attention_mask.unsqueeze(1).astype("bool")
+            # if attention_mask is already 4-D, do nothing
+            else:
+                expanded_attn_mask = attention_mask
+        else:
+            expanded_attn_mask = _make_causal_mask(input_shape, past_key_values_length=past_key_values_length)
+        # Convert bool attention_mask to float attention mask, which will be added to attention_scores later
+        if get_env_device() in ["npu", "mlu", "intel_hpu"]:
+            x = paddle.to_tensor(0.0, dtype="float32")
+            y = paddle.to_tensor(paddle.finfo(dtype).min, dtype="float32")
+            expanded_attn_mask = paddle.where(expanded_attn_mask.cast("bool"), x, y).astype(dtype)
+        elif get_env_device() == "xpu":
+            x = paddle.to_tensor(0.0, dtype="float32")
+            y = paddle.to_tensor(-1.7005809656952787e38, dtype="float32")
+            expanded_attn_mask = paddle.where(expanded_attn_mask.cast("bool"), x, y)
+        elif get_env_device() == "gcu":
+            min_val = paddle.finfo(dtype).min
+            x = paddle.to_tensor(0.0, dtype=dtype)
+            y = paddle.to_tensor(min_val, dtype=dtype)
+            expanded_attn_mask = paddle.where(expanded_attn_mask.cast("bool"), x, y).astype(dtype)
         else:
-            attention_mask = paddle.zeros_like(input_ids, dtype=paddle.get_default_dtype())
-        return paddle.unsqueeze(attention_mask, axis=[1, 2])
+            expanded_attn_mask = paddle.where(expanded_attn_mask.cast("bool"), 0.0, paddle.finfo(dtype).min)
+            expanded_attn_mask = expanded_attn_mask.astype(dtype)
+        return expanded_attn_mask
 
     @staticmethod
     def prepare_seq_len_for_generation(input_ids, pad_token_id, eos_token_id):
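
To make the new behavior concrete, here is a self-contained sketch of what the rewritten mask preparation computes for a left-padded batch. It reproduces the logic inline rather than importing PaddleNLP, and the tensor values are illustrative:

```python
import paddle

pad_token_id, eos_token_id = 0, 2
input_ids = paddle.to_tensor([[0, 0, 5, 6, 7],
                              [3, 4, 5, 6, 7]])

# prepare_attention_mask_for_generation now returns a 2-D "keep" mask
# (1.0 = real token, 0.0 = padding) instead of the old 4-D additive mask.
is_pad_in_inputs = paddle.any(input_ids == pad_token_id).item()
pad_is_not_eos = (eos_token_id is None) or (pad_token_id != eos_token_id)
padding_mask = (input_ids != pad_token_id).astype(paddle.get_default_dtype())
default_mask = paddle.ones(input_ids.shape[:2], dtype=paddle.get_default_dtype())
attention_mask = padding_mask if (is_pad_in_inputs and pad_is_not_eos) else default_mask
# [[0., 0., 1., 1., 1.],
#  [1., 1., 1., 1., 1.]]

# _prepare_decoder_attention_mask then combines this 2-D mask with a causal
# mask and turns it into an additive float mask; on the default device path
# the conversion is essentially:
seq_len = input_ids.shape[-1]
expanded = attention_mask.unsqueeze([1, 2]).astype("bool")           # [bsz, 1, 1, src_len]
causal = paddle.tril(paddle.ones((seq_len, seq_len), dtype="bool"))  # [tgt_len, src_len]
combined = paddle.logical_and(expanded, causal)                      # [bsz, 1, tgt_len, src_len]
additive = paddle.where(combined, 0.0, paddle.finfo(paddle.float32).min).astype("float32")
# 0.0 where attention is allowed, dtype-min where masked, ready to be added to attention scores
```
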
@@ -853,12 +903,8 @@ def generate(
                 bos_token_id, encoder_output=model_kwargs["inputs_embeds"]
             )
 
-        if model_kwargs.get("attention_mask", None) is None:
-            # TODO
-            # Init `attention_mask` depending on `pad_token_id`
-            model_kwargs["attention_mask"] = self.prepare_attention_mask_for_generation(
-                input_ids, pad_token_id, eos_token_id
-            )
+        kwargs_has_attention_mask = model_kwargs.get("attention_mask", None) is not None
+        accepts_attention_mask = "attention_mask" in set(inspect.signature(self.forward).parameters.keys())
         self.is_encoder_decoder = self.config.is_encoder_decoder
 
         if self.is_encoder_decoder:
@@ -880,6 +926,11 @@ def generate(
 
         pad_token_id = self.set_pad_token_id(pad_token_id, eos_token_id)
 
+        if not kwargs_has_attention_mask and accepts_attention_mask:
+            model_kwargs["attention_mask"] = self.prepare_attention_mask_for_generation(
+                input_ids, pad_token_id, eos_token_id
+            )
+
         if generation_config.max_length != 0 and generation_config.max_new_tokens == DEFAULT_MAX_NEW_TOKENS:
             logger.warning("`max_length` will be deprecated in future releases, use `max_new_tokens` instead.")
             generation_config.max_new_tokens = generation_config.max_length
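
The `inspect.signature` gate added above means a default mask is only built when it is missing from `model_kwargs` and the model's `forward` can actually accept one. A toy illustration of that check (hypothetical classes, not PaddleNLP models):

```python
import inspect

class WithMask:
    def forward(self, input_ids, attention_mask=None, **kwargs):
        ...

class WithoutMask:
    def forward(self, input_ids, **kwargs):
        ...

def accepts_attention_mask(model):
    return "attention_mask" in set(inspect.signature(model.forward).parameters.keys())

print(accepts_attention_mask(WithMask()))     # True  -> generate() may infer a mask
print(accepts_attention_mask(WithoutMask()))  # False -> generate() skips mask creation
```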