
Commit c6ce96c

fix moe tuning
1 parent 8528026 commit c6ce96c

File tree

1 file changed (+11, -4 lines)


test/kernel/fuse_moe_tuning.py

Lines changed: 11 additions & 4 deletions
```diff
@@ -98,18 +98,25 @@ def test_kernel(
     rnd_logics = torch.randn(m, expert_num - num_fused_experts, device="cuda")
     topk_values, topk_ids = torch.topk(rnd_logics, topk, dim=1)
     topk_weights = torch.randn((m, topk + num_fused_experts), device="cuda", dtype=dtype) / 10
+
     if num_fused_experts > 0:
-        topk_ids = F.pad(topk_ids, (0, 1), mode="constant", value=expert_num)
+        pad_topk_ids = torch.arange(
+            start=expert_num - num_fused_experts,
+            end=expert_num,
+            step=1,
+            dtype=topk_ids.dtype,
+            device="cuda").view(1, num_fused_experts).repeat(topk_ids.shape[0], 1)
+        topk_ids = torch.cat([topk_ids, pad_topk_ids], dim=1)
 
     expert_to_tokens = torch.empty((expert_num, (topk + num_fused_experts) * m), dtype=torch.int32, device="cuda")
     expert_to_weights = torch.empty((expert_num, (topk + num_fused_experts) * m), dtype=torch.float32, device="cuda")
     moe_align(topk_ids=topk_ids, out=expert_to_tokens)
     expert_to_token_num = torch.empty((expert_num,), dtype=torch.int32, device="cuda")
     moe_align1(expert_to_tokens, topk_weights, expert_to_weights, expert_to_token_num, topk=topk + num_fused_experts)
 
-    out1 = torch.zeros((m * (topk + 1), 2 * n), dtype=torch.bfloat16, device="cuda")
-    down_in = torch.zeros((m * (topk + 1), n), dtype=torch.bfloat16, device="cuda")
-    out2 = torch.zeros((m * (topk + 1), k), dtype=torch.bfloat16, device="cuda")
+    out1 = torch.zeros((m * (topk + num_fused_experts), 2 * n), dtype=torch.bfloat16, device="cuda")
+    down_in = torch.zeros((m * (topk + num_fused_experts), n), dtype=torch.bfloat16, device="cuda")
+    out2 = torch.zeros((m * (topk + num_fused_experts), k), dtype=torch.bfloat16, device="cuda")
 
     for _ in range(test_count):
         input_tuples.append(
```
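For context: the old code padded `topk_ids` with a single constant column holding the value `expert_num`, i.e. one slot pointing past the last valid expert id, regardless of how many experts are fused. The new code instead appends the real ids of the last `num_fused_experts` experts to every row, and the scratch buffers `out1`/`down_in`/`out2` are sized with `topk + num_fused_experts` instead of the hard-coded `topk + 1`. Below is a minimal standalone sketch comparing the two padding behaviors; it reuses the names from the diff (`expert_num`, `num_fused_experts`, `topk_ids`), but the tiny sizes and the use of CPU tensors are illustrative assumptions, not values from the test.

```python
import torch
import torch.nn.functional as F

# Toy sizes for illustration only (assumed, not from the test).
m, topk = 4, 2
expert_num, num_fused_experts = 8, 2

# Route each token over the non-fused experts, as in the test.
rnd_logics = torch.randn(m, expert_num - num_fused_experts)
_, topk_ids = torch.topk(rnd_logics, topk, dim=1)

# Old behavior: one constant pad column holding the out-of-range id `expert_num`.
old_ids = F.pad(topk_ids, (0, 1), mode="constant", value=expert_num)

# New behavior: append the actual ids of the fused experts
# (expert_num - num_fused_experts .. expert_num - 1) to every row.
pad_topk_ids = (
    torch.arange(expert_num - num_fused_experts, expert_num, dtype=topk_ids.dtype)
    .view(1, num_fused_experts)
    .repeat(m, 1)
)
new_ids = torch.cat([topk_ids, pad_topk_ids], dim=1)

print(old_ids.shape)      # torch.Size([4, 3]): always exactly one pad slot
print(new_ids.shape)      # torch.Size([4, 4]): topk + num_fused_experts slots
print(new_ids[:, topk:])  # every row ends with [6, 7], valid fused-expert ids
```

With `num_fused_experts = 1` the two shapes coincide, which is presumably why the `topk + 1` buffer sizing worked before; for any other value the old padding both under-allocates the output buffers and routes the pad slot to an id past the last valid expert.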
