Commit 10ad5df

sxu authored and facebook-github-bot committed
Fix static attention RoPE implementation
Differential Revision: D76951243
1 parent a1dec07 commit 10ad5df

File tree: 2 files changed (+9, -1 lines)

examples/models/llama/static_attention.py

Lines changed: 1 addition & 1 deletion
@@ -352,7 +352,7 @@ def forward(
         x_r, x_i = x[..., ::2], x[..., 1::2]
         x_out_r = x_r * freqs_cos - x_i * freqs_sin
         x_out_i = x_r * freqs_sin + x_i * freqs_cos
-        x_out = torch.cat([x_out_r, x_out_i], dim=-1)
+        x_out = torch.stack([x_out_r, x_out_i], dim=-1).flatten(2)
         return x_out


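The one-line fix changes how the rotated halves are recombined. Since forward() reads the real and imaginary components from the even and odd positions of the last dimension (x[..., ::2] and x[..., 1::2]), the output has to interleave them again; torch.cat instead placed all real parts before all imaginary parts. A minimal sketch (not part of the commit; shapes are illustrative) using an identity rotation, where the output should reproduce the input exactly:

    import torch

    # (batch, seq, head_dim) with interleaved (real, imag) pairs along the last dim
    x = torch.arange(12, dtype=torch.float32).reshape(1, 3, 4)
    x_r, x_i = x[..., ::2], x[..., 1::2]

    # Identity rotation (cos = 1, sin = 0), so x_out_r == x_r and x_out_i == x_i.
    cat_out = torch.cat([x_r, x_i], dim=-1)                 # [r0, r1, i0, i1]: pairs split apart
    stack_out = torch.stack([x_r, x_i], dim=-1).flatten(2)  # [r0, i0, r1, i1]: pairs re-interleaved

    print(torch.equal(cat_out, x))    # False: layout no longer matches the input
    print(torch.equal(stack_out, x))  # True: matches, which is what the fix restores
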
examples/models/llama/tests/test_static_attention.py

Lines changed: 8 additions & 0 deletions
@@ -30,6 +30,14 @@ def test(use_qk_norm, use_conv2d):
             rope = Rope(config)
             attn_mha = AttentionMHA(config, layer_id, rope).eval()
             static_attn = StaticAttention(config, layer_id, rope).eval()
+            if use_qk_norm:
+                with torch.no_grad():
+                    attn_mha.q_norm_fn.weight.copy_(
+                        torch.rand(config.head_dim) * 0.2 + 0.9
+                    )
+                    attn_mha.k_norm_fn.weight.copy_(
+                        torch.rand(config.head_dim) * 0.2 + 0.9
+                    )
             static_attn.load_weights_from_attention_mha(attn_mha)
             if use_conv2d:
                 static_attn.linear_to_conv2d()
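
The test change perturbs the q/k RMSNorm weights away from their default all-ones initialization before they are copied into StaticAttention. The likely motivation (not stated in the commit message): with default weights, a path that fails to transfer or apply the norm weights would still produce identical outputs, so the qk_norm comparison could pass vacuously. A rough sketch of that effect, using a hand-rolled RMS norm as a stand-in for the module used in the model:

    import torch

    def rms_norm(x, weight, eps=1e-6):
        # Simplified stand-in for the model's q/k norm; not the actual implementation.
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps) * weight

    x = torch.randn(2, 8)
    default_w = torch.ones(8)                # default initialization
    perturbed_w = torch.rand(8) * 0.2 + 0.9  # the recipe used in the test

    # A buggy path that drops the learned weight is equivalent to using all-ones:
    buggy = rms_norm(x, torch.ones(8))
    print(torch.allclose(buggy, rms_norm(x, default_w)))    # True: bug invisible with default weights
    print(torch.allclose(buggy, rms_norm(x, perturbed_w)))  # False: perturbed weights expose it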
