diff --git a/paddlenlp/transformers/deepseek_v2/modeling_pp.py b/paddlenlp/transformers/deepseek_v2/modeling_pp.py
index 9e05faba3593..673a3edb3edd 100644
--- a/paddlenlp/transformers/deepseek_v2/modeling_pp.py
+++ b/paddlenlp/transformers/deepseek_v2/modeling_pp.py
@@ -869,6 +869,9 @@ def build_overlapped_nodes(forward_chunk, backward_chunk):
     overlap_node = OverlapedScheduleChunk(forward_overlap_layers, backward_overlap_layers, use_fuion=DSV3_USE_FP8_GEMM)
     return forward_pre_node, backward_pre_node, overlap_node, forward_post_node, backward_post_node
 
+import queue
+
+global_inputs_embeds_mtp_queue = queue.Queue()
 
 class DeepseekV2EmbeddingPipe(nn.Layer):
     def __init__(self, config: DeepseekV2Config):
@@ -902,7 +905,7 @@ def forward(self, args):
         inputs_embeds = self.embed_tokens(input_ids)
 
         batch_size, seq_length = input_ids.shape
-        if self.config.send_mtp_embed:
+        if self.config.num_nextn_predict_layers > 0:
             seq_length -= self.config.num_nextn_predict_layers
 
         if attention_mask is not None:
@@ -925,7 +928,7 @@ def forward(self, args):
                 attention_mask = paddle.tril(paddle.ones((seq_length, seq_length), dtype="bool"))
                 attention_mask.stop_gradient = True
 
-        if self.config.send_mtp_embed:
+        if self.config.num_nextn_predict_layers > 0:
             inputs_embeds_extra = inputs_embeds[:, -self.config.num_nextn_predict_layers :, :]  # [B, S, D]
             inputs_embeds = inputs_embeds[:, : -self.config.num_nextn_predict_layers, :]
             inputs_embeds_ori = inputs_embeds
@@ -937,6 +940,7 @@ def forward(self, args):
                 # [seq_len * bs / n, num_head * head_dim] (n is mp parallelism)
                 inputs_embeds = ScatterOp.apply(inputs_embeds)
             embeds_res = [inputs_embeds]
+            mtp_embeds = []
             for depth in range(self.config.num_nextn_predict_layers):
                 inputs_embeds_mtp = paddle.concat(
                     [
@@ -948,12 +952,18 @@ def forward(self, args):
                 if self.sequence_parallel:
                     inputs_embeds_mtp = inputs_embeds_mtp.reshape([-1, inputs_embeds_mtp.shape[-1]])
                     inputs_embeds_mtp = ScatterOp.apply(inputs_embeds_mtp)
-                embeds_res.append(inputs_embeds_mtp)
-            # if not self.sequence_parallel
-            # mtp_embeds: [B*num_nextn_predict_layers, seq_len, hidden_size]
-            # else:
-            # mtp_embeds: [B*seq_len*num_nextn_predict_layers, hidden_size]
-            inputs_embeds = paddle.concat(embeds_res, axis=-1)
+                mtp_embeds.append(inputs_embeds_mtp)
+
+            if self.config.send_mtp_embed:
+                embeds_res.extend(mtp_embeds)
+                # if not self.sequence_parallel
+                # mtp_embeds: [B*num_nextn_predict_layers, seq_len, hidden_size]
+                # else:
+                # mtp_embeds: [B*seq_len*num_nextn_predict_layers, hidden_size]
+                inputs_embeds = paddle.concat(embeds_res, axis=-1)
+            else:
+                global global_inputs_embeds_mtp_queue
+                global_inputs_embeds_mtp_queue.put(mtp_embeds)
             return return_args(inputs_embeds, attention_mask, attn_mask_startend_row_indices, position_ids)
         else:
             if self.sequence_parallel:
@@ -1235,10 +1245,15 @@ def build_schedule_node(self):
 class DeepseekV2MTPLayerPipe(DeepseekV2MTPLayer):
     def forward(self, args):
         hidden_states, attention_mask, attn_mask_startend_row_indices, position_ids = parse_args(args)
-
-        hidden_states_list = paddle.split(hidden_states, self.config.num_nextn_predict_layers + 1, axis=-1)
-        hidden_states_main_model = hidden_states_list[0]
-        inputs_embeds_cur_depth_list = hidden_states_list[1:]
+        if self.config.send_mtp_embed:
+            hidden_states_list = paddle.split(hidden_states, self.config.num_nextn_predict_layers + 1, axis=-1)
+            hidden_states_main_model = hidden_states_list[0]
+            inputs_embeds_cur_depth_list = hidden_states_list[1:]
+        else:
+            hidden_states_main_model = hidden_states
+            global global_inputs_embeds_mtp_queue
+            inputs_embeds_cur_depth_list = global_inputs_embeds_mtp_queue.get()
+
         has_gradient = not hidden_states_main_model.stop_gradient
 
         if attention_mask is not None and attention_mask.dtype == paddle.int32:
@@ -1303,7 +1318,7 @@ def __init__(self, config):
     def forward(self, args):
         hidden_states, attention_mask, attn_mask_startend_row_indices, position_ids = parse_args(args)
 
-        if self.config.send_mtp_embed:
+        if self.config.num_nextn_predict_layers > 0:
             hidden_states_list = paddle.split(hidden_states, self.config.num_nextn_predict_layers + 1, axis=-1)
             hidden_states = hidden_states_list[0]
             hidden_states_mtp = hidden_states_list[-self.config.num_nextn_predict_layers :]
@@ -1328,7 +1343,7 @@ def embedding_weight(self):
         return get_attr(self, "weight")
 
     def forward(self, args: Union[Tuple, paddle.Tensor]):
-        if self.config.send_mtp_embed:
+        if self.config.num_nextn_predict_layers > 0:
             logits = []
             for _hidden_states in args:
                 logits.append(super().forward(_hidden_states))
@@ -1343,7 +1358,7 @@ def build_schedule_node(self):
 
 class DeepseekV2PretrainingCriterionPipe(DeepseekV2PretrainingCriterion):
     def forward(self, logits, labels):
-        if self.config.send_mtp_embed:
+        if self.config.num_nextn_predict_layers > 0:
             mtp_logits = logits[1:]
             logits = logits[0]
             loss = super().forward(logits, labels, mtp_logits=mtp_logits)
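A minimal sketch of the hand-off this patch introduces: when config.send_mtp_embed is false, DeepseekV2EmbeddingPipe no longer concatenates the MTP embeddings into the hidden-states tensor sent through the pipeline; it puts them into the module-level global_inputs_embeds_mtp_queue, and DeepseekV2MTPLayerPipe pops them in FIFO order. The stand-in names EmbeddingStage, MTPStage, and the list-valued "tensors" below are illustrative assumptions, not PaddleNLP APIs; the sketch only mirrors the put/get contract, under the assumption that producer and consumer run in the same process and puts and gets stay matched one-to-one per micro-batch.

import queue

# Module-level queue, mirroring global_inputs_embeds_mtp_queue in the patch.
mtp_embeds_queue: "queue.Queue[list]" = queue.Queue()


class EmbeddingStage:
    """Stand-in for DeepseekV2EmbeddingPipe: emits main embeddings plus MTP embeddings."""

    def __init__(self, send_mtp_embed: bool):
        self.send_mtp_embed = send_mtp_embed

    def forward(self, inputs_embeds, mtp_embeds):
        if self.send_mtp_embed:
            # Old path: pack everything into the pipeline payload
            # (the real code concatenates along the last axis).
            return [inputs_embeds] + mtp_embeds
        # New path: keep the pipeline payload small and stash the MTP
        # embeddings in the process-local queue instead.
        mtp_embeds_queue.put(mtp_embeds)
        return [inputs_embeds]


class MTPStage:
    """Stand-in for DeepseekV2MTPLayerPipe: recovers the MTP embeddings."""

    def __init__(self, send_mtp_embed: bool):
        self.send_mtp_embed = send_mtp_embed

    def forward(self, packed):
        if self.send_mtp_embed:
            # Old path: split the packed payload back apart.
            hidden, mtp_embeds = packed[0], packed[1:]
        else:
            # New path: one get() per micro-batch, matching the earlier put().
            hidden, mtp_embeds = packed[0], mtp_embeds_queue.get()
        return hidden, mtp_embeds


if __name__ == "__main__":
    embedding, mtp = EmbeddingStage(send_mtp_embed=False), MTPStage(send_mtp_embed=False)
    packed = embedding.forward(inputs_embeds=[1.0, 2.0], mtp_embeds=[[3.0, 4.0]])
    print(mtp.forward(packed))  # ([1.0, 2.0], [[3.0, 4.0]])

The ordering contract is the main thing the sketch surfaces: each forward of the embedding stage must be answered by exactly one get() in the MTP stage, in the same order, otherwise micro-batches would read each other's MTP embeddings.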