Skip to content

Commit b9db2c1

Browse files
authored
Refine fp8 node (#10798)
* refine fp8 node * refine fp8 node * fix bug * fix * fix bug * opt tma
1 parent d147557 commit b9db2c1

File tree

5 files changed

+330
-80
lines changed

5 files changed

+330
-80
lines changed

paddlenlp/transformers/deepseek_v2/configuration.py

Lines changed: 7 additions & 0 deletions
Original file line number · Diff line number · Diff line change
@@ -181,6 +181,9 @@ def __init__(
181181
using_flex_token=False,
182182
use_dualpipev=False,
183183
send_mtp_embed=False,
184+
recompute_fwd_gate_up=False,
185+
dequant_input=False,
186+
is_split_group_gemm=False,
184187
**kwargs,
185188
):
186189
self.vocab_size = vocab_size
@@ -231,6 +234,10 @@ def __init__(
231234
self.using_flex_token = using_flex_token
232235
self.use_dualpipev = use_dualpipev
233236
self.send_mtp_embed = send_mtp_embed
237+
self.recompute_fwd_gate_up = recompute_fwd_gate_up
238+
self.dequant_input = dequant_input
239+
self.is_split_group_gemm = is_split_group_gemm
240+
234241

235242
super().__init__(
236243
pad_token_id=pad_token_id,

paddlenlp/transformers/deepseek_v2/modeling_pp.py

Lines changed: 1 addition & 1 deletion
Original file line number · Diff line number · Diff line change
@@ -174,7 +174,7 @@ def backward(self, output_grad):
174174

175175
assert not self.send_mtp_embed, "not support have mtp have yet"
176176

177-
dx, dw1, dw2 = fp8_mlp_bwd(do3, self.x_fp8, self.x_scale, self.shared_experts.w1, self.shared_experts.w2)
177+
dx = fp8_mlp_bwd(do3, self.x_fp8, self.x_scale, self.shared_experts.w1, self.shared_experts.w2)
178178

179179
self.x_fp8 = None
180180
self.x_scale = None

0 commit comments

Comments (0)