@@ -155,6 +155,7 @@ def __init__(
         config,
         shared_experts=None,
         using_post_norm_recompute=False,
+        output_mtp_embed_first=False,
         name="PostProcessNode",
     ):
         self.send_mtp_embed = send_mtp_embed
@@ -163,6 +164,7 @@ def __init__(
         self.config = config
         self.alpha = alpha
         self.using_post_norm_recompute = using_post_norm_recompute
+        self.output_mtp_embed_first = output_mtp_embed_first
         self.name = name
 
         if self.using_post_norm_recompute:
@@ -205,6 +207,7 @@ def forward_without_residual(self, inputs):
         hidden_states.stop_gradient = False
 
         if self.send_mtp_embed:
+            assert not self.output_mtp_embed_first, "forward_without_residual doesn't support output_mtp_embed_first"
             hidden_states = paddle.concat([hidden_states, inputs_embeds_mtp], axis=-1)
             self.mtp_embed_shape = inputs_embeds_mtp.shape  # save mtp_embed's shape for the backward pass
@@ -245,7 +248,10 @@ def forward(self, inputs):
         hidden_states = residual + final_hidden_states
 
         if self.send_mtp_embed:
-            hidden_states = paddle.concat([hidden_states, inputs_embeds_mtp], axis=-1)
+            if self.output_mtp_embed_first:
+                hidden_states = paddle.concat([inputs_embeds_mtp, hidden_states], axis=-1)
+            else:
+                hidden_states = paddle.concat([hidden_states, inputs_embeds_mtp], axis=-1)
             self.mtp_embed_shape = inputs_embeds_mtp.shape  # save mtp_embed's shape for the backward pass
 
         return return_args(hidden_states)
@@ -257,8 +263,12 @@ def backward(self, output_grad):
         if self.send_mtp_embed:
             # split the gradient: the first part of do3 corresponds to hidden_states, the latter part to inputs_embeds_mtp
             hidden_size = do3.shape[-1] - self.mtp_embed_shape[-1]
-            hidden_states_grad = do3[..., :hidden_size]
-            inputs_embeds_mtp_grad = do3[..., hidden_size:]
+            if self.output_mtp_embed_first:
+                hidden_states_grad = do3[..., hidden_size:]
+                inputs_embeds_mtp_grad = do3[..., :hidden_size]
+            else:
+                hidden_states_grad = do3[..., :hidden_size]
+                inputs_embeds_mtp_grad = do3[..., hidden_size:]
         else:
             hidden_states_grad = do3
             inputs_embeds_mtp_grad = None
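The backward slicing has to mirror the forward concatenation order that output_mtp_embed_first selects, otherwise the two gradient slices get swapped. A minimal standalone sketch of that symmetry (not part of the patch; shapes are invented, and both tensors share the same trailing hidden dimension as they do in this model):

import paddle

hidden_states = paddle.randn([2, 4, 8])
inputs_embeds_mtp = paddle.randn([2, 4, 8])

# forward with output_mtp_embed_first=True: the MTP embedding is packed first
packed = paddle.concat([inputs_embeds_mtp, hidden_states], axis=-1)

# backward: the gradient w.r.t. the packed tensor is split at the same boundary
grad = paddle.ones_like(packed)
hidden_size = grad.shape[-1] - inputs_embeds_mtp.shape[-1]
inputs_embeds_mtp_grad = grad[..., :hidden_size]   # first slice -> MTP embedding grad
hidden_states_grad = grad[..., hidden_size:]       # second slice -> hidden states grad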
@@ -1682,7 +1692,8 @@ def build_schedule_node(self):
             self.config,
             self.mlp.shared_experts,
             self.config.using_post_norm_recompute,
-            "post_process_node",
+            output_mtp_embed_first=isinstance(self, DeepseekV2MTPLayer),
+            name="post_process_node",
         )
         return FusionFp8DecoderLayerNode(
             attn_and_gate_node=attn_and_gate_node,
@@ -1780,7 +1791,68 @@ def forward(self, args):
         hidden_states = paddle.concat(output_list, axis=-1)
         return return_args(hidden_states, attention_mask, attn_mask_startend_row_indices, position_ids)
 
+    def attn_compute_for_fusion(self, args):
+        hidden_states, attention_mask, attn_mask_startend_row_indices, position_ids = parse_args(args)
+        assert attention_mask is None
+        assert attn_mask_startend_row_indices is None
+        assert position_ids is None
+        assert self.config.num_nextn_predict_layers == 1
+
+        if self.config.send_mtp_embed:
+            hidden_states_list = paddle.split(hidden_states, self.config.num_nextn_predict_layers + 1, axis=-1)
+            hidden_states_main_model = hidden_states_list[0]
+            inputs_embeds_cur_depth_list = hidden_states_list[1:]
+        else:
+            hidden_states_main_model = hidden_states
+            global global_inputs_embeds_mtp_queue
+            inputs_embeds_cur_depth_list = global_inputs_embeds_mtp_queue.get()
+
+        hidden_states = hidden_states_main_model
+        nextn_hidden_state = inputs_embeds_cur_depth_list[0]
+
+        # mtp compute
+        hidden_states = self.hnorm(hidden_states)
+        nextn_hidden_state = self.enorm(nextn_hidden_state)
+
+        hidden_states = self.eh_proj(paddle.concat([hidden_states, nextn_hidden_state], axis=-1))
+
+        # attention compute
+        hidden_states, residual = self.self_attn_compute(hidden_states)
+
+        if self.using_post_norm_recompute:
+            probs, routing_map, l_aux, _, norm_out = self.mlp.router(hidden_states)
+        else:
+            probs, routing_map, l_aux, _ = self.mlp.router(hidden_states)
+
+        # common return values
+        ret = (
+            hidden_states_main_model,
+            hidden_states,
+            residual,
+            probs,
+            routing_map,
+            l_aux,
+        )
+        ret = (*ret, norm_out) if self.using_post_norm_recompute else ret
+
+        return ret
+
     def build_schedule_node(self):
+        if isinstance(self.mlp, DeepseekV2MoE):
+            self.mlp.update_flex_token()
+            if (
+                self.mlp.using_flex_token and
+                DSV3_USE_FP8_GEMM and
+                self.config.num_nextn_predict_layers == 1
+            ):
+                prev_send_mtp_embed = self.config.send_mtp_embed
+                self.config.send_mtp_embed = True  # must be True in MTP node
+
+                node = DeepseekV2DecoderLayerPipe.build_schedule_node(self)
+                assert isinstance(node, FusionFp8DecoderLayerNode)
+
+                self.config.send_mtp_embed = prev_send_mtp_embed
+                return node
         return ScheduleNode(self.forward, name="DeepseekV2MTPLayerPipe")
 
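For reference, a minimal standalone sketch (not part of the patch; shapes invented) of the unpacking that attn_compute_for_fusion performs when config.send_mtp_embed is set: the upstream stage packs the main-model hidden states together with one embedding per predict depth along the last axis, and the MTP node splits them back into num_nextn_predict_layers + 1 chunks:

import paddle

num_nextn_predict_layers = 1   # matches the assert in attn_compute_for_fusion
hidden_size = 8                # illustrative only

packed = paddle.randn([2, 4, hidden_size * (num_nextn_predict_layers + 1)])

chunks = paddle.split(packed, num_nextn_predict_layers + 1, axis=-1)
hidden_states_main_model = chunks[0]       # forwarded untouched as the first return value
inputs_embeds_cur_depth_list = chunks[1:]  # MTP embedding(s), consumed by enorm/eh_proj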