
Commit ae4f27b

support quant return transpose only (#10833)
* support quant transpose only
* fix bug
1 parent 5ccd5a8 commit ae4f27b
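
The change replaces the `_, _, x_t_fp8, x_t_scale = ...` unpacking pattern with a `return_transpose_only` flag, so call sites that only need the transposed FP8 tensor no longer receive (and discard) the non-transposed copy. Below is a minimal before/after sketch of the calling convention as it appears in this diff; the example tensor shape and the availability of `fp8_quant_blockwise` on an FP8-capable Paddle build are assumptions.

import paddle

# Illustrative activation; 1x128 blockwise quantization runs along the last dim (assumption).
x = paddle.randn([256, 512], dtype="bfloat16")

# Before: input_transpose=True returned four tensors, and callers that only needed
# the transposed copy discarded the first two.
_, _, x_t_fp8, x_t_scale = paddle.incubate.nn.functional.fp8_quant_blockwise(
    x, output_scale_transpose=True, quant_method="1x128", input_transpose=True
)

# After: return_transpose_only=True returns only the transposed FP8 tensor and its scale.
x_t_fp8, x_t_scale = paddle.incubate.nn.functional.fp8_quant_blockwise(
    x, output_scale_transpose=True, quant_method="1x128", input_transpose=True, return_transpose_only=True
)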

File tree

1 file changed: +86 -58 lines changed

paddlenlp/transformers/fp8_utils.py

Lines changed: 86 additions & 58 deletions
@@ -108,20 +108,24 @@ def forward(ctx, x, custom_map):
                 x, output_scale_transpose=True, quant_method="1x128", input_transpose=False
             )
             x = padding(x, 0)
-            _, _, x_t_fp8, x_t_scale = paddle.incubate.nn.functional.fp8_quant_blockwise(
-                x, output_scale_transpose=True, quant_method="1x128", input_transpose=True
+            x_t_fp8, x_t_scale = paddle.incubate.nn.functional.fp8_quant_blockwise(
+                x, output_scale_transpose=True, quant_method="1x128", input_transpose=True, return_transpose_only=True
             )
         else:
             x_fp8, x_scale, x_t_fp8, x_t_scale = paddle.incubate.nn.functional.fp8_quant_blockwise(
                 x, output_scale_transpose=True, quant_method="1x128", input_transpose=True
             )

-        _, _, w_fp8, w_sacle = paddle.incubate.nn.functional.fp8_quant_blockwise(
-            weight, output_scale_transpose=False, quant_method="128x128", input_transpose=True
+        w_fp8, w_scale = paddle.incubate.nn.functional.fp8_quant_blockwise(
+            weight,
+            output_scale_transpose=False,
+            quant_method="128x128",
+            input_transpose=True,
+            return_transpose_only=True,
         )

         out = paddle.empty([x_fp8.shape[0], w_fp8.shape[0]], dtype=x.dtype)
-        deep_gemm.gemm_fp8_fp8_bf16_nt((x_fp8, x_scale.T), (w_fp8, w_sacle), out, num_sms=112)
+        deep_gemm.gemm_fp8_fp8_bf16_nt((x_fp8, x_scale.T), (w_fp8, w_scale), out, num_sms=112)
         out = out.reshape([x_orig_shape[0], -1, weight.shape[-1]])

         # save for bwd
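
For orientation, the forward path in this hunk quantizes the activation with a 1x128 blockwise scheme and the weight with a 128x128 scheme (keeping only the transposed copy), then hands each (fp8, scale) pair to DeepGEMM. A condensed sketch of that pattern follows, assuming the same `paddle`/`deep_gemm` imports as fp8_utils.py and FP8-capable hardware; names and shapes are illustrative, not the file's actual helpers.

import paddle
import deep_gemm  # DeepGEMM bindings, as used throughout fp8_utils.py (assumption)

def fp8_linear_fwd_sketch(x_2d, weight):
    # Activation: 1x128 blockwise quant; scales come back transposed for the GEMM call.
    x_fp8, x_scale = paddle.incubate.nn.functional.fp8_quant_blockwise(
        x_2d, output_scale_transpose=True, quant_method="1x128", input_transpose=False
    )
    # Weight: 128x128 blockwise quant; only the transposed copy is materialized (new flag).
    w_fp8, w_scale = paddle.incubate.nn.functional.fp8_quant_blockwise(
        weight,
        output_scale_transpose=False,
        quant_method="128x128",
        input_transpose=True,
        return_transpose_only=True,
    )
    # FP8 x FP8 -> BF16 GEMM over the paired tensors and scales.
    out = paddle.empty([x_fp8.shape[0], w_fp8.shape[0]], dtype=x_2d.dtype)
    deep_gemm.gemm_fp8_fp8_bf16_nt((x_fp8, x_scale.T), (w_fp8, w_scale), out, num_sms=112)
    return out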
@@ -140,20 +144,24 @@ def backward(ctx, dout):
                 dout_2d, output_scale_transpose=True, quant_method="1x128", input_transpose=False
             )
             dout_2d = padding(dout_2d, 0)
-            _, _, dout_t_fp8, dout_t_scale = paddle.incubate.nn.functional.fp8_quant_blockwise(
-                dout_2d, output_scale_transpose=True, quant_method="1x128", input_transpose=True
+            dout_t_fp8, dout_t_scale = paddle.incubate.nn.functional.fp8_quant_blockwise(
+                dout_2d,
+                output_scale_transpose=True,
+                quant_method="1x128",
+                input_transpose=True,
+                return_transpose_only=True,
             )
         else:
             dout_fp8, dout_scale, dout_t_fp8, dout_t_scale = paddle.incubate.nn.functional.fp8_quant_blockwise(
                 dout_2d, output_scale_transpose=True, quant_method="1x128", input_transpose=True
             )
-        w_fp8, w_sacle = paddle.incubate.nn.functional.fp8_quant_blockwise(
+        w_fp8, w_scale = paddle.incubate.nn.functional.fp8_quant_blockwise(
             weight, output_scale_transpose=False, quant_method="128x128", input_transpose=False
         )
         dx = paddle.empty([ctx.x_t_shape[1], ctx.x_t_shape[0]], dout.dtype)
         dx_orig_shape = dout.shape[:-1]
         dx_orig_shape.append(ctx.x_t_shape[0])
-        deep_gemm.gemm_fp8_fp8_bf16_nt((dout_fp8, dout_scale.T), (w_fp8, w_sacle), dx)
+        deep_gemm.gemm_fp8_fp8_bf16_nt((dout_fp8, dout_scale.T), (w_fp8, w_scale), dx)
         dx = dx.reshape(dx_orig_shape)

         # ===== dw1 = deep_gemm(x_t_fp8, dout_t_fp8)
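
In this backward hunk, only the transposed quantized copy of `dout` is needed from the second quantization call because it feeds the weight-gradient GEMM (`dw1 = deep_gemm(x_t_fp8, dout_t_fp8)` below), while dx uses the non-transposed `dout` quant together with a non-transposed 128x128 weight quant. A condensed sketch of the dx path only, combining the pattern of the backward hunks in this file; same assumed imports as the sketch above, names illustrative.

import paddle
import deep_gemm  # assumption: same bindings as fp8_utils.py

def fp8_linear_bwd_dx_sketch(dout_2d, weight):
    # dout: 1x128 blockwise quant (the non-transposed copy is enough for dx).
    dout_fp8, dout_scale = paddle.incubate.nn.functional.fp8_quant_blockwise(
        dout_2d, output_scale_transpose=True, quant_method="1x128", input_transpose=False
    )
    # Weight: 128x128 quant without transposition for the dx GEMM.
    w_fp8, w_scale = paddle.incubate.nn.functional.fp8_quant_blockwise(
        weight, output_scale_transpose=False, quant_method="128x128", input_transpose=False
    )
    dx = paddle.empty([dout_fp8.shape[0], w_fp8.shape[0]], dout_2d.dtype)
    deep_gemm.gemm_fp8_fp8_bf16_nt((dout_fp8, dout_scale.T), (w_fp8, w_scale), dx)
    return dx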
@@ -204,13 +212,17 @@ def forward(ctx, x, custom_map):
         x_fp8, x_scale = paddle.incubate.nn.functional.fp8_quant_blockwise(
             x, output_scale_transpose=True, quant_method="1x128", input_transpose=False
         )
-        _, _, w_fp8, w_sacle = paddle.incubate.nn.functional.fp8_quant_blockwise(
-            weight, output_scale_transpose=False, quant_method="128x128", input_transpose=True
+        w_fp8, w_scale = paddle.incubate.nn.functional.fp8_quant_blockwise(
+            weight,
+            output_scale_transpose=False,
+            quant_method="128x128",
+            input_transpose=True,
+            return_transpose_only=True,
         )

         # compute out = mm(x, w_t)
         out = paddle.empty([x_fp8.shape[0], w_fp8.shape[0]], dtype=x.dtype)
-        deep_gemm.gemm_fp8_fp8_bf16_nt((x_fp8, x_scale.T), (w_fp8, w_sacle), out, num_sms=112)
+        deep_gemm.gemm_fp8_fp8_bf16_nt((x_fp8, x_scale.T), (w_fp8, w_scale), out, num_sms=112)
         out = out.reshape([x_orig_shape[0], -1, weight.shape[-1]])

         ctx.save_for_backward(x, weight)
@@ -223,11 +235,11 @@ def backward(ctx, dout):

         # padding
         x = padding(x, 0)
-        _, _, x_t_fp8, x_t_scale = paddle.incubate.nn.functional.fp8_quant_blockwise(
-            x, output_scale_transpose=True, quant_method="1x128", input_transpose=True
+        x_t_fp8, x_t_scale = paddle.incubate.nn.functional.fp8_quant_blockwise(
+            x, output_scale_transpose=True, quant_method="1x128", input_transpose=True, return_transpose_only=True
         )

-        w_fp8, w_sacle = paddle.incubate.nn.functional.fp8_quant_blockwise(
+        w_fp8, w_scale = paddle.incubate.nn.functional.fp8_quant_blockwise(
             weight, output_scale_transpose=False, quant_method="128x128", input_transpose=False
         )

@@ -237,16 +249,20 @@ def backward(ctx, dout):
                 dout_2d, output_scale_transpose=True, quant_method="1x128", input_transpose=False
             )
             dout_2d = padding(dout_2d, 0)
-            _, _, dout_t_fp8, dout_t_scale = paddle.incubate.nn.functional.fp8_quant_blockwise(
-                dout_2d, output_scale_transpose=True, quant_method="1x128", input_transpose=True
+            dout_t_fp8, dout_t_scale = paddle.incubate.nn.functional.fp8_quant_blockwise(
+                dout_2d,
+                output_scale_transpose=True,
+                quant_method="1x128",
+                input_transpose=True,
+                return_transpose_only=True,
             )
         else:
             dout_fp8, dout_scale, dout_t_fp8, dout_t_scale = paddle.incubate.nn.functional.fp8_quant_blockwise(
                 dout_2d, output_scale_transpose=True, quant_method="1x128", input_transpose=True
             )

         dx = paddle.empty([dout_fp8.shape[0], w_fp8.shape[0]], dout.dtype)
-        deep_gemm.gemm_fp8_fp8_bf16_nt((dout_fp8, dout_scale.T), (w_fp8, w_sacle), dx, num_sms=112)
+        deep_gemm.gemm_fp8_fp8_bf16_nt((dout_fp8, dout_scale.T), (w_fp8, w_scale), dx, num_sms=112)
         dx = dx.reshape(dx_orig_shape)

         # ===== dw1 = deep_gemm(x_t_fp8, dout_t_fp8)
@@ -293,11 +309,11 @@ def fp8_mlp_fwd(x, w1, w2):
         x, output_scale_transpose=True, quant_method="1x128", input_transpose=False
     )

-    _, _, w1_fp8, w1_sacle = paddle.incubate.nn.functional.fp8_quant_blockwise(
-        w1, output_scale_transpose=False, quant_method="128x128", input_transpose=True
+    w1_fp8, w1_scale = paddle.incubate.nn.functional.fp8_quant_blockwise(
+        w1, output_scale_transpose=False, quant_method="128x128", input_transpose=True, return_transpose_only=True
     )
     o1 = paddle.empty([x_fp8.shape[0], w1_fp8.shape[0]], dtype=x.dtype)
-    deep_gemm.gemm_fp8_fp8_bf16_nt((x_fp8, x_scale.T), (w1_fp8, w1_sacle), o1, num_sms=112)
+    deep_gemm.gemm_fp8_fp8_bf16_nt((x_fp8, x_scale.T), (w1_fp8, w1_scale), o1, num_sms=112)

     # ===== o2 = swiglu(o1) =====
     o2 = swiglu(o1)
@@ -306,8 +322,8 @@ def fp8_mlp_fwd(x, w1, w2):
     )

     # ===== o3 = deep_gemm(o2_fp8, w2_t_fp8) =====
-    _, _, w2_t_fp8, w2_t_scale = paddle.incubate.nn.functional.fp8_quant_blockwise(
-        w2, output_scale_transpose=False, quant_method="128x128", input_transpose=True
+    w2_t_fp8, w2_t_scale = paddle.incubate.nn.functional.fp8_quant_blockwise(
+        w2, output_scale_transpose=False, quant_method="128x128", input_transpose=True, return_transpose_only=True
     )
     o3 = paddle.empty([o2_fp8.shape[0], w2_t_fp8.shape[0]], dtype=o1.dtype)
     deep_gemm.gemm_fp8_fp8_bf16_nt((o2_fp8, o2_scale.T), (w2_t_fp8, w2_t_scale), o3, num_sms=112)
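
Taken together, the two fp8_mlp_fwd hunks show the FP8 MLP forward flow: quantize x, GEMM against the transposed-quantized w1, apply swiglu, quantize the result, then GEMM against the transposed-quantized w2. A condensed sketch under the same assumptions as the earlier sketches; the exact arguments of the o2 quantization are not fully visible in the diff and are assumed to mirror the x quantization, and the swiglu import path is an assumption.

import paddle
import deep_gemm  # assumption: same bindings as fp8_utils.py
from paddle.incubate.nn.functional import swiglu  # import path assumed, as used by fp8_utils.py

def fp8_mlp_fwd_sketch(x_2d, w1, w2):
    # x -> o1 = FP8 GEMM of 1x128-quantized x against 128x128 transposed-quantized w1.
    x_fp8, x_scale = paddle.incubate.nn.functional.fp8_quant_blockwise(
        x_2d, output_scale_transpose=True, quant_method="1x128", input_transpose=False
    )
    w1_fp8, w1_scale = paddle.incubate.nn.functional.fp8_quant_blockwise(
        w1, output_scale_transpose=False, quant_method="128x128", input_transpose=True, return_transpose_only=True
    )
    o1 = paddle.empty([x_fp8.shape[0], w1_fp8.shape[0]], dtype=x_2d.dtype)
    deep_gemm.gemm_fp8_fp8_bf16_nt((x_fp8, x_scale.T), (w1_fp8, w1_scale), o1, num_sms=112)

    # o2 = swiglu(o1), then o3 = FP8 GEMM against the transposed-quantized w2.
    o2 = swiglu(o1)
    o2_fp8, o2_scale = paddle.incubate.nn.functional.fp8_quant_blockwise(
        o2, output_scale_transpose=True, quant_method="1x128", input_transpose=False
    )
    w2_t_fp8, w2_t_scale = paddle.incubate.nn.functional.fp8_quant_blockwise(
        w2, output_scale_transpose=False, quant_method="128x128", input_transpose=True, return_transpose_only=True
    )
    o3 = paddle.empty([o2_fp8.shape[0], w2_t_fp8.shape[0]], dtype=o1.dtype)
    deep_gemm.gemm_fp8_fp8_bf16_nt((o2_fp8, o2_scale.T), (w2_t_fp8, w2_t_scale), o3, num_sms=112)
    return o3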
@@ -333,15 +349,15 @@ def fp8_mlp_bwd(do3, x, w1, w2):
         x, output_scale_transpose=True, quant_method="1x128", input_transpose=False
     )
     x = padding(x, 0)
-    _, _, x_t_fp8, x_t_scale = paddle.incubate.nn.functional.fp8_quant_blockwise(
-        x, output_scale_transpose=True, quant_method="1x128", input_transpose=True
+    x_t_fp8, x_t_scale = paddle.incubate.nn.functional.fp8_quant_blockwise(
+        x, output_scale_transpose=True, quant_method="1x128", input_transpose=True, return_transpose_only=True
     )

-    _, _, w1_fp8, w1_sacle = paddle.incubate.nn.functional.fp8_quant_blockwise(
-        w1, output_scale_transpose=False, quant_method="128x128", input_transpose=True
+    w1_fp8, w1_scale = paddle.incubate.nn.functional.fp8_quant_blockwise(
+        w1, output_scale_transpose=False, quant_method="128x128", input_transpose=True, return_transpose_only=True
     )
     o1 = paddle.empty([x_fp8.shape[0], w1_fp8.shape[0]], dtype=do3.dtype)
-    deep_gemm.gemm_fp8_fp8_bf16_nt((x_fp8, x_scale.T), (w1_fp8, w1_sacle), o1, num_sms=112)
+    deep_gemm.gemm_fp8_fp8_bf16_nt((x_fp8, x_scale.T), (w1_fp8, w1_scale), o1, num_sms=112)

     # ===== [recompute] o2 = swiglu(o1) =====
     o2 = swiglu(o1)
@@ -352,8 +368,8 @@ def fp8_mlp_bwd(do3, x, w1, w2):
             do3, output_scale_transpose=True, quant_method="1x128", input_transpose=False
         )
         do3 = padding(do3, 0)
-        _, _, do3_t_fp8, do3_t_scale = paddle.incubate.nn.functional.fp8_quant_blockwise(
-            do3, output_scale_transpose=True, quant_method="1x128", input_transpose=True
+        do3_t_fp8, do3_t_scale = paddle.incubate.nn.functional.fp8_quant_blockwise(
+            do3, output_scale_transpose=True, quant_method="1x128", input_transpose=True, return_transpose_only=True
         )
     else:
         do3_fp8, do3_scale, do3_t_fp8, do3_t_scale = paddle.incubate.nn.functional.fp8_quant_blockwise(
@@ -367,8 +383,8 @@ def fp8_mlp_bwd(do3, x, w1, w2):

     # ===== dw2 = deep_gemm(o2_t_fp8, do3_t_fp8)
     o2 = padding(o2, 0)
-    _, _, o2_t_fp8, o2_t_scale = paddle.incubate.nn.functional.fp8_quant_blockwise(
-        o2, output_scale_transpose=True, quant_method="1x128", input_transpose=True
+    o2_t_fp8, o2_t_scale = paddle.incubate.nn.functional.fp8_quant_blockwise(
+        o2, output_scale_transpose=True, quant_method="1x128", input_transpose=True, return_transpose_only=True
     )

     if hasattr(w2, "main_grad"):
@@ -409,18 +425,18 @@ def fp8_mlp_bwd(do3, x, w1, w2):
             do1, output_scale_transpose=True, quant_method="1x128", input_transpose=False
         )
         do1 = padding(do1, 0)
-        _, _, do1_t_fp8, do1_t_scale = paddle.incubate.nn.functional.fp8_quant_blockwise(
-            do1, output_scale_transpose=True, quant_method="1x128", input_transpose=True
+        do1_t_fp8, do1_t_scale = paddle.incubate.nn.functional.fp8_quant_blockwise(
+            do1, output_scale_transpose=True, quant_method="1x128", input_transpose=True, return_transpose_only=True
         )
     else:
         do1_fp8, do1_scale, do1_t_fp8, do1_t_scale = paddle.incubate.nn.functional.fp8_quant_blockwise(
            do1, output_scale_transpose=True, quant_method="1x128", input_transpose=True
         )
-    w1_fp8, w1_sacle = paddle.incubate.nn.functional.fp8_quant_blockwise(
+    w1_fp8, w1_scale = paddle.incubate.nn.functional.fp8_quant_blockwise(
         w1, output_scale_transpose=False, quant_method="128x128", input_transpose=False
     )
     dx = paddle.empty([do1_fp8.shape[0], w1_fp8.shape[0]], do1.dtype)
-    deep_gemm.gemm_fp8_fp8_bf16_nt((do1_fp8, do1_scale.T), (w1_fp8, w1_sacle), dx, num_sms=112)
+    deep_gemm.gemm_fp8_fp8_bf16_nt((do1_fp8, do1_scale.T), (w1_fp8, w1_scale), dx, num_sms=112)
     if len(x_orig_shape) > 2:
         dx = dx.reshape([x_orig_shape[0], -1, dx.shape[-1]])

@@ -469,11 +485,11 @@ def forward(ctx, x, w1, w2):
             x, output_scale_transpose=True, quant_method="1x128", input_transpose=False
         )

-        _, _, w1_fp8, w1_sacle = paddle.incubate.nn.functional.fp8_quant_blockwise(
-            w1, output_scale_transpose=False, quant_method="128x128", input_transpose=True
+        w1_fp8, w1_scale = paddle.incubate.nn.functional.fp8_quant_blockwise(
+            w1, output_scale_transpose=False, quant_method="128x128", input_transpose=True, return_transpose_only=True
         )
         o1 = paddle.empty([x_fp8.shape[0], w1_fp8.shape[0]], dtype=x.dtype)
-        deep_gemm.gemm_fp8_fp8_bf16_nt((x_fp8, x_scale.T), (w1_fp8, w1_sacle), o1)
+        deep_gemm.gemm_fp8_fp8_bf16_nt((x_fp8, x_scale.T), (w1_fp8, w1_scale), o1)

         # ===== o2 = swiglu(o1) =====
         o2 = swiglu(o1)
@@ -482,8 +498,8 @@ def forward(ctx, x, w1, w2):
         )

         # ===== o3 = deep_gemm(o2_fp8, w2_t_fp8) =====
-        _, _, w2_t_fp8, w2_t_scale = paddle.incubate.nn.functional.fp8_quant_blockwise(
-            w2, output_scale_transpose=False, quant_method="128x128", input_transpose=True
+        w2_t_fp8, w2_t_scale = paddle.incubate.nn.functional.fp8_quant_blockwise(
+            w2, output_scale_transpose=False, quant_method="128x128", input_transpose=True, return_transpose_only=True
         )
         o3 = paddle.empty([o2_fp8.shape[0], w2_t_fp8.shape[0]], dtype=o1.dtype)
         deep_gemm.gemm_fp8_fp8_bf16_nt((o2_fp8, o2_scale.T), (w2_t_fp8, w2_t_scale), o3)
@@ -510,17 +526,21 @@ def backward(ctx, do3):
         x_fp8, x_scale, w1, w2, x_orig_shape = ctx.saved_tensor()
         x_orig_shape = x_orig_shape.numpy()

-        _, _, w1_fp8, w1_sacle = paddle.incubate.nn.functional.fp8_quant_blockwise(
-            w1, output_scale_transpose=False, quant_method="128x128", input_transpose=True
+        w1_fp8, w1_scale = paddle.incubate.nn.functional.fp8_quant_blockwise(
+            w1, output_scale_transpose=False, quant_method="128x128", input_transpose=True, return_transpose_only=True
         )
         o1 = paddle.empty([x_fp8.shape[0], w1_fp8.shape[0]], dtype=do3.dtype)
-        deep_gemm.gemm_fp8_fp8_bf16_nt((x_fp8, x_scale.T), (w1_fp8, w1_sacle), o1)
+        deep_gemm.gemm_fp8_fp8_bf16_nt((x_fp8, x_scale.T), (w1_fp8, w1_scale), o1)

         x_dequant_fp16 = paddle.incubate.nn.functional.fused_act_dequant(x_fp8, x_scale.T.contiguous())
         x_dequant_fp16 = padding(x_dequant_fp16, 0)

-        _, _, x_t_fp8, x_t_scale = paddle.incubate.nn.functional.fp8_quant_blockwise(
-            x_dequant_fp16, output_scale_transpose=True, quant_method="1x128", input_transpose=True
+        x_t_fp8, x_t_scale = paddle.incubate.nn.functional.fp8_quant_blockwise(
+            x_dequant_fp16,
+            output_scale_transpose=True,
+            quant_method="1x128",
+            input_transpose=True,
+            return_transpose_only=True,
         )

         # ===== [recompute] o2 = swiglu(o1) =====
@@ -532,8 +552,12 @@ def backward(ctx, do3):
                 do3, output_scale_transpose=True, quant_method="1x128", input_transpose=False
             )
             do3 = padding(do3, 0)
-            _, _, do3_t_fp8, do3_t_scale = paddle.incubate.nn.functional.fp8_quant_blockwise(
-                do3, output_scale_transpose=True, quant_method="1x128", input_transpose=True
+            do3_t_fp8, do3_t_scale = paddle.incubate.nn.functional.fp8_quant_blockwise(
+                do3,
+                output_scale_transpose=True,
+                quant_method="1x128",
+                input_transpose=True,
+                return_transpose_only=True,
             )
         else:
             do3_fp8, do3_scale, do3_t_fp8, do3_t_scale = paddle.incubate.nn.functional.fp8_quant_blockwise(
@@ -547,8 +571,8 @@ def backward(ctx, do3):

         # ===== dw2 = deep_gemm(o2_t_fp8, do3_t_fp8)
         o2 = padding(o2, 0)
-        _, _, o2_t_fp8, o2_t_scale = paddle.incubate.nn.functional.fp8_quant_blockwise(
-            o2, output_scale_transpose=True, quant_method="1x128", input_transpose=True
+        o2_t_fp8, o2_t_scale = paddle.incubate.nn.functional.fp8_quant_blockwise(
+            o2, output_scale_transpose=True, quant_method="1x128", input_transpose=True, return_transpose_only=True
         )

         dw2 = kitchen_fp8_gemm(o2_t_fp8, o2_t_scale, do3_t_fp8, do3_t_scale, True, True, rtn_dtype=paddle.float32)
@@ -562,18 +586,22 @@ def backward(ctx, do3):
                 do1, output_scale_transpose=True, quant_method="1x128", input_transpose=False
             )
             do1 = padding(do1, 0)
-            _, _, do1_t_fp8, do1_t_scale = paddle.incubate.nn.functional.fp8_quant_blockwise(
-                do1, output_scale_transpose=True, quant_method="1x128", input_transpose=True
+            do1_t_fp8, do1_t_scale = paddle.incubate.nn.functional.fp8_quant_blockwise(
+                do1,
+                output_scale_transpose=True,
+                quant_method="1x128",
+                input_transpose=True,
+                return_transpose_only=True,
             )
         else:
             do1_fp8, do1_scale, do1_t_fp8, do1_t_scale = paddle.incubate.nn.functional.fp8_quant_blockwise(
                 do1, output_scale_transpose=True, quant_method="1x128", input_transpose=True
             )
-        w1_fp8, w1_sacle = paddle.incubate.nn.functional.fp8_quant_blockwise(
+        w1_fp8, w1_scale = paddle.incubate.nn.functional.fp8_quant_blockwise(
             w1, output_scale_transpose=False, quant_method="128x128", input_transpose=False
         )
         dx = paddle.empty([do1_fp8.shape[0], w1_fp8.shape[0]], do1.dtype)
-        deep_gemm.gemm_fp8_fp8_bf16_nt((do1_fp8, do1_scale.T), (w1_fp8, w1_sacle), dx)
+        deep_gemm.gemm_fp8_fp8_bf16_nt((do1_fp8, do1_scale.T), (w1_fp8, w1_scale), dx)
         if len(x_orig_shape) > 2:
             dx = dx.reshape([x_orig_shape[0], -1, dx.shape[-1]])

@@ -739,9 +767,9 @@ def fwd_down(self, o1, unzipped_probs, expert_w2, num_expert, o3=None, clear_o1=
         [m_sum, k] = [m_sum, n] * [num_groups, n, k]
         """
         # concat and transpose w2
-        w2_quant, w2_sacle = paddle.incubate.nn.functional.fused_stack_transpose_quant(expert_w2, transpose=True)
+        w2_quant, w2_scale = paddle.incubate.nn.functional.fused_stack_transpose_quant(expert_w2, transpose=True)
         w2_quant = w2_quant.reshape([num_expert, -1, w2_quant.shape[-1]])
-        w2_sacle = w2_sacle.reshape([num_expert, -1, w2_sacle.shape[-1]])
+        w2_scale = w2_scale.reshape([num_expert, -1, w2_scale.shape[-1]])

         # quant o2
         with paddle.amp.auto_cast(False):
@@ -762,10 +790,10 @@ def fwd_down(self, o1, unzipped_probs, expert_w2, num_expert, o3=None, clear_o1=
             o3 = paddle.empty(o3_shape, dtype=o1.dtype)
         if numpy.prod(o2_fp8.shape) != 0:
             if self.is_split_group_gemm:
-                split_group_gemm(o2_fp8, o2_scale, w2_quant, w2_sacle, self.tokens_per_expert, o3)
+                split_group_gemm(o2_fp8, o2_scale, w2_quant, w2_scale, self.tokens_per_expert, o3)
             else:
                 deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(
-                    (o2_fp8, o2_scale), (w2_quant, w2_sacle), o3, m_indices=self.m_indices, num_sms=112
+                    (o2_fp8, o2_scale), (w2_quant, w2_scale), o3, m_indices=self.m_indices, num_sms=112
                 )
         return o3, unzipped_probs
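
In fwd_down, the expert weights are stacked, transposed, and quantized in one fused op, and the grouped GEMM is driven by `self.m_indices`, which appears to map every row of the contiguous activation to its expert. A small self-contained sketch of how such an index vector could be built from per-expert token counts; the variable names and the interpretation of `m_indices` are assumptions inferred from the grouped-GEMM call, not from documentation.

import paddle

# Hypothetical per-expert token counts for three experts.
tokens_per_expert = [3, 1, 2]

# Expert id for each row of the contiguous (expert-sorted) activation, the kind of
# per-row index that m_grouped_gemm_fp8_fp8_bf16_nt_contiguous consumes above (assumption).
m_indices = paddle.concat(
    [paddle.full([n], i, dtype="int32") for i, n in enumerate(tokens_per_expert)]
)
print(m_indices.numpy())  # [0 0 0 1 2 2]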

0 commit comments
