Commit c90f3bd

[Bug] Fix precision of gate and e_score_correction_bias in Glm4Moe (#2637)
1 parent ed81369 commit c90f3bd


paddleformers/transformers/glm4_moe/modeling.py

Lines changed: 8 additions & 1 deletion
@@ -304,6 +304,9 @@ def __init__(self, config, num_experts, expert_hidden_size, **kwargs):
         )
         self.expert_usage.stop_gradient = True
 
+        # weight and e_score_correction_bias do not need to be cast to low precision
+        self._cast_to_low_precision = False
+
     def forward(self, hidden_states):
         """
         Args:
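The `_cast_to_low_precision = False` flag appears intended to let precision-casting utilities skip this sublayer when the rest of the model is converted to bfloat16 or float16. The sketch below is illustrative only, not Paddle's or PaddleFormers' internal casting code; the helper name and traversal are assumptions. It shows how a per-layer flag like this can partition parameters into "keep float32" and "downcast" buckets.

# Illustrative sketch (assumed helper, not Paddle's internal casting code):
# partition parameters by each sublayer's _cast_to_low_precision flag.
import paddle

def split_params_by_precision(model: paddle.nn.Layer):
    keep_fp32, downcast = [], []
    for name, sublayer in model.named_sublayers(include_self=True):
        # Layers that opt out (e.g. the MoE gate in this commit) keep float32 weights.
        flag = getattr(sublayer, "_cast_to_low_precision", True)
        bucket = downcast if flag else keep_fp32
        for pname, _ in sublayer.named_parameters(include_sublayers=False):
            bucket.append(f"{name}.{pname}" if name else pname)
    return keep_fp32, downcast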
@@ -340,12 +343,15 @@ def __init__(self, config: Glm4MoeConfig):
 
         self.weight = paddle.create_parameter(
             shape=[self.n_routed_experts, config.hidden_size],
-            dtype="bfloat16",
+            dtype="float32",
             default_initializer=paddle.nn.initializer.Uniform(),
         )
 
         self.register_buffer("e_score_correction_bias", paddle.zeros((self.n_routed_experts,), dtype=paddle.float32))
 
+        # weight and e_score_correction_bias do not need to be cast to low precision
+        self._cast_to_low_precision = False
+
     @paddle.no_grad()
     def get_topk_indices(self, scores):
         scores_for_choice = scores.reshape([-1, self.n_routed_experts]) + self.e_score_correction_bias.unsqueeze(0)
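Why float32 matters here: `get_topk_indices` adds `e_score_correction_bias` to router scores that can sit closer together than bfloat16 can resolve (7 mantissa bits give a spacing of about 2e-3 near 0.3), so rounding can collapse nearby scores and flip which experts win top-k. A minimal, self-contained illustration with hypothetical numbers, not the model's real scores:

import paddle

# Hypothetical router scores and correction bias; the 2e-4 score gap and
# 5e-4 bias are below bfloat16's ~2e-3 resolution near 0.3.
scores = paddle.to_tensor([0.3010, 0.3008, 0.1500])
bias = paddle.to_tensor([0.0000, 0.0005, 0.0000])

# In float32 the correction bias tips top-1 to expert 1.
top1_fp32 = paddle.argsort(scores + bias, descending=True)[0]

# Rounding the corrected scores through bfloat16 collapses experts 0 and 1
# to the same value (0.30078125), so top-1 can silently change.
rounded = (scores + bias).astype("bfloat16").astype("float32")
top1_bf16 = paddle.argsort(rounded, descending=True)[0]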
@@ -588,6 +594,7 @@ class Glm4MoePreTrainedModel(PretrainedModel):
     config: Glm4MoeConfig
     config_class = Glm4MoeConfig
     base_model_prefix = "model"
+    _keep_in_fp32_modules = ["mlp.gate.weight", "e_score_correction_bias"]
     transpose_weight_keys = ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"]
 
     @classmethod
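`_keep_in_fp32_modules` carries the same intent to the checkpoint-loading path by name matching: parameters whose names contain these substrings stay float32 even when the model is loaded in bf16/fp16. A hedged sketch of the pattern follows; the function and matching rule are assumptions, not PaddleFormers' exact loader.

import paddle

_keep_in_fp32_modules = ["mlp.gate.weight", "e_score_correction_bias"]

def cast_state_dict(state_dict, dtype="bfloat16"):
    # Downcast everything except parameters matched by the fp32 allowlist.
    out = {}
    for name, tensor in state_dict.items():
        if any(pattern in name for pattern in _keep_in_fp32_modules):
            out[name] = tensor.astype("float32")  # router gate stays full precision
        else:
            out[name] = tensor.astype(dtype)
    return out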
