@@ -3650,3 +3650,93 @@ def __exit__(self, exc_type, exc_value, traceback):
             block_sparse_moe.router.forward = block_sparse_moe.router._orig_forward
             block_sparse_moe.input_linear.forward = block_sparse_moe.input_linear._orig_forward
             block_sparse_moe.output_linear.forward = block_sparse_moe.output_linear._orig_forward
+
+
+# copied from https://github.com/huggingface/transformers/blob/v4.46.3/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py#L401
+def gpt_bigcode_attn(self, query, key, value, attention_mask=None, head_mask=None):
+    if head_mask is not None:
+        # The super dispatch is done in the forward.
+        raise ValueError("PyTorch SDPA does not support head_mask. Please open an issue in Transformers repository.")
+
+    scale = None
+    if not self.scale_attn_weights:
+        scale = 1
+
+    # MQA models: (batch_size, query_length, num_heads * head_dim)
+    # MHA models: (batch_size, num_heads, query_length, head_dim)
+    query_shape = query.shape
+    batch_size = query_shape[0]
+    key.shape[-2]
+
+    if self.multi_query:
+        query_length = query_shape[1]
+
+        # SDPA requires the dimension [..., sequence_length, head_dim].
+        query = query.view(batch_size, query_length, self.num_heads, self.head_dim).transpose(1, 2)
+
+        # Without these unsqueeze, SDPA complains as the query and key/value have a different number of dimensions.
+        key = key.unsqueeze(1)
+        value = value.unsqueeze(1)
+
+        # Although these expand are not numerically useful, PyTorch can not dispatch to memory-efficient backend
+        # and flash attention backend (No available kernel. Aborting execution.) from the shapes
+        # query = [batch_size, num_heads, query_length, head_dim]
+        # key = [batch_size, 1, past_length, head_dim]
+        # value = [batch_size, 1, past_length, head_dim]
+        #
+        # torch==2.1.2 is bugged with non-contiguous inputs with custom attn_mask (https://github.com/pytorch/pytorch/issues/112577), hence the check.
+        if is_torch_version(">=", "2.2.0"):
+            key = key.expand(-1, self.num_heads, -1, -1)
+            value = value.expand(-1, self.num_heads, -1, -1)
+    else:
+        query_length = query_shape[-1]
+
+        # See the comment above.
+        if query.device.type == "cuda" and attention_mask is not None:
+            query = query.contiguous()
+            key = key.contiguous()
+            value = value.contiguous()
+
+    # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
+    # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
+    # The query_length > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not
+    # create a causal mask in case query_length == 1.
+    is_causal = True if self.is_causal and attention_mask is None and query_length > 1 else False
+    # different from the original: since the model weights may be loaded in their original format, the transformer.wte dtype can differ from the query dtype
+    if attention_mask is not None:
+        attention_mask = attention_mask.to(query.dtype)
+    sdpa_result = torch.nn.functional.scaled_dot_product_attention(
+        query,
+        key,
+        value,
+        attn_mask=attention_mask,
+        dropout_p=self.attn_pdrop if self.training else 0.0,
+        is_causal=is_causal,
+        scale=scale,
+    )
+
+    if self.multi_query:
+        # (batch_size, num_heads, seq_len, head_dim) --> (batch_size, seq_len, num_heads, head_dim)
+        sdpa_result = sdpa_result.transpose(1, 2)
+
+        # Reshape is kind of expensive here, as it does a memory copy,
+        # but I did not manage to make away without it (logits do not match when using view)
+        # (batch_size, seq_len, num_heads, head_dim) --> (batch_size, seq_len, num_heads * head_dim)
+        sdpa_result = sdpa_result.reshape(query_shape)
+
+    return sdpa_result, None
+
+
+class GptBigCodeModelPatcher(DecoderModelPatcher):
+    def __enter__(self):
+        super().__enter__()
+        if getattr(self._model.config, "_attn_implementation", "eager") == "sdpa":
+            for layer in self._model.transformer.h:
+                layer.attn._orig_attn = layer.attn._attn
+                layer.attn._attn = types.MethodType(gpt_bigcode_attn, layer.attn)
+
+    def __exit__(self, exc_type, exc_value, traceback):
+        super().__exit__(exc_type, exc_value, traceback)
+        if getattr(self._model.config, "_attn_implementation", "eager") == "sdpa":
+            for layer in self._model.transformer.h:
+                layer.attn._attn = layer.attn._orig_attn
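
For readers unfamiliar with this patching pattern, below is a minimal, self-contained sketch (toy classes, not the optimum/transformers API) of what `GptBigCodeModelPatcher` does: a context manager stashes the original bound `_attn` method, rebinds a module-level replacement with `types.MethodType`, and restores the original on exit.

```python
import types


class ToyAttention:
    def _attn(self, query):
        return f"eager attention on {query}"


def patched_attn(self, query):
    # Replacement implementation; `self` is the ToyAttention instance it is bound to.
    return f"sdpa-style attention on {query}"


class ToyAttentionPatcher:
    def __init__(self, module):
        self._module = module

    def __enter__(self):
        # Keep a reference to the original bound method so it can be restored later.
        self._module._orig_attn = self._module._attn
        # Bind the module-level function to the instance, mirroring the patcher above.
        self._module._attn = types.MethodType(patched_attn, self._module)
        return self._module

    def __exit__(self, exc_type, exc_value, traceback):
        self._module._attn = self._module._orig_attn


attn = ToyAttention()
with ToyAttentionPatcher(attn):
    print(attn._attn("q"))  # patched implementation
print(attn._attn("q"))      # original implementation restored
```

Binding with `types.MethodType` keeps `self` pointing at the existing attention module, so the replacement can read attributes such as `num_heads`, `head_dim`, or `multi_query` without otherwise modifying the module.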