@@ -109,11 +109,20 @@ def patch_model_with_bettertransformer(model):
     return model


-def patch_update_causal_mask(model, transformers_version):
+def patch_update_causal_mask(model, transformers_version, inner_model_name="model", patch_fn=None):
     if is_transformers_version(">=", transformers_version):
-        inner_model = getattr(model, "model", getattr(model, "transformer", None))
+        inner_model = getattr(model, inner_model_name, None)
         if inner_model is not None:
-            inner_model._update_causal_mask = types.MethodType(_llama_gemma_update_causal_mask, inner_model)
+            if hasattr(inner_model, "_update_causal_mask"):
+                inner_model._orig_update_causal_mask = inner_model._update_causal_mask
+            patch_fn = patch_fn or _llama_gemma_update_causal_mask
+            inner_model._update_causal_mask = types.MethodType(patch_fn, inner_model)
+
+
+def unpatch_update_causal_mask(model, inner_model_name="model"):
+    inner_model = getattr(model, inner_model_name, None)
+    if inner_model is not None and hasattr(inner_model, "_orig_update_causal_mask"):
+        inner_model._update_causal_mask = inner_model._orig_update_causal_mask


 # initialization of sin/cos cached in bf16/fp16 leads to accuracy loss
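
For reference, a minimal sketch of what this helper pair does at runtime. The tiny Llama configuration below is purely illustrative; the helpers and _llama_gemma_update_causal_mask are assumed to be in scope as defined in this file, and the final assertion assumes a transformers release (>= 4.39) in which LlamaModel still defines _update_causal_mask:

from transformers import LlamaConfig, LlamaForCausalLM, LlamaModel

# Tiny random model purely for illustration; any model exposing `.model._update_causal_mask` works.
config = LlamaConfig(
    hidden_size=32, intermediate_size=64, num_hidden_layers=1,
    num_attention_heads=2, num_key_value_heads=2, vocab_size=64,
)
model = LlamaForCausalLM(config)

patch_update_causal_mask(model, "4.39.0")  # default patch_fn resolves to _llama_gemma_update_causal_mask
assert model.model._update_causal_mask.__func__ is _llama_gemma_update_causal_mask

unpatch_update_causal_mask(model)  # restores the bound method saved in _orig_update_causal_mask
assert model.model._update_causal_mask.__func__ is LlamaModel._update_causal_mask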
@@ -579,13 +588,11 @@ def __enter__(self):

         # llama/gemma has some accuracy issues with bf16 with transformers >= 4.39
         # fill the causal mask in a slightly different way to avoid overflow on some platforms
-        patch_update_causal_mask(self._model, "4.39.0")
+        patch_update_causal_mask(self._model, "4.39.0", "model" if hasattr(self._model, "model") else "transformer")

     def __exit__(self, exc_type, exc_value, traceback):
         super().__exit__(exc_type, exc_value, traceback)
-        inner_model = getattr(self._model, "model", getattr(self._model, "transformer", None))
-        if hasattr(inner_model, "_orig_update_causal_mask"):
-            inner_model._update_causal_mask = inner_model._orig_update_causal_mask
+        unpatch_update_causal_mask(self._model, "model" if hasattr(self._model, "model") else "transformer")


 # copied from https://github.com/huggingface/transformers/commit/57d7594a79a9f5d835abf2d4d384db0e4818e548 to unblock export with transformers 4.42
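
The overflow mentioned in the comment above can be reproduced in isolation. This standalone snippet only illustrates why the patched mask code fills with the float16 minimum rather than the dtype minimum; it is not part of the patch:

import torch

# torch.finfo(torch.float32).min does not fit into float16: the cast saturates to -inf,
# and -inf entries can later turn into NaN inside fused attention kernels.
print(torch.tensor(torch.finfo(torch.float32).min).to(torch.float16))  # tensor(-inf, dtype=torch.float16)

# The float16 minimum stays finite in both fp16 and bf16 while still zeroing out masked
# positions after softmax; _falcon_update_causal_mask below applies the same trick explicitly.
print(torch.tensor(torch.finfo(torch.float16).min).to(torch.float16))  # tensor(-65504., dtype=torch.float16)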
@@ -1865,6 +1872,67 @@ def __exit__(self, exc_type, exc_value, traceback):
                 layer.self_attn.forward = layer.self_attn._orig_forward


+# copied from https://github.com/huggingface/optimum/blob/2112e99122d7f23a1da1a9d263fef64301050ea7/optimum/bettertransformer/models/attention.py#L168
+# for preserving backward compatibility between outdated codegen remote code and new transformers
+def _codegen_wrapped_scaled_dot_product_legacy(
+    self,
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    attention_mask: Optional[torch.Tensor] = None,
+    head_mask: Optional[torch.Tensor] = None,
+):
+    from optimum.bettertransformer.models.attention import raise_on_head_mask
+
+    raise_on_head_mask(head_mask)
+    batch_size = query.shape[0]
+    mask_value = torch.finfo(value.dtype).min
+    mask_value = torch.full([], mask_value, dtype=value.dtype)
+
+    if batch_size == 1 and attention_mask is not None and attention_mask[0, 0, -1, -1] < -1:
+        raise ValueError("BetterTransformer does not support padding='max_length' with a batch size of 1.")
+
+    # in codegen the query and key are always in fp32 regardless of the dtype of the model
+    # https://github.com/huggingface/transformers/blob/5b28b7833297adf65c5160a685425ddb1eee5ce2/src/transformers/models/codegen/modeling_codegen.py#L226
+    query = query.to(value.dtype)
+    key = key.to(value.dtype)
+
+    dropout_p = self.dropout_prob_attn if self.training else 0.0
+    if batch_size == 1 or self.training:
+        if query.shape[2] > 1:
+            # first step of the decoding
+            sdpa_result = torch.nn.functional.scaled_dot_product_attention(
+                query, key, value, attn_mask=None, dropout_p=dropout_p, is_causal=True
+            )
+        else:
+            # in this case, which is the later decoding steps, the `causal_mask` in
+            # https://github.com/huggingface/transformers/blob/ae54e3c3b18bac0832ad62ea9b896dfd52a09850/src/transformers/models/gpt2/modeling_gpt2.py#L195
+            # is [True, ..., True] so actually not causal
+            sdpa_result = torch.nn.functional.scaled_dot_product_attention(
+                query, key, value, attn_mask=None, dropout_p=dropout_p, is_causal=False
+            )
+    else:
+        query_length, key_length = query.size(-2), key.size(-2)
+
+        # causal_mask is always [True, ..., True] otherwise, so executing this is unnecessary
+        if query_length > 1:
+            causal_mask = self.causal_mask[:, :, key_length - query_length : key_length, :key_length].to(torch.bool)
+
+            causal_mask = torch.where(causal_mask, 0, mask_value)
+
+            # torch.Tensor.expand does no memory copy
+            causal_mask = causal_mask.expand(batch_size, -1, -1, -1)
+
+            # we use torch.min to avoid having tensor(-inf)
+            attention_mask = torch.min(causal_mask, attention_mask)
+
+        sdpa_result = torch.nn.functional.scaled_dot_product_attention(
+            query, key, value, attn_mask=attention_mask, dropout_p=dropout_p, is_causal=False
+        )
+
+    return sdpa_result, None
+
+
 class CodeGenModelPatcher(DecoderModelPatcher):
     def __enter__(self):
         super().__enter__()
@@ -1873,14 +1941,23 @@ def __enter__(self):
         # To avoid breaking the model on the tracing stage, we reduce the area of the bettertransformer patch to _attn only.
         from optimum.bettertransformer.models.attention import codegen_wrapped_scaled_dot_product

+        attn_fn = codegen_wrapped_scaled_dot_product
+        if is_torch_version(">=", "2.1.0") and is_transformers_version(">=", "4.45"):
+            # in transformers 4.45 the causal_mask const buffer was removed from the model;
+            # if it still exists, it means legacy remote code was loaded
+            if hasattr(self._model.transformer.h[0].attn, "causal_mask"):
+                attn_fn = _codegen_wrapped_scaled_dot_product_legacy
+
         for layer in self._model.transformer.h:
             if is_torch_version(">=", "2.1.0") and not self._model.config.output_attentions:
                 orig_self_attn_fwd = layer.attn._attn
-                layer.attn._attn = types.MethodType(codegen_wrapped_scaled_dot_product, layer.attn)
+                layer.attn._attn = types.MethodType(attn_fn, layer.attn)
                 layer.attn._orig_attn = orig_self_attn_fwd
+        patch_update_causal_mask(self._model, "4.45.0", "transformer")

     def __exit__(self, exc_type, exc_value, traceback):
         super().__exit__(exc_type, exc_value, traceback)
+        unpatch_update_causal_mask(self._model, "transformer")
         for layer in self._model.transformer.h:
             if hasattr(layer.attn, "_orig_attn"):
                 layer.attn._attn = layer.attn._orig_attn
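
A toy illustration of the two SDPA branches used by _codegen_wrapped_scaled_dot_product_legacy above; the shapes are arbitrary and the snippet is not part of the patch:

import torch

# Prefill (query length > 1): causality is expressed through is_causal=True, no explicit mask.
q = torch.randn(1, 2, 4, 8)              # (batch, heads, q_len, head_dim)
k = v = torch.randn(1, 2, 4, 8)
prefill = torch.nn.functional.scaled_dot_product_attention(q, k, v, is_causal=True)

# Later decoding steps (single query token): every cached position is visible,
# so is_causal=False with attn_mask=None is equivalent and avoids a degenerate causal mask.
q_step = torch.randn(1, 2, 1, 8)
k_step = v_step = torch.randn(1, 2, 5, 8)
decode = torch.nn.functional.scaled_dot_product_attention(q_step, k_step, v_step, is_causal=False)

print(prefill.shape, decode.shape)        # torch.Size([1, 2, 4, 8]) torch.Size([1, 2, 1, 8])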
@@ -2275,8 +2352,7 @@ def __enter__(self):

     def __exit__(self, exc_type, exc_value, traceback):
         super().__exit__(exc_type, exc_value, traceback)
-        if hasattr(self._model.model, "_orig_update_causal_mask"):
-            self._model.model._update_causal_mask = self._model.model._orig_update_causal_mask
+        unpatch_update_causal_mask(self._model)
         for layer in self._model.model.layers:
             if hasattr(layer.self_attn, "_orig_forward"):
                 layer.self_attn.forward = layer.self_attn._orig_forward
@@ -2413,8 +2489,7 @@ def __enter__(self):

     def __exit__(self, exc_type, exc_value, traceback):
         super().__exit__(exc_type, exc_value, traceback)
-        if hasattr(self._model.model, "_orig_update_causal_mask"):
-            self._model.model._update_causal_mask = self._model.model._orig_update_causal_mask
+        unpatch_update_causal_mask(self._model)


 class RotaryEmbPatcher(DecoderModelPatcher):
@@ -2425,12 +2500,119 @@ def __enter__(self):
             _reinitialize_cos_sin_cached_fp32(layer.self_attn.rotary_emb)


+def _falcon_update_causal_mask(
+    self,
+    attention_mask: torch.Tensor,
+    input_tensor: torch.Tensor,
+    cache_position: torch.Tensor,
+    past_key_values: "Cache",
+    output_attentions: bool,
+    head_mask: torch.Tensor,
+    alibi: torch.Tensor,
+):
+    # copied from https://github.com/huggingface/transformers/blob/a30c865f991dfec9452cc64bd9a97bfbb96be036/src/transformers/models/falcon/modeling_falcon.py#L1130
+    from transformers.cache_utils import StaticCache
+    from transformers.modeling_attn_mask_utils import AttentionMaskConverter
+
+    # TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static
+    # KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes.
+    # (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using
+    # `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114
+
+    if hasattr(self, "_prepare_4d_causal_attention_mask_with_cache_position"):
+        _prepare_4d_causal_attention_mask_with_cache_position = (
+            self._prepare_4d_causal_attention_mask_with_cache_position
+        )
+    else:
+        from transformers.models.falcon.modeling_falcon import _prepare_4d_causal_attention_mask_with_cache_position
+
+    if self.config._attn_implementation == "flash_attention_2":
+        if attention_mask is not None and 0.0 in attention_mask:
+            return attention_mask
+        return None
+
+    # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
+    # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
+    # to infer the attention mask.
+    past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
+    using_static_cache = isinstance(past_key_values, StaticCache)
+
+    # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
+    if (
+        self.config._attn_implementation == "sdpa"
+        and not using_static_cache
+        and not output_attentions
+        and head_mask is None
+        and alibi is None
+    ):
+        if AttentionMaskConverter._ignore_causal_mask_sdpa(
+            attention_mask,
+            inputs_embeds=input_tensor,
+            past_key_values_length=past_seen_tokens,
+            is_training=self.training,
+        ):
+            return None
+
+    dtype, device = input_tensor.dtype, input_tensor.device
+    # difference from the original: use the float16 minimum instead of torch.finfo(dtype).min to prevent overflow in fp16/bf16 execution
+    min_dtype = torch.finfo(torch.float16).min
+    batch_size, sequence_length, _ = input_tensor.shape
+    if using_static_cache:
+        target_length = past_key_values.get_max_length()
+    else:
+        target_length = (
+            attention_mask.shape[-1]
+            if isinstance(attention_mask, torch.Tensor)
+            else past_seen_tokens + sequence_length
+        )
+
+    # In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
+    causal_mask = _prepare_4d_causal_attention_mask_with_cache_position(
+        attention_mask,
+        sequence_length=sequence_length,
+        target_length=target_length,
+        dtype=dtype,
+        device=device,
+        min_dtype=min_dtype,
+        cache_position=cache_position,
+        batch_size=input_tensor.shape[0],
+    )
+
+    # We take care to integrate alibi bias in the causal_mask here
+    if head_mask is None and alibi is not None:
+        alibi = alibi.reshape(batch_size, -1, *alibi.shape[1:])
+        causal_mask = torch.masked_fill(
+            alibi / math.sqrt(self.config.hidden_size // self.num_heads),
+            causal_mask < -1,
+            min_dtype,
+        )
+
+    if (
+        self.config._attn_implementation == "sdpa"
+        and attention_mask is not None
+        and attention_mask.device.type == "cuda"
+        and not output_attentions
+    ):
+        # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
+        # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
+        # Details: https://github.com/pytorch/pytorch/issues/110213
+        causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
+
+    return causal_mask
+
+
 class FalconModelPatcher(DecoderModelPatcher):
     def __enter__(self):
         super().__enter__()
         if is_transformers_version("<", "4.44.99"):
             for layer in self._model.transformer.h:
                 _reinitialize_cos_sin_cached_fp32(layer.self_attention.rotary_emb)
+        else:
+            patch_update_causal_mask(self._model, "4.45.0", "transformer", _falcon_update_causal_mask)
+
+    def __exit__(self, exc_type, exc_value, traceback):
+        super().__exit__(exc_type, exc_value, traceback)
+        unpatch_update_causal_mask(self._model, "transformer")


 class GptNeoxModelPatcher(DecoderModelPatcher):
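
A toy version of the alibi handling inside _falcon_update_causal_mask above. All shapes and values are arbitrary, and the alibi tensor is given a full (q_len, kv_len) trailing shape for simplicity (the real Falcon alibi is broadcast from a slimmer shape); only the reshape/masked_fill pattern matters:

import math
import torch

batch_size, num_heads, q_len, kv_len, hidden_size = 1, 2, 4, 4, 8
min_dtype = torch.finfo(torch.float16).min

# Additive 4D causal mask as produced upstream: 0 where attending is allowed, min_dtype where blocked.
causal_mask = torch.triu(torch.full((q_len, kv_len), min_dtype), diagonal=1)[None, None, :, :]

alibi = torch.randn(batch_size * num_heads, q_len, kv_len)
alibi = alibi.reshape(batch_size, -1, *alibi.shape[1:])  # (batch, num_heads, q_len, kv_len)

# Blocked positions (already at min_dtype, hence < -1) stay blocked; allowed ones receive the scaled bias.
biased_mask = torch.masked_fill(
    alibi / math.sqrt(hidden_size // num_heads),
    causal_mask < -1,
    min_dtype,
)
print(biased_mask.shape)  # torch.Size([1, 2, 4, 4])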
@@ -2439,6 +2621,22 @@ def __enter__(self):
         if is_transformers_version("<", "4.44.99"):
             for layer in self._model.gpt_neox.layers:
                 _reinitialize_cos_sin_cached_fp32(layer.attention.rotary_emb)
+        else:
+            patch_update_causal_mask(self._model, "4.45.0", "gpt_neox")
+
+    def __exit__(self, exc_type, exc_value, traceback):
+        super().__exit__(exc_type, exc_value, traceback)
+        unpatch_update_causal_mask(self._model, "gpt_neox")
+
+
+class GptJModelPatcher(DecoderModelPatcher):
+    def __enter__(self):
+        super().__enter__()
+        patch_update_causal_mask(self._model, "4.45.0", "transformer")
+
+    def __exit__(self, exc_type, exc_value, traceback):
+        super().__exit__(exc_type, exc_value, traceback)
+        unpatch_update_causal_mask(self._model, "transformer")


 class GptNeoxJapaneseModelPatcher(DecoderModelPatcher):
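
For reference, a sketch of how the inner_model_name argument selects the submodule on these architectures. The tiny GPT-NeoX configuration is illustrative, the helpers are assumed to be in scope as defined in this file, and the installed transformers is assumed to be >= 4.45 so the version gate passes:

from transformers import GPTNeoXConfig, GPTNeoXForCausalLM

config = GPTNeoXConfig(hidden_size=32, intermediate_size=64, num_hidden_layers=1,
                       num_attention_heads=2, vocab_size=64)
model = GPTNeoXForCausalLM(config)

# Non-matching submodule name: getattr(model, "transformer", None) is None, so nothing is patched.
patch_update_causal_mask(model, "4.45.0", "transformer")
assert "_update_causal_mask" not in vars(model.gpt_neox)

# Matching name, as used by GptNeoxModelPatcher above: an instance-level override is installed.
patch_update_causal_mask(model, "4.45.0", "gpt_neox")
assert "_update_causal_mask" in vars(model.gpt_neox)
unpatch_update_causal_mask(model, "gpt_neox")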
@@ -2447,6 +2645,12 @@ def __enter__(self):
         if is_transformers_version("<", "4.44.99"):
             for layer in self._model.gpt_neox_japanese.layers:
                 _reinitialize_cos_sin_cached_fp32(layer.attention.rotary_emb)
+        else:
+            patch_update_causal_mask(self._model, "4.45.0", "gpt_neox_japanese")
+
+    def __exit__(self, exc_type, exc_value, traceback):
+        super().__exit__(exc_type, exc_value, traceback)
+        unpatch_update_causal_mask(self._model, "gpt_neox_japanese")


 class Gemma2ModelPatcher(LlamaModelPatcher):