Skip to content

Commit 32a74ce

Browse files
committed
Remove qwen rope, use hf rope instead
1 parent e49fe94 commit 32a74ce

File tree

1 file changed

+0
-43
lines changed

1 file changed

+0
-43
lines changed

examples/models/llama/rope.py

Lines changed: 0 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -210,46 +210,6 @@ def hf_apply_rotary_emb_to_k(k, cos, sin, position_ids=None, unsqueeze_dim=1):
210210
return k_embed
211211

212212

213-
# ======================= Qwen2 Implementation ========================
214-
215-
216-
def qwen_precompute_freqs_cis(dim: int, end: int, theta: float = 1_000_000.0):
217-
"""
218-
Precompute frequency tensor for Qwen2-style RoPE.
219-
"""
220-
freqs = 1.0 / (
221-
theta ** (torch.arange(0, dim, 2, device="cpu")[: (dim // 2)].float() / dim)
222-
)
223-
t = torch.arange(end, device=freqs.device)
224-
freqs = torch.outer(t, freqs).float()
225-
freqs_cos = torch.cos(freqs)
226-
freqs_sin = torch.sin(freqs)
227-
return freqs_cos, freqs_sin
228-
229-
230-
def qwen_apply_rotary_emb(
231-
q: torch.Tensor, k: torch.Tensor, freqs_cos: torch.Tensor, freqs_sin: torch.Tensor
232-
) -> Tuple[torch.Tensor, torch.Tensor]:
233-
"""
234-
Apply Qwen2-style RoPE to query and key tensors.
235-
"""
236-
237-
def rotate_half(x):
238-
"""Rotates half the hidden dims of the input."""
239-
x1 = x[..., : x.shape[-1] // 2]
240-
x2 = x[..., x.shape[-1] // 2 :]
241-
return torch.cat((-x2, x1), dim=-1)
242-
243-
# Reshape cos and sin for broadcasting
244-
cos = freqs_cos.unsqueeze(1) # [seq_len, 1, head_dim]
245-
sin = freqs_sin.unsqueeze(1) # [seq_len, 1, head_dim]
246-
247-
# Apply rotation
248-
q_embed = (q * cos) + (rotate_half(q) * sin)
249-
k_embed = (k * cos) + (rotate_half(k) * sin)
250-
return q_embed, k_embed
251-
252-
253213
class Rope(torch.nn.Module):
254214
def __init__(self, params: ModelArgs):
255215
super().__init__()
@@ -259,9 +219,6 @@ def __init__(self, params: ModelArgs):
259219
if self.params.use_hf_rope:
260220
self.precompute_freqs_cis = hf_precompute_freqs_cis
261221
self.apply_rotary_emb = hf_apply_rotary_emb
262-
# elif self.params.use_qwen_rope:
263-
# self.precompute_freqs_cis = qwen_precompute_freqs_cis
264-
# self.apply_rotary_emb = qwen_apply_rotary_emb
265222
else:
266223
self.precompute_freqs_cis = partial(
267224
precompute_freqs_cis,

0 commit comments

Comments
 (0)