@@ -3163,6 +3163,28 @@ def forward(
         )
 
 
+class FastCrossEntropyFunction(paddle.autograd.PyLayer):
+    @staticmethod
+    def forward(ctx, preds, labels):
+
+        softmax_val, loss = paddle._C_ops.cross_entropy_with_softmax(preds, labels, False, True, False, -100, -1)
+
+        # print("softmax val", softmax_val.dtype)
+
+        ctx.save_for_backward(labels, softmax_val)
+        return loss
+
+    @staticmethod
+    def backward(ctx, dout):
+        labels, softmax_val = ctx.saved_tensor()
+
+        preds_grad = paddle.incubate.nn.functional.cross_entropy_with_softmax_bwd_w_downcast(
+            labels, softmax_val.cast(paddle.float32), dout.cast(paddle.float32)
+        )
+
+        return preds_grad, None
+
+
 class DeepseekV2PretrainingCriterion(nn.Layer):
     """
     Criterion for Mixtral.
@@ -3190,7 +3212,7 @@ def forward(self, prediction_scores, masked_lm_labels, router_loss=None, mtp_log
 
         def compute_loss(preds, labels):
             with paddle.amp.auto_cast(False):
-                masked_lm_loss = self.loss_func(preds.astype("float32"), labels.unsqueeze(2))
+                masked_lm_loss = FastCrossEntropyFunction.apply(preds, labels.unsqueeze(2))
                 binary_sequence = paddle.where(
                     masked_lm_loss > 0, paddle.ones_like(masked_lm_loss), paddle.zeros_like(masked_lm_loss)
                 )
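
A rough sanity-check sketch for the new `FastCrossEntropyFunction`, not part of the diff itself. It assumes the class above is in scope and that the running Paddle build exposes `paddle._C_ops.cross_entropy_with_softmax`; exercising the backward additionally needs the custom `paddle.incubate.nn.functional.cross_entropy_with_softmax_bwd_w_downcast` kernel, which not every build ships.

```python
# Rough sanity check, assuming FastCrossEntropyFunction (above) is importable.
import paddle
import paddle.nn.functional as F

paddle.seed(0)
batch, seq_len, vocab = 2, 8, 128
logits = paddle.randn([batch, seq_len, vocab], dtype="float32")
logits.stop_gradient = False
labels = paddle.randint(0, vocab, [batch, seq_len], dtype="int64")

# Forward path: the fused op should agree with the standard cross entropy
# (hard labels, ignore_index=-100, softmax over the last axis).
fused_loss = FastCrossEntropyFunction.apply(logits, labels.unsqueeze(-1))
ref_loss = F.cross_entropy(logits, labels.unsqueeze(-1), reduction="none", ignore_index=-100)
print(float((fused_loss.flatten() - ref_loss.flatten()).abs().max()))  # expect ~0

# Backward path (requires the custom downcasting kernel, typically a GPU build):
# bf16_logits = logits.cast(paddle.bfloat16)
# bf16_logits.stop_gradient = False
# FastCrossEntropyFunction.apply(bf16_logits, labels.unsqueeze(-1)).sum().backward()
# print(bf16_logits.grad.dtype)  # gradient produced directly by the fused backward
```

With this change, `compute_loss` calls `FastCrossEntropyFunction.apply(preds, labels.unsqueeze(2))` directly, so the logits no longer need an explicit `astype("float32")` before the loss; any float32 work is confined to the fused kernels.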