diff --git a/paddlenlp/transformers/deepseek_v2/modeling.py b/paddlenlp/transformers/deepseek_v2/modeling.py
index d96e1120b331..90dd25b8fa36 100644
--- a/paddlenlp/transformers/deepseek_v2/modeling.py
+++ b/paddlenlp/transformers/deepseek_v2/modeling.py
@@ -3163,6 +3163,28 @@ def forward(
     )
 
 
+class FastCrossEntropyFunction(paddle.autograd.PyLayer):
+    @staticmethod
+    def forward(ctx, preds, labels):
+
+        softmax_val, loss = paddle._C_ops.cross_entropy_with_softmax(preds, labels, False, True, False, -100, -1)
+
+        # Keep the softmax output; backward reconstructs the logits gradient from it.
+
+        ctx.save_for_backward(labels, softmax_val)
+        return loss
+
+    @staticmethod
+    def backward(ctx, dout):
+        labels, softmax_val = ctx.saved_tensor()
+
+        preds_grad = paddle.incubate.nn.functional.cross_entropy_with_softmax_bwd_w_downcast(
+            labels, softmax_val.cast(paddle.float32), dout.cast(paddle.float32)
+        )
+
+        return preds_grad, None
+
+
 class DeepseekV2PretrainingCriterion(nn.Layer):
     """
     Criterion for Mixtral.
@@ -3190,7 +3212,7 @@ def forward(self, prediction_scores, masked_lm_labels, router_loss=None, mtp_log
 
         def compute_loss(preds, labels):
             with paddle.amp.auto_cast(False):
-                masked_lm_loss = self.loss_func(preds.astype("float32"), labels.unsqueeze(2))
+                masked_lm_loss = FastCrossEntropyFunction.apply(preds, labels.unsqueeze(2))
 
                 binary_sequence = paddle.where(
                     masked_lm_loss > 0, paddle.ones_like(masked_lm_loss), paddle.zeros_like(masked_lm_loss)
                 )
diff --git a/paddlenlp/transformers/deepseek_v2/modeling_pp.py b/paddlenlp/transformers/deepseek_v2/modeling_pp.py
index 99b817ebe32b..f2cd941af934 100644
--- a/paddlenlp/transformers/deepseek_v2/modeling_pp.py
+++ b/paddlenlp/transformers/deepseek_v2/modeling_pp.py
@@ -1191,7 +1191,11 @@ def forward_backward(self, inputs, output_grad, combine_bw_event_to_wait=None, p
             paddle.base.core.nvprof_nvtx_pop()  # dense_attn_moe_combine
 
             paddle.base.core.nvprof_nvtx_push("moe_mlp")
+            assert WeightGradStore.funcs_queue.empty()
+            WeightGradStore.enabled = True
             output_grad = self.backward_node.mlp_backward(output_grad)
+            WeightGradStore.enabled = False
+            WeightGradStore.flush()
             paddle.base.core.nvprof_nvtx_pop()  # moe_mlp
 
             paddle.base.core.nvprof_nvtx_push("dense_mlp_moe_dispatch")
@@ -1199,6 +1203,12 @@ def forward_backward(self, inputs, output_grad, combine_bw_event_to_wait=None, p
                 output_grad, async_finish=True, allocate_on_comm_stream=True
             )
             dispatch_bw_event = deep_ep.get_event_from_comm_stream(self.backward_node.moe_group.id)
+
+            paddle.base.core.nvprof_nvtx_push("dispatch_backward_dw")
+            WeightGradStore.pop()
+            assert WeightGradStore.funcs_queue.empty()
+            paddle.base.core.nvprof_nvtx_pop()
+
             inputs = self.forward_node.mlp_node.forward(inputs)
             dispatch_bw_event.calc_stream_wait(self.backward_node.moe_group.id)
             paddle.base.core.nvprof_nvtx_pop()  # dense_mlp_moe_dispatch
@@ -1307,17 +1317,13 @@ def build_overlapped_nodes(forward_chunk, backward_chunk):
     overlap_node = OverlapedScheduleChunk(forward_overlap_layers, backward_overlap_layers, use_fuion=DSV3_USE_FP8_GEMM)
     return forward_pre_node, backward_pre_node, overlap_node, forward_post_node, backward_post_node
 
+
 class EmbeddingFunction(paddle.autograd.PyLayer):
     @staticmethod
-    def forward(ctx, x, weight):
-        out = paddle.nn.functional.embedding(
-            x,
-            weight=weight,
-            padding_idx=None,
-            max_norm=None,
-            norm_type=2.0,
-            sparse=False,
-            scale_grad_by_freq=False )
+    def forward(ctx, x, weight):
+        out = paddle.nn.functional.embedding(
+            x, weight=weight, padding_idx=None, max_norm=None, norm_type=2.0, sparse=False, scale_grad_by_freq=False
+        )
 
         ctx.save_for_backward(x, weight)
         return out
@@ -1325,15 +1331,15 @@ def forward(ctx, x, weight):
 
     @staticmethod
     def backward(ctx, dout):
         x, weight = ctx.saved_tensor()
-
-        if hasattr( weight, "main_grad" ):
-            paddle.incubate.nn.functional.embedding_grad_add_to_(x, weight.main_grad, dout)
+
+        if hasattr(weight, "main_grad"):
+            paddle.incubate.nn.functional.embedding_grad_add_to_(x, weight.main_grad, dout)
         else:
             paddle.incubate.nn.functional.embedding_grad_add_to_(x, weight.grad, dout)
-
         return None, None
 
+
 class DeepseekV2EmbeddingPipe(nn.Layer):
     def __init__(self, config: DeepseekV2Config):
         super(DeepseekV2EmbeddingPipe, self).__init__()
@@ -1363,7 +1369,7 @@ def forward(self, args):
             _type_: _description_
         """
         input_ids, attention_mask, attn_mask_startend_row_indices, position_ids = parse_args(args)
-        inputs_embeds = EmbeddingFunction.apply( input_ids, self.embed_tokens.weight )
+        inputs_embeds = EmbeddingFunction.apply(input_ids, self.embed_tokens.weight)
         batch_size, seq_length = input_ids.shape
 
         if self.config.num_nextn_predict_layers > 0:
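
Note: the loss change above swaps `self.loss_func(preds.astype("float32"), ...)` for a PyLayer that calls the fused `cross_entropy_with_softmax` kernel directly and downcasts in the backward. A quick way to sanity-check the numerics is to compare the fused loss against Paddle's reference cross-entropy. The snippet below is a minimal sketch, not part of the patch: the shapes, dtype, and tolerance are illustrative, and it assumes the patch is applied and the running Paddle build ships the fused kernels the PyLayer uses.

```python
import paddle

from paddlenlp.transformers.deepseek_v2.modeling import FastCrossEntropyFunction

# Illustrative shapes: [batch, seq, vocab] logits, int64 labels.
batch, seq, vocab = 2, 8, 128
preds = paddle.randn([batch, seq, vocab]).astype("bfloat16")
labels = paddle.randint(0, vocab, shape=[batch, seq], dtype="int64")

# Fused path introduced by this PR (labels get a trailing dim, as in compute_loss).
fused = FastCrossEntropyFunction.apply(preds, labels.unsqueeze(2))

# Reference path this PR replaces: upcast the logits to float32 first.
ref = paddle.nn.functional.cross_entropy(
    preds.astype("float32"), labels.unsqueeze(2), reduction="none"
)

# bfloat16 logits imply a loose tolerance; 1e-2 is an arbitrary illustrative choice.
print(paddle.allclose(fused.astype("float32"), ref, atol=1e-2))
```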
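The `WeightGradStore` choreography in `forward_backward` is a zero-bubble-style split of the backward pass into dX and dW: while `enabled` is set, `mlp_backward` only queues weight-gradient work; `flush()` seals the batch, and `pop()` drains it later so the dW GEMMs run while the MoE dispatch occupies the comm stream, which is why the nvtx range around `pop()` is named `dispatch_backward_dw`. Below is a hypothetical mental model of the store's contract, inferred only from the calls in this diff; the real class lives elsewhere in PaddleNLP, and the `put` method shown here is an assumed producer-side hook.

```python
import queue


class WeightGradStoreSketch:
    """Hypothetical stand-in for WeightGradStore, matching only the calls in this diff."""

    enabled = False              # producers check this before deferring dW work
    cache = []                   # dW closures queued during one backward pass
    funcs_queue = queue.Queue()  # flushed batches, drained by pop()

    @classmethod
    def put(cls, dw_func):
        # Assumed hook: a layer's backward calls this instead of computing
        # its weight gradient inline while `enabled` is True.
        cls.cache.append(dw_func)

    @classmethod
    def flush(cls):
        # Seal everything queued by the current backward into one batch.
        cls.funcs_queue.put(cls.cache)
        cls.cache = []

    @classmethod
    def pop(cls):
        # Drain one batch; in the diff this runs while dispatch owns the comm stream.
        for dw_func in cls.funcs_queue.get():
            dw_func()
```

Under this reading, the asserts in the diff are consistency checks: the queue must be empty before a new batch is queued, and empty again once `pop()` has drained the deferred dW work.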