Skip to content

Commit 6a08720

Browse files
authored
[BUG] 提升 GLM4 MOE Rope 的精度 (Improve the precision of GLM4 MoE RoPE) (#2663)
1 parent f6ebf4d commit 6a08720

File tree

1 file changed

+17
-12
lines changed

1 file changed

+17
-12
lines changed

paddleformers/transformers/glm4_moe/modeling.py

Lines changed: 17 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -882,19 +882,24 @@ def __init__(self, config: Glm4MoeConfig, device=None):
882882

883883
@paddle.no_grad()
884884
def forward(self, x, position_ids):
885-
inv_freq_expanded = (
886-
self.inv_freq.unsqueeze(0)
887-
.unsqueeze(-1)
888-
.cast(paddle.float32)
889-
.expand([position_ids.shape[0], -1, 1])
890-
.to(x.place)
891-
)
892-
position_ids_expanded = position_ids.unsqueeze(1).cast(paddle.float32)
885+
# NOTE: Paddle's Automatic Mixed Precision (AMP) has a default op whitelist that may automatically cast
886+
# certain operations (like matmul) to FP16/BF16 for performance optimization. However, in scenarios where
887+
# numerical stability is critical (e.g., RoPE init/compute), this conversion can lead to precision loss.
888+
# Disabling auto_cast here ensures the matmul operation runs in the original precision (FP32) as intended.
889+
with paddle.amp.auto_cast(False):
890+
inv_freq_expanded = (
891+
self.inv_freq.unsqueeze(0)
892+
.unsqueeze(-1)
893+
.cast(paddle.float32)
894+
.expand([position_ids.shape[0], -1, 1])
895+
.to(x.place)
896+
)
897+
position_ids_expanded = position_ids.unsqueeze(1).cast(paddle.float32)
893898

894-
freqs = paddle.matmul(inv_freq_expanded, position_ids_expanded).transpose([0, 2, 1])
895-
emb = paddle.cat((freqs, freqs), axis=-1)
896-
cos = paddle.cos(emb) * self.attention_scaling
897-
sin = paddle.sin(emb) * self.attention_scaling
899+
freqs = paddle.matmul(inv_freq_expanded, position_ids_expanded).transpose([0, 2, 1])
900+
emb = paddle.cat((freqs, freqs), axis=-1)
901+
cos = paddle.cos(emb) * self.attention_scaling
902+
sin = paddle.sin(emb) * self.attention_scaling
898903

899904
return cos.cast(dtype=x.dtype), sin.cast(dtype=x.dtype)
900905

0 commit comments

Comments
 (0)