@@ -34,11 +34,35 @@ def apply_rotary_emb_single(
     freqs_sin = freqs_sin[None, :, None, :]
     x_out_r = x_r * freqs_cos - x_i * freqs_sin
     x_out_i = x_r * freqs_sin + x_i * freqs_cos
-
     x_out = torch.cat([x_out_r, x_out_i], dim=-1)
     return x_out
 
 
+def apply_partial_rotary_emb_single(
+    x: torch.Tensor, freqs_cos: torch.Tensor, freqs_sin: torch.Tensor
+) -> torch.Tensor:
+    # Partial RoPE: rotate only the leading rotary_dim channels of each head;
+    # the remaining channels pass through unchanged.
+    if x.dim() == 4:
+        freqs_cos = freqs_cos[None, :, None, :]
+        freqs_sin = freqs_sin[None, :, None, :]
+
+    rotary_dim = freqs_cos.shape[-1] * 2
+
+    x_rot, x_pass = x[..., :rotary_dim], x[..., rotary_dim:]
+    x_r, x_i = x_rot[..., : x_rot.shape[-1] // 2], x_rot[..., x_rot.shape[-1] // 2 :]
+    x_out_r = x_r * freqs_cos - x_i * freqs_sin
+    x_out_i = x_r * freqs_sin + x_i * freqs_cos
+    x_rotated = torch.cat([x_out_r, x_out_i], dim=-1)
+    return torch.cat([x_rotated, x_pass], dim=-1)
+
+
+# Per-model RoPE dispatch: phi_4_mini uses partial rotary embeddings,
+# while qwen2_5 and llama3_2 apply the full rotation.
+APPLY_ROPE_EMBEDDING_FUNCTIONS = {
+    "phi_4_mini": apply_partial_rotary_emb_single,
+    "qwen2_5": apply_rotary_emb_single,
+    "llama3_2": apply_rotary_emb_single,
+}
+
+
 class LlamaAttention(nn.Module):
     def __init__(self, config: ModelArgs, output_new_cache_only=False):
         super().__init__()
@@ -59,6 +83,9 @@ def __init__(self, config: ModelArgs, output_new_cache_only=False):
             k_norm_dim = self.head_dim
             self.q_norm_fn = torch.nn.RMSNorm(q_norm_dim, eps=config.norm_eps)
             self.k_norm_fn = torch.nn.RMSNorm(k_norm_dim, eps=config.norm_eps)
+        self.apply_rope_emb = APPLY_ROPE_EMBEDDING_FUNCTIONS[
+            config.base_model_name_or_path
+        ]
 
         self.wq = nn.Linear(
             self.dim,
@@ -199,15 +226,15 @@ def forward_sha( # noqa: C901
         for i in range(len(q)):
             if self.use_qk_norm and self.qk_norm_before_rope:
                 q[i] = self.q_norm_fn(q[i])
-            q[i] = apply_rotary_emb_single(q[i], freqs_cos, freqs_sin)
+            q[i] = self.apply_rope_emb(q[i], freqs_cos, freqs_sin)
             if hasattr(self.config, "enable_r3") and self.config.enable_r3:
                 q[i] = torch.matmul(q[i], self.r3_weight.T)
             if self.use_qk_norm and not self.qk_norm_before_rope:
                 q[i] = self.q_norm_fn(q[i])
         for i in range(len(k)):
             if self.use_qk_norm and self.qk_norm_before_rope:
                 k[i] = self.k_norm_fn(k[i])
-            k[i] = apply_rotary_emb_single(k[i], freqs_cos, freqs_sin).transpose(1, 2)
+            k[i] = self.apply_rope_emb(k[i], freqs_cos, freqs_sin).transpose(1, 2)
             if hasattr(self.config, "enable_r3") and self.config.enable_r3:
                 k[i] = torch.matmul(k[i], self.r3_weight.T)
             if self.use_qk_norm and not self.qk_norm_before_rope:
@@ -272,8 +299,8 @@ def forward(
             q = self.q_norm_fn(q)
             k = self.k_norm_fn(k)
 
-        q = apply_rotary_emb_single(q, freqs_cos, freqs_sin)
-        k = apply_rotary_emb_single(k, freqs_cos, freqs_sin).permute(0, 2, 3, 1)
+        q = self.apply_rope_emb(q, freqs_cos, freqs_sin)
+        k = self.apply_rope_emb(k, freqs_cos, freqs_sin).permute(0, 2, 3, 1)
 
         if self.use_qk_norm and not self.qk_norm_before_rope:
             q = self.q_norm_fn(q)
@@ -368,7 +395,8 @@ def __init__(self, config: ModelArgs, output_new_cache_only=False):
         super().__init__()
         self.dim = config.dim
         self.attention = LlamaAttention(
-            config=config, output_new_cache_only=output_new_cache_only
+            config=config,
+            output_new_cache_only=output_new_cache_only,
         )
         self.feed_forward = FeedForward(config)
         self.attention_norm = torch.nn.RMSNorm(config.dim, eps=config.norm_eps)
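
A minimal standalone sketch (not part of this commit) illustrating the property the new apply_partial_rotary_emb_single relies on: only the first rotary_dim = 2 * freqs_cos.shape[-1] channels of each head are rotated, and the trailing channels pass through unchanged. The tensor shapes below (batch 1, seq_len 4, 2 heads, head_dim 8, rotary_dim 4) are illustrative assumptions, not values from the commit.

import torch

def apply_partial_rotary_emb_single(x, freqs_cos, freqs_sin):
    # Same logic as the function added in this commit.
    if x.dim() == 4:
        freqs_cos = freqs_cos[None, :, None, :]
        freqs_sin = freqs_sin[None, :, None, :]
    rotary_dim = freqs_cos.shape[-1] * 2
    x_rot, x_pass = x[..., :rotary_dim], x[..., rotary_dim:]
    x_r, x_i = x_rot[..., : rotary_dim // 2], x_rot[..., rotary_dim // 2 :]
    x_out_r = x_r * freqs_cos - x_i * freqs_sin
    x_out_i = x_r * freqs_sin + x_i * freqs_cos
    return torch.cat([torch.cat([x_out_r, x_out_i], dim=-1), x_pass], dim=-1)

# Illustrative shapes: batch 1, seq_len 4, 2 heads, head_dim 8, rotary_dim 4.
x = torch.randn(1, 4, 2, 8)
freqs_cos, freqs_sin = torch.randn(4, 2), torch.randn(4, 2)
out = apply_partial_rotary_emb_single(x, freqs_cos, freqs_sin)

# The output keeps the input shape, and the last head_dim - rotary_dim
# channels are identical to the input.
assert out.shape == x.shape
assert torch.equal(out[..., 4:], x[..., 4:])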