@@ -99,20 +99,27 @@ def test_kernel(
     topk_values, topk_ids = torch.topk(rnd_logics, topk, dim=1)
     if num_fused_shared_experts > 0:
         # When fused shared experts are present, the shared experts' ids must be padded into topk_ids
-        pad_topk_ids = torch.arange(
-            start=expert_num - num_fused_shared_experts,
-            end=expert_num,
-            step=1,
-            dtype=topk_ids.dtype,
-            device="cuda").view(1, num_fused_shared_experts).repeat(topk_ids.shape[0], 1)
+        pad_topk_ids = (
+            torch.arange(
+                start=expert_num - num_fused_shared_experts, end=expert_num, step=1, dtype=topk_ids.dtype, device="cuda"
+            )
+            .view(1, num_fused_shared_experts)
+            .repeat(topk_ids.shape[0], 1)
+        )
         topk_ids = torch.cat([topk_ids, pad_topk_ids], dim=1)
     topk_weights = torch.randn((m, topk + num_fused_shared_experts), device="cuda", dtype=dtype) / 10
 
-    expert_to_tokens = torch.empty((expert_num, (topk + num_fused_shared_experts) * m), dtype=torch.int32, device="cuda")
-    expert_to_weights = torch.empty((expert_num, (topk + num_fused_shared_experts) * m), dtype=torch.float32, device="cuda")
+    expert_to_tokens = torch.empty(
+        (expert_num, (topk + num_fused_shared_experts) * m), dtype=torch.int32, device="cuda"
+    )
+    expert_to_weights = torch.empty(
+        (expert_num, (topk + num_fused_shared_experts) * m), dtype=torch.float32, device="cuda"
+    )
     moe_align(topk_ids=topk_ids, out=expert_to_tokens)
     expert_to_token_num = torch.empty((expert_num,), dtype=torch.int32, device="cuda")
-    moe_align1(expert_to_tokens, topk_weights, expert_to_weights, expert_to_token_num, topk=topk + num_fused_shared_experts)
+    moe_align1(
+        expert_to_tokens, topk_weights, expert_to_weights, expert_to_token_num, topk=topk + num_fused_shared_experts
+    )
 
     out1 = torch.zeros((m * (topk + num_fused_shared_experts), 2 * n), dtype=torch.bfloat16, device="cuda")
     down_in = torch.zeros((m * (topk + num_fused_shared_experts), n), dtype=torch.bfloat16, device="cuda")
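For context, the restyled pad_topk_ids expression is behavior-preserving: the fused shared experts always occupy the last num_fused_shared_experts expert ids, and every token gets those ids appended after its routed top-k ids. Below is a minimal, CPU-only sketch of that padding pattern; the concrete values of expert_num, num_fused_shared_experts, m, and topk are illustrative assumptions, not taken from this diff.

import torch

expert_num = 8                 # assumed total expert count, shared experts included
num_fused_shared_experts = 2   # assumed number of fused shared experts
m, topk = 4, 3                 # assumed token count and routed top-k

# Routed ids only hit the non-shared experts in this sketch.
topk_ids = torch.randint(0, expert_num - num_fused_shared_experts, (m, topk))

# Shared experts live at the tail ids [expert_num - num_fused_shared_experts, expert_num);
# build one row of those ids and repeat it for every token.
pad_topk_ids = (
    torch.arange(expert_num - num_fused_shared_experts, expert_num, dtype=topk_ids.dtype)
    .view(1, num_fused_shared_experts)
    .repeat(topk_ids.shape[0], 1)
)
topk_ids = torch.cat([topk_ids, pad_topk_ids], dim=1)
print(topk_ids.shape)  # torch.Size([4, 5]), i.e. (m, topk + num_fused_shared_experts)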