@@ -1161,7 +1161,7 @@ def __exit__(self, exc_type, exc_value, traceback):
             block.attention.forward = block.attention._orig_forward
 
 
-# Adapted from https://github.com/huggingface/transformers/blob/ccdabc5642bf84849af93f591e207dc625c8e1e1/src/transformers/models/phi3/modeling_phi3.py#L426
+# Adapted from https://github.com/huggingface/transformers/blob/ccdabc5642bf84849af93f591e207dc625c8e1e1/src/transformers/models/phi3/modeling_phi3.py#L729
 def _phi3_self_attn_sdpa_forward(
     self,
     hidden_states: torch.Tensor,
@@ -1170,6 +1170,7 @@ def _phi3_self_attn_sdpa_forward(
     past_key_value: Optional[Tuple[torch.Tensor]] = None,
     output_attentions: bool = False,
     use_cache: bool = False,
+    cache_position: Optional[torch.LongTensor] = None,
 ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
     if output_attentions:
         return self._orig_forward(
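The new `cache_position` argument carries the absolute positions of the tokens being processed in the current step. Below is a minimal sketch (not part of the patch) of how a caller typically derives it from the number of tokens already held in the cache; the helper name `make_cache_position` is hypothetical.

import torch

def make_cache_position(past_seen_tokens: int, q_len: int, device: torch.device) -> torch.LongTensor:
    # Absolute positions of this step's tokens, e.g. tensor([5]) when decoding
    # the sixth token with five tokens already cached.
    return torch.arange(past_seen_tokens, past_seen_tokens + q_len, device=device)

# Usage: a single-token decode step after a 5-token prompt.
cache_position = make_cache_position(past_seen_tokens=5, q_len=1, device=torch.device("cpu"))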
@@ -1181,10 +1182,9 @@ def _phi3_self_attn_sdpa_forward(
             use_cache=use_cache,
         )
 
-    # TO DO: remove llama imports when transformers with phi3 support will be released
-    try:
+    if is_transformers_version(">=", "4.41.0"):
         from transformers.models.phi3.modeling_phi3 import apply_rotary_pos_emb, repeat_kv
-    except ImportError:
+    else:
         from transformers.models.llama.modeling_llama import apply_rotary_pos_emb, repeat_kv
 
     bsz, q_len, _ = hidden_states.size()
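Both import branches supply the same pair of helpers; `repeat_kv`, used a few lines below on the cached key/value states, expands the grouped key/value heads so they align with the query heads. A functionally equivalent sketch for illustration, assuming the usual (batch, num_kv_heads, seq_len, head_dim) layout:

import torch

def repeat_kv_sketch(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    # (batch, num_kv_heads, seq_len, head_dim) -> (batch, num_kv_heads * n_rep, seq_len, head_dim)
    batch, num_kv_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    expanded = hidden_states[:, :, None, :, :].expand(batch, num_kv_heads, n_rep, slen, head_dim)
    return expanded.reshape(batch, num_kv_heads * n_rep, slen, head_dim)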
@@ -1206,17 +1206,15 @@ def _phi3_self_attn_sdpa_forward(
     query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
 
     if past_key_value is not None:
-        cache_kwargs = {"sin": sin, "cos": cos}  # Specific to RoPE models
+        cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}  # Specific to RoPE models
         key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
 
     key_states = repeat_kv(key_states, self.num_key_value_groups)
     value_states = repeat_kv(value_states, self.num_key_value_groups)
 
+    causal_mask = attention_mask
     if attention_mask is not None:
-        if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
-            raise ValueError(
-                f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
-            )
+        causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
 
     # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
     # Reference: https://github.com/pytorch/pytorch/issues/112577.
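Forwarding `cache_position` to `past_key_value.update` lets cache implementations that pre-allocate storage know which slots the new key/value entries should occupy, while a dynamic cache can ignore it. Replacing the strict size check with a slice keeps the 4D mask usable whether it was built for the current key length or for the full pre-allocated cache. A shape-only sketch of that slice, with illustrative dimensions:

import torch

bsz, n_heads, q_len, head_dim = 1, 32, 1, 96
max_cache_len, kv_len = 4096, 17            # mask built for the full cache, 17 keys actually present

attention_mask = torch.zeros(bsz, 1, q_len, max_cache_len)     # 4D additive mask
key_states = torch.randn(bsz, n_heads, kv_len, head_dim)

causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]  # -> (1, 1, 1, 17)
assert causal_mask.shape == (bsz, 1, q_len, kv_len)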
@@ -1229,7 +1227,7 @@ def _phi3_self_attn_sdpa_forward(
         query_states,
         key_states,
         value_states,
-        attn_mask=attention_mask,
+        attn_mask=causal_mask,
         dropout_p=self.attention_dropout if self.training else 0.0,
         # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1.
         is_causal=self.is_causal and attention_mask is None and q_len > 1,
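An explicit `attn_mask` and `is_causal=True` are mutually exclusive ways of masking in `F.scaled_dot_product_attention`: with no mask and more than one query token the fused kernel builds its own causal mask, while with `q_len == 1` the single query may attend to every cached key, so no mask is needed at all. A small usage sketch of the same call pattern outside the patch:

import torch
import torch.nn.functional as F

q = torch.randn(1, 32, 4, 96)   # (batch, heads, q_len, head_dim)
k = torch.randn(1, 32, 4, 96)
v = torch.randn(1, 32, 4, 96)

# Prefill without padding: let the kernel apply causal masking itself.
out_prefill = F.scaled_dot_product_attention(q, k, v, attn_mask=None, is_causal=True)

# Padded batch or decode step: pass the sliced 4D additive mask instead.
causal_mask = torch.zeros(1, 1, 4, 4)
out_masked = F.scaled_dot_product_attention(q, k, v, attn_mask=causal_mask, is_causal=False)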
@@ -1561,7 +1559,7 @@ def __exit__(self, exc_type, exc_value, traceback):
             layer.attn._attn = layer.attn._orig_attn
 
 
-# adapted from https://github.com/huggingface/transformers/blob/v4.40.2/src/transformers/models/dbrx/modeling_dbrx.py#L763
+# Adapted from https://github.com/huggingface/transformers/blob/v4.40.2/src/transformers/models/dbrx/modeling_dbrx.py#L763
 def _dbrx_experts_forward(
     self, x: torch.Tensor, weights: torch.Tensor, top_weights: torch.Tensor, top_experts: torch.LongTensor
 ):
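The body of `_dbrx_experts_forward` is elided from this hunk. For orientation only, the sketch below shows the generic top-k mixture-of-experts dispatch loop that routines with this signature typically implement; the `experts` module list and the omission of the unused `weights` argument are illustrative assumptions, not the patched code.

import torch
import torch.nn as nn

def moe_dispatch_sketch(x, top_weights, top_experts, experts: nn.ModuleList, moe_num_experts: int):
    # x:           (tokens, hidden)   flattened token activations
    # top_weights: (tokens, top_k)    router weights of the selected experts
    # top_experts: (tokens, top_k)    indices of the selected experts
    out = torch.zeros_like(x)
    # (num_experts, top_k, tokens) boolean routing map
    expert_mask = nn.functional.one_hot(top_experts, num_classes=moe_num_experts).permute(2, 1, 0)
    for expert_idx in range(moe_num_experts):
        topk_idx, token_idx = torch.where(expert_mask[expert_idx])
        if token_idx.shape[0] == 0:
            continue  # no tokens routed to this expert in the batch
        expert_out = experts[expert_idx](x[token_idx]) * top_weights[token_idx, topk_idx, None]
        out.index_add_(0, token_idx, expert_out)
    return out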
@@ -1606,7 +1604,7 @@ def _dbrx_experts_forward(
     return out
 
 
-# adapted from https://github.com/huggingface/transformers/blob/v4.40.2/src/transformers/models/dbrx/modeling_dbrx.py#L1228
+# Adapted from https://github.com/huggingface/transformers/blob/v4.40.2/src/transformers/models/dbrx/modeling_dbrx.py#L1228
 def _dbrx_update_causal_mask_legacy(
     self, attention_mask: Optional[torch.Tensor], input_tensor: torch.Tensor, cache_position: torch.Tensor
 ) -> Optional[torch.Tensor]:
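`_dbrx_update_causal_mask_legacy` (body elided here) produces the 4D additive mask that the attention forwards above slice to the key length. The general recipe: fill a (q_len, target_len) matrix with the dtype minimum, zero the entries at or before each query's `cache_position`, expand it to (batch, 1, q_len, target_len), and fold in the 2D padding mask. A hedged sketch of that recipe, with an assumed `target_length` argument:

import torch

def build_causal_mask_sketch(attention_mask, input_tensor, cache_position, target_length):
    dtype, device = input_tensor.dtype, input_tensor.device
    min_dtype = torch.finfo(dtype).min
    q_len = input_tensor.shape[1]

    causal_mask = torch.full((q_len, target_length), fill_value=min_dtype, dtype=dtype, device=device)
    # keep min_dtype only for key positions strictly after each query's absolute position
    causal_mask = causal_mask * (torch.arange(target_length, device=device) > cache_position.reshape(-1, 1))
    causal_mask = causal_mask[None, None, :, :].expand(input_tensor.shape[0], 1, -1, -1).clone()

    if attention_mask is not None and attention_mask.dim() == 2:
        mask_length = attention_mask.shape[-1]
        padding = causal_mask[..., :mask_length].eq(0.0) * attention_mask[:, None, None, :].eq(0.0)
        causal_mask[..., :mask_length] = causal_mask[..., :mask_length].masked_fill(padding, min_dtype)
    return causal_mask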
@@ -1803,6 +1801,7 @@ def __exit__(self, exc_type, exc_value, traceback):
             block.ffn.experts.forward = block.ffn.experts._orig_forward
 
 
+# Adapted from https://github.com/huggingface/transformers/blob/v4.41.0/src/transformers/models/persimmon/modeling_persimmon.py#L264
 def _persimmon_self_attn_sdpa_forward(
     self,
     hidden_states: torch.Tensor,
@@ -1811,6 +1810,7 @@ def _persimmon_self_attn_sdpa_forward(
     past_key_value: Optional["Cache"] = None,
     output_attentions: bool = False,
     use_cache: bool = False,
+    cache_position: Optional[torch.LongTensor] = None,
 ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
     from transformers.models.persimmon.modeling_persimmon import apply_rotary_pos_emb
 
@@ -1865,14 +1865,23 @@ def _persimmon_self_attn_sdpa_forward(
 
     if past_key_value is not None:
         # Specific to RoPE models with partial rotation
-        cache_kwargs = {"sin": sin, "cos": cos, "partial_rotation_size": self.rotary_emb.dim}
+        cache_kwargs = {
+            "sin": sin,
+            "cos": cos,
+            "partial_rotation_size": self.rotary_emb.dim,
+            "cache_position": cache_position,
+        }
         key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
 
+    causal_mask = attention_mask
+    if attention_mask is not None:  # no matter the length, we just slice it
+        causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
+
     attn_output = F.scaled_dot_product_attention(
         query_states,
         key_states,
         value_states,
-        attention_mask,
+        causal_mask,
         scale=1 / math.sqrt(self.head_dim),
         dropout_p=self.attention_dropout.p,
     )
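Persimmon applies rotary embeddings to only the first `rotary_emb.dim` channels of each head, which is why the cache update needs the extra `partial_rotation_size` hint; the explicit `scale=1 / math.sqrt(self.head_dim)` matches SDPA's default scaling and mirrors the eager implementation. A hedged sketch of the partial-rotation step that precedes this hunk (the exact `apply_rotary_pos_emb` signature varies across transformers versions, so it is taken as a parameter here):

import torch

def apply_partial_rotary_sketch(query_states, key_states, cos, sin, rotary_ndims, apply_rotary_pos_emb):
    # Split each head into the slice that is rotated and the slice that passes through untouched.
    query_rot, query_pass = query_states[..., :rotary_ndims], query_states[..., rotary_ndims:]
    key_rot, key_pass = key_states[..., :rotary_ndims], key_states[..., rotary_ndims:]

    query_rot, key_rot = apply_rotary_pos_emb(query_rot, key_rot, cos, sin)

    # Reassemble: rotated channels first, pass-through channels after.
    query_states = torch.cat((query_rot, query_pass), dim=-1)
    key_states = torch.cat((key_rot, key_pass), dim=-1)
    return query_states, key_states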