@@ -2681,6 +2681,96 @@ def __exit__(self, exc_type, exc_value, traceback):
             unpatch_update_causal_mask(self._model, "gpt_neox_japanese")


+def _gpt_neo_attn_forward(
+    self,
+    hidden_states,
+    attention_mask=None,
+    layer_past=None,
+    head_mask=None,
+    use_cache=False,
+    output_attentions=False,
+    cache_position=None,
+):
+    if output_attentions:
+        # SDPA cannot return attention weights, so fall back to the original
+        # eager `_attn` implementation for this call.
+        self._attn = self._orig_attn
+
+    return self._orig_forward(
+        hidden_states,
+        attention_mask=attention_mask,
+        layer_past=layer_past,
+        head_mask=head_mask,
+        use_cache=use_cache,
+        output_attentions=output_attentions,
+        cache_position=cache_position,
+    )
+
+
+# Adapted from https://github.com/huggingface/optimum/blob/main/optimum/bettertransformer/models/attention.py#L185
+def _gpt_neo_attn_sdpa(
+    self,
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    attention_mask: Optional[torch.Tensor] = None,
+    head_mask: Optional[torch.Tensor] = None,
+):
+    batch_size = query.shape[0]
+
+    mask_value = torch.finfo(torch.float16).min
+    mask_value = torch.full([], mask_value, dtype=value.dtype)
+
+    dropout_p = float(self.config.attention_dropout) if self.training else 0.0
+    if (batch_size == 1 or self.training) and self.attention_type == "global":
+        if query.shape[2] > 1:
+            # Prefill step (query length > 1): rely on SDPA's built-in causal masking.
+            sdpa_result = torch.nn.functional.scaled_dot_product_attention(
+                query, key, value, attn_mask=None, dropout_p=dropout_p, is_causal=True
+            )
+        else:
+            # Single-token decoding step: no causal mask is needed.
+            sdpa_result = torch.nn.functional.scaled_dot_product_attention(
+                query, key, value, attn_mask=None, dropout_p=dropout_p, is_causal=False, scale=1.0
+            )
+    else:
+        query_length, key_length = query.size(-2), key.size(-2)
+
+        # Build an additive causal mask from the registered `bias` buffer.
+        causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length]
+
+        causal_mask = torch.where(causal_mask, 0, mask_value)
+        if batch_size > 1:
+            # torch.Tensor.expand does no memory copy
+            causal_mask = causal_mask.expand(batch_size, -1, -1, -1)
+
+        if attention_mask is not None:
+            attention_mask = causal_mask + attention_mask
+
+        sdpa_result = torch.nn.functional.scaled_dot_product_attention(
+            query, key, value, attn_mask=attention_mask, dropout_p=dropout_p, is_causal=False, scale=1.0
+        )
+
+    return sdpa_result, None
+
+
+class GptNeoModelPatcher(DecoderModelPatcher):
+    def __enter__(self):
+        super().__enter__()
+        if is_transformers_version(">=", "4.45.0") and is_torch_version(">=", "2.1.0"):
+            self._model.config._orig_attn_implementation = self._model.config._attn_implementation
+            self._model.config._attn_implementation = "sdpa"
+            for layer in self._model.transformer.h:
+                self_attn = layer.attn.attention
+                # Keep references to the original methods so __exit__ can restore them.
+                self_attn._orig_attn = self_attn._attn
+                self_attn._attn = types.MethodType(_gpt_neo_attn_sdpa, self_attn)
+                self_attn._orig_forward = self_attn.forward
+                self_attn.forward = types.MethodType(_gpt_neo_attn_forward, self_attn)
+
+    def __exit__(self, exc_type, exc_value, traceback):
+        super().__exit__(exc_type, exc_value, traceback)
+        if hasattr(self._model.config, "_orig_attn_implementation"):
+            self._model.config._attn_implementation = self._model.config._orig_attn_implementation
+            for layer in self._model.transformer.h:
+                layer.attn.attention.forward = layer.attn.attention._orig_forward
+                layer.attn.attention._attn = layer.attn.attention._orig_attn
+
+
 class Gemma2ModelPatcher(LlamaModelPatcher):
     def __init__(
         self,
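
For context, a patcher like this is normally applied as a context manager around model tracing, so the SDPA-friendly attention is only active while the graph is captured and the original implementation is restored afterwards. The sketch below illustrates that usage pattern; the constructor arguments (export config plus model), the helper name, and the tracing call are assumptions for illustration, not part of this diff.

    # Hypothetical usage sketch: `export_config`, `model`, and `example_inputs`
    # are placeholders; the constructor is assumed to follow optimum's
    # ModelPatcher(config, model, model_kwargs) convention.
    import torch

    def trace_gpt_neo(export_config, model, example_inputs):
        # Inside the `with` block each layer's `_attn` is swapped for the SDPA
        # variant and config._attn_implementation is set to "sdpa"; __exit__
        # restores the original attention once tracing is done.
        with GptNeoModelPatcher(export_config, model):
            traced = torch.jit.trace(model, example_kwarg_inputs=example_inputs)
        return traced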