@@ -951,15 +951,107 @@ def __exit__(self, exc_type, exc_value, traceback):
             block.attention.forward = block.attention._orig_forward


+# Adapted from https://github.com/huggingface/transformers/blob/ccdabc5642bf84849af93f591e207dc625c8e1e1/src/transformers/models/phi3/modeling_phi3.py#L426
+def _phi3_self_attn_sdpa_forward(
+    self,
+    hidden_states: torch.Tensor,
+    attention_mask: Optional[torch.Tensor] = None,
+    position_ids: Optional[torch.LongTensor] = None,
+    past_key_value: Optional[Tuple[torch.Tensor]] = None,
+    output_attentions: bool = False,
+    use_cache: bool = False,
+) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+    if output_attentions:
+        return self._orig_forward(
+            hidden_states=hidden_states,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            past_key_value=past_key_value,
+            output_attentions=output_attentions,
+            use_cache=use_cache,
+        )
+
+    # TODO: remove the llama imports once a transformers release with phi3 support is available
+    try:
+        from transformers.models.phi3.modeling_phi3 import apply_rotary_pos_emb, repeat_kv
+    except ImportError:
+        from transformers.models.llama.modeling_llama import apply_rotary_pos_emb, repeat_kv
+
+    bsz, q_len, _ = hidden_states.size()
+
+    qkv = self.qkv_proj(hidden_states)
+    query_pos = self.num_heads * self.head_dim
+    query_states = qkv[..., :query_pos]
+    key_states = qkv[..., query_pos : query_pos + self.num_key_value_heads * self.head_dim]
+    value_states = qkv[..., query_pos + self.num_key_value_heads * self.head_dim :]
+
+    query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+    key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+    value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+    kv_seq_len = key_states.shape[-2]
+    if past_key_value is not None:
+        kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
+    cos, sin = self.rotary_emb(value_states, position_ids, seq_len=kv_seq_len)
+
+    query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
+
+    if past_key_value is not None:
+        cache_kwargs = {"sin": sin, "cos": cos}  # Specific to RoPE models
+        key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
+
+    key_states = repeat_kv(key_states, self.num_key_value_groups)
+    value_states = repeat_kv(value_states, self.num_key_value_groups)
+
+    if attention_mask is not None:
+        if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
+            raise ValueError(
+                f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
+            )
+
+    # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
+    # Reference: https://github.com/pytorch/pytorch/issues/112577.
+    if query_states.device.type == "cuda" and attention_mask is not None:
+        query_states = query_states.contiguous()
+        key_states = key_states.contiguous()
+        value_states = value_states.contiguous()
+
+    attn_output = torch.nn.functional.scaled_dot_product_attention(
+        query_states,
+        key_states,
+        value_states,
+        attn_mask=attention_mask,
+        dropout_p=self.attention_dropout if self.training else 0.0,
+        # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1.
+        is_causal=self.is_causal and attention_mask is None and q_len > 1,
+    )
+
+    attn_output = attn_output.transpose(1, 2).contiguous()
+    attn_output = attn_output.view(bsz, q_len, self.hidden_size)
+
+    attn_output = self.o_proj(attn_output)
+
+    return attn_output, None, past_key_value
+
+
 class Phi3ModelPatcher(DecoderModelPatcher):
     def __enter__(self):
         super().__enter__()
-
         # https://github.com/huggingface/transformers/blob/30ee508c6c92a1c0aa0281d193c7c0fb815b8d2f/src/transformers/models/phi3/modeling_phi3.py#L113
         # init inv_freq for torchscript tracing
         for layer in self._model.model.layers:
+            if is_torch_version(">=", "2.1.0"):
+                orig_self_attn_fwd = layer.self_attn.forward
+                layer.self_attn.forward = types.MethodType(_phi3_self_attn_sdpa_forward, layer.self_attn)
+                layer.self_attn._orig_forward = orig_self_attn_fwd
+
             if layer.self_attn.rotary_emb.inv_freq is None:
                 rotary_emb = layer.self_attn.rotary_emb
                 layer.self_attn.rotary_emb.inv_freq = 1.0 / (
                     rotary_emb.base ** (torch.arange(0, rotary_emb.dim, 2, dtype=torch.int64).float() / rotary_emb.dim)
                 )
+
+    def __exit__(self, exc_type, exc_value, traceback):
+        super().__exit__(exc_type, exc_value, traceback)
+        for layer in self._model.model.layers:
+            if hasattr(layer.self_attn, "_orig_forward"):
+                layer.self_attn.forward = layer.self_attn._orig_forward
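
For reference, here is a minimal, self-contained sketch of the patch-and-restore pattern this patcher relies on: types.MethodType rebinds forward on each module inside __enter__, and __exit__ restores the saved original. The _ToyAttention, _patched_forward, and ToyPatcher names are hypothetical stand-ins for illustration only and are not part of this change.

import types


class _ToyAttention:
    # Stand-in for a self-attention module; the real target is layer.self_attn.
    def forward(self, x):
        return x  # placeholder for the eager attention path


def _patched_forward(self, x):
    # Replacement forward; the real patch dispatches to scaled_dot_product_attention instead.
    return x


class ToyPatcher:
    def __init__(self, modules):
        self._modules = modules

    def __enter__(self):
        for m in self._modules:
            m._orig_forward = m.forward  # keep the original bound method
            m.forward = types.MethodType(_patched_forward, m)  # rebind forward to the patched function
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        for m in self._modules:
            if hasattr(m, "_orig_forward"):
                m.forward = m._orig_forward  # restore the original path


modules = [_ToyAttention(), _ToyAttention()]
with ToyPatcher(modules):
    assert all(m.forward.__func__ is _patched_forward for m in modules)  # patched inside the context
assert all(m.forward.__func__ is _ToyAttention.forward for m in modules)  # restored afterwards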