@@ -39,6 +39,24 @@ def apply_rotary_emb_single(
     return x_out
 
 
+def apply_partial_rotary_emb_single(
+    x: torch.Tensor, freqs_cos: torch.Tensor, freqs_sin: torch.Tensor
+) -> torch.Tensor:
+
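+    # Broadcast the per-position cos/sin tables across the batch and head axes
+    # when x carries an explicit head axis (e.g. bsz, seq_len, n_heads, head_dim).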
+    if x.dim() == 4:
+        freqs_cos = freqs_cos[None, :, None, :]
+        freqs_sin = freqs_sin[None, :, None, :]
+
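+    # Only the leading rotary_dim channels get rotary position encoding; the
+    # cos/sin tables each cover half that width. As a hypothetical example,
+    # head_dim=128 with partial_rotary_factor=0.5 yields rotary_dim=64.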
+    rotary_dim = freqs_cos.shape[-1] * 2
+
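+    # Rotate the first rotary_dim channels as (real, imaginary) halves and keep
+    # the remaining channels untouched.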
+    x_rot, x_pass = x[..., :rotary_dim], x[..., rotary_dim:]
+    x_r, x_i = x_rot[..., : x_rot.shape[-1] // 2], x_rot[..., x_rot.shape[-1] // 2 :]
+    x_out_r = x_r * freqs_cos - x_i * freqs_sin
+    x_out_i = x_r * freqs_sin + x_i * freqs_cos
+    x_rotated = torch.cat([x_out_r, x_out_i], dim=-1)
+    return torch.cat([x_rotated, x_pass], dim=-1)
+
+
 class LlamaAttention(nn.Module):
     def __init__(self, config: ModelArgs, output_new_cache_only=False):
         super().__init__()
@@ -60,6 +78,11 @@ def __init__(self, config: ModelArgs, output_new_cache_only=False):
             self.q_norm_fn = torch.nn.RMSNorm(q_norm_dim, eps=config.norm_eps)
             self.k_norm_fn = torch.nn.RMSNorm(k_norm_dim, eps=config.norm_eps)
 
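+        # With partial RoPE (partial_rotary_factor < 1), only part of each head
+        # dimension is rotated, so pick the matching helper once at init time.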
+        if config.partial_rotary_factor < 1:
+            self.apply_rope_emb = apply_partial_rotary_emb_single
+        else:
+            self.apply_rope_emb = apply_rotary_emb_single
+
         self.wq = nn.Linear(
             self.dim,
             self.n_heads * self.head_dim,
@@ -199,17 +222,17 @@ def forward_sha( # noqa: C901
         for i in range(len(q)):
             if self.use_qk_norm and self.qk_norm_before_rope:
                 q[i] = self.q_norm_fn(q[i])
-            q[i] = apply_rotary_emb_single(q[i], freqs_cos, freqs_sin)
+            q[i] = self.apply_rope_emb(q[i], freqs_cos, freqs_sin)
             if hasattr(self.config, "enable_r3") and self.config.enable_r3:
-                q[i] = torch.matmul(q[i], self.r3_weight.T)
+                q[i] = torch.matmul(q[i], self.r3_weight)
             if self.use_qk_norm and not self.qk_norm_before_rope:
                 q[i] = self.q_norm_fn(q[i])
         for i in range(len(k)):
             if self.use_qk_norm and self.qk_norm_before_rope:
                 k[i] = self.k_norm_fn(k[i])
-            k[i] = apply_rotary_emb_single(k[i], freqs_cos, freqs_sin).transpose(1, 2)
+            k[i] = self.apply_rope_emb(k[i], freqs_cos, freqs_sin).transpose(1, 2)
             if hasattr(self.config, "enable_r3") and self.config.enable_r3:
-                k[i] = torch.matmul(k[i], self.r3_weight.T)
+                k[i] = torch.matmul(k[i], self.r3_weight)
             if self.use_qk_norm and not self.qk_norm_before_rope:
                 k[i] = self.k_norm_fn(k[i])
 
@@ -272,8 +295,8 @@ def forward(
             q = self.q_norm_fn(q)
             k = self.k_norm_fn(k)
 
-        q = apply_rotary_emb_single(q, freqs_cos, freqs_sin)
-        k = apply_rotary_emb_single(k, freqs_cos, freqs_sin).permute(0, 2, 3, 1)
+        q = self.apply_rope_emb(q, freqs_cos, freqs_sin)
+        k = self.apply_rope_emb(k, freqs_cos, freqs_sin).permute(0, 2, 3, 1)
 
         if self.use_qk_norm and not self.qk_norm_before_rope:
             q = self.q_norm_fn(q)
@@ -368,7 +391,8 @@ def __init__(self, config: ModelArgs, output_new_cache_only=False):
         super().__init__()
         self.dim = config.dim
         self.attention = LlamaAttention(
-            config=config, output_new_cache_only=output_new_cache_only
+            config=config,
+            output_new_cache_only=output_new_cache_only,
         )
         self.feed_forward = FeedForward(config)
         self.attention_norm = torch.nn.RMSNorm(config.dim, eps=config.norm_eps)
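
A minimal shape check for the new helper, assuming the usual (bsz, seq_len, n_heads, head_dim) layout and hypothetical sizes; this sketch is not part of the diff:

import torch

# Hypothetical sizes: head_dim=128, rotary_dim=64, so the cos/sin tables have width 32.
x = torch.randn(1, 16, 8, 128)    # (bsz, seq_len, n_heads, head_dim)
freqs_cos = torch.randn(16, 32)   # (seq_len, rotary_dim // 2)
freqs_sin = torch.randn(16, 32)
out = apply_partial_rotary_emb_single(x, freqs_cos, freqs_sin)
assert out.shape == x.shape
assert torch.equal(out[..., 64:], x[..., 64:])  # pass-through channels unchanged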