@@ -301,7 +301,7 @@ def __exit__(self, exc_type, exc_value, traceback):
 # adopted from
 # https://github.com/huggingface/transformers/blob/v4.39.3/src/transformers/models/gemma/modeling_gemma.py#L965
 # https://github.com/huggingface/transformers/blob/v4.39.3/src/transformers/models/llama/modeling_llama.py#L1058
-def _llama_gemma_update_causal_mask(self, attention_mask, input_tensor, cache_position, past_seen_tokens=None):
+def _llama_gemma_update_causal_mask_legacy(self, attention_mask, input_tensor, cache_position, past_seen_tokens=None):
     from transformers.modeling_attn_mask_utils import AttentionMaskConverter
 
     if self.config._attn_implementation == "sdpa" and past_seen_tokens is not None:
@@ -314,10 +314,12 @@ def _llama_gemma_update_causal_mask(self, attention_mask, input_tensor, cache_po
 
     dtype, device = input_tensor.dtype, input_tensor.device
 
+    # difference with original modeling
     # using the minimum of a dtype with larger bit width (float32) may lead to overflow
     # during execution on platforms with default lower precision (bfloat16, float16)
     min_dtype = torch.finfo(torch.float16).min
     sequence_length = input_tensor.shape[1]
+    # difference with original modeling
     if hasattr(getattr(self.layers[0], "self_attn", {}), "past_key_value"):  # static cache
         target_length = self.config.max_position_embeddings
     else:  # dynamic cache
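
The overflow noted in the comment above is easiest to see with a standalone sketch (plain PyTorch, toy values, not part of the patch): the float32 minimum falls outside the finite range of float16 and bfloat16 and turns into -inf when the mask is cast or executed in those dtypes, while the float16 minimum stays finite everywhere, which is why the patched code fills the mask with torch.finfo(torch.float16).min.

import torch

# Illustration only: compare the two candidate fill values.
f32_min = torch.finfo(torch.float32).min   # ~ -3.4028e38
f16_min = torch.finfo(torch.float16).min   # -65504.0

for dtype in (torch.float16, torch.bfloat16):
    # The float32 minimum overflows to -inf once cast to the lower-precision dtype.
    print(dtype, torch.tensor(f32_min).to(dtype))
    # The float16 minimum remains a finite, large-negative value in every dtype.
    print(dtype, torch.tensor(f16_min).to(dtype))
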
@@ -329,7 +331,9 @@ def _llama_gemma_update_causal_mask(self, attention_mask, input_tensor, cache_po
 
         target_length = attention_mask.shape[-1] if isinstance(attention_mask, torch.Tensor) else current_length
 
+    # difference with original modeling
     causal_mask = torch.full((sequence_length, target_length), fill_value=1, dtype=dtype, device=device) * min_dtype
+
     if sequence_length != 1:
         causal_mask = torch.triu(causal_mask, diagonal=1)
     causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
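
The three mask-building lines above are easier to picture with a self-contained sketch (toy sizes, CPU only, values chosen for illustration); it reproduces the same steps for a prefill of 4 tokens against a target length of 6 and prints the resulting pattern. The patch guards the triu call with sequence_length != 1 only to skip it for single-token decode steps.

import torch

sequence_length, target_length = 4, 6           # toy prefill length vs. padded cache length
min_dtype = torch.finfo(torch.float16).min
cache_position = torch.arange(sequence_length)  # positions 0..3 being written this step

# same construction as above, minus device handling
causal_mask = torch.full((sequence_length, target_length), fill_value=1, dtype=torch.float32) * min_dtype
causal_mask = torch.triu(causal_mask, diagonal=1)  # block strictly-future positions
causal_mask *= torch.arange(target_length) > cache_position.reshape(-1, 1)  # zero out (unmask) every column at or before each token's cache position
print(causal_mask)
# row i comes out 0 (attend) for columns <= i and min_dtype (blocked) for columns > i,
# including the not-yet-written tail of the static cache.
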
@@ -366,6 +370,104 @@ def _llama_gemma_update_causal_mask(self, attention_mask, input_tensor, cache_po
     return causal_mask
 
 
+# adopted from https://github.com/huggingface/transformers/blob/f4014e75db0190792b3feeccfc5dc5b5f9f0ce7b/src/transformers/models/llama/modeling_llama.py#L1036
+def _llama_gemma_update_causal_mask_latest(
+    self,
+    attention_mask,
+    input_tensor,
+    cache_position,
+    past_key_values,
+    output_attentions,
+):
+    from transformers.cache_utils import StaticCache
+    from transformers.modeling_attn_mask_utils import AttentionMaskConverter
+
+    # TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static
+    # KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode step due to the dynamic shapes.
+    # (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using
+    # `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114
+
+    if self.config._attn_implementation == "flash_attention_2":
+        if attention_mask is not None and 0.0 in attention_mask:
+            return attention_mask
+        return None
+
+    # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
+    # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
+    # to infer the attention mask.
+    past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
+    using_static_cache = isinstance(past_key_values, StaticCache)
+
+    # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
+    if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions:
+        if AttentionMaskConverter._ignore_causal_mask_sdpa(
+            attention_mask,
+            inputs_embeds=input_tensor,
+            past_key_values_length=past_seen_tokens,
+            is_training=self.training,
+        ):
+            return None
+
+    dtype, device = input_tensor.dtype, input_tensor.device
+    # difference with original modeling
+    # using the minimum of a dtype with larger bit width (float32) may lead to overflow
+    # during execution on platforms with default lower precision (bfloat16, float16)
+    min_dtype = torch.finfo(torch.float16).min
+
+    sequence_length = input_tensor.shape[1]
+    if using_static_cache:
+        target_length = past_key_values.get_max_length()
+    else:
+        target_length = (
+            attention_mask.shape[-1]
+            if isinstance(attention_mask, torch.Tensor)
+            else past_seen_tokens + sequence_length + 1
+        )
+
+    if attention_mask is not None and attention_mask.dim() == 4:
+        # in this case we assume that the mask comes already in inverted form and requires no inversion or slicing
+        if attention_mask.max() != 0:
+            raise ValueError("Custom 4D attention mask should be passed in inverted form with max==0`")
+        causal_mask = attention_mask
+    else:
+        # difference with original modeling
+        causal_mask = (
+            torch.full((sequence_length, target_length), fill_value=1, dtype=dtype, device=device) * min_dtype
+        )
+
+        if sequence_length != 1:
+            causal_mask = torch.triu(causal_mask, diagonal=1)
+        causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
+        causal_mask = causal_mask[None, None, :, :].expand(input_tensor.shape[0], 1, -1, -1)
+        if attention_mask is not None:
+            causal_mask = causal_mask.clone()  # copy to contiguous memory for in-place edit
+            mask_length = attention_mask.shape[-1]
+            padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
+            padding_mask = padding_mask == 0
+            causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
+                padding_mask, min_dtype
+            )
+    if (
+        self.config._attn_implementation == "sdpa"
+        and attention_mask is not None
+        and attention_mask.device.type == "cuda"
+        and not output_attentions
+    ):
+        # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
+        # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
+        # Details: https://github.com/pytorch/pytorch/issues/110213
+        causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
+
+    return causal_mask
+
+
+# TODO: deprecate _llama_gemma_update_causal_mask_legacy when transformers>=4.41.0
+if is_transformers_version(">", "4.40.2"):
+    _llama_gemma_update_causal_mask = _llama_gemma_update_causal_mask_latest
+else:
+    _llama_gemma_update_causal_mask = _llama_gemma_update_causal_mask_legacy
+
+
 class GemmaModelPatcher(DecoderModelPatcher):
     def __enter__(self):
         super().__enter__()
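
For completeness, a standalone sketch (toy 3x3 sizes, made up here, not the patched module itself) of the padding merge performed inside `_llama_gemma_update_causal_mask_latest`: a 2D attention mask with one left-padded token is broadcast against the 4D causal mask, the positions whose sum is exactly 0 are the ones that are causally visible yet padded, and those positions are re-filled with min_dtype.

import torch

min_dtype = torch.finfo(torch.float16).min
batch, sequence_length, target_length = 1, 3, 3
cache_position = torch.arange(sequence_length)

# build a toy 4D causal mask with the same steps as the patch
causal_mask = torch.full((sequence_length, target_length), fill_value=1, dtype=torch.float32) * min_dtype
causal_mask = torch.triu(causal_mask, diagonal=1)
causal_mask *= torch.arange(target_length) > cache_position.reshape(-1, 1)
causal_mask = causal_mask[None, None, :, :].expand(batch, 1, -1, -1)

attention_mask = torch.tensor([[0.0, 1.0, 1.0]])  # first token is left padding
causal_mask = causal_mask.clone()                 # expand() returns a view; clone before in-place edits
mask_length = attention_mask.shape[-1]
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
padding_mask = padding_mask == 0                  # True only where "causally visible" meets "padding token"
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(padding_mask, min_dtype)
print(causal_mask)                                # column 0 is now min_dtype in every row, on top of the causal pattern
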