@@ -120,7 +120,7 @@ def forward(ctx, x, weight):
        )

        out = paddle.empty([x_fp8.shape[0], w_fp8.shape[0]], dtype=x.dtype)
-        deep_gemm.gemm_fp8_fp8_bf16_nt((x_fp8, x_scale.T), (w_fp8, w_sacle), out)
+        deep_gemm.gemm_fp8_fp8_bf16_nt((x_fp8, x_scale.T), (w_fp8, w_sacle), out, num_sms=112)
        out = out.reshape([x_orig_shape[0], -1, weight.shape[-1]])

        # save for bwd
@@ -194,7 +194,7 @@ def forward(ctx, x, weight):

        # compute out = mm(x, w_t)
        out = paddle.empty([x_fp8.shape[0], w_fp8.shape[0]], dtype=x.dtype)
-        deep_gemm.gemm_fp8_fp8_bf16_nt((x_fp8, x_scale.T), (w_fp8, w_sacle), out)
+        deep_gemm.gemm_fp8_fp8_bf16_nt((x_fp8, x_scale.T), (w_fp8, w_sacle), out, num_sms=112)
        out = out.reshape([x_orig_shape[0], -1, weight.shape[-1]])

        ctx.save_for_backward(x, weight)
@@ -230,7 +230,7 @@ def backward(ctx, dout):
        )

        dx = paddle.empty([dout_fp8.shape[0], w_fp8.shape[0]], dout.dtype)
-        deep_gemm.gemm_fp8_fp8_bf16_nt((dout_fp8, dout_scale.T), (w_fp8, w_sacle), dx)
+        deep_gemm.gemm_fp8_fp8_bf16_nt((dout_fp8, dout_scale.T), (w_fp8, w_sacle), dx, num_sms=112)
        dx = dx.reshape(dx_orig_shape)

        # ===== dw1 = deep_gemm(x_t_fp8, dout_t_fp8)
@@ -267,7 +267,7 @@ def fp8_mlp_fwd(x, w1, w2):
        w1, output_scale_transpose=False, quant_method="128x128", input_transpose=True
    )
    o1 = paddle.empty([x_fp8.shape[0], w1_fp8.shape[0]], dtype=x.dtype)
-    deep_gemm.gemm_fp8_fp8_bf16_nt((x_fp8, x_scale.T), (w1_fp8, w1_sacle), o1)
+    deep_gemm.gemm_fp8_fp8_bf16_nt((x_fp8, x_scale.T), (w1_fp8, w1_sacle), o1, num_sms=112)

    # ===== o2 = swiglu(o1) =====
    o2 = swiglu(o1)
@@ -280,7 +280,7 @@ def fp8_mlp_fwd(x, w1, w2):
        w2, output_scale_transpose=False, quant_method="128x128", input_transpose=True
    )
    o3 = paddle.empty([o2_fp8.shape[0], w2_t_fp8.shape[0]], dtype=o1.dtype)
-    deep_gemm.gemm_fp8_fp8_bf16_nt((o2_fp8, o2_scale.T), (w2_t_fp8, w2_t_scale), o3)
+    deep_gemm.gemm_fp8_fp8_bf16_nt((o2_fp8, o2_scale.T), (w2_t_fp8, w2_t_scale), o3, num_sms=112)
    if len(x_orig_shape) > 2:
        o3 = o3.reshape([x_orig_shape[0], -1, o3.shape[-1]])

@@ -297,7 +297,7 @@ def fp8_mlp_bwd(do3, x_fp8, x_scale, w1, w2):
        w1, output_scale_transpose=False, quant_method="128x128", input_transpose=True
    )
    o1 = paddle.empty([x_fp8.shape[0], w1_fp8.shape[0]], dtype=do3.dtype)
-    deep_gemm.gemm_fp8_fp8_bf16_nt((x_fp8, x_scale.T), (w1_fp8, w1_sacle), o1)
+    deep_gemm.gemm_fp8_fp8_bf16_nt((x_fp8, x_scale.T), (w1_fp8, w1_sacle), o1, num_sms=112)

    x_dequant_fp16 = paddle.incubate.nn.functional.fused_act_dequant(x_fp8, x_scale.T.contiguous())
    x_dequant_fp16 = padding(x_dequant_fp16, 0)
@@ -326,7 +326,7 @@ def fp8_mlp_bwd(do3, x_fp8, x_scale, w1, w2):
        w2, output_scale_transpose=False, quant_method="128x128", input_transpose=False
    )
    do2 = paddle.empty([do3_fp8.shape[0], w2_fp8.shape[0]], do3.dtype)
-    deep_gemm.gemm_fp8_fp8_bf16_nt((do3_fp8, do3_scale.T), (w2_fp8, w2_scale), do2)
+    deep_gemm.gemm_fp8_fp8_bf16_nt((do3_fp8, do3_scale.T), (w2_fp8, w2_scale), do2, num_sms=112)

    # ===== dw2 = deep_gemm(o2_t_fp8, do3_t_fp8)
    o2 = padding(o2, 0)
@@ -383,7 +383,7 @@ def fp8_mlp_bwd(do3, x_fp8, x_scale, w1, w2):
        w1, output_scale_transpose=False, quant_method="128x128", input_transpose=False
    )
    dx = paddle.empty([do1_fp8.shape[0], w1_fp8.shape[0]], do1.dtype)
-    deep_gemm.gemm_fp8_fp8_bf16_nt((do1_fp8, do1_scale.T), (w1_fp8, w1_sacle), dx)
+    deep_gemm.gemm_fp8_fp8_bf16_nt((do1_fp8, do1_scale.T), (w1_fp8, w1_sacle), dx, num_sms=112)
    if len(x_orig_shape) > 2:
        dx = dx.reshape([x_orig_shape[0], -1, dx.shape[-1]])

@@ -577,7 +577,7 @@ def split_group_gemm(x_fp8, x_scale, w_fp8, w_scale, tokens_per_expert, gemm_out
        x_scale_tma_align = x_scale[start_idx:end_idx].T.contiguous().T

        deep_gemm.gemm_fp8_fp8_bf16_nt(
-            (x_fp8[start_idx:end_idx], x_scale_tma_align), (w_fp8[i], w_scale[i]), gemm_out[start_idx:end_idx]
+            (x_fp8[start_idx:end_idx], x_scale_tma_align), (w_fp8[i], w_scale[i]), gemm_out[start_idx:end_idx], num_sms=112
        )

        start_idx = end_idx
@@ -681,7 +681,7 @@ def fwd_gate_up(self, x_bf16, expert_w1, num_expert, tokens_per_expert):
            split_group_gemm(x_fp8, x_scale, w1_t_quant, w1_t_scale, tokens_per_expert, o1)
        else:
            deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(
-                (x_fp8, x_scale), (w1_t_quant, w1_t_scale), o1, m_indices=self.m_indices
+                (x_fp8, x_scale), (w1_t_quant, w1_t_scale), o1, m_indices=self.m_indices, num_sms=112
            )

        if self.dequant_input:
@@ -728,7 +728,7 @@ def fwd_down(self, o1, unzipped_probs, expert_w2, num_expert, o3=None, clear_o1=
            split_group_gemm(o2_fp8, o2_scale, w2_quant, w2_sacle, self.tokens_per_expert, o3)
        else:
            deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(
-                (o2_fp8, o2_scale), (w2_quant, w2_sacle), o3, m_indices=self.m_indices
+                (o2_fp8, o2_scale), (w2_quant, w2_sacle), o3, m_indices=self.m_indices, num_sms=112
            )
        return o3, unzipped_probs

@@ -763,6 +763,7 @@ def bwd_dowm_input(self, expert_w2, unzipped_grad, o1, inplace_swiglu_prob=False
                (bw_w2_quant, bw_w2_scale),
                do2_s,
                m_indices=self.m_indices,
+                num_sms=112
            )

        with paddle.amp.auto_cast(False):
@@ -806,7 +807,7 @@ def bwd_gate_up_input(self, do1, expert_w1, dx=None):
        else:
            do1_scale = paddle.transpose(paddle.transpose(do1_scale, [1, 0]).contiguous(), [1, 0])
            deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(
-                (do1_fp8, do1_scale), (bw_w1_quant, bw_w1_scale), dx, m_indices=self.m_indices
+                (do1_fp8, do1_scale), (bw_w1_quant, bw_w1_scale), dx, m_indices=self.m_indices, num_sms=112
            )

        return dx
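Note: every hunk above applies the same change, passing a hard-coded num_sms=112 to the DeepGEMM kernels. Below is a minimal sketch of how the same pattern could be expressed once instead of at each call site; the helper name fp8_gemm_nt and the DEEPGEMM_NUM_SMS environment variable are illustrative assumptions and not part of this diff, while the deep_gemm.gemm_fp8_fp8_bf16_nt call shape mirrors the lines above.

# Sketch only (not part of this diff): centralize the SM budget instead of
# repeating the literal 112. DEEPGEMM_NUM_SMS and fp8_gemm_nt are hypothetical.
import os

import paddle
import deep_gemm

# Hypothetical: read the SM budget once; default matches the diff's value.
DEEPGEMM_NUM_SMS = int(os.getenv("DEEPGEMM_NUM_SMS", "112"))


def fp8_gemm_nt(x_fp8, x_scale, w_fp8, w_scale, out_dtype):
    # out has shape [M, N] for x_fp8 [M, K] and w_fp8 [N, K], as in the diff.
    out = paddle.empty([x_fp8.shape[0], w_fp8.shape[0]], dtype=out_dtype)
    deep_gemm.gemm_fp8_fp8_bf16_nt(
        (x_fp8, x_scale.T), (w_fp8, w_scale), out, num_sms=DEEPGEMM_NUM_SMS
    )
    return out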