Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 9 additions & 2 deletions examples/models/llama/static_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -352,7 +352,7 @@ def forward(
x_r, x_i = x[..., ::2], x[..., 1::2]
x_out_r = x_r * freqs_cos - x_i * freqs_sin
x_out_i = x_r * freqs_sin + x_i * freqs_cos
x_out = torch.cat([x_out_r, x_out_i], dim=-1)
x_out = torch.stack([x_out_r, x_out_i], dim=-1).flatten(2)
return x_out


Expand All @@ -378,6 +378,7 @@ def __init__(self, config: ModelArgs, layer_id: int, rope: Rope):
self.inv_scale = 1.0 / (float(self.head_dim) ** 0.5)
self.attention_qkv_bias = config.attention_qkv_bias
self.use_qk_norm = config.use_qk_norm
self.qk_norm_before_rope = config.qk_norm_before_rope
self.use_conv2d = False

self.wqs = nn.ModuleList(
Expand Down Expand Up @@ -449,12 +450,17 @@ def from_conv2ds(ts):
new_ks = from_conv2ds(new_ks)
new_vs = from_conv2ds(new_vs)

if self.use_qk_norm:
if self.use_qk_norm and self.qk_norm_before_rope:
new_qs = [self.q_norm(q) for q in new_qs]
new_ks = [self.k_norm(k) for k in new_ks]

new_qs = [self.rope(q, freqs_cos, freqs_sin) for q in new_qs]
new_ks = [self.rope(k, freqs_cos, freqs_sin) for k in new_ks]

if self.use_qk_norm and not self.qk_norm_before_rope:
new_qs = [self.q_norm(q) for q in new_qs]
new_ks = [self.k_norm(k) for k in new_ks]

all_ks = []
all_vs = []
for i in range(self.n_kv_heads):
Expand Down Expand Up @@ -505,6 +511,7 @@ def load_weights_from_attention_mha(self, other: AttentionMHA):

if other.use_qk_norm:
self.use_qk_norm = True
self.qk_norm_before_rope = other.qk_norm_before_rope
self.q_norm = torch.nn.RMSNorm(other.q_norm_fn.dim, other.q_norm_fn.eps)
self.q_norm.load_state_dict(other.q_norm_fn.state_dict())
self.k_norm = torch.nn.RMSNorm(other.k_norm_fn.dim, other.k_norm_fn.eps)
Expand Down
12 changes: 12 additions & 0 deletions examples/models/llama/tests/test_static_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,14 @@ def test(use_qk_norm, use_conv2d):
rope = Rope(config)
attn_mha = AttentionMHA(config, layer_id, rope).eval()
static_attn = StaticAttention(config, layer_id, rope).eval()
if use_qk_norm:
with torch.no_grad():
attn_mha.q_norm_fn.weight.copy_(
torch.rand(config.head_dim) * 0.2 + 0.9
)
attn_mha.k_norm_fn.weight.copy_(
torch.rand(config.head_dim) * 0.2 + 0.9
)
static_attn.load_weights_from_attention_mha(attn_mha)
if use_conv2d:
static_attn.linear_to_conv2d()
Expand Down Expand Up @@ -60,11 +68,15 @@ def test_hf_rope_without_cache(self):
n_heads=4,
n_kv_heads=2,
max_seq_len=8,
use_qk_norm=True,
use_hf_rope=True,
)
layer_id = 0
rope = Rope(config)
attn_mha = AttentionMHA(config, layer_id, rope).eval()
with torch.no_grad():
attn_mha.q_norm_fn.weight.copy_(torch.rand(config.head_dim) * 0.2 + 0.9)
attn_mha.k_norm_fn.weight.copy_(torch.rand(config.head_dim) * 0.2 + 0.9)
static_attn = StaticAttention(config, layer_id, rope).eval()
static_attn.load_weights_from_attention_mha(attn_mha)

Expand Down
Loading