Skip to content

Commit a12912a

Browse files
committed
Partial rotary embeddings
1 parent 617d811 commit a12912a

File tree

3 files changed

+13
-3
lines changed

3 files changed

+13
-3
lines changed

examples/models/llama/model_args.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ class ModelArgs:
3838
apply_embedding: bool = True # Use embedding inside the transformer
3939
apply_output: bool = True # Use output layer (unembedding) inside the transformer
4040
use_hf_rope: bool = False # Use HuggingFace's RoPE implementation
41+
partial_rotary_factor: float = 1.0
4142
rope_theta: Optional[float] = (
4243
None # The official name to override self.rope_freq_base.
4344
)

examples/models/llama/rope.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -190,8 +190,13 @@ def hf_apply_rotary_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
190190
"""
191191
cos = cos.unsqueeze(unsqueeze_dim)
192192
sin = sin.unsqueeze(unsqueeze_dim)
193-
q_embed = (q * cos) + (rotate_half(q) * sin)
194-
k_embed = (k * cos) + (rotate_half(k) * sin)
193+
194+
rotary_dim = cos.shape[-1]
195+
q_rot, q_pass = q[..., :rotary_dim], q[..., rotary_dim:]
196+
k_rot, k_pass = k[..., :rotary_dim], k[..., rotary_dim:]
197+
198+
q_embed = torch.cat([(q_rot * cos) + (rotate_half(q_rot) * sin), q_pass], dim=-1)
199+
k_embed = torch.cat([(k_rot * cos) + (rotate_half(k_rot) * sin), k_pass], dim=-1)
195200
return q_embed, k_embed
196201

197202

@@ -227,7 +232,10 @@ def __init__(self, params: ModelArgs):
227232

228233
# Choose the appropriate RoPE implementation
229234
if self.params.use_hf_rope:
230-
self.precompute_freqs_cis = hf_precompute_freqs_cis
235+
self.precompute_freqs_cis = partial(
236+
hf_precompute_freqs_cis,
237+
partial_rotary_factor=self.params.partial_rotary_factor,
238+
)
231239
self.apply_rotary_emb = hf_apply_rotary_emb
232240
else:
233241
self.precompute_freqs_cis = partial(

examples/models/phi-4-mini/config.json

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,5 +10,6 @@
1010
"use_scaled_rope": false,
1111
"vocab_size": 200064,
1212
"use_hf_rope": true,
13+
"partial_rotary_factor": 0.75,
1314
"attention_qkv_bias": false
1415
}

0 commit comments

Comments (0)