Skip to content

Commit 6964fc8

Browse files
phlrainwaliwali777
andauthored
mlp use add to (#10945)
Co-authored-by: xuexixi <[email protected]>
1 parent ba1cc7b commit 6964fc8

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

paddlenlp/transformers/fp8_utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -696,13 +696,13 @@ def backward(ctx, do3):
696696
)
697697

698698
# ===== call func common_fp8_mlp_bwd =====
699-
dx, dw1, dw2 = FP8LinearFunctionBase.common_fp8_mlp_bwd(do3, x_fp8, x_scale, x_t_fp8, x_t_scale, w1, w2, False)
699+
dx = FP8LinearFunctionBase.common_fp8_mlp_bwd(do3, x_fp8, x_scale, x_t_fp8, x_t_scale, w1, w2, True)
700700

701701
# ===== reshape to origin shape =====
702702
if len(x_orig_shape) > 2:
703703
dx = dx.reshape([x_orig_shape[0], -1, dx.shape[-1]])
704704

705-
return dx, dw1, dw2
705+
return dx, None, None
706706

707707

708708
class FP8Mlp(paddle.nn.Layer):

0 commit comments

Comments
 (0)