diff --git a/examples/models/llama/static_attention.py b/examples/models/llama/static_attention.py
index 69ee4e192e1..57b5796cbb3 100644
--- a/examples/models/llama/static_attention.py
+++ b/examples/models/llama/static_attention.py
@@ -352,7 +352,7 @@ def forward(
             x_r, x_i = x[..., ::2], x[..., 1::2]
             x_out_r = x_r * freqs_cos - x_i * freqs_sin
             x_out_i = x_r * freqs_sin + x_i * freqs_cos
-            x_out = torch.cat([x_out_r, x_out_i], dim=-1)
+            x_out = torch.stack([x_out_r, x_out_i], dim=-1).flatten(2)
             return x_out
@@ -378,6 +378,7 @@ def __init__(self, config: ModelArgs, layer_id: int, rope: Rope):
         self.inv_scale = 1.0 / (float(self.head_dim) ** 0.5)
         self.attention_qkv_bias = config.attention_qkv_bias
         self.use_qk_norm = config.use_qk_norm
+        self.qk_norm_before_rope = config.qk_norm_before_rope
         self.use_conv2d = False
 
         self.wqs = nn.ModuleList(
@@ -449,12 +450,17 @@ def from_conv2ds(ts):
             new_ks = from_conv2ds(new_ks)
             new_vs = from_conv2ds(new_vs)
 
-        if self.use_qk_norm:
+        if self.use_qk_norm and self.qk_norm_before_rope:
             new_qs = [self.q_norm(q) for q in new_qs]
             new_ks = [self.k_norm(k) for k in new_ks]
 
         new_qs = [self.rope(q, freqs_cos, freqs_sin) for q in new_qs]
         new_ks = [self.rope(k, freqs_cos, freqs_sin) for k in new_ks]
+
+        if self.use_qk_norm and not self.qk_norm_before_rope:
+            new_qs = [self.q_norm(q) for q in new_qs]
+            new_ks = [self.k_norm(k) for k in new_ks]
+
         all_ks = []
         all_vs = []
         for i in range(self.n_kv_heads):
@@ -505,6 +511,7 @@ def load_weights_from_attention_mha(self, other: AttentionMHA):
 
         if other.use_qk_norm:
             self.use_qk_norm = True
+            self.qk_norm_before_rope = other.qk_norm_before_rope
             self.q_norm = torch.nn.RMSNorm(other.q_norm_fn.dim, other.q_norm_fn.eps)
             self.q_norm.load_state_dict(other.q_norm_fn.state_dict())
             self.k_norm = torch.nn.RMSNorm(other.k_norm_fn.dim, other.k_norm_fn.eps)
diff --git a/examples/models/llama/tests/test_static_attention.py b/examples/models/llama/tests/test_static_attention.py
index e40643299ef..a6eac24db1f 100644
--- a/examples/models/llama/tests/test_static_attention.py
+++ b/examples/models/llama/tests/test_static_attention.py
@@ -30,6 +30,14 @@ def test(use_qk_norm, use_conv2d):
             rope = Rope(config)
             attn_mha = AttentionMHA(config, layer_id, rope).eval()
             static_attn = StaticAttention(config, layer_id, rope).eval()
+            if use_qk_norm:
+                with torch.no_grad():
+                    attn_mha.q_norm_fn.weight.copy_(
+                        torch.rand(config.head_dim) * 0.2 + 0.9
+                    )
+                    attn_mha.k_norm_fn.weight.copy_(
+                        torch.rand(config.head_dim) * 0.2 + 0.9
+                    )
             static_attn.load_weights_from_attention_mha(attn_mha)
             if use_conv2d:
                 static_attn.linear_to_conv2d()
@@ -60,11 +68,15 @@ def test_hf_rope_without_cache(self):
             n_heads=4,
             n_kv_heads=2,
             max_seq_len=8,
+            use_qk_norm=True,
             use_hf_rope=True,
         )
         layer_id = 0
         rope = Rope(config)
         attn_mha = AttentionMHA(config, layer_id, rope).eval()
+        with torch.no_grad():
+            attn_mha.q_norm_fn.weight.copy_(torch.rand(config.head_dim) * 0.2 + 0.9)
+            attn_mha.k_norm_fn.weight.copy_(torch.rand(config.head_dim) * 0.2 + 0.9)
         static_attn = StaticAttention(config, layer_id, rope).eval()
         static_attn.load_weights_from_attention_mha(attn_mha)
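
Note (not part of the patch): a minimal sketch of why the rotary-embedding output is assembled with torch.stack(...).flatten(2) instead of torch.cat. The inputs are split into interleaved even/odd lanes (x[..., ::2] and x[..., 1::2]), so the rotated halves have to be re-interleaved pairwise rather than concatenated block-wise. The shapes below are illustrative assumptions, not values from the model config.

import torch

# Illustrative shape only: (batch=1, seq=1, head_dim=6).
x_out_r = torch.arange(6.0).reshape(1, 1, 6)         # rotated "real" lanes r0..r5
x_out_i = torch.arange(6.0, 12.0).reshape(1, 1, 6)   # rotated "imaginary" lanes i0..i5

# Old assembly: block layout [r0..r5, i0..i5], which does not match the
# interleaved split x[..., ::2] / x[..., 1::2] used on the way in.
cat_out = torch.cat([x_out_r, x_out_i], dim=-1)

# New assembly: pairwise interleave back to [r0, i0, r1, i1, ...].
stack_out = torch.stack([x_out_r, x_out_i], dim=-1).flatten(2)

print(cat_out)    # [[[ 0., 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11.]]]
print(stack_out)  # [[[ 0., 6., 1., 7., 2., 8., 3., 9., 4., 10., 5., 11.]]]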