@@ -39,6 +39,29 @@ def apply_rotary_emb_single(
     return x_out
 
 
+def apply_partial_rotary_emb_single(
+    x: torch.Tensor, freqs_cos: torch.Tensor, freqs_sin: torch.Tensor
+) -> torch.Tensor:
+
+    if x.dim() == 4:
+        freqs_cos = freqs_cos[None, :, None, :]
+        freqs_sin = freqs_sin[None, :, None, :]
+
+    rotary_dim = freqs_cos.shape[-1] * 2
+
+    x_rot, x_pass = x[..., :rotary_dim], x[..., rotary_dim:]
+    x_r, x_i = x_rot[..., : x_rot.shape[-1] // 2], x_rot[..., x_rot.shape[-1] // 2 :]
+    x_out_r = x_r * freqs_cos - x_i * freqs_sin
+    x_out_i = x_r * freqs_sin + x_i * freqs_cos
+    x_rotated = torch.cat([x_out_r, x_out_i], dim=-1)
+    return torch.cat([x_rotated, x_pass], dim=-1)
+
+
+APPLY_ROPE_EMBEDDING_FUNCTIONS = {
+    "phi_4_mini": apply_partial_rotary_emb_single,
+}
+
+
 class LlamaAttention(nn.Module):
     def __init__(self, config: ModelArgs, output_new_cache_only=False):
         super().__init__()
@@ -59,6 +82,9 @@ def __init__(self, config: ModelArgs, output_new_cache_only=False):
             k_norm_dim = self.head_dim
             self.q_norm_fn = torch.nn.RMSNorm(q_norm_dim, eps=config.norm_eps)
             self.k_norm_fn = torch.nn.RMSNorm(k_norm_dim, eps=config.norm_eps)
+        self.apply_rope_emb = APPLY_ROPE_EMBEDDING_FUNCTIONS.get(
+            config.base_model_name_or_path, apply_rotary_emb_single
+        )
 
         self.wq = nn.Linear(
             self.dim,
@@ -199,17 +225,17 @@ def forward_sha( # noqa: C901
         for i in range(len(q)):
             if self.use_qk_norm and self.qk_norm_before_rope:
                 q[i] = self.q_norm_fn(q[i])
-            q[i] = apply_rotary_emb_single(q[i], freqs_cos, freqs_sin)
+            q[i] = self.apply_rope_emb(q[i], freqs_cos, freqs_sin)
             if hasattr(self.config, "enable_r3") and self.config.enable_r3:
-                q[i] = torch.matmul(q[i], self.r3_weight.T)
+                q[i] = torch.matmul(q[i], self.r3_weight)
             if self.use_qk_norm and not self.qk_norm_before_rope:
                 q[i] = self.q_norm_fn(q[i])
         for i in range(len(k)):
             if self.use_qk_norm and self.qk_norm_before_rope:
                 k[i] = self.k_norm_fn(k[i])
-            k[i] = apply_rotary_emb_single(k[i], freqs_cos, freqs_sin).transpose(1, 2)
+            k[i] = self.apply_rope_emb(k[i], freqs_cos, freqs_sin).transpose(1, 2)
             if hasattr(self.config, "enable_r3") and self.config.enable_r3:
-                k[i] = torch.matmul(k[i], self.r3_weight.T)
+                k[i] = torch.matmul(k[i], self.r3_weight)
             if self.use_qk_norm and not self.qk_norm_before_rope:
                 k[i] = self.k_norm_fn(k[i])
 
@@ -272,8 +298,8 @@ def forward(
             q = self.q_norm_fn(q)
             k = self.k_norm_fn(k)
 
-        q = apply_rotary_emb_single(q, freqs_cos, freqs_sin)
-        k = apply_rotary_emb_single(k, freqs_cos, freqs_sin).permute(0, 2, 3, 1)
+        q = self.apply_rope_emb(q, freqs_cos, freqs_sin)
+        k = self.apply_rope_emb(k, freqs_cos, freqs_sin).permute(0, 2, 3, 1)
 
         if self.use_qk_norm and not self.qk_norm_before_rope:
             q = self.q_norm_fn(q)
@@ -368,7 +394,8 @@ def __init__(self, config: ModelArgs, output_new_cache_only=False):
         super().__init__()
         self.dim = config.dim
         self.attention = LlamaAttention(
-            config=config, output_new_cache_only=output_new_cache_only
+            config=config,
+            output_new_cache_only=output_new_cache_only,
         )
         self.feed_forward = FeedForward(config)
         self.attention_norm = torch.nn.RMSNorm(config.dim, eps=config.norm_eps)
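
Note: the snippet below is a standalone sketch, not part of the patch. It illustrates what the new apply_partial_rotary_emb_single helper does: only the first rotary_dim = 2 * freqs_cos.shape[-1] channels of each head are rotated, the remaining channels pass through unchanged, and APPLY_ROPE_EMBEDDING_FUNCTIONS.get(config.base_model_name_or_path, apply_rotary_emb_single) falls back to the existing full-rotation helper for every other model. The tensor shapes and the implied 0.75 partial-rotary factor below are illustrative assumptions, not values taken from the repo.

import torch

def apply_partial_rotary_emb_single(x, freqs_cos, freqs_sin):
    # Same math as the helper added in this commit.
    if x.dim() == 4:
        freqs_cos = freqs_cos[None, :, None, :]
        freqs_sin = freqs_sin[None, :, None, :]
    rotary_dim = freqs_cos.shape[-1] * 2
    x_rot, x_pass = x[..., :rotary_dim], x[..., rotary_dim:]
    x_r, x_i = x_rot[..., : rotary_dim // 2], x_rot[..., rotary_dim // 2 :]
    x_out_r = x_r * freqs_cos - x_i * freqs_sin
    x_out_i = x_r * freqs_sin + x_i * freqs_cos
    return torch.cat([torch.cat([x_out_r, x_out_i], dim=-1), x_pass], dim=-1)

# Assumed layout: (batch, seq_len, n_heads, head_dim) with a made-up head_dim
# of 64 and rotary_dim of 48 (i.e. a 0.75 partial-rotary factor).
batch, seq_len, n_heads, head_dim, rotary_dim = 1, 8, 4, 64, 48
x = torch.randn(batch, seq_len, n_heads, head_dim)
freqs_cos = torch.randn(seq_len, rotary_dim // 2)
freqs_sin = torch.randn(seq_len, rotary_dim // 2)

out = apply_partial_rotary_emb_single(x, freqs_cos, freqs_sin)
assert out.shape == x.shape
# The tail of each head (channels rotary_dim and beyond) is passed through untouched.
assert torch.equal(out[..., rotary_dim:], x[..., rotary_dim:])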