@@ -304,6 +304,9 @@ def __init__(self, config, num_experts, expert_hidden_size, **kwargs):
         )
         self.expert_usage.stop_gradient = True
 
+        # weight and e_score_correction_bias do not need to be cast to low precision
+        self._cast_to_low_precision = False
+
     def forward(self, hidden_states):
         """
         Args:
@@ -340,12 +343,15 @@ def __init__(self, config: Glm4MoeConfig):
 
         self.weight = paddle.create_parameter(
             shape=[self.n_routed_experts, config.hidden_size],
-            dtype="bfloat16",
+            dtype="float32",
             default_initializer=paddle.nn.initializer.Uniform(),
         )
 
         self.register_buffer("e_score_correction_bias", paddle.zeros((self.n_routed_experts,), dtype=paddle.float32))
 
+        # weight and e_score_correction_bias do not need to be cast to low precision
+        self._cast_to_low_precision = False
+
     @paddle.no_grad()
     def get_topk_indices(self, scores):
         scores_for_choice = scores.reshape([-1, self.n_routed_experts]) + self.e_score_correction_bias.unsqueeze(0)
@@ -588,6 +594,7 @@ class Glm4MoePreTrainedModel(PretrainedModel):
     config: Glm4MoeConfig
     config_class = Glm4MoeConfig
     base_model_prefix = "model"
+    _keep_in_fp32_modules = ["mlp.gate.weight", "e_score_correction_bias"]
     transpose_weight_keys = ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"]
 
     @classmethod
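A minimal sketch of how a fp32-keep list such as `_keep_in_fp32_modules` is typically matched against parameter names when the rest of the model is cast to a low-precision dtype. The helper `target_dtype_for` and the example parameter paths are hypothetical, and substring matching against the keep-list is an assumption about how the attribute is consumed, not the actual PaddleNLP implementation.

```python
# Hypothetical sketch, not the PaddleNLP implementation: parameters whose names
# match an entry in the keep-list stay in float32, everything else is cast to
# the low-precision dtype (e.g. bfloat16).

_keep_in_fp32_modules = ["mlp.gate.weight", "e_score_correction_bias"]


def target_dtype_for(param_name: str, low_precision_dtype: str = "bfloat16") -> str:
    """Return the dtype a parameter should have after model-wide casting."""
    if any(key in param_name for key in _keep_in_fp32_modules):
        return "float32"
    return low_precision_dtype


# Example (parameter paths are illustrative): the MoE router weight stays in
# fp32, while an expert projection is cast to bf16.
assert target_dtype_for("model.layers.3.mlp.gate.weight") == "float32"
assert target_dtype_for("model.layers.3.mlp.experts.0.up_proj.weight") == "bfloat16"
```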