Commit 0211a03

Static attention: support overriding RMSNorm class
Differential Revision: D78926603
Pull Request resolved: #12833
1 parent 69f3da0

1 file changed: +9 −3 lines changed

examples/models/llama/static_attention.py

@@ -840,7 +840,9 @@ def _forward_mha(
 
         return y.transpose(1, 2).contiguous().view(bsz, seq_len, -1), out_cache_state
 
-    def load_weights_from_attention_mha(self, other: AttentionMHA):
+    def load_weights_from_attention_mha(
+        self, other: AttentionMHA, rms_norm_class=torch.nn.RMSNorm
+    ):
         if self.split_mha:
             for i in range(self.n_heads):
                 self.wqs[i].weight.data.copy_(
@@ -864,9 +866,13 @@ def load_weights_from_attention_mha(self, other: AttentionMHA):
         if other.use_qk_norm:
             self.use_qk_norm = True
             self.qk_norm_before_rope = other.qk_norm_before_rope
-            self.q_norm = torch.nn.RMSNorm(other.q_norm_fn.dim, other.q_norm_fn.eps)
+            self.q_norm = rms_norm_class(other.q_norm_fn.dim, other.q_norm_fn.eps).to(
+                other.q_norm_fn.weight.dtype
+            )
             self.q_norm.load_state_dict(other.q_norm_fn.state_dict())
-            self.k_norm = torch.nn.RMSNorm(other.k_norm_fn.dim, other.k_norm_fn.eps)
+            self.k_norm = rms_norm_class(other.k_norm_fn.dim, other.k_norm_fn.eps).to(
+                other.k_norm_fn.weight.dtype
+            )
             self.k_norm.load_state_dict(other.k_norm_fn.state_dict())
 
     def adopt_hf_rope(self):
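
A minimal sketch of the pattern the new rms_norm_class argument enables: construct the override class with the source norm's dim and eps, cast it to the source weight dtype, then copy the learned weights over. MyRMSNorm, dim, eps, and the standalone setup below are illustrative assumptions; only the rms_norm_class keyword itself comes from this diff.

import torch


class MyRMSNorm(torch.nn.RMSNorm):
    """Hypothetical stand-in for a backend-specific RMSNorm variant."""


dim, eps = 64, 1e-6
source_norm = torch.nn.RMSNorm(dim, eps=eps).to(torch.float16)

# Mirrors the added lines: instantiate the override class, match the source
# weight dtype, then load the source norm's weights.
q_norm = MyRMSNorm(dim, eps).to(source_norm.weight.dtype)
q_norm.load_state_dict(source_norm.state_dict())
assert q_norm.weight.dtype == torch.float16

# Against the API changed here (variable names are illustrative), the override
# would be passed as:
# static_attention.load_weights_from_attention_mha(mha, rms_norm_class=MyRMSNorm)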
