@@ -108,20 +108,24 @@ def forward(ctx, x, custom_map):
                 x, output_scale_transpose=True, quant_method="1x128", input_transpose=False
             )
             x = padding(x, 0)
-            _, _, x_t_fp8, x_t_scale = paddle.incubate.nn.functional.fp8_quant_blockwise(
-                x, output_scale_transpose=True, quant_method="1x128", input_transpose=True
+            x_t_fp8, x_t_scale = paddle.incubate.nn.functional.fp8_quant_blockwise(
+                x, output_scale_transpose=True, quant_method="1x128", input_transpose=True, return_transpose_only=True
             )
         else:
             x_fp8, x_scale, x_t_fp8, x_t_scale = paddle.incubate.nn.functional.fp8_quant_blockwise(
                 x, output_scale_transpose=True, quant_method="1x128", input_transpose=True
             )

-        _, _, w_fp8, w_sacle = paddle.incubate.nn.functional.fp8_quant_blockwise(
-            weight, output_scale_transpose=False, quant_method="128x128", input_transpose=True
+        w_fp8, w_scale = paddle.incubate.nn.functional.fp8_quant_blockwise(
+            weight,
+            output_scale_transpose=False,
+            quant_method="128x128",
+            input_transpose=True,
+            return_transpose_only=True,
         )

         out = paddle.empty([x_fp8.shape[0], w_fp8.shape[0]], dtype=x.dtype)
-        deep_gemm.gemm_fp8_fp8_bf16_nt((x_fp8, x_scale.T), (w_fp8, w_sacle), out, num_sms=112)
+        deep_gemm.gemm_fp8_fp8_bf16_nt((x_fp8, x_scale.T), (w_fp8, w_scale), out, num_sms=112)
         out = out.reshape([x_orig_shape[0], -1, weight.shape[-1]])

         # save for bwd
@@ -140,20 +144,24 @@ def backward(ctx, dout):
                 dout_2d, output_scale_transpose=True, quant_method="1x128", input_transpose=False
             )
             dout_2d = padding(dout_2d, 0)
-            _, _, dout_t_fp8, dout_t_scale = paddle.incubate.nn.functional.fp8_quant_blockwise(
-                dout_2d, output_scale_transpose=True, quant_method="1x128", input_transpose=True
+            dout_t_fp8, dout_t_scale = paddle.incubate.nn.functional.fp8_quant_blockwise(
+                dout_2d,
+                output_scale_transpose=True,
+                quant_method="1x128",
+                input_transpose=True,
+                return_transpose_only=True,
             )
         else:
             dout_fp8, dout_scale, dout_t_fp8, dout_t_scale = paddle.incubate.nn.functional.fp8_quant_blockwise(
                 dout_2d, output_scale_transpose=True, quant_method="1x128", input_transpose=True
             )
-        w_fp8, w_sacle = paddle.incubate.nn.functional.fp8_quant_blockwise(
+        w_fp8, w_scale = paddle.incubate.nn.functional.fp8_quant_blockwise(
             weight, output_scale_transpose=False, quant_method="128x128", input_transpose=False
         )
         dx = paddle.empty([ctx.x_t_shape[1], ctx.x_t_shape[0]], dout.dtype)
         dx_orig_shape = dout.shape[:-1]
         dx_orig_shape.append(ctx.x_t_shape[0])
-        deep_gemm.gemm_fp8_fp8_bf16_nt((dout_fp8, dout_scale.T), (w_fp8, w_sacle), dx)
+        deep_gemm.gemm_fp8_fp8_bf16_nt((dout_fp8, dout_scale.T), (w_fp8, w_scale), dx)
         dx = dx.reshape(dx_orig_shape)

         # ===== dw1 = deep_gemm(x_t_fp8, dout_t_fp8)
@@ -204,13 +212,17 @@ def forward(ctx, x, custom_map):
         x_fp8, x_scale = paddle.incubate.nn.functional.fp8_quant_blockwise(
             x, output_scale_transpose=True, quant_method="1x128", input_transpose=False
         )
-        _, _, w_fp8, w_sacle = paddle.incubate.nn.functional.fp8_quant_blockwise(
-            weight, output_scale_transpose=False, quant_method="128x128", input_transpose=True
+        w_fp8, w_scale = paddle.incubate.nn.functional.fp8_quant_blockwise(
+            weight,
+            output_scale_transpose=False,
+            quant_method="128x128",
+            input_transpose=True,
+            return_transpose_only=True,
         )

         # compute out = mm(x, w_t)
         out = paddle.empty([x_fp8.shape[0], w_fp8.shape[0]], dtype=x.dtype)
-        deep_gemm.gemm_fp8_fp8_bf16_nt((x_fp8, x_scale.T), (w_fp8, w_sacle), out, num_sms=112)
+        deep_gemm.gemm_fp8_fp8_bf16_nt((x_fp8, x_scale.T), (w_fp8, w_scale), out, num_sms=112)
         out = out.reshape([x_orig_shape[0], -1, weight.shape[-1]])

         ctx.save_for_backward(x, weight)
@@ -223,11 +235,11 @@ def backward(ctx, dout):

         # padding
         x = padding(x, 0)
-        _, _, x_t_fp8, x_t_scale = paddle.incubate.nn.functional.fp8_quant_blockwise(
-            x, output_scale_transpose=True, quant_method="1x128", input_transpose=True
+        x_t_fp8, x_t_scale = paddle.incubate.nn.functional.fp8_quant_blockwise(
+            x, output_scale_transpose=True, quant_method="1x128", input_transpose=True, return_transpose_only=True
         )

-        w_fp8, w_sacle = paddle.incubate.nn.functional.fp8_quant_blockwise(
+        w_fp8, w_scale = paddle.incubate.nn.functional.fp8_quant_blockwise(
             weight, output_scale_transpose=False, quant_method="128x128", input_transpose=False
         )

@@ -237,16 +249,20 @@ def backward(ctx, dout):
                 dout_2d, output_scale_transpose=True, quant_method="1x128", input_transpose=False
             )
             dout_2d = padding(dout_2d, 0)
-            _, _, dout_t_fp8, dout_t_scale = paddle.incubate.nn.functional.fp8_quant_blockwise(
-                dout_2d, output_scale_transpose=True, quant_method="1x128", input_transpose=True
+            dout_t_fp8, dout_t_scale = paddle.incubate.nn.functional.fp8_quant_blockwise(
+                dout_2d,
+                output_scale_transpose=True,
+                quant_method="1x128",
+                input_transpose=True,
+                return_transpose_only=True,
             )
         else:
             dout_fp8, dout_scale, dout_t_fp8, dout_t_scale = paddle.incubate.nn.functional.fp8_quant_blockwise(
                 dout_2d, output_scale_transpose=True, quant_method="1x128", input_transpose=True
             )

         dx = paddle.empty([dout_fp8.shape[0], w_fp8.shape[0]], dout.dtype)
-        deep_gemm.gemm_fp8_fp8_bf16_nt((dout_fp8, dout_scale.T), (w_fp8, w_sacle), dx, num_sms=112)
+        deep_gemm.gemm_fp8_fp8_bf16_nt((dout_fp8, dout_scale.T), (w_fp8, w_scale), dx, num_sms=112)
         dx = dx.reshape(dx_orig_shape)

         # ===== dw1 = deep_gemm(x_t_fp8, dout_t_fp8)
@@ -293,11 +309,11 @@ def fp8_mlp_fwd(x, w1, w2):
         x, output_scale_transpose=True, quant_method="1x128", input_transpose=False
     )

-    _, _, w1_fp8, w1_sacle = paddle.incubate.nn.functional.fp8_quant_blockwise(
-        w1, output_scale_transpose=False, quant_method="128x128", input_transpose=True
+    w1_fp8, w1_scale = paddle.incubate.nn.functional.fp8_quant_blockwise(
+        w1, output_scale_transpose=False, quant_method="128x128", input_transpose=True, return_transpose_only=True
     )
     o1 = paddle.empty([x_fp8.shape[0], w1_fp8.shape[0]], dtype=x.dtype)
-    deep_gemm.gemm_fp8_fp8_bf16_nt((x_fp8, x_scale.T), (w1_fp8, w1_sacle), o1, num_sms=112)
+    deep_gemm.gemm_fp8_fp8_bf16_nt((x_fp8, x_scale.T), (w1_fp8, w1_scale), o1, num_sms=112)

     # ===== o2 = swiglu(o1) =====
     o2 = swiglu(o1)
@@ -306,8 +322,8 @@ def fp8_mlp_fwd(x, w1, w2):
     )

     # ===== o3 = deep_gemm(o2_fp8, w2_t_fp8) =====
-    _, _, w2_t_fp8, w2_t_scale = paddle.incubate.nn.functional.fp8_quant_blockwise(
-        w2, output_scale_transpose=False, quant_method="128x128", input_transpose=True
+    w2_t_fp8, w2_t_scale = paddle.incubate.nn.functional.fp8_quant_blockwise(
+        w2, output_scale_transpose=False, quant_method="128x128", input_transpose=True, return_transpose_only=True
     )
     o3 = paddle.empty([o2_fp8.shape[0], w2_t_fp8.shape[0]], dtype=o1.dtype)
     deep_gemm.gemm_fp8_fp8_bf16_nt((o2_fp8, o2_scale.T), (w2_t_fp8, w2_t_scale), o3, num_sms=112)
@@ -333,15 +349,15 @@ def fp8_mlp_bwd(do3, x, w1, w2):
         x, output_scale_transpose=True, quant_method="1x128", input_transpose=False
     )
     x = padding(x, 0)
-    _, _, x_t_fp8, x_t_scale = paddle.incubate.nn.functional.fp8_quant_blockwise(
-        x, output_scale_transpose=True, quant_method="1x128", input_transpose=True
+    x_t_fp8, x_t_scale = paddle.incubate.nn.functional.fp8_quant_blockwise(
+        x, output_scale_transpose=True, quant_method="1x128", input_transpose=True, return_transpose_only=True
     )

-    _, _, w1_fp8, w1_sacle = paddle.incubate.nn.functional.fp8_quant_blockwise(
-        w1, output_scale_transpose=False, quant_method="128x128", input_transpose=True
+    w1_fp8, w1_scale = paddle.incubate.nn.functional.fp8_quant_blockwise(
+        w1, output_scale_transpose=False, quant_method="128x128", input_transpose=True, return_transpose_only=True
     )
     o1 = paddle.empty([x_fp8.shape[0], w1_fp8.shape[0]], dtype=do3.dtype)
-    deep_gemm.gemm_fp8_fp8_bf16_nt((x_fp8, x_scale.T), (w1_fp8, w1_sacle), o1, num_sms=112)
+    deep_gemm.gemm_fp8_fp8_bf16_nt((x_fp8, x_scale.T), (w1_fp8, w1_scale), o1, num_sms=112)

     # ===== [recompute] o2 = swiglu(o1) =====
     o2 = swiglu(o1)
@@ -352,8 +368,8 @@ def fp8_mlp_bwd(do3, x, w1, w2):
             do3, output_scale_transpose=True, quant_method="1x128", input_transpose=False
         )
         do3 = padding(do3, 0)
-        _, _, do3_t_fp8, do3_t_scale = paddle.incubate.nn.functional.fp8_quant_blockwise(
-            do3, output_scale_transpose=True, quant_method="1x128", input_transpose=True
+        do3_t_fp8, do3_t_scale = paddle.incubate.nn.functional.fp8_quant_blockwise(
+            do3, output_scale_transpose=True, quant_method="1x128", input_transpose=True, return_transpose_only=True
         )
     else:
         do3_fp8, do3_scale, do3_t_fp8, do3_t_scale = paddle.incubate.nn.functional.fp8_quant_blockwise(
@@ -367,8 +383,8 @@ def fp8_mlp_bwd(do3, x, w1, w2):

    # ===== dw2 = deep_gemm(o2_t_fp8, do3_t_fp8)
    o2 = padding(o2, 0)
-    _, _, o2_t_fp8, o2_t_scale = paddle.incubate.nn.functional.fp8_quant_blockwise(
-        o2, output_scale_transpose=True, quant_method="1x128", input_transpose=True
+    o2_t_fp8, o2_t_scale = paddle.incubate.nn.functional.fp8_quant_blockwise(
+        o2, output_scale_transpose=True, quant_method="1x128", input_transpose=True, return_transpose_only=True
    )

    if hasattr(w2, "main_grad"):
@@ -409,18 +425,18 @@ def fp8_mlp_bwd(do3, x, w1, w2):
             do1, output_scale_transpose=True, quant_method="1x128", input_transpose=False
         )
         do1 = padding(do1, 0)
-        _, _, do1_t_fp8, do1_t_scale = paddle.incubate.nn.functional.fp8_quant_blockwise(
-            do1, output_scale_transpose=True, quant_method="1x128", input_transpose=True
+        do1_t_fp8, do1_t_scale = paddle.incubate.nn.functional.fp8_quant_blockwise(
+            do1, output_scale_transpose=True, quant_method="1x128", input_transpose=True, return_transpose_only=True
         )
     else:
         do1_fp8, do1_scale, do1_t_fp8, do1_t_scale = paddle.incubate.nn.functional.fp8_quant_blockwise(
             do1, output_scale_transpose=True, quant_method="1x128", input_transpose=True
         )
-    w1_fp8, w1_sacle = paddle.incubate.nn.functional.fp8_quant_blockwise(
+    w1_fp8, w1_scale = paddle.incubate.nn.functional.fp8_quant_blockwise(
         w1, output_scale_transpose=False, quant_method="128x128", input_transpose=False
     )
     dx = paddle.empty([do1_fp8.shape[0], w1_fp8.shape[0]], do1.dtype)
-    deep_gemm.gemm_fp8_fp8_bf16_nt((do1_fp8, do1_scale.T), (w1_fp8, w1_sacle), dx, num_sms=112)
+    deep_gemm.gemm_fp8_fp8_bf16_nt((do1_fp8, do1_scale.T), (w1_fp8, w1_scale), dx, num_sms=112)
     if len(x_orig_shape) > 2:
         dx = dx.reshape([x_orig_shape[0], -1, dx.shape[-1]])

@@ -469,11 +485,11 @@ def forward(ctx, x, w1, w2):
             x, output_scale_transpose=True, quant_method="1x128", input_transpose=False
         )

-        _, _, w1_fp8, w1_sacle = paddle.incubate.nn.functional.fp8_quant_blockwise(
-            w1, output_scale_transpose=False, quant_method="128x128", input_transpose=True
+        w1_fp8, w1_scale = paddle.incubate.nn.functional.fp8_quant_blockwise(
+            w1, output_scale_transpose=False, quant_method="128x128", input_transpose=True, return_transpose_only=True
         )
         o1 = paddle.empty([x_fp8.shape[0], w1_fp8.shape[0]], dtype=x.dtype)
-        deep_gemm.gemm_fp8_fp8_bf16_nt((x_fp8, x_scale.T), (w1_fp8, w1_sacle), o1)
+        deep_gemm.gemm_fp8_fp8_bf16_nt((x_fp8, x_scale.T), (w1_fp8, w1_scale), o1)

         # ===== o2 = swiglu(o1) =====
         o2 = swiglu(o1)
@@ -482,8 +498,8 @@ def forward(ctx, x, w1, w2):
         )

         # ===== o3 = deep_gemm(o2_fp8, w2_t_fp8) =====
-        _, _, w2_t_fp8, w2_t_scale = paddle.incubate.nn.functional.fp8_quant_blockwise(
-            w2, output_scale_transpose=False, quant_method="128x128", input_transpose=True
+        w2_t_fp8, w2_t_scale = paddle.incubate.nn.functional.fp8_quant_blockwise(
+            w2, output_scale_transpose=False, quant_method="128x128", input_transpose=True, return_transpose_only=True
         )
         o3 = paddle.empty([o2_fp8.shape[0], w2_t_fp8.shape[0]], dtype=o1.dtype)
         deep_gemm.gemm_fp8_fp8_bf16_nt((o2_fp8, o2_scale.T), (w2_t_fp8, w2_t_scale), o3)
@@ -510,17 +526,21 @@ def backward(ctx, do3):
         x_fp8, x_scale, w1, w2, x_orig_shape = ctx.saved_tensor()
         x_orig_shape = x_orig_shape.numpy()

-        _, _, w1_fp8, w1_sacle = paddle.incubate.nn.functional.fp8_quant_blockwise(
-            w1, output_scale_transpose=False, quant_method="128x128", input_transpose=True
+        w1_fp8, w1_scale = paddle.incubate.nn.functional.fp8_quant_blockwise(
+            w1, output_scale_transpose=False, quant_method="128x128", input_transpose=True, return_transpose_only=True
         )
         o1 = paddle.empty([x_fp8.shape[0], w1_fp8.shape[0]], dtype=do3.dtype)
-        deep_gemm.gemm_fp8_fp8_bf16_nt((x_fp8, x_scale.T), (w1_fp8, w1_sacle), o1)
+        deep_gemm.gemm_fp8_fp8_bf16_nt((x_fp8, x_scale.T), (w1_fp8, w1_scale), o1)

         x_dequant_fp16 = paddle.incubate.nn.functional.fused_act_dequant(x_fp8, x_scale.T.contiguous())
         x_dequant_fp16 = padding(x_dequant_fp16, 0)

-        _, _, x_t_fp8, x_t_scale = paddle.incubate.nn.functional.fp8_quant_blockwise(
-            x_dequant_fp16, output_scale_transpose=True, quant_method="1x128", input_transpose=True
+        x_t_fp8, x_t_scale = paddle.incubate.nn.functional.fp8_quant_blockwise(
+            x_dequant_fp16,
+            output_scale_transpose=True,
+            quant_method="1x128",
+            input_transpose=True,
+            return_transpose_only=True,
         )

         # ===== [recompute] o2 = swiglu(o1) =====
@@ -532,8 +552,12 @@ def backward(ctx, do3):
                 do3, output_scale_transpose=True, quant_method="1x128", input_transpose=False
             )
             do3 = padding(do3, 0)
-            _, _, do3_t_fp8, do3_t_scale = paddle.incubate.nn.functional.fp8_quant_blockwise(
-                do3, output_scale_transpose=True, quant_method="1x128", input_transpose=True
+            do3_t_fp8, do3_t_scale = paddle.incubate.nn.functional.fp8_quant_blockwise(
+                do3,
+                output_scale_transpose=True,
+                quant_method="1x128",
+                input_transpose=True,
+                return_transpose_only=True,
             )
         else:
             do3_fp8, do3_scale, do3_t_fp8, do3_t_scale = paddle.incubate.nn.functional.fp8_quant_blockwise(
@@ -547,8 +571,8 @@ def backward(ctx, do3):

         # ===== dw2 = deep_gemm(o2_t_fp8, do3_t_fp8)
         o2 = padding(o2, 0)
-        _, _, o2_t_fp8, o2_t_scale = paddle.incubate.nn.functional.fp8_quant_blockwise(
-            o2, output_scale_transpose=True, quant_method="1x128", input_transpose=True
+        o2_t_fp8, o2_t_scale = paddle.incubate.nn.functional.fp8_quant_blockwise(
+            o2, output_scale_transpose=True, quant_method="1x128", input_transpose=True, return_transpose_only=True
         )

         dw2 = kitchen_fp8_gemm(o2_t_fp8, o2_t_scale, do3_t_fp8, do3_t_scale, True, True, rtn_dtype=paddle.float32)
@@ -562,18 +586,22 @@ def backward(ctx, do3):
                 do1, output_scale_transpose=True, quant_method="1x128", input_transpose=False
             )
             do1 = padding(do1, 0)
-            _, _, do1_t_fp8, do1_t_scale = paddle.incubate.nn.functional.fp8_quant_blockwise(
-                do1, output_scale_transpose=True, quant_method="1x128", input_transpose=True
+            do1_t_fp8, do1_t_scale = paddle.incubate.nn.functional.fp8_quant_blockwise(
+                do1,
+                output_scale_transpose=True,
+                quant_method="1x128",
+                input_transpose=True,
+                return_transpose_only=True,
             )
         else:
             do1_fp8, do1_scale, do1_t_fp8, do1_t_scale = paddle.incubate.nn.functional.fp8_quant_blockwise(
                 do1, output_scale_transpose=True, quant_method="1x128", input_transpose=True
             )
-        w1_fp8, w1_sacle = paddle.incubate.nn.functional.fp8_quant_blockwise(
+        w1_fp8, w1_scale = paddle.incubate.nn.functional.fp8_quant_blockwise(
             w1, output_scale_transpose=False, quant_method="128x128", input_transpose=False
         )
         dx = paddle.empty([do1_fp8.shape[0], w1_fp8.shape[0]], do1.dtype)
-        deep_gemm.gemm_fp8_fp8_bf16_nt((do1_fp8, do1_scale.T), (w1_fp8, w1_sacle), dx)
+        deep_gemm.gemm_fp8_fp8_bf16_nt((do1_fp8, do1_scale.T), (w1_fp8, w1_scale), dx)
         if len(x_orig_shape) > 2:
             dx = dx.reshape([x_orig_shape[0], -1, dx.shape[-1]])

@@ -739,9 +767,9 @@ def fwd_down(self, o1, unzipped_probs, expert_w2, num_expert, o3=None, clear_o1=
         [m_sum, k] = [m_sum, n] * [num_groups, n, k]
         """
         # concat and transpose w2
-        w2_quant, w2_sacle = paddle.incubate.nn.functional.fused_stack_transpose_quant(expert_w2, transpose=True)
+        w2_quant, w2_scale = paddle.incubate.nn.functional.fused_stack_transpose_quant(expert_w2, transpose=True)
         w2_quant = w2_quant.reshape([num_expert, -1, w2_quant.shape[-1]])
-        w2_sacle = w2_sacle.reshape([num_expert, -1, w2_sacle.shape[-1]])
+        w2_scale = w2_scale.reshape([num_expert, -1, w2_scale.shape[-1]])

         # quant o2
         with paddle.amp.auto_cast(False):
@@ -762,10 +790,10 @@ def fwd_down(self, o1, unzipped_probs, expert_w2, num_expert, o3=None, clear_o1=
             o3 = paddle.empty(o3_shape, dtype=o1.dtype)
         if numpy.prod(o2_fp8.shape) != 0:
             if self.is_split_group_gemm:
-                split_group_gemm(o2_fp8, o2_scale, w2_quant, w2_sacle, self.tokens_per_expert, o3)
+                split_group_gemm(o2_fp8, o2_scale, w2_quant, w2_scale, self.tokens_per_expert, o3)
             else:
                 deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(
-                    (o2_fp8, o2_scale), (w2_quant, w2_sacle), o3, m_indices=self.m_indices, num_sms=112
+                    (o2_fp8, o2_scale), (w2_quant, w2_scale), o3, m_indices=self.m_indices, num_sms=112
                 )
         return o3, unzipped_probs

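For reference, a minimal standalone sketch of the calling pattern used throughout these hunks: when only the transposed FP8 tensor and its scale are needed, the call unpacks two values and passes return_transpose_only=True instead of unpacking four and discarding the first two. The helper name quantize_transposed below is hypothetical, and the exact signature of paddle.incubate.nn.functional.fp8_quant_blockwise (an incubating API) is assumed from the calls shown above.

import paddle

def quantize_transposed(t):
    # Quantize `t` blockwise (1x128) and return only the transposed FP8 tensor
    # and its scale, as the updated calls in this diff do.
    t_t_fp8, t_t_scale = paddle.incubate.nn.functional.fp8_quant_blockwise(
        t,
        output_scale_transpose=True,
        quant_method="1x128",
        input_transpose=True,
        return_transpose_only=True,
    )
    return t_t_fp8, t_t_scale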