@@ -62,16 +62,16 @@ def test_kernel(
     use_fp8_w8a8: bool,
     is_up: bool,
     block_shape,
-    num_fused_experts: int,
+    num_fused_shared_experts: int,
     **config,
 ):
     set_seed()
     input_tuples = []
 
     a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
     w1_scale = w2_scale = None
-    if num_fused_experts > 0:
-        expert_num += num_fused_experts
+    if num_fused_shared_experts > 0:
+        expert_num += num_fused_shared_experts
 
     if use_fp8_w8a8:
         init_dtype = dtype
@@ -95,28 +95,28 @@ def test_kernel(
     w1 = torch.randn(expert_num, 2 * n, k, dtype=dtype).cuda()
     w2 = torch.randn(expert_num, k, 2 * n // 2, dtype=dtype).cuda()
 
-    rnd_logics = torch.randn(m, expert_num - num_fused_experts, device="cuda")
+    rnd_logics = torch.randn(m, expert_num - num_fused_shared_experts, device="cuda")
     topk_values, topk_ids = torch.topk(rnd_logics, topk, dim=1)
-    topk_weights = torch.randn((m, topk + num_fused_experts), device="cuda", dtype=dtype) / 10
-
-    if num_fused_experts > 0:
+    if num_fused_shared_experts > 0:
+        # When fused shared experts are present, pad the shared experts' ids into topk_ids
         pad_topk_ids = torch.arange(
-            start=expert_num - num_fused_experts,
+            start=expert_num - num_fused_shared_experts,
             end=expert_num,
             step=1,
             dtype=topk_ids.dtype,
-            device="cuda").view(1, num_fused_experts).repeat(topk_ids.shape[0], 1)
+            device="cuda").view(1, num_fused_shared_experts).repeat(topk_ids.shape[0], 1)
         topk_ids = torch.cat([topk_ids, pad_topk_ids], dim=1)
+    topk_weights = torch.randn((m, topk + num_fused_shared_experts), device="cuda", dtype=dtype) / 10
 
-    expert_to_tokens = torch.empty((expert_num, (topk + num_fused_experts) * m), dtype=torch.int32, device="cuda")
-    expert_to_weights = torch.empty((expert_num, (topk + num_fused_experts) * m), dtype=torch.float32, device="cuda")
+    expert_to_tokens = torch.empty((expert_num, (topk + num_fused_shared_experts) * m), dtype=torch.int32, device="cuda")
+    expert_to_weights = torch.empty((expert_num, (topk + num_fused_shared_experts) * m), dtype=torch.float32, device="cuda")
     moe_align(topk_ids=topk_ids, out=expert_to_tokens)
     expert_to_token_num = torch.empty((expert_num,), dtype=torch.int32, device="cuda")
-    moe_align1(expert_to_tokens, topk_weights, expert_to_weights, expert_to_token_num, topk=topk + num_fused_experts)
+    moe_align1(expert_to_tokens, topk_weights, expert_to_weights, expert_to_token_num, topk=topk + num_fused_shared_experts)
 
-    out1 = torch.zeros((m * (topk + num_fused_experts), 2 * n), dtype=torch.bfloat16, device="cuda")
-    down_in = torch.zeros((m * (topk + num_fused_experts), n), dtype=torch.bfloat16, device="cuda")
-    out2 = torch.zeros((m * (topk + num_fused_experts), k), dtype=torch.bfloat16, device="cuda")
+    out1 = torch.zeros((m * (topk + num_fused_shared_experts), 2 * n), dtype=torch.bfloat16, device="cuda")
+    down_in = torch.zeros((m * (topk + num_fused_shared_experts), n), dtype=torch.bfloat16, device="cuda")
+    out2 = torch.zeros((m * (topk + num_fused_shared_experts), k), dtype=torch.bfloat16, device="cuda")
 
     for _ in range(test_count):
         input_tuples.append(
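
Note: the hunk above is the core behavioral change; the padding step appends the shared experts' ids (the last ids in the expert table) to every token's routed topk_ids. A minimal standalone sketch of that step, using CPU tensors and arbitrary sizes purely for illustration (names mirror the test but the values are not from the PR):

    import torch

    m, topk = 4, 2
    num_fused_shared_experts = 1
    expert_num = 8 + num_fused_shared_experts  # routed experts plus fused shared experts

    rnd_logics = torch.randn(m, expert_num - num_fused_shared_experts)
    _, topk_ids = torch.topk(rnd_logics, topk, dim=1)

    # Shared experts occupy the last ids; every token is routed to them unconditionally.
    pad_topk_ids = (
        torch.arange(expert_num - num_fused_shared_experts, expert_num, dtype=topk_ids.dtype)
        .view(1, num_fused_shared_experts)
        .repeat(m, 1)
    )
    topk_ids = torch.cat([topk_ids, pad_topk_ids], dim=1)
    print(topk_ids.shape)  # (m, topk + num_fused_shared_experts)
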
@@ -230,7 +230,7 @@ def worker(
     use_fp8_w8a8: bool,
     is_up: bool,
     block_shape,
-    num_fused_experts: int,
+    num_fused_shared_experts: int,
     test_configs,
     queue,
 ):
@@ -247,7 +247,7 @@ def worker(
             use_fp8_w8a8=use_fp8_w8a8,
             is_up=is_up,
             block_shape=block_shape,
-            num_fused_experts=num_fused_experts,
+            num_fused_shared_experts=num_fused_shared_experts,
             **test_configs[index],
         )
         queue.put(cost_time)  # Put result in queue
@@ -280,7 +280,7 @@ def get_test_configs(split_id, split_count):
         4,
         8,
     ]:
-        for BLOCK_SIZE_M in [32, 64, 128]:
+        for BLOCK_SIZE_M in [16, 32, 64, 128]:
             for BLOCK_SIZE_N in [32, 64, 128]:
                 for BLOCK_SIZE_K in [32, 64, 128]:
                     t_config = {
@@ -311,7 +311,7 @@ def tuning_configs(
     use_fp8_w8a8: bool,
     is_up: bool,
     block_shape,
-    num_fused_experts: int,
+    num_fused_shared_experts: int,
 ):
     os.environ["CUDA_VISIBLE_DEVICES"] = str(device_id)
     best_config, best_cost_time = None, 10000000
@@ -335,7 +335,7 @@ def tuning_configs(
                 use_fp8_w8a8,
                 is_up,
                 block_shape,
-                num_fused_experts,
+                num_fused_shared_experts,
                 test_configs,
                 queue,
             ),
@@ -370,7 +370,7 @@ def tuning_configs(
                 use_fp8_w8a8,
                 is_up,
                 block_shape,
-                num_fused_experts,
+                num_fused_shared_experts,
                 test_configs,
                 queue,
             ),
@@ -437,15 +437,15 @@ def main(args):
                 "use_fp8_w8a8": use_fp8_w8a8,
                 "is_up": True,
                 "block_shape": block_shape,
-                "num_fused_experts": args.num_fused_experts,
+                "num_fused_shared_experts": args.num_fused_shared_experts,
             },
         )
         up_dict[m] = ans
         MoeGroupedGemmKernelConfig.save_config(
             N=n * 2,
             K=hidden_dim,
             topk_num=topk_num,
-            expert_num=expert_num + args.num_fused_experts,
+            expert_num=expert_num + args.num_fused_shared_experts,
             mul_routed_weight=False,
             use_fp8_w8a8=use_fp8_w8a8,
             out_dtype=str(torch.bfloat16),
@@ -467,7 +467,7 @@ def main(args):
                 "use_fp8_w8a8": use_fp8_w8a8,
                 "is_up": False,
                 "block_shape": block_shape,
-                "num_fused_experts": args.num_fused_experts,
+                "num_fused_shared_experts": args.num_fused_shared_experts,
             },
         )
         down_dict[m] = ans
@@ -476,7 +476,7 @@ def main(args):
             N=hidden_dim,
             K=n,
             topk_num=1,
-            expert_num=expert_num + args.num_fused_experts,
+            expert_num=expert_num + args.num_fused_shared_experts,
             mul_routed_weight=True,
             use_fp8_w8a8=use_fp8_w8a8,
             out_dtype=str(torch.bfloat16),
@@ -489,6 +489,6 @@ def main(args):
     parser.add_argument("--model_dir", type=str, default="deepseek-ai/DeepSeek-R1")
     parser.add_argument("--tp", type=int, default=8)
     parser.add_argument("--use_fp8_w8a8", action="store_true")
-    parser.add_argument("--num_fused_experts", type=int, default=0)
+    parser.add_argument("--num_fused_shared_experts", type=int, default=0)
     args = parser.parse_args()
     main(args)
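
With the renamed flag, a typical invocation might look like the following; the script filename is a placeholder, while the other flags come from the argparse definitions above:

    python <tuning_script>.py --model_dir deepseek-ai/DeepSeek-R1 --tp 8 --use_fp8_w8a8 --num_fused_shared_experts 1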