@@ -882,19 +882,24 @@ def __init__(self, config: Glm4MoeConfig, device=None):
 
     @paddle.no_grad()
     def forward(self, x, position_ids):
-        inv_freq_expanded = (
-            self.inv_freq.unsqueeze(0)
-            .unsqueeze(-1)
-            .cast(paddle.float32)
-            .expand([position_ids.shape[0], -1, 1])
-            .to(x.place)
-        )
-        position_ids_expanded = position_ids.unsqueeze(1).cast(paddle.float32)
+        # NOTE: Paddle's Automatic Mixed Precision (AMP) has a default op whitelist that may automatically cast
+        # certain operations (like matmul) to FP16/BF16 for performance optimization. However, in scenarios where
+        # numerical stability is critical (e.g., RoPE init/compute), this conversion can lead to precision loss.
+        # Disabling auto_cast here ensures the matmul operation runs in the original precision (FP32) as intended.
+        with paddle.amp.auto_cast(False):
+            inv_freq_expanded = (
+                self.inv_freq.unsqueeze(0)
+                .unsqueeze(-1)
+                .cast(paddle.float32)
+                .expand([position_ids.shape[0], -1, 1])
+                .to(x.place)
+            )
+            position_ids_expanded = position_ids.unsqueeze(1).cast(paddle.float32)
 
-        freqs = paddle.matmul(inv_freq_expanded, position_ids_expanded).transpose([0, 2, 1])
-        emb = paddle.cat((freqs, freqs), axis=-1)
-        cos = paddle.cos(emb) * self.attention_scaling
-        sin = paddle.sin(emb) * self.attention_scaling
+            freqs = paddle.matmul(inv_freq_expanded, position_ids_expanded).transpose([0, 2, 1])
+            emb = paddle.cat((freqs, freqs), axis=-1)
+            cos = paddle.cos(emb) * self.attention_scaling
+            sin = paddle.sin(emb) * self.attention_scaling
 
         return cos.cast(dtype=x.dtype), sin.cast(dtype=x.dtype)
 
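For context on the NOTE above, here is a minimal, self-contained sketch (not part of this patch) of the AMP behavior it describes. The shapes and variable names are illustrative only; on a device where Paddle AMP is active, the outer auto_cast context downcasts white-listed ops such as matmul, while the nested auto_cast(False) block keeps the same computation in float32, which is what the patched forward() relies on for the RoPE frequencies.

import paddle

# Hypothetical shapes, mirroring the inv_freq/position_ids layout used in the hunk above.
inv_freq = paddle.rand([1, 64, 1], dtype="float32")
position_ids = paddle.arange(2048, dtype="float32").reshape([1, 1, -1])

with paddle.amp.auto_cast(True, dtype="float16"):
    # matmul is on the default AMP white list, so on a supported device its
    # inputs and output are cast down to float16 here.
    lossy = paddle.matmul(inv_freq, position_ids)
    with paddle.amp.auto_cast(False):
        # The nested auto_cast(False) keeps this matmul in float32, matching
        # the precision the RoPE computation expects.
        exact = paddle.matmul(inv_freq, position_ids)

print(lossy.dtype, exact.dtype)  # e.g. paddle.float16 vs paddle.float32 when AMP is in effect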