
Commit 0432fe2

sxu authored and facebook-github-bot committed
Static attention: support overriding RMSNorm class
Summary: The CoreML backend should not use the default torch.nn.RMSNorm because it leads to worse precision, whereas the QNN backend should, since it lowers to a single, more efficient operator. Support overriding the normalization layer class. Differential Revision: D78926603
1 parent 7e82d00 commit 0432fe2

File tree

1 file changed: +9, -3 lines


examples/models/llama/static_attention.py

Lines changed: 9 additions & 3 deletions
@@ -840,7 +840,9 @@ def _forward_mha(
 
         return y.transpose(1, 2).contiguous().view(bsz, seq_len, -1), out_cache_state
 
-    def load_weights_from_attention_mha(self, other: AttentionMHA):
+    def load_weights_from_attention_mha(
+        self, other: AttentionMHA, rms_norm_class=torch.nn.RMSNorm
+    ):
         if self.split_mha:
             for i in range(self.n_heads):
                 self.wqs[i].weight.data.copy_(
@@ -864,9 +866,13 @@ def load_weights_from_attention_mha(self, other: AttentionMHA):
         if other.use_qk_norm:
             self.use_qk_norm = True
             self.qk_norm_before_rope = other.qk_norm_before_rope
-            self.q_norm = torch.nn.RMSNorm(other.q_norm_fn.dim, other.q_norm_fn.eps)
+            self.q_norm = rms_norm_class(other.q_norm_fn.dim, other.q_norm_fn.eps).to(
+                other.q_norm_fn.weight.dtype
+            )
             self.q_norm.load_state_dict(other.q_norm_fn.state_dict())
-            self.k_norm = torch.nn.RMSNorm(other.k_norm_fn.dim, other.k_norm_fn.eps)
+            self.k_norm = rms_norm_class(other.k_norm_fn.dim, other.k_norm_fn.eps).to(
+                other.k_norm_fn.weight.dtype
+            )
             self.k_norm.load_state_dict(other.k_norm_fn.state_dict())
 
     def adopt_hf_rope(self):
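
For context, a backend that wants to avoid the fused torch.nn.RMSNorm can pass its own module class through the new rms_norm_class argument. The sketch below is illustrative only: DecomposedRMSNorm and the surrounding names are hypothetical and not part of this commit. The only constraints taken from the diff are that the class is constructed as rms_norm_class(dim, eps) and then populated via load_state_dict() from the original q/k norm, so a drop-in class needs the same constructor signature and a matching weight parameter.

import torch


class DecomposedRMSNorm(torch.nn.Module):
    """Hypothetical RMSNorm variant built from basic elementwise ops.

    A backend that lowers the fused torch.nn.RMSNorm poorly (e.g. for
    precision reasons) could substitute a class like this one. It keeps the
    same (dim, eps) constructor and a `weight` parameter so that
    load_state_dict() from the original torch.nn.RMSNorm still works.
    """

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = torch.nn.Parameter(torch.ones(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Normalize in fp32 and cast back, trading the fused op for
        # explicit control over where the reduction happens.
        x_fp32 = x.float()
        rms = torch.rsqrt(x_fp32.pow(2).mean(-1, keepdim=True) + self.eps)
        return (x_fp32 * rms).to(x.dtype) * self.weight


# Usage with the new keyword argument (static_attn and attn_mha are assumed
# to be an existing StaticAttention / AttentionMHA pair from the Llama example):
#
#     static_attn.load_weights_from_attention_mha(
#         attn_mha, rms_norm_class=DecomposedRMSNorm
#     )
#
# Omitting rms_norm_class keeps the previous behavior (torch.nn.RMSNorm).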
