24 changes: 23 additions & 1 deletion paddlenlp/transformers/deepseek_v2/modeling.py
@@ -3163,6 +3163,28 @@ def forward(
)


class FastCrossEntropyFunction(paddle.autograd.PyLayer):
    @staticmethod
    def forward(ctx, preds, labels):
        # Fused softmax + cross-entropy; trailing args map to soft_label=False,
        # use_softmax=True, numeric_stable_mode=False, ignore_index=-100, axis=-1.
        softmax_val, loss = paddle._C_ops.cross_entropy_with_softmax(preds, labels, False, True, False, -100, -1)

        # Keep the softmax output so the backward pass can reuse it instead of
        # recomputing it from the logits.
        ctx.save_for_backward(labels, softmax_val)
        return loss

    @staticmethod
    def backward(ctx, dout):
        labels, softmax_val = ctx.saved_tensor()

        # Fused backward: derives the logits gradient from the saved softmax and
        # downcasts the result in a single kernel.
        preds_grad = paddle.incubate.nn.functional.cross_entropy_with_softmax_bwd_w_downcast(
            labels, softmax_val.cast(paddle.float32), dout.cast(paddle.float32)
        )

        # No gradient flows back to the labels.
        return preds_grad, None


class DeepseekV2PretrainingCriterion(nn.Layer):
"""
Criterion for Mixtral.
@@ -3190,7 +3212,7 @@ def forward(self, prediction_scores, masked_lm_labels, router_loss=None, mtp_log

def compute_loss(preds, labels):
with paddle.amp.auto_cast(False):
masked_lm_loss = self.loss_func(preds.astype("float32"), labels.unsqueeze(2))
masked_lm_loss = FastCrossEntropyFunction.apply(preds, labels.unsqueeze(2))
binary_sequence = paddle.where(
masked_lm_loss > 0, paddle.ones_like(masked_lm_loss), paddle.zeros_like(masked_lm_loss)
)
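For reference, here is a minimal sketch of exercising the new loss path in isolation. The shapes and dtypes are illustrative assumptions, and the backward depends on a Paddle build that provides `paddle.incubate.nn.functional.cross_entropy_with_softmax_bwd_w_downcast`:

```python
import paddle

# Illustrative shapes/dtypes only: [batch, seq_len, vocab] bf16 logits, int64 labels.
batch, seq_len, vocab = 2, 8, 128
preds = paddle.randn([batch, seq_len, vocab]).astype("bfloat16")
preds.stop_gradient = False
labels = paddle.randint(0, vocab, shape=[batch, seq_len], dtype="int64")

# Mirrors compute_loss above: labels gain a trailing axis so they line up with
# the class dimension of the logits.
loss = FastCrossEntropyFunction.apply(preds, labels.unsqueeze(2))
loss.sum().backward()  # backward runs the fused, downcasting kernel
```

Compared with the replaced `preds.astype("float32")` path, the PyLayer hands the fused kernel the original logits and leaves the dtype handling to the fused backward op.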
34 changes: 20 additions & 14 deletions paddlenlp/transformers/deepseek_v2/modeling_pp.py
@@ -1191,14 +1191,24 @@ def forward_backward(self, inputs, output_grad, combine_bw_event_to_wait=None, p
paddle.base.core.nvprof_nvtx_pop() # dense_attn_moe_combine

paddle.base.core.nvprof_nvtx_push("moe_mlp")
assert WeightGradStore.funcs_queue.empty()
WeightGradStore.enabled = True
output_grad = self.backward_node.mlp_backward(output_grad)
WeightGradStore.enabled = False
WeightGradStore.flush()
paddle.base.core.nvprof_nvtx_pop() # moe_mlp

paddle.base.core.nvprof_nvtx_push("dense_mlp_moe_dispatch")
output_grad = self.backward_node.dispatch_backward(
output_grad, async_finish=True, allocate_on_comm_stream=True
)
dispatch_bw_event = deep_ep.get_event_from_comm_stream(self.backward_node.moe_group.id)

paddle.base.core.nvprof_nvtx_push("dispatch_backward_dw")
WeightGradStore.pop()
assert WeightGradStore.funcs_queue.empty()
paddle.base.core.nvprof_nvtx_pop()

inputs = self.forward_node.mlp_node.forward(inputs)
dispatch_bw_event.calc_stream_wait(self.backward_node.moe_group.id)
paddle.base.core.nvprof_nvtx_pop() # dense_mlp_moe_dispatch
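The hunk above relies on a defer/flush/pop pattern: weight-gradient work recorded while `WeightGradStore.enabled` is set is queued by `flush()` and only executed by `pop()` after the dispatch backward has been launched asynchronously on the communication stream, so the dW computation can overlap with the all-to-all. Below is a minimal, self-contained sketch of that pattern; it is a toy stand-in with illustrative names, not PaddleNLP's actual `WeightGradStore`:

```python
import queue
from typing import Callable, List


class ToyWeightGradStore:
    """Toy stand-in for the defer/flush/pop pattern (not the real WeightGradStore)."""

    enabled: bool = False
    _pending: List[Callable[[], None]] = []
    funcs_queue: "queue.Queue[Callable[[], None]]" = queue.Queue()

    @classmethod
    def put(cls, fn: Callable[[], None]) -> None:
        # While enabled, weight-grad closures are recorded instead of executed.
        if cls.enabled:
            cls._pending.append(fn)
        else:
            fn()

    @classmethod
    def flush(cls) -> None:
        # Hand everything recorded during the backward over to the queue.
        for fn in cls._pending:
            cls.funcs_queue.put(fn)
        cls._pending = []

    @classmethod
    def pop(cls) -> None:
        # Drain the queue, e.g. while an async all-to-all is still in flight.
        while not cls.funcs_queue.empty():
            cls.funcs_queue.get()()


# Usage mirroring the schedule above (compute/communication calls are placeholders):
ToyWeightGradStore.enabled = True
ToyWeightGradStore.put(lambda: print("deferred dW GEMM"))  # recorded, not run yet
ToyWeightGradStore.enabled = False
ToyWeightGradStore.flush()
# ... launch dispatch_backward(async_finish=True) on the comm stream here ...
ToyWeightGradStore.pop()  # deferred dW work runs while the communication proceeds
assert ToyWeightGradStore.funcs_queue.empty()
```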
@@ -1307,33 +1317,29 @@ def build_overlapped_nodes(forward_chunk, backward_chunk):
overlap_node = OverlapedScheduleChunk(forward_overlap_layers, backward_overlap_layers, use_fuion=DSV3_USE_FP8_GEMM)
return forward_pre_node, backward_pre_node, overlap_node, forward_post_node, backward_post_node


class EmbeddingFunction(paddle.autograd.PyLayer):
@staticmethod
def forward(ctx, x, weight):
out = paddle.nn.functional.embedding(
x,
weight=weight,
padding_idx=None,
max_norm=None,
norm_type=2.0,
sparse=False,
scale_grad_by_freq=False )
def forward(ctx, x, weight):
out = paddle.nn.functional.embedding(
x, weight=weight, padding_idx=None, max_norm=None, norm_type=2.0, sparse=False, scale_grad_by_freq=False
)

ctx.save_for_backward(x, weight)
return out

@staticmethod
def backward(ctx, dout):
x, weight = ctx.saved_tensor()
if hasattr( weight, "main_grad" ):
paddle.incubate.nn.functional.embedding_grad_add_to_(x, weight.main_grad, dout)

if hasattr(weight, "main_grad"):
paddle.incubate.nn.functional.embedding_grad_add_to_(x, weight.main_grad, dout)
else:
paddle.incubate.nn.functional.embedding_grad_add_to_(x, weight.grad, dout)


return None, None


class DeepseekV2EmbeddingPipe(nn.Layer):
def __init__(self, config: DeepseekV2Config):
super(DeepseekV2EmbeddingPipe, self).__init__()
@@ -1363,7 +1369,7 @@ def forward(self, args):
_type_: _description_
"""
input_ids, attention_mask, attn_mask_startend_row_indices, position_ids = parse_args(args)
inputs_embeds = EmbeddingFunction.apply( input_ids, self.embed_tokens.weight )
inputs_embeds = EmbeddingFunction.apply(input_ids, self.embed_tokens.weight)

batch_size, seq_length = input_ids.shape
if self.config.num_nextn_predict_layers > 0:
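Finally, a hedged sketch of exercising the new `EmbeddingFunction` on its own. The sizes are illustrative, the `main_grad` buffer is emulated by hand, and `paddle.incubate.nn.functional.embedding_grad_add_to_` is assumed to exist in the Paddle build in use (the PR calls it from `paddle.incubate`, which typically implies a recent or specialized build):

```python
import paddle

# Illustrative sizes only.
vocab_size, hidden, batch, seq_len = 1000, 64, 2, 16

embed = paddle.nn.Embedding(vocab_size, hidden)
# Emulate the master-gradient buffer some optimizer setups attach to parameters;
# whether your configuration uses `main_grad` at all is setup-dependent.
embed.weight.main_grad = paddle.zeros([vocab_size, hidden], dtype="float32")

input_ids = paddle.randint(0, vocab_size, shape=[batch, seq_len], dtype="int64")
out = EmbeddingFunction.apply(input_ids, embed.weight)

# The custom backward accumulates directly into `main_grad` (or `grad`) via
# embedding_grad_add_to_ and returns no gradients through autograd.
out.sum().backward()
```

The apparent motivation is to accumulate the embedding gradient in place in one fused call, rather than materializing a separate gradient tensor for the full embedding table and adding it into `main_grad` as a second step.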