@@ -192,6 +192,7 @@ def forward_without_residual(self, inputs):

         if self.send_mtp_embed:
             hidden_states = paddle.concat([hidden_states, inputs_embeds_mtp], axis=-1)
+            self.mtp_embed_shape = inputs_embeds_mtp.shape  # save the mtp_embed shape for the backward pass

         return return_args(hidden_states)

@@ -227,37 +228,47 @@ def forward(self, inputs):

         if self.send_mtp_embed:
             hidden_states = paddle.concat([hidden_states, inputs_embeds_mtp], axis=-1)
+            self.mtp_embed_shape = inputs_embeds_mtp.shape  # save the mtp_embed shape for the backward pass

         return return_args(hidden_states)

     @paddle.no_grad()
     def backward(self, output_grad):
         (do3,) = output_grad

-        assert not self.send_mtp_embed, "not support have mtp have yet"
+        if self.send_mtp_embed:
+            # split the gradient: the leading part of do3 corresponds to hidden_states, the trailing part to inputs_embeds_mtp
+            hidden_size = do3.shape[-1] - self.mtp_embed_shape[-1]
+            hidden_states_grad = do3[..., :hidden_size]
+            inputs_embeds_mtp_grad = do3[..., hidden_size:]
+        else:
+            hidden_states_grad = do3
+            inputs_embeds_mtp_grad = None
+
         if self.using_post_norm_recompute:
             dx = FP8LinearFunctionBase.fp8_mlp_bwd_norm_rc(
-                do3,
+                hidden_states_grad,
                 self.x,
                 self.shared_experts.norm_weight,
                 self.shared_experts.norm_eps,
                 self.shared_experts.w1,
                 self.shared_experts.w2,
             )
         else:
-            dx = FP8LinearFunctionBase.fp8_mlp_bwd(do3, self.x, self.shared_experts.w1, self.shared_experts.w2)
+            dx = FP8LinearFunctionBase.fp8_mlp_bwd(
+                hidden_states_grad, self.x, self.shared_experts.w1, self.shared_experts.w2, True
+            )

         self.x = None

-        residual_grad = do3
-
-        hidden_states_grad = dx
-
+        residual_grad = hidden_states_grad

         l_aux_grad = paddle.ones(1, dtype=self.l_aux.dtype) * self.alpha
+        final_hidden_states_grad = hidden_states_grad

-        final_hidden_states_grad = do3
-
-        return (hidden_states_grad, residual_grad, l_aux_grad, final_hidden_states_grad)
+        if self.send_mtp_embed:
+            return (inputs_embeds_mtp_grad, dx, residual_grad, l_aux_grad, final_hidden_states_grad)
+        else:
+            return (dx, residual_grad, l_aux_grad, final_hidden_states_grad)


 class DecoderLayerNode(ScheduleNode):
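For context on the `backward` change above: because the forward pass concatenates `hidden_states` and `inputs_embeds_mtp` along the last axis, the incoming gradient can simply be split at the saved width to recover each input's gradient. A minimal standalone sketch of that idea (illustrative shapes and names, not the PR's tensors):

```python
import paddle

# Stand-ins for hidden_states and inputs_embeds_mtp; sizes are made up.
hidden = paddle.randn([2, 4, 8])
mtp_embed = paddle.randn([2, 4, 8])
hidden.stop_gradient = False
mtp_embed.stop_gradient = False

out = paddle.concat([hidden, mtp_embed], axis=-1)  # forward: concat on the last axis
do3 = paddle.ones_like(out)                        # pretend upstream gradient

# Manual split, mirroring the new send_mtp_embed branch:
hidden_size = do3.shape[-1] - mtp_embed.shape[-1]
hidden_states_grad = do3[..., :hidden_size]
inputs_embeds_mtp_grad = do3[..., hidden_size:]

# Autograd agrees: concat's backward routes each slice back to its input.
out.backward(do3)
assert paddle.allclose(hidden.grad, hidden_states_grad)
assert paddle.allclose(mtp_embed.grad, inputs_embeds_mtp_grad)
```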
@@ -749,6 +760,9 @@ def attn_backward(self, output_grad):
                 hs_grad,
                 token_probs_grad,
             ) = output_grad
+            inputs_embeds_mtp_grad_shape = hidden_states_grad.shape
+            inputs_embeds_mtp_grad_shape[-1] = -1
+            inputs_embeds_mtp_grad = inputs_embeds_mtp_grad.view(inputs_embeds_mtp_grad_shape)
         else:
             hidden_states_grad, residual_grad, l_aux_grad, hs_grad, token_probs_grad = output_grad
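The three added lines reshape the MTP gradient so its leading dimensions match `hidden_states_grad`, leaving the last dimension to be inferred via `-1`. A small sketch with made-up shapes (in Paddle, `Tensor.shape` returns a plain Python list, so it can be edited before the `view`):

```python
import paddle

# Illustrative: hidden grad is [batch, seq, hidden]; the MTP grad arrives flattened.
hidden_states_grad = paddle.randn([2, 4, 8])
inputs_embeds_mtp_grad = paddle.randn([8, 16])   # pretend flattened layout

target_shape = hidden_states_grad.shape          # e.g. [2, 4, 8]
target_shape[-1] = -1                            # infer the last dimension
inputs_embeds_mtp_grad = inputs_embeds_mtp_grad.view(target_shape)

print(inputs_embeds_mtp_grad.shape)              # [2, 4, 16] -- leading dims now match
```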
@@ -906,8 +920,11 @@ def forward_backward(self, inputs, output_grad, combine_bw_event_to_wait=None, p
        combine_forward_event.calc_stream_wait(self.forward_node.moe_group.id)

        final_out = self.forward_node.post_process_node.forward_without_residual(inputs)
-        inputs = final_out + combine_fwd_out
-
+        if final_out.shape[-1] != combine_fwd_out.shape[-1]:
+            final_out[:, :, : combine_fwd_out.shape[-1]] += combine_fwd_out  # broadcast and add directly
+        else:
+            final_out += combine_fwd_out
+        inputs = final_out
        combine_fwd_out._record_stream()

        paddle.base.core.nvprof_nvtx_pop()
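The new branch covers the case where `final_out` carries extra MTP channels appended on the last axis: `combine_fwd_out` is added only into the leading hidden-state slice, and the MTP tail is left untouched. A toy sketch with invented sizes:

```python
import paddle

hidden = 8
mtp_extra = 8   # extra channels appended when send_mtp_embed is on (illustrative)

final_out = paddle.zeros([2, 4, hidden + mtp_extra])
combine_fwd_out = paddle.ones([2, 4, hidden])

if final_out.shape[-1] != combine_fwd_out.shape[-1]:
    # add only into the leading hidden-state channels
    final_out[:, :, : combine_fwd_out.shape[-1]] += combine_fwd_out
else:
    final_out += combine_fwd_out

print(final_out[..., :hidden].sum())   # 64.0 -- the hidden slice received the add
print(final_out[..., hidden:].sum())   # 0.0  -- the MTP slice is unchanged
```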
@@ -1072,7 +1089,7 @@ def forward(self, args):
         if self.config.send_mtp_embed:
             batch_size, _, hidden_size = hidden_states.shape
             batch_size_mtp = hidden_size // (self.config.num_nextn_predict_layers + 1)
-            inputs_embeds_mtp = hidden_states[..., -batch_size_mtp:]
+            inputs_embeds_mtp = hidden_states[..., batch_size_mtp:]
             hidden_states = hidden_states[..., :batch_size_mtp]

         has_gradient = not hidden_states.stop_gradient
@@ -1129,7 +1146,7 @@ def attn_compute(self, args):

             batch_size, _, hidden_size = hidden_states.shape
             batch_size_mtp = hidden_size // (self.config.num_nextn_predict_layers + 1)
-            inputs_embeds_mtp = hidden_states[..., -batch_size_mtp:]
+            inputs_embeds_mtp = hidden_states[..., batch_size_mtp:]
             hidden_states = hidden_states[..., :batch_size_mtp]

         def attn_compute_func(hidden_states):
@@ -1162,7 +1179,7 @@ def attn_compute_for_fusion(self, args):
             # slice from holy tensor
             batch_size, _, hidden_size = hidden_states.shape
             batch_size_mtp = hidden_size // (self.config.num_nextn_predict_layers + 1)
-            inputs_embeds_mtp = hidden_states[..., -batch_size_mtp:]
+            inputs_embeds_mtp = hidden_states[..., batch_size_mtp:]
             hidden_states = hidden_states[..., :batch_size_mtp]

         hidden_states, residual = self.self_attn_compute(hidden_states)
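The slice change repeated in the last three hunks matters once more than one next-token-prediction layer is configured: `[..., batch_size_mtp:]` keeps every appended MTP embedding, while the old `[..., -batch_size_mtp:]` kept only the last one; the two coincide only when `num_nextn_predict_layers == 1`. A toy illustration with invented sizes:

```python
import paddle

num_nextn_predict_layers = 2
base = 4                                             # per-part width, illustrative
hidden_size = base * (num_nextn_predict_layers + 1)  # hidden + 2 MTP embeddings
hidden_states = paddle.arange(2 * 3 * hidden_size, dtype="float32").reshape([2, 3, hidden_size])

batch_size_mtp = hidden_size // (num_nextn_predict_layers + 1)  # == base

new_slice = hidden_states[..., batch_size_mtp:]    # all MTP embeddings, width 8
old_slice = hidden_states[..., -batch_size_mtp:]   # only the last MTP embedding, width 4
core = hidden_states[..., :batch_size_mtp]         # the actual hidden states, width 4

print(new_slice.shape, old_slice.shape, core.shape)  # [2, 3, 8] [2, 3, 4] [2, 3, 4]
```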