@@ -2501,6 +2501,40 @@ def __enter__(self):
                 _reinitialize_cos_sin_cached_fp32(layer.self_attn.rotary_emb)
 
 
+# Adapted from https://github.com/huggingface/transformers/blob/31f9a289a6207be6cae746e009d8e0db523be203/src/transformers/models/falcon/modeling_falcon.py#L1138
+def _falcon_prepare_4d_causal_attention_mask_with_cache_position(
+    attention_mask: torch.Tensor,
+    sequence_length: int,
+    target_length: int,
+    dtype: torch.dtype,
+    device: torch.device,
+    cache_position: torch.Tensor,
+    batch_size: int,
+    **kwargs,
+):
+    if attention_mask is not None and attention_mask.dim() == 4:
+        # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
+        causal_mask = attention_mask
+    else:
+        # different from original: allow to provide min_dtype as parameter
+        min_dtype = torch.finfo(dtype).min if "min_dtype" not in kwargs else kwargs["min_dtype"]
+        causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device)
+        if sequence_length != 1:
+            causal_mask = torch.triu(causal_mask, diagonal=1)
+        causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
+        causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
+        if attention_mask is not None:
+            causal_mask = causal_mask.clone()  # copy to contiguous memory for in-place edit
+            mask_length = attention_mask.shape[-1]
+            padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
+            padding_mask = padding_mask == 0
+            causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
+                padding_mask, min_dtype
+            )
+
+    return causal_mask
+
+
 def _falcon_update_causal_mask(
     self,
     attention_mask: torch.Tensor,
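
A quick, self-contained sketch of what the vendored helper above returns; the inputs below are toy values chosen for illustration, not anything taken from the patch, and the `min_dtype` keyword is the one deviation from the upstream transformers helper noted in the code comment:

```python
import torch

# Illustrative only: batch of 1, 3 new query tokens against a KV length of 5,
# with the last cached position padded out by the 2D attention mask.
attention_mask = torch.tensor([[1, 1, 1, 1, 0]])  # (batch, target_length)
cache_position = torch.arange(3)                  # positions being written this step
mask = _falcon_prepare_4d_causal_attention_mask_with_cache_position(
    attention_mask,
    sequence_length=3,
    target_length=5,
    dtype=torch.float32,
    device=torch.device("cpu"),
    cache_position=cache_position,
    batch_size=1,
    min_dtype=torch.finfo(torch.float16).min,  # optional override added by this patch
)
print(mask.shape)  # torch.Size([1, 1, 3, 5]); 0.0 where attention is allowed, min_dtype elsewhere
```
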
@@ -2520,13 +2554,6 @@ def _falcon_update_causal_mask(
     # (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using
     # `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114
 
-    if hasattr(self, "_prepare_4d_causal_attention_mask_with_cache_position"):
-        _prepare_4d_causal_attention_mask_with_cache_position = (
-            self._prepare_4d_causal_attention_mask_with_cache_position
-        )
-    else:
-        from transformers.models.falcon.modeling_falcon import _prepare_4d_causal_attention_mask_with_cache_position
-
     if self.config._attn_implementation == "flash_attention_2":
         if attention_mask is not None and 0.0 in attention_mask:
             return attention_mask
@@ -2568,7 +2595,7 @@ def _falcon_update_causal_mask(
     )
 
     # In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
-    causal_mask = _prepare_4d_causal_attention_mask_with_cache_position(
+    causal_mask = _falcon_prepare_4d_causal_attention_mask_with_cache_position(
         attention_mask,
         sequence_length=sequence_length,
         target_length=target_length,
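
For context on the dispatch change above (the `hasattr`/import fallback is dropped and the vendored copy is always called): the 4D float mask produced here is additive, so 0.0 keeps a position and a large negative value removes it, which is how SDPA-style attention consumes it. A minimal, self-contained sketch with toy shapes (not Falcon's real configuration) is shown below; the call mirrors the patched dispatch, everything else is illustrative:

```python
import torch
import torch.nn.functional as F

# Toy dimensions, purely for illustration.
batch, heads, q_len, kv_len, head_dim = 1, 2, 3, 5, 8
causal_mask = _falcon_prepare_4d_causal_attention_mask_with_cache_position(
    torch.ones(batch, kv_len),            # 2D padding mask: nothing padded
    sequence_length=q_len,
    target_length=kv_len,
    dtype=torch.float32,
    device=torch.device("cpu"),
    cache_position=torch.arange(q_len),
    batch_size=batch,
)
q, k, v = (torch.randn(batch, heads, n, head_dim) for n in (q_len, kv_len, kv_len))
# The (batch, 1, q_len, kv_len) mask broadcasts over the head dimension.
out = F.scaled_dot_product_attention(q, k, v, attn_mask=causal_mask)
print(out.shape)  # torch.Size([1, 2, 3, 8])
```
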