@@ -98,18 +98,25 @@ def test_kernel(
9898 rnd_logics = torch .randn (m , expert_num - num_fused_experts , device = "cuda" )
9999 topk_values , topk_ids = torch .topk (rnd_logics , topk , dim = 1 )
100100 topk_weights = torch .randn ((m , topk + num_fused_experts ), device = "cuda" , dtype = dtype ) / 10
101+
101102 if num_fused_experts > 0 :
102- topk_ids = F .pad (topk_ids , (0 , 1 ), mode = "constant" , value = expert_num )
103+ pad_topk_ids = torch .arange (
104+ start = expert_num - num_fused_experts ,
105+ end = expert_num ,
106+ step = 1 ,
107+ dtype = topk_ids .dtype ,
108+ device = "cuda" ).view (1 , num_fused_experts ).repeat (topk_ids .shape [0 ], 1 )
109+ topk_ids = torch .cat ([topk_ids , pad_topk_ids ], dim = 1 )
103110
104111 expert_to_tokens = torch .empty ((expert_num , (topk + num_fused_experts ) * m ), dtype = torch .int32 , device = "cuda" )
105112 expert_to_weights = torch .empty ((expert_num , (topk + num_fused_experts ) * m ), dtype = torch .float32 , device = "cuda" )
106113 moe_align (topk_ids = topk_ids , out = expert_to_tokens )
107114 expert_to_token_num = torch .empty ((expert_num ,), dtype = torch .int32 , device = "cuda" )
108115 moe_align1 (expert_to_tokens , topk_weights , expert_to_weights , expert_to_token_num , topk = topk + num_fused_experts )
109116
110- out1 = torch .zeros ((m * (topk + 1 ), 2 * n ), dtype = torch .bfloat16 , device = "cuda" )
111- down_in = torch .zeros ((m * (topk + 1 ), n ), dtype = torch .bfloat16 , device = "cuda" )
112- out2 = torch .zeros ((m * (topk + 1 ), k ), dtype = torch .bfloat16 , device = "cuda" )
117+ out1 = torch .zeros ((m * (topk + num_fused_experts ), 2 * n ), dtype = torch .bfloat16 , device = "cuda" )
118+ down_in = torch .zeros ((m * (topk + num_fused_experts ), n ), dtype = torch .bfloat16 , device = "cuda" )
119+ out2 = torch .zeros ((m * (topk + num_fused_experts ), k ), dtype = torch .bfloat16 , device = "cuda" )
113120
114121 for _ in range (test_count ):
115122 input_tuples .append (
0 commit comments