@@ -210,46 +210,6 @@ def hf_apply_rotary_emb_to_k(k, cos, sin, position_ids=None, unsqueeze_dim=1):
210210 return k_embed
211211
212212
213- # ======================= Qwen2 Implementation ========================
214-
215-
def qwen_precompute_freqs_cis(dim: int, end: int, theta: float = 1_000_000.0):
    """
    Build the cosine and sine lookup tables for Qwen2-style RoPE.

    Args:
        dim: Rotary dimension; ``dim // 2`` frequencies are generated.
        end: Number of positions (rows) in the tables.
        theta: Base of the geometric frequency progression.

    Returns:
        A ``(freqs_cos, freqs_sin)`` pair of float32 CPU tensors, each of
        shape ``[end, dim // 2]``.
    """
    half_dim = dim // 2
    # Even indices 0, 2, 4, ... scaled by dim; the slice drops the extra
    # element that arange produces when dim is odd.
    exponents = torch.arange(0, dim, 2, device="cpu")[:half_dim].float() / dim
    inv_freq = 1.0 / (theta**exponents)

    # Angle for every (position, frequency) pair.
    positions = torch.arange(end, device=inv_freq.device)
    angles = torch.outer(positions, inv_freq).float()

    return torch.cos(angles), torch.sin(angles)
228-
229-
def qwen_apply_rotary_emb(
    q: torch.Tensor, k: torch.Tensor, freqs_cos: torch.Tensor, freqs_sin: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Apply Qwen2-style (rotate-half) RoPE to query and key tensors.

    Args:
        q: Query tensor; its last dimension is the head dimension.
        k: Key tensor, rotated with the same tables as ``q``.
        freqs_cos: Cosine table indexed by position on dim 0. Its last
            dimension may be either the full head dimension or half of it.
        freqs_sin: Sine table with the same layout as ``freqs_cos``.

    Returns:
        ``(q_embed, k_embed)``, the rotated query and key tensors.

    Note:
        ``rotate_half`` pairs element ``i`` with element ``i + head_dim // 2``,
        so the tables must cover the full head dimension with each frequency
        repeated in both halves. Half-width tables are duplicated here;
        without that, ``q * cos`` fails to broadcast for any head dimension
        greater than 2.
    """

    def rotate_half(x):
        """Rotates half the hidden dims of the input."""
        half = x.shape[-1] // 2
        return torch.cat((-x[..., half:], x[..., :half]), dim=-1)

    head_dim = q.shape[-1]

    def expand_to_head_dim(table):
        # A half-width table holds one entry per frequency; repeat it so both
        # halves of the head dimension see the same angle.
        if table.shape[-1] * 2 == head_dim:
            return torch.cat((table, table), dim=-1)
        return table

    # Insert a singleton axis at dim 1 so the tables broadcast over it
    # (the heads axis, going by the original comment — TODO confirm layout).
    cos = expand_to_head_dim(freqs_cos).unsqueeze(1)
    sin = expand_to_head_dim(freqs_sin).unsqueeze(1)

    # Apply rotation
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed
251-
252-
253213class Rope (torch .nn .Module ):
254214 def __init__ (self , params : ModelArgs ):
255215 super ().__init__ ()
@@ -259,9 +219,6 @@ def __init__(self, params: ModelArgs):
259219 if self .params .use_hf_rope :
260220 self .precompute_freqs_cis = hf_precompute_freqs_cis
261221 self .apply_rotary_emb = hf_apply_rotary_emb
262- # elif self.params.use_qwen_rope:
263- # self.precompute_freqs_cis = qwen_precompute_freqs_cis
264- # self.apply_rotary_emb = qwen_apply_rotary_emb
265222 else :
266223 self .precompute_freqs_cis = partial (
267224 precompute_freqs_cis ,
0 commit comments