Fix mtp bug #10895

Open · wants to merge 1 commit into base: dsv3_dev
45 changes: 30 additions & 15 deletions paddlenlp/transformers/deepseek_v2/modeling_pp.py
@@ -869,6 +869,9 @@ def build_overlapped_nodes(forward_chunk, backward_chunk):
     overlap_node = OverlapedScheduleChunk(forward_overlap_layers, backward_overlap_layers, use_fuion=DSV3_USE_FP8_GEMM)
     return forward_pre_node, backward_pre_node, overlap_node, forward_post_node, backward_post_node
 
+import queue
+
+global_inputs_embeds_mtp_queue = queue.Queue()
 
 class DeepseekV2EmbeddingPipe(nn.Layer):
     def __init__(self, config: DeepseekV2Config):
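
Aside (not part of the diff): a minimal, self-contained sketch of the hand-off pattern this hunk sets up. A module-level `queue.Queue()` acts as an in-process side channel, so MTP embeddings can reach `DeepseekV2MTPLayerPipe` without being concatenated into the tensor that flows through the pipeline; this only works when producer and consumer run in the same process. The helper names (`_mtp_side_channel`, `embedding_stage`, `mtp_stage`) are illustrative, not from the PR.

```python
import queue

import paddle

# Illustrative module-level side channel, mirroring global_inputs_embeds_mtp_queue.
_mtp_side_channel = queue.Queue()


def embedding_stage(inputs_embeds, mtp_embeds, send_mtp_embed):
    """Producer: either widen the pipeline tensor or stash the MTP embeds aside."""
    if send_mtp_embed:
        # Concatenate along the hidden axis so one tensor flows to the next stage.
        return paddle.concat([inputs_embeds] + mtp_embeds, axis=-1)
    _mtp_side_channel.put(mtp_embeds)  # nothing extra travels through the pipeline
    return inputs_embeds


def mtp_stage(hidden_states, num_nextn_predict_layers, send_mtp_embed):
    """Consumer: recover the main hidden states plus the per-depth MTP embeds."""
    if send_mtp_embed:
        parts = paddle.split(hidden_states, num_nextn_predict_layers + 1, axis=-1)
        return parts[0], parts[1:]
    return hidden_states, _mtp_side_channel.get()
```
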
@@ -902,7 +905,7 @@ def forward(self, args):
         inputs_embeds = self.embed_tokens(input_ids)
 
         batch_size, seq_length = input_ids.shape
-        if self.config.send_mtp_embed:
+        if self.config.num_nextn_predict_layers > 0:
             seq_length -= self.config.num_nextn_predict_layers
 
         if attention_mask is not None:
@@ -925,7 +928,7 @@
             attention_mask = paddle.tril(paddle.ones((seq_length, seq_length), dtype="bool"))
             attention_mask.stop_gradient = True
 
-        if self.config.send_mtp_embed:
+        if self.config.num_nextn_predict_layers > 0:
             inputs_embeds_extra = inputs_embeds[:, -self.config.num_nextn_predict_layers :, :]  # [B, S, D]
             inputs_embeds = inputs_embeds[:, : -self.config.num_nextn_predict_layers, :]
             inputs_embeds_ori = inputs_embeds
@@ -937,6 +940,7 @@ def forward(self, args):
                 # [seq_len * bs / n, num_head * head_dim] (n is mp parallelism)
                 inputs_embeds = ScatterOp.apply(inputs_embeds)
             embeds_res = [inputs_embeds]
+            mtp_embeds = []
             for depth in range(self.config.num_nextn_predict_layers):
                 inputs_embeds_mtp = paddle.concat(
                     [
@@ -948,12 +952,18 @@
                 if self.sequence_parallel:
                     inputs_embeds_mtp = inputs_embeds_mtp.reshape([-1, inputs_embeds_mtp.shape[-1]])
                     inputs_embeds_mtp = ScatterOp.apply(inputs_embeds_mtp)
-                embeds_res.append(inputs_embeds_mtp)
-            # if not self.sequence_parallel
-            # mtp_embeds: [B*num_nextn_predict_layers, seq_len, hidden_size]
-            # else:
-            # mtp_embeds: [B*seq_len*num_nextn_predict_layers, hidden_size]
-            inputs_embeds = paddle.concat(embeds_res, axis=-1)
+                mtp_embeds.append(inputs_embeds_mtp)
+
+            if self.config.send_mtp_embed:
+                embeds_res.extend(mtp_embeds)
+                # if not self.sequence_parallel
+                # mtp_embeds: [B*num_nextn_predict_layers, seq_len, hidden_size]
+                # else:
+                # mtp_embeds: [B*seq_len*num_nextn_predict_layers, hidden_size]
+                inputs_embeds = paddle.concat(embeds_res, axis=-1)
+            else:
+                global global_inputs_embeds_mtp_queue
+                global_inputs_embeds_mtp_queue.put(mtp_embeds)
             return return_args(inputs_embeds, attention_mask, attn_mask_startend_row_indices, position_ids)
         else:
             if self.sequence_parallel:
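
Aside: a toy check of how the loop above builds each MTP depth's input embedding. The body of the truncated `paddle.concat([...])` is not fully shown in this hunk; the sketch assumes it joins the left-shifted original embeddings with the reserved tail tokens, which is my reading of the surrounding code.

```python
import paddle

# Toy shapes: batch=1, full sequence (incl. the MTP tail) = 6 tokens, hidden=4.
num_nextn_predict_layers = 2
full = paddle.arange(6, dtype="float32").reshape([1, 6, 1]).tile([1, 1, 4])

# Same slicing as above: reserve the last num_nextn_predict_layers tokens.
inputs_embeds_extra = full[:, -num_nextn_predict_layers:, :]  # tokens 4, 5
inputs_embeds_ori = full[:, :-num_nextn_predict_layers, :]    # tokens 0..3

for depth in range(num_nextn_predict_layers):
    # Depth d sees the sequence shifted left by (d + 1) positions.
    inputs_embeds_mtp = paddle.concat(
        [
            inputs_embeds_ori[:, (depth + 1):, :],
            inputs_embeds_extra[:, : (depth + 1), :],
        ],
        axis=1,
    )
    print(depth, inputs_embeds_mtp[:, :, 0].tolist())
    # depth 0 -> [[1.0, 2.0, 3.0, 4.0]]
    # depth 1 -> [[2.0, 3.0, 4.0, 5.0]]
```
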
@@ -1235,10 +1245,15 @@ def build_schedule_node(self):
 class DeepseekV2MTPLayerPipe(DeepseekV2MTPLayer):
     def forward(self, args):
         hidden_states, attention_mask, attn_mask_startend_row_indices, position_ids = parse_args(args)
-        hidden_states_list = paddle.split(hidden_states, self.config.num_nextn_predict_layers + 1, axis=-1)
-        hidden_states_main_model = hidden_states_list[0]
-        inputs_embeds_cur_depth_list = hidden_states_list[1:]
+
+        if self.config.send_mtp_embed:
+            hidden_states_list = paddle.split(hidden_states, self.config.num_nextn_predict_layers + 1, axis=-1)
+            hidden_states_main_model = hidden_states_list[0]
+            inputs_embeds_cur_depth_list = hidden_states_list[1:]
+        else:
+            hidden_states_main_model = hidden_states
+            global global_inputs_embeds_mtp_queue
+            inputs_embeds_cur_depth_list = global_inputs_embeds_mtp_queue.get()
 
         has_gradient = not hidden_states_main_model.stop_gradient
 
         if attention_mask is not None and attention_mask.dtype == paddle.int32:
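
Aside: the queue path above relies on a strict one-put/one-get pairing between `DeepseekV2EmbeddingPipe.forward` and `DeepseekV2MTPLayerPipe.forward`, one pair per micro-batch in FIFO order. A minimal sketch of that assumption (pure illustration; the pipeline scheduler is outside this diff):

```python
import queue

q = queue.Queue()  # stands in for global_inputs_embeds_mtp_queue

# Producer side: one put per embedding forward pass (one per micro-batch).
for micro_batch in range(4):
    q.put(("mtp_embeds", micro_batch))

# Consumer side: one get per MTP-layer forward pass, same order.
for micro_batch in range(4):
    tag, idx = q.get()
    assert idx == micro_batch  # only holds if puts and gets stay 1:1 and in order
```
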
Expand Down Expand Up @@ -1303,7 +1318,7 @@ def __init__(self, config):
def forward(self, args):
hidden_states, attention_mask, attn_mask_startend_row_indices, position_ids = parse_args(args)

if self.config.send_mtp_embed:
if self.config.num_nextn_predict_layers > 0:
hidden_states_list = paddle.split(hidden_states, self.config.num_nextn_predict_layers + 1, axis=-1)
hidden_states = hidden_states_list[0]
hidden_states_mtp = hidden_states_list[-self.config.num_nextn_predict_layers :]
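
Aside: a toy check of the split used here when the MTP streams travel concatenated along the hidden axis. `paddle.split` with `num_nextn_predict_layers + 1` sections recovers the main stream and each MTP stream; the shapes below are made up for the example.

```python
import paddle

num_nextn_predict_layers = 2
main = paddle.zeros([2, 8, 16])
mtp = [paddle.full([2, 8, 16], float(d + 1)) for d in range(num_nextn_predict_layers)]

packed = paddle.concat([main] + mtp, axis=-1)                        # [2, 8, 48]
parts = paddle.split(packed, num_nextn_predict_layers + 1, axis=-1)  # 3 x [2, 8, 16]

assert parts[0].shape == [2, 8, 16]
assert float(parts[-1][0, 0, 0]) == 2.0  # the last chunk is the deepest MTP stream
```
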
@@ -1328,7 +1343,7 @@ def embedding_weight(self):
         return get_attr(self, "weight")
 
     def forward(self, args: Union[Tuple, paddle.Tensor]):
-        if self.config.send_mtp_embed:
+        if self.config.num_nextn_predict_layers > 0:
             logits = []
             for _hidden_states in args:
                 logits.append(super().forward(_hidden_states))
@@ -1343,7 +1358,7 @@ def build_schedule_node(self):
 
 class DeepseekV2PretrainingCriterionPipe(DeepseekV2PretrainingCriterion):
     def forward(self, logits, labels):
-        if self.config.send_mtp_embed:
+        if self.config.num_nextn_predict_layers > 0:
             mtp_logits = logits[1:]
             logits = logits[0]
             loss = super().forward(logits, labels, mtp_logits=mtp_logits)
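
Aside: with MTP enabled, the head above emits one logits tensor per stream, and this criterion treats element 0 as the main-model logits and the rest as `mtp_logits`. A minimal sketch of that indexing convention (the vocab size and shapes are invented; the loss math itself lives in `DeepseekV2PretrainingCriterion` and is not part of this diff):

```python
import paddle

# Stand-in for the head -> criterion hand-off when MTP is enabled.
vocab, num_nextn_predict_layers = 32, 2
logits = [paddle.randn([2, 8, vocab]) for _ in range(num_nextn_predict_layers + 1)]

main_logits, mtp_logits = logits[0], logits[1:]
assert len(mtp_logits) == num_nextn_predict_layers
assert main_logits.shape == [2, 8, vocab]
```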