Commit 1d98c32

simple test
Signed-off-by: Bill Nell <[email protected]>
1 parent 16092a5 commit 1d98c32

File tree

1 file changed: +41 -17 lines changed


tests/kernels/moe/test_moe.py

Lines changed: 41 additions & 17 deletions

```diff
@@ -116,9 +116,12 @@ def batch_by_experts(
     assert topk_ids.dim() == 2
     assert topk_ids.shape[0] == a.shape[0]
 
+    num_tokens = a.shape[0]
+    topk = topk_ids.shape[1]
+
     tokens_per_expert = torch.zeros(num_experts, dtype=torch.int, device=a.device)
-    for i in range(topk_ids.shape[0]):
-        for j in range(topk_ids.shape[1]):
+    for i in range(num_tokens):
+        for j in range(topk):
             expert_id = topk_ids[i, j]
             tokens_per_expert[expert_id] = tokens_per_expert[expert_id] + 1
 
```
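
The two new locals just name `a.shape[0]` and `topk_ids.shape[1]`; the counting loop itself is unchanged. For reference, the whole histogram can also be computed in one call. A minimal standalone sketch of that equivalence (toy shapes, not part of the commit):

```python
import torch

# Toy shapes for illustration only.
num_experts, num_tokens, topk = 8, 33, 2
topk_ids = torch.randint(0, num_experts, (num_tokens, topk))

# Loop version, as in the test helper.
tokens_per_expert = torch.zeros(num_experts, dtype=torch.int)
for i in range(num_tokens):
    for j in range(topk):
        tokens_per_expert[topk_ids[i, j]] += 1

# One-call equivalent: histogram over the flattened expert ids.
vectorized = torch.bincount(topk_ids.flatten(), minlength=num_experts).int()
assert torch.equal(tokens_per_expert, vectorized)
```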

```diff
@@ -128,34 +131,41 @@ def batch_by_experts(
                       dtype=a.dtype, device=a.device)
     #print(f"b_a shape {b_a.shape}")
 
-    #experts_per_token = torch.zeros(a.shape[0], dtype=torch.int, device=a.device)
+    experts_per_token = torch.zeros(num_experts, dtype=torch.int, device=a.device)
 
-    for i in range(topk_ids.shape[0]):
-        for j in range(topk_ids.shape[1]):
-            expert_id = topk_ids[i, j]
-            #idx = experts_per_token[i]
-            b_a[expert_id, j:j+1, :] = a[i, :]
-            #experts_per_token[i] = experts_per_token[i] + 1
+    for token in range(num_tokens):
+        for j in range(topk):
+            expert_id = topk_ids[token, j]
+            idx = experts_per_token[expert_id]
+            b_a[expert_id, idx:idx+1, :] = a[token, :]
+            experts_per_token[expert_id] = experts_per_token[expert_id] + 1
+
+    if False:
+        print(f"topk_ids = {topk_ids}")
+        print(f"tokens_per_expert = {tokens_per_expert}")
+        print(f"experts_per_token = {experts_per_token}")
 
     return b_a, tokens_per_expert
 
 
-def unbatch_output(b_out, topk_ids, K):
+def unbatch_output(b_out, topk_weight, topk_ids, K):
     num_tokens, topk = topk_ids.shape
 
     #print(f"b_out = {b_out.shape} M={num_tokens}, K={K}, topk={topk}")
     num_experts = b_out.shape[0]
-    out = torch.zeros((num_tokens, topk, K), dtype=b_out.dtype, device=b_out.device)
+    topk = topk_ids.shape[1]
+    out = torch.zeros((num_tokens, K), dtype=b_out.dtype, device=b_out.device)
     expert_counts = torch.zeros(num_experts, dtype=torch.int, device=b_out.device)
+    experts = torch.arange(0, num_experts, dtype=torch.int, device=b_out.device)
     for token in range(num_tokens):
         expert_ids = topk_ids[token]
         #print(f"b_out[0] = {b_out[0].shape}")
         for i in range(expert_ids.numel()):
             expert_id = expert_ids[i]
             idx = expert_counts[expert_id]
-            out[token, i:i+1, :] = b_out[expert_id, idx:idx+1, :]
-            idx = idx + 1
-            expert_counts[expert_id] = idx
+            #print(f"out = {out[token, :].shape}, b_out = {b_out[expert_id, idx:idx+1, :].shape}, topk_w = {topk_weight[token, i]}")
+            out[token, :] = out[token, :] + b_out[expert_id, idx:idx+1, :] * topk_weight[token, i]
+            expert_counts[expert_id] = expert_counts[expert_id] + 1
 
     return out
 
```
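
Two behavioral fixes land here. First, `batch_by_experts` now writes each token into the next free slot for its expert (`idx = experts_per_token[expert_id]`) instead of slot `j`, so two tokens routed to the same expert no longer overwrite each other. Second, `unbatch_output` takes `topk_weight` and accumulates the weighted expert outputs directly into a `(num_tokens, K)` tensor rather than returning an unweighted `(num_tokens, topk, K)` one. A standalone sketch of that accumulation pattern (toy shapes, hypothetical data):

```python
import torch

# Toy dimensions, for illustration only.
num_experts, num_tokens, topk, K = 4, 6, 2, 8
max_rows = num_tokens * topk                   # worst case: every token on one expert
b_out = torch.randn(num_experts, max_rows, K)  # stand-in for per-expert batched outputs
topk_ids = torch.randint(0, num_experts, (num_tokens, topk))
topk_weight = torch.rand(num_tokens, topk)

out = torch.zeros(num_tokens, K)
expert_counts = torch.zeros(num_experts, dtype=torch.int)
for token in range(num_tokens):
    for i in range(topk):
        expert_id = topk_ids[token, i]
        # Rows are consumed in the same token order batch_by_experts wrote them.
        idx = expert_counts[expert_id]
        # Weight while gathering, so no separate weighted sum is needed afterwards.
        out[token] += b_out[expert_id, idx] * topk_weight[token, i]
        expert_counts[expert_id] += 1
```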

```diff
@@ -173,9 +183,9 @@ def torch_batched_moe(a, w1, w2, tokens_per_expert, topk_weight, topk_ids):
         #out[expert, :num, :] = SiluAndMul()(a[expert,:num,:] @ w1[expert].transpose(0, 1)) @ w2[expert].transpose(0, 1)
         out[expert, :, :] = SiluAndMul()(a[expert,:,:] @ w1[expert].transpose(0, 1)) @ w2[expert].transpose(0, 1)
 
-    out = unbatch_output(out, topk_ids, w2.shape[1])
+    out = unbatch_output(out, topk_weight, topk_ids, K)
 
-    return (out * topk_weight.view(num_tokens, -1, 1).to(out.dtype)).sum(dim=1)
+    return out #(out * topk_weight.view(num_tokens, -1, 1).to(out.dtype)).sum(dim=1)
 
 
 def torch_moe2(a, w1, w2, topk_weight, topk_ids):
```
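
Since `unbatch_output` now applies `topk_weight` itself, the old post-hoc reduction is kept only as a comment. The two formulations are algebraically identical; a quick standalone check of the identity on random data (illustrative shapes):

```python
import torch

num_tokens, topk, K = 5, 2, 4
gathered = torch.randn(num_tokens, topk, K)  # stand-in for unweighted per-slot outputs
w = torch.rand(num_tokens, topk)

# Old formulation: gather first, then weight and reduce over the topk axis.
old = (gathered * w.view(num_tokens, -1, 1)).sum(dim=1)

# New formulation: weight each slot as it is accumulated.
new = torch.zeros(num_tokens, K)
for i in range(topk):
    new += gathered[:, i, :] * w[:, i:i+1]

torch.testing.assert_close(old, new)
```
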
```diff
@@ -200,6 +210,12 @@ def torch_moe2(a, w1, w2, topk_weight, topk_ids):
 @pytest.mark.parametrize("e", NUM_EXPERTS)
 @pytest.mark.parametrize("topk", TOP_KS)
 @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
+#@pytest.mark.parametrize("m", [33])
+#@pytest.mark.parametrize("n", [128])
+#@pytest.mark.parametrize("k", [128])
+#@pytest.mark.parametrize("e", [8])
+#@pytest.mark.parametrize("topk", [2])
+#@pytest.mark.parametrize("dtype", [torch.float16])
 def test_fused_moe_batched_experts(
     m: int,
     n: int,
```
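
The commented-out decorators are a debugging convenience for pinning the sweep to a single shape. Note they are meant to replace the grid decorators above rather than stack with them, since pytest rejects two parametrizations of the same argument name. A hypothetical pinned variant, for illustration only:

```python
import pytest
import torch

# Swap these in for the grid decorators to debug one case; stacking both sets
# would make pytest raise a duplicate-parametrization error during collection.
@pytest.mark.parametrize("m", [33])
@pytest.mark.parametrize("n", [128])
@pytest.mark.parametrize("k", [128])
@pytest.mark.parametrize("e", [8])
@pytest.mark.parametrize("topk", [2])
@pytest.mark.parametrize("dtype", [torch.float16])
def test_one_shape(m, n, k, e, topk, dtype):
    ...
```
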
```diff
@@ -208,12 +224,13 @@ def test_fused_moe_batched_experts(
     topk: int,
     dtype: torch.dtype,
 ):
+    current_platform.seed_everything(7)
+
     a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
     w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10
     w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10
 
     score = torch.randn((m, e), device="cuda", dtype=dtype)
-    e_map = None
 
     vllm_config = VllmConfig()
     with set_current_vllm_config(vllm_config):
```
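
`current_platform.seed_everything(7)` pins all the `torch.randn` inputs below, so a tolerance failure reproduces identically across runs; dropping the unused `e_map` is pure cleanup. A minimal sketch of what such a seeding helper typically does (an assumption about the helper's behavior, not vLLM's actual implementation):

```python
import random

import numpy as np
import torch

def seed_everything(seed: int) -> None:
    # Seed every RNG a test might touch so failures replay exactly.
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)  # also seeds the CUDA generators on all devices
```
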
```diff
@@ -238,6 +255,13 @@ def test_fused_moe_batched_experts(
         topk_ids,
         global_num_experts=e)
 
+    if False:
+        torch.set_printoptions(profile="full")
+        print("BASELINE")
+        print(torch_output)
+        print("OUTPUT")
+        print(triton_output)
+
     #torch.testing.assert_close(triton_b_output, torch_b_output, atol=2e-2, rtol=0)
     torch.testing.assert_close(triton_output, torch_output, atol=2e-2, rtol=0)
 
```
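
The final comparison uses a purely absolute tolerance: `atol=2e-2, rtol=0` means every element must satisfy `|triton - torch| <= 2e-2`, a reasonable budget for fp16/bf16 MoE outputs where relative error on near-zero values is noisy. A toy illustration of the tolerance rule:

```python
import torch

expected = torch.tensor([1.000, 0.001])
actual = torch.tensor([1.015, 0.019])

# Passes: each |actual - expected| is within the 2e-2 absolute budget.
torch.testing.assert_close(actual, expected, atol=2e-2, rtol=0)

# A relative tolerance would reject the near-zero element instead:
# allowed error there would be 5e-2 * 0.001 = 5e-5, far below the 1.8e-2 gap.
# torch.testing.assert_close(actual, expected, atol=0, rtol=5e-2)  # would raise
```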
