@@ -443,7 +443,17 @@ class Ernie4_5_MoePretrainedModel(PretrainedModel):
     config_class = Ernie4_5_MoeConfig
     base_model_prefix = "model"
     _keep_in_fp32_modules = ["mlp.gate.weight", "e_score_correction_bias"]
-    transpose_weight_keys = ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj", "gate"]
+    transpose_weight_keys = [
+        "q_proj",
+        "k_proj",
+        "v_proj",
+        "o_proj",
+        "gate_proj",
+        "up_proj",
+        "down_proj",
+        "gate",
+        "mtp_linear_proj.0",
+    ]

     @classmethod
     def _get_tensor_parallel_mappings(cls, config, is_split=True):
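The new `"mtp_linear_proj.0"` entry makes the first MTP projection weight participate in the same transpose handling as the attention and MLP projections. As a rough, hypothetical sketch of how such a key list is typically consumed during checkpoint conversion (the real logic lives elsewhere in the codebase; `maybe_transpose` is an illustrative name, not the repo's API):

```python
import numpy as np

# Hypothetical helper: transpose any 2-D weight whose parameter name matches
# one of the keys in transpose_weight_keys (names mirror the class attribute).
transpose_weight_keys = [
    "q_proj", "k_proj", "v_proj", "o_proj",
    "gate_proj", "up_proj", "down_proj", "gate",
    "mtp_linear_proj.0",
]

def maybe_transpose(name, tensor):
    if tensor.ndim == 2 and any(key in name for key in transpose_weight_keys):
        return tensor.T
    return tensor

w = np.zeros((128, 64))
print(maybe_transpose("model.mtp_linear_proj.0.weight", w).shape)  # (64, 128)
```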
@@ -659,16 +669,18 @@ def __init__(self, config: Ernie4_5_MoeConfig):
         self.mtp_linear_proj = paddle.nn.LayerList(
             [
                 GeneralLinear.create(
-                    self.config.hidden_size * 2,
-                    self.config.hidden_size,
+                    config.hidden_size * 2,
+                    config.hidden_size,
                     has_bias=config.use_bias,
                     config=config,
                     fuse_matmul_bias=config.fuse_linear,
+                    linear_type="default",
                 )
-                for _ in range(self.config.num_nextn_predict_layers)
+                for _ in range(config.num_nextn_predict_layers)
             ]
         )
         if config.sequence_parallel:
+            logger.info("enable sequence parallel for mtp_linear")
             for mtp_linear in self.mtp_linear_proj:
                 mark_as_sequence_parallel_parameter(mtp_linear.weight)
                 if config.use_bias:
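Switching from `self.config` to the local `config` argument matches the rest of this block, which already reads `config.use_bias` and `config.fuse_linear`. Structurally, the construction amounts to `num_nextn_predict_layers` linear layers, each folding a concatenated pair of hidden states (`2 * hidden_size`) back to `hidden_size`. A minimal stand-in using plain `paddle.nn.Linear` in place of the repo's `GeneralLinear.create` factory (sizes are illustrative):

```python
import paddle

hidden_size = 64              # stand-in for config.hidden_size
num_nextn_predict_layers = 2  # stand-in for config.num_nextn_predict_layers

# One projection per future-token depth, each mapping 2 * hidden_size -> hidden_size.
mtp_linear_proj = paddle.nn.LayerList(
    [paddle.nn.Linear(hidden_size * 2, hidden_size) for _ in range(num_nextn_predict_layers)]
)

x = paddle.randn([4, hidden_size * 2])
print(mtp_linear_proj[0](x).shape)  # [4, 64]
```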
@@ -795,7 +807,7 @@ def forward(
                 attention_mask, inputs_embeds.shape[:2], kv_seq_len, inputs_embeds.dtype
             )

-        if self.config.num_nextn_predict_layers > 0:
+        if self.training and self.config.num_nextn_predict_layers > 0:
             inputs_embeds_extra = inputs_embeds[:, -self.config.num_nextn_predict_layers :, :]
             inputs_embeds = inputs_embeds[:, : -self.config.num_nextn_predict_layers, :]
             inputs_embeds_ori = inputs_embeds
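The guarded slicing splits the trailing `num_nextn_predict_layers` positions off for the MTP heads while the main trunk consumes the prefix. A minimal sketch with made-up shapes:

```python
import paddle

num_nextn_predict_layers = 2              # illustrative value
inputs_embeds = paddle.randn([1, 8, 16])  # [batch, seq_len, hidden]

# Last k positions feed the MTP heads; the trunk sees the remaining prefix.
inputs_embeds_extra = inputs_embeds[:, -num_nextn_predict_layers:, :]
inputs_embeds = inputs_embeds[:, :-num_nextn_predict_layers, :]
print(inputs_embeds_extra.shape)  # [1, 2, 16]
print(inputs_embeds.shape)        # [1, 6, 16]
```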
@@ -896,7 +908,7 @@ def forward(
             all_gate_logits = all_gate_logits + (gate_logits,)

         # Multi Token Prediction
-        if self.config.num_nextn_predict_layers > 0:
+        if self.training and self.config.num_nextn_predict_layers > 0:
             mtp_outputs.append(hidden_states)

             for depth in range(self.config.num_nextn_predict_layers):
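Both forward-pass hunks add the same `self.training` guard, so the MTP branch now runs only in train mode and is skipped at inference. `paddle.nn.Layer` carries that flag, toggled by `.train()` / `.eval()`; a tiny demonstration of the gating pattern (the `Toy` layer is hypothetical):

```python
import paddle

class Toy(paddle.nn.Layer):
    def forward(self, x):
        # Mirrors the PR's guard: the extra branch runs only while training.
        if self.training:
            return x * 2
        return x

m = Toy()
x = paddle.ones([1])
print(m(x))  # training mode by default -> branch taken
m.eval()     # sets self.training = False
print(m(x))  # branch skipped at inference
```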
@@ -1088,6 +1100,9 @@ def forward(
         Returns:
             Union[tuple, MoECausalLMOutputWithPast]: Model outputs.
         """
+        if kwargs.get("attn_mask_start_row_indices", None) is not None and attn_mask_startend_row_indices is None:
+            attn_mask_startend_row_indices = kwargs["attn_mask_start_row_indices"]
+
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
         output_hidden_states = (
             output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
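The last hunk adds a backward-compatibility shim: callers still passing the older `attn_mask_start_row_indices` kwarg are mapped onto the renamed `attn_mask_startend_row_indices` parameter, which only takes effect when the new name was not supplied. A self-contained sketch of the same aliasing pattern (the bare `forward` function here is illustrative, not the model's actual signature):

```python
def forward(attn_mask_startend_row_indices=None, **kwargs):
    # Accept the legacy kwarg name and forward it to the new parameter.
    if kwargs.get("attn_mask_start_row_indices", None) is not None and attn_mask_startend_row_indices is None:
        attn_mask_startend_row_indices = kwargs["attn_mask_start_row_indices"]
    return attn_mask_startend_row_indices

print(forward(attn_mask_start_row_indices=[0, 3]))  # [0, 3], via the old name
```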