@@ -50,6 +50,10 @@ def swiglu(x, y=None):
 ]


+def get_sm_num():
+    return 112
+
+
 def set_parameter_color(
     parameters, color, group=None, offline_quant_expert_weight=True, clear_origin_weight_when_offline_quant=True
 ):
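Note: get_sm_num() centralizes the SM budget that was previously hardcoded as num_sms=118 at every deep_gemm call site, and the new default is 112, presumably leaving some SMs free for concurrent kernels. If the budget should track the actual device rather than a constant, a minimal sketch (an illustration only, assuming Paddle's paddle.device.cuda.get_device_properties() is available and exposes multi_processor_count, with 112 kept as the fallback) could look like:

import paddle

def get_sm_num():
    # Illustrative alternative, not part of this patch: query the SM count
    # of the current GPU and fall back to the patch's default of 112.
    try:
        props = paddle.device.cuda.get_device_properties()
        return props.multi_processor_count
    except Exception:
        return 112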
@@ -159,7 +163,7 @@ def padding_and_quant_input(tensor):
     tensor_t_fp8, tensor_t_scale = paddle.incubate.nn.functional.fp8_quant_blockwise(
         tensor,
         output_scale_transpose=True,
-        tquant_method="1x128",
+        quant_method="1x128",
         input_transpose=True,
         return_transpose_only=True,
     )
@@ -178,7 +182,7 @@ def kitchen_gemm(
         if out is None:
             out = paddle.zeros([x_fp8.shape[0], w_fp8.shape[0]], rtn_dtype)
         if numpy.prod(x_fp8.shape) != 0 and numpy.prod(w_fp8.shape) != 0:
-            deep_gemm.wgrad_gemm_fp8_fp8_fp32_nt((x_fp8, x_scale), (w_fp8, w_scale), out, num_sms=118)
+            deep_gemm.wgrad_gemm_fp8_fp8_fp32_nt((x_fp8, x_scale), (w_fp8, w_scale), out, num_sms=get_sm_num())
         return out

     if out is not None:
@@ -261,7 +265,9 @@ def compute_fp8_linear(
     if out is None:
         out = paddle.empty([input_fp8.shape[0], weight_fp8.shape[0]], dtype=weight.dtype)

-    deep_gemm.gemm_fp8_fp8_bf16_nt((input_fp8, input_scale.T), (weight_fp8, weight_scale), out, num_sms=118)
+    deep_gemm.gemm_fp8_fp8_bf16_nt(
+        (input_fp8, input_scale.T), (weight_fp8, weight_scale), out, num_sms=get_sm_num()
+    )

     # Return outputs
     if return_mode == "output_only":
@@ -351,7 +357,7 @@ def common_fp8_mlp_bwd(
     # Recompute o1 using deep_gemm(x_fp8, w1_t_fp8)
     w1_fp8, w1_scale = weight_quant(w1, True)
     o1 = paddle.empty([x_fp8.shape[0], w1_fp8.shape[0]], dtype=do3.dtype)
-    deep_gemm.gemm_fp8_fp8_bf16_nt((x_fp8, x_scale.T), (w1_fp8, w1_scale), o1, num_sms=118)
+    deep_gemm.gemm_fp8_fp8_bf16_nt((x_fp8, x_scale.T), (w1_fp8, w1_scale), o1, num_sms=get_sm_num())

     # ===== [recompute] o2 = swiglu(o1) =====
     o2 = swiglu(o1)
@@ -838,7 +844,7 @@ def split_group_gemm(x_fp8, x_scale, w_fp8, w_scale, tokens_per_expert, gemm_out
             (x_fp8[start_idx:end_idx], x_scale_tma_align),
             (w_fp8[i], w_scale[i]),
             gemm_out[start_idx:end_idx],
-            num_sms=118,
+            num_sms=get_sm_num(),
         )

         start_idx = end_idx
@@ -927,7 +933,7 @@ def fwd_gate_up(self, x, expert_w1, num_expert, tokens_per_expert, m_indices=Non
             (w1_t_quant, w1_t_scale),
             o1,
             m_indices=self.m_indices if m_indices is None else m_indices,
-            num_sms=118,
+            num_sms=get_sm_num(),
         )

         if m_indices is None:
@@ -981,7 +987,7 @@ def fwd_down(
             (w2_quant, w2_scale),
             o3,
             m_indices=m_indices if self.fwd_subbatch else self.m_indices,
-            num_sms=118,
+            num_sms=get_sm_num(),
         )

         return o3
@@ -1022,7 +1028,7 @@ def bwd_dowm_input(self, expert_w2, unzipped_grad, o1, tokens_per_expert, m_indi
             (bw_w2_quant, bw_w2_scale),
             do2_s,
             m_indices=m_indices if self.bwd_subbatch else self.m_indices,
-            num_sms=118,
+            num_sms=get_sm_num(),
         )

         with paddle.amp.auto_cast(False):
@@ -1068,7 +1074,7 @@ def bwd_gate_up_input(self, do1, expert_w1, tokens_per_expert, m_indices=None, d
             (bw_w1_quant, bw_w1_scale),
             dx,
             m_indices=m_indices if self.bwd_subbatch else self.m_indices,
-            num_sms=118,
+            num_sms=get_sm_num(),
         )

         return dx
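With every remaining call site now routed through get_sm_num(), the SM budget can be tuned in one place. As a purely hypothetical follow-up (the environment variable name FP8_GEMM_NUM_SMS is invented here for illustration and is not part of this patch), the helper could also accept an override while keeping 112 as the default:

import os

def get_sm_num():
    # Hypothetical sketch: read an environment-variable override for the
    # SM budget passed to the deep_gemm calls; default to the patch's 112.
    return int(os.getenv("FP8_GEMM_NUM_SMS", "112"))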