Commit 0efb202

Implement overlapping of MTP Decoder Layer (#10963)
1 parent 21d7d20 · commit 0efb202


paddlenlp/transformers/deepseek_v2/modeling_pp.py

Lines changed: 76 additions & 4 deletions
@@ -155,6 +155,7 @@ def __init__(
         config,
         shared_experts=None,
         using_post_norm_recompute=False,
+        output_mtp_embed_first=False,
         name="PostProcessNode",
     ):
         self.send_mtp_embed = send_mtp_embed
@@ -163,6 +164,7 @@ def __init__(
         self.config = config
         self.alpha = alpha
         self.using_post_norm_recompute = using_post_norm_recompute
+        self.output_mtp_embed_first = output_mtp_embed_first
         self.name = name

         if self.using_post_norm_recompute:
@@ -205,6 +207,7 @@ def forward_without_residual(self, inputs):
         hidden_states.stop_gradient = False

         if self.send_mtp_embed:
+            assert not self.output_mtp_embed_first, "forward_without_residual doesn't support output_mtp_embed_first"
             hidden_states = paddle.concat([hidden_states, inputs_embeds_mtp], axis=-1)
             self.mtp_embed_shape = inputs_embeds_mtp.shape  # keep mtp_embed's shape for the backward pass

@@ -245,7 +248,10 @@ def forward(self, inputs):
         hidden_states = residual + final_hidden_states

         if self.send_mtp_embed:
-            hidden_states = paddle.concat([hidden_states, inputs_embeds_mtp], axis=-1)
+            if self.output_mtp_embed_first:
+                hidden_states = paddle.concat([inputs_embeds_mtp, hidden_states], axis=-1)
+            else:
+                hidden_states = paddle.concat([hidden_states, inputs_embeds_mtp], axis=-1)
             self.mtp_embed_shape = inputs_embeds_mtp.shape  # keep mtp_embed's shape for the backward pass

         return return_args(hidden_states)
@@ -257,8 +263,12 @@ def backward(self, output_grad):
         if self.send_mtp_embed:
             # split the gradient: the front part of do3 corresponds to hidden_states, the back part to inputs_embeds_mtp
             hidden_size = do3.shape[-1] - self.mtp_embed_shape[-1]
-            hidden_states_grad = do3[..., :hidden_size]
-            inputs_embeds_mtp_grad = do3[..., hidden_size:]
+            if self.output_mtp_embed_first:
+                hidden_states_grad = do3[..., hidden_size:]
+                inputs_embeds_mtp_grad = do3[..., :hidden_size]
+            else:
+                hidden_states_grad = do3[..., :hidden_size]
+                inputs_embeds_mtp_grad = do3[..., hidden_size:]
         else:
             hidden_states_grad = do3
             inputs_embeds_mtp_grad = None
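
A quick way to see why the forward and backward changes belong together: the backward slice has to mirror whichever concat order the forward pass used. Below is a minimal standalone sketch with made-up tensors, not code from the patch; it assumes the MTP embedding has the same width as hidden_states, consistent with the equal-chunk split used elsewhere in this file.

    import paddle

    # stand-ins for the two tensors packed along the last axis
    hidden = paddle.rand([2, 8])   # pretend hidden_states
    mtp = paddle.rand([2, 8])      # pretend inputs_embeds_mtp (same width assumed)

    # forward with output_mtp_embed_first=True packs the MTP embedding first
    do3 = paddle.concat([mtp, hidden], axis=-1)

    # backward mirrors that layout: head slice -> MTP gradient, tail slice -> hidden gradient
    hidden_size = do3.shape[-1] - mtp.shape[-1]
    inputs_embeds_mtp_grad = do3[..., :hidden_size]
    hidden_states_grad = do3[..., hidden_size:]

    assert paddle.allclose(inputs_embeds_mtp_grad, mtp)
    assert paddle.allclose(hidden_states_grad, hidden)

Flipping either the concat order or the slice order alone would route the hidden-state gradient into the MTP branch; keeping both keyed off output_mtp_embed_first avoids that mismatch.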
@@ -1682,7 +1692,8 @@ def build_schedule_node(self):
             self.config,
             self.mlp.shared_experts,
             self.config.using_post_norm_recompute,
-            "post_process_node",
+            output_mtp_embed_first=isinstance(self, DeepseekV2MTPLayer),
+            name="post_process_node",
         )
         return FusionFp8DecoderLayerNode(
             attn_and_gate_node=attn_and_gate_node,
@@ -1780,7 +1791,68 @@ def forward(self, args):
         hidden_states = paddle.concat(output_list, axis=-1)
         return return_args(hidden_states, attention_mask, attn_mask_startend_row_indices, position_ids)

+    def attn_compute_for_fusion(self, args):
+        hidden_states, attention_mask, attn_mask_startend_row_indices, position_ids = parse_args(args)
+        assert attention_mask is None
+        assert attn_mask_startend_row_indices is None
+        assert position_ids is None
+        assert self.config.num_nextn_predict_layers == 1
+
+        if self.config.send_mtp_embed:
+            hidden_states_list = paddle.split(hidden_states, self.config.num_nextn_predict_layers + 1, axis=-1)
+            hidden_states_main_model = hidden_states_list[0]
+            inputs_embeds_cur_depth_list = hidden_states_list[1:]
+        else:
+            hidden_states_main_model = hidden_states
+            global global_inputs_embeds_mtp_queue
+            inputs_embeds_cur_depth_list = global_inputs_embeds_mtp_queue.get()
+
+        hidden_states = hidden_states_main_model
+        nextn_hidden_state = inputs_embeds_cur_depth_list[0]
+
+        # mtp compute
+        hidden_states = self.hnorm(hidden_states)
+        nextn_hidden_state = self.enorm(nextn_hidden_state)
+
+        hidden_states = self.eh_proj(paddle.concat([hidden_states, nextn_hidden_state], axis=-1))
+
+        # attention compute
+        hidden_states, residual = self.self_attn_compute(hidden_states)
+
+        if self.using_post_norm_recompute:
+            probs, routing_map, l_aux, _, norm_out = self.mlp.router(hidden_states)
+        else:
+            probs, routing_map, l_aux, _ = self.mlp.router(hidden_states)
+
+        # common return values
+        ret = (
+            hidden_states_main_model,
+            hidden_states,
+            residual,
+            probs,
+            routing_map,
+            l_aux,
+        )
+        ret = (*ret, norm_out) if self.using_post_norm_recompute else ret
+
+        return ret
+
     def build_schedule_node(self):
+        if isinstance(self.mlp, DeepseekV2MoE):
+            self.mlp.update_flex_token()
+            if (
+                self.mlp.using_flex_token and
+                DSV3_USE_FP8_GEMM and
+                self.config.num_nextn_predict_layers == 1
+            ):
+                prev_send_mtp_embed = self.config.send_mtp_embed
+                self.config.send_mtp_embed = True  # must be True in MTP node
+
+                node = DeepseekV2DecoderLayerPipe.build_schedule_node(self)
+                assert isinstance(node, FusionFp8DecoderLayerNode)
+
+                self.config.send_mtp_embed = prev_send_mtp_embed
+                return node
         return ScheduleNode(self.forward, name="DeepseekV2MTPLayerPipe")
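
For context on the new attn_compute_for_fusion entry point: when send_mtp_embed is on, the incoming tensor carries the main-model hidden states and the MTP embedding fused along the last axis, and paddle.split pulls them apart into num_nextn_predict_layers + 1 equal chunks. A small self-contained sketch with invented shapes (assuming num_nextn_predict_layers == 1, as the new assert requires):

    import paddle

    num_nextn_predict_layers = 1
    hidden = paddle.rand([2, 16, 8])      # [batch, seq, hidden] stand-in for the main model
    mtp_embed = paddle.rand([2, 16, 8])   # MTP embedding of the same width

    # what the previous pipeline stage sends when send_mtp_embed is enabled
    fused = paddle.concat([hidden, mtp_embed], axis=-1)   # last dim doubles

    # what attn_compute_for_fusion does on entry
    parts = paddle.split(fused, num_nextn_predict_layers + 1, axis=-1)
    hidden_states_main_model = parts[0]
    inputs_embeds_cur_depth_list = parts[1:]

    assert paddle.allclose(hidden_states_main_model, hidden)
    assert paddle.allclose(inputs_embeds_cur_depth_list[0], mtp_embed)

build_schedule_node then forces config.send_mtp_embed to True only while reusing DeepseekV2DecoderLayerPipe.build_schedule_node and restores the previous value afterwards, so the flag flip stays local to constructing the fused MTP node.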