Skip to content

Commit be62526

Browse files
authored
optimizer_dual_pp_post_node_memory (#10806)
1 parent fe5e87a commit be62526

File tree

2 files changed

+26
-20
lines changed

2 files changed

+26
-20
lines changed

paddlenlp/transformers/deepseek_v2/modeling_pp.py

Lines changed: 4 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -153,13 +153,10 @@ def forward(self, inputs):
153153

154154
with paddle.no_grad():
155155
if self.shared_experts is not None:
156-
x_fp8, x_scale, shared_expert_output = fp8_mlp_fwd(
157-
hidden_states, self.shared_experts.w1, self.shared_experts.w2
158-
)
156+
shared_expert_output = fp8_mlp_fwd(hidden_states, self.shared_experts.w1, self.shared_experts.w2)
159157
final_hidden_states = final_hidden_states + shared_expert_output
160158

161-
self.x_fp8 = x_fp8
162-
self.x_scale = x_scale
159+
self.x = hidden_states
163160
self.l_aux = l_aux
164161
hidden_states = residual + final_hidden_states
165162

@@ -174,10 +171,9 @@ def backward(self, output_grad):
174171

175172
assert not self.send_mtp_embed, "not support have mtp have yet"
176173

177-
dx = fp8_mlp_bwd(do3, self.x_fp8, self.x_scale, self.shared_experts.w1, self.shared_experts.w2)
174+
dx = fp8_mlp_bwd(do3, self.x, self.shared_experts.w1, self.shared_experts.w2)
178175

179-
self.x_fp8 = None
180-
self.x_scale = None
176+
self.x = None
181177

182178
residual_grad = do3
183179

paddlenlp/transformers/fp8_utils.py

Lines changed: 22 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -284,28 +284,35 @@ def fp8_mlp_fwd(x, w1, w2):
284284
if len(x_orig_shape) > 2:
285285
o3 = o3.reshape([x_orig_shape[0], -1, o3.shape[-1]])
286286

287-
return x_fp8, x_scale, o3
287+
return o3
288288

289289

290-
def fp8_mlp_bwd(do3, x_fp8, x_scale, w1, w2):
290+
def fp8_mlp_bwd(do3, x, w1, w2):
291291
do3_orig_shape = do3.shape
292292
do3 = do3.reshape([-1, do3_orig_shape[-1]])
293293

294-
x_orig_shape = x_fp8.shape
294+
x_orig_shape = x.shape
295+
x = x.reshape([-1, x_orig_shape[-1]])
296+
297+
if x.shape[0] % 128 == 0:
298+
x_fp8, x_scale, x_t_fp8, x_t_scale = paddle.incubate.nn.functional.fp8_quant_blockwise(
299+
x, output_scale_transpose=True, quant_method="1x128", input_transpose=True
300+
)
301+
else:
302+
x_fp8, x_scale = paddle.incubate.nn.functional.fp8_quant_blockwise(
303+
x, output_scale_transpose=True, quant_method="1x128", input_transpose=False
304+
)
305+
x = padding(x, 0)
306+
_, _, x_t_fp8, x_t_scale = paddle.incubate.nn.functional.fp8_quant_blockwise(
307+
x, output_scale_transpose=True, quant_method="1x128", input_transpose=True
308+
)
295309

296310
_, _, w1_fp8, w1_sacle = paddle.incubate.nn.functional.fp8_quant_blockwise(
297311
w1, output_scale_transpose=False, quant_method="128x128", input_transpose=True
298312
)
299313
o1 = paddle.empty([x_fp8.shape[0], w1_fp8.shape[0]], dtype=do3.dtype)
300314
deep_gemm.gemm_fp8_fp8_bf16_nt((x_fp8, x_scale.T), (w1_fp8, w1_sacle), o1, num_sms=112)
301315

302-
x_dequant_fp16 = paddle.incubate.nn.functional.fused_act_dequant(x_fp8, x_scale.T.contiguous())
303-
x_dequant_fp16 = padding(x_dequant_fp16, 0)
304-
305-
_, _, x_t_fp8, x_t_scale = paddle.incubate.nn.functional.fp8_quant_blockwise(
306-
x_dequant_fp16, output_scale_transpose=True, quant_method="1x128", input_transpose=True
307-
)
308-
309316
# ===== [recompute] o2 = swiglu(o1) =====
310317
o2 = swiglu(o1)
311318

@@ -577,7 +584,10 @@ def split_group_gemm(x_fp8, x_scale, w_fp8, w_scale, tokens_per_expert, gemm_out
577584
x_scale_tma_align = x_scale[start_idx:end_idx].T.contiguous().T
578585

579586
deep_gemm.gemm_fp8_fp8_bf16_nt(
580-
(x_fp8[start_idx:end_idx], x_scale_tma_align), (w_fp8[i], w_scale[i]), gemm_out[start_idx:end_idx], num_sms=112
587+
(x_fp8[start_idx:end_idx], x_scale_tma_align),
588+
(w_fp8[i], w_scale[i]),
589+
gemm_out[start_idx:end_idx],
590+
num_sms=112,
581591
)
582592

583593
start_idx = end_idx
@@ -763,7 +773,7 @@ def bwd_dowm_input(self, expert_w2, unzipped_grad, o1, inplace_swiglu_prob=False
763773
(bw_w2_quant, bw_w2_scale),
764774
do2_s,
765775
m_indices=self.m_indices,
766-
num_sms=112
776+
num_sms=112,
767777
)
768778

769779
with paddle.amp.auto_cast(False):

0 commit comments

Comments
 (0)