Skip to content

Commit f83311a

Browse files
authored
optimize cross entropy speed (#11012)
1 parent 475942a commit f83311a

File tree

1 file changed

+23
-1
lines changed

1 file changed

+23
-1
lines changed

paddlenlp/transformers/deepseek_v2/modeling.py

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3163,6 +3163,28 @@ def forward(
31633163
)
31643164

31653165

3166+
class FastCrossEntropyFunction(paddle.autograd.PyLayer):
    """Fused cross-entropy-with-softmax as a custom autograd layer.

    Speed/memory optimization over ``self.loss_func(preds.astype("float32"), ...)``:
    the forward pass runs the fused ``cross_entropy_with_softmax`` op directly on
    ``preds`` in its native dtype and saves the (possibly low-precision) softmax;
    the upcast to float32 is deferred to the backward pass, where a fused
    backward kernel produces the logits gradient with a downcast built in.
    """

    @staticmethod
    def forward(ctx, preds, labels):
        """Compute the per-token cross-entropy loss.

        Args:
            preds: logits tensor.
            labels: integer label tensor; entries equal to -100 are ignored
                (matches the ``ignore_index`` argument passed below).

        Returns:
            The loss tensor from the fused op.
        """
        # NOTE(review): the positional flags mirror the original call
        # (presumably soft_label=False, use_softmax=True, ..., ignore_index=-100,
        # axis=-1) — confirm against the installed paddle's
        # ``cross_entropy_with_softmax`` signature before changing any of them.
        softmax_val, loss = paddle._C_ops.cross_entropy_with_softmax(
            preds, labels, False, True, False, -100, -1
        )
        # Save the softmax in its native (possibly reduced-precision) dtype;
        # backward() casts it to float32 only when the gradient is needed.
        ctx.save_for_backward(labels, softmax_val)
        return loss

    @staticmethod
    def backward(ctx, dout):
        """Compute the gradient w.r.t. ``preds`` via the fused backward kernel.

        Args:
            dout: upstream gradient of the loss.

        Returns:
            Tuple of (gradient for ``preds``, None for ``labels``) — labels
            are integer targets and receive no gradient.
        """
        labels, softmax_val = ctx.saved_tensor()
        # Fused backward with built-in downcast; both tensor inputs are
        # upcast to float32 here (deferred from forward to save memory).
        preds_grad = paddle.incubate.nn.functional.cross_entropy_with_softmax_bwd_w_downcast(
            labels, softmax_val.cast(paddle.float32), dout.cast(paddle.float32)
        )
        return preds_grad, None
3186+
3187+
31663188
class DeepseekV2PretrainingCriterion(nn.Layer):
31673189
"""
31683190
Criterion for Mixtral.
@@ -3190,7 +3212,7 @@ def forward(self, prediction_scores, masked_lm_labels, router_loss=None, mtp_log
31903212

31913213
def compute_loss(preds, labels):
31923214
with paddle.amp.auto_cast(False):
3193-
masked_lm_loss = self.loss_func(preds.astype("float32"), labels.unsqueeze(2))
3215+
masked_lm_loss = FastCrossEntropyFunction.apply(preds, labels.unsqueeze(2))
31943216
binary_sequence = paddle.where(
31953217
masked_lm_loss > 0, paddle.ones_like(masked_lm_loss), paddle.zeros_like(masked_lm_loss)
31963218
)

0 commit comments

Comments
 (0)