Skip to content

Commit 0dfd27e

Browse files
committed
cleanup
Signed-off-by: Bill Nell <[email protected]>
1 parent 1d98c32 commit 0dfd27e

File tree

1 file changed

+7
-25
lines changed

1 file changed

+7
-25
lines changed

tests/kernels/moe/test_moe.py

Lines changed: 7 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -112,7 +112,6 @@ def batch_by_experts(
112112
topk_ids: torch.Tensor,
113113
num_experts: int
114114
) -> torch.Tensor:
115-
#print(topk_ids.shape, topk_ids)
116115
assert topk_ids.dim() == 2
117116
assert topk_ids.shape[0] == a.shape[0]
118117

@@ -125,45 +124,36 @@ def batch_by_experts(
125124
expert_id = topk_ids[i, j]
126125
tokens_per_expert[expert_id] = tokens_per_expert[expert_id] + 1
127126

128-
#print(f"token_per_expert {tokens_per_expert.max()}")
129127
max_num_tokens = tokens_per_expert.max()
130128
b_a = torch.zeros((num_experts, max_num_tokens, a.shape[1]),
131129
dtype=a.dtype, device=a.device)
132130
#print(f"b_a shape {b_a.shape}")
133131

134-
experts_per_token = torch.zeros(num_experts, dtype=torch.int, device=a.device)
132+
token_counts = torch.zeros(num_experts, dtype=torch.int, device=a.device)
135133

136134
for token in range(num_tokens):
137135
for j in range(topk):
138136
expert_id = topk_ids[token, j]
139-
idx = experts_per_token[expert_id]
137+
idx = token_counts[expert_id]
140138
b_a[expert_id, idx:idx+1, :] = a[token, :]
141-
experts_per_token[expert_id] = experts_per_token[expert_id] + 1
142-
143-
if False:
144-
print(f"topk_ids = {topk_ids}")
145-
print(f"tokens_per_expert = {tokens_per_expert}")
146-
print(f"experts_per_token = {experts_per_token}")
139+
token_counts[expert_id] = token_counts[expert_id] + 1
147140

148141
return b_a, tokens_per_expert
149142

150143

151144
def unbatch_output(b_out, topk_weight, topk_ids, K):
152145
num_tokens, topk = topk_ids.shape
153146

154-
#print(f"b_out = {b_out.shape} M={num_tokens}, K={K}, topk={topk}")
155147
num_experts = b_out.shape[0]
156148
topk = topk_ids.shape[1]
157149
out = torch.zeros((num_tokens, K), dtype=b_out.dtype, device=b_out.device)
158150
expert_counts = torch.zeros(num_experts, dtype=torch.int, device=b_out.device)
159151
experts = torch.arange(0, num_experts, dtype=torch.int, device=b_out.device)
160152
for token in range(num_tokens):
161153
expert_ids = topk_ids[token]
162-
#print(f"b_out[0] = {b_out[0].shape}")
163154
for i in range(expert_ids.numel()):
164155
expert_id = expert_ids[i]
165156
idx = expert_counts[expert_id]
166-
#print(f"out = {out[token, :].shape}, b_out = {b_out[expert_id, idx:idx+1, :].shape}, topk_w = {topk_weight[token, i]}")
167157
out[token, :] = out[token, :] + b_out[expert_id, idx:idx+1, :] * topk_weight[token, i]
168158
expert_counts[expert_id] = expert_counts[expert_id] + 1
169159

@@ -172,20 +162,19 @@ def unbatch_output(b_out, topk_weight, topk_ids, K):
172162

173163
def torch_batched_moe(a, w1, w2, tokens_per_expert, topk_weight, topk_ids):
174164
assert a.dim() == 3
175-
#print(f"A = {a.shape} {a[0, :, :].shape}")
176165
num_tokens, topk = topk_ids.shape
177166
_, max_num_tokens, K = a.shape
178167
num_experts = w1.shape[0]
179168
out = torch.zeros((num_experts, max_num_tokens, w2.shape[1]), dtype=a.dtype, device=a.device)
180169
for expert in range(num_experts):
181170
num = tokens_per_expert[expert]
182171
if num > 0:
183-
#out[expert, :num, :] = SiluAndMul()(a[expert,:num,:] @ w1[expert].transpose(0, 1)) @ w2[expert].transpose(0, 1)
184-
out[expert, :, :] = SiluAndMul()(a[expert,:,:] @ w1[expert].transpose(0, 1)) @ w2[expert].transpose(0, 1)
172+
out[expert, :num, :] = SiluAndMul()(a[expert,:num,:] @ w1[expert].transpose(0, 1)) @ w2[expert].transpose(0, 1)
173+
#out[expert, :, :] = SiluAndMul()(a[expert,:,:] @ w1[expert].transpose(0, 1)) @ w2[expert].transpose(0, 1)
185174

186175
out = unbatch_output(out, topk_weight, topk_ids, K)
187176

188-
return out #(out * topk_weight.view(num_tokens, -1, 1).to(out.dtype)).sum(dim=1)
177+
return out
189178

190179

191180
def torch_moe2(a, w1, w2, topk_weight, topk_ids):
@@ -210,12 +199,6 @@ def torch_moe2(a, w1, w2, topk_weight, topk_ids):
210199
@pytest.mark.parametrize("e", NUM_EXPERTS)
211200
@pytest.mark.parametrize("topk", TOP_KS)
212201
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
213-
#@pytest.mark.parametrize("m", [33])
214-
#@pytest.mark.parametrize("n", [128])
215-
#@pytest.mark.parametrize("k", [128])
216-
#@pytest.mark.parametrize("e", [8])
217-
#@pytest.mark.parametrize("topk", [2])
218-
#@pytest.mark.parametrize("dtype", [torch.float16])
219202
def test_fused_moe_batched_experts(
220203
m: int,
221204
n: int,
@@ -248,7 +231,7 @@ def test_fused_moe_batched_experts(
248231
topk_weight,
249232
topk_ids)
250233
else:
251-
triton_output = fused_experts(a, # b_a
234+
triton_output = fused_experts(b_a,
252235
w1,
253236
w2,
254237
topk_weight,
@@ -262,7 +245,6 @@ def test_fused_moe_batched_experts(
262245
print("OUTPUT")
263246
print(triton_output)
264247

265-
#torch.testing.assert_close(triton_b_output, torch_b_output, atol=2e-2, rtol=0)
266248
torch.testing.assert_close(triton_output, torch_output, atol=2e-2, rtol=0)
267249

268250

0 commit comments

Comments (0)