@@ -112,7 +112,6 @@ def batch_by_experts(
112
112
topk_ids : torch .Tensor ,
113
113
num_experts : int
114
114
) -> torch .Tensor :
115
- #print(topk_ids.shape, topk_ids)
116
115
assert topk_ids .dim () == 2
117
116
assert topk_ids .shape [0 ] == a .shape [0 ]
118
117
@@ -125,45 +124,36 @@ def batch_by_experts(
125
124
expert_id = topk_ids [i , j ]
126
125
tokens_per_expert [expert_id ] = tokens_per_expert [expert_id ] + 1
127
126
128
- #print(f"token_per_expert {tokens_per_expert.max()}")
129
127
max_num_tokens = tokens_per_expert .max ()
130
128
b_a = torch .zeros ((num_experts , max_num_tokens , a .shape [1 ]),
131
129
dtype = a .dtype , device = a .device )
132
130
#print(f"b_a shape {b_a.shape}")
133
131
134
- experts_per_token = torch .zeros (num_experts , dtype = torch .int , device = a .device )
132
+ token_counts = torch .zeros (num_experts , dtype = torch .int , device = a .device )
135
133
136
134
for token in range (num_tokens ):
137
135
for j in range (topk ):
138
136
expert_id = topk_ids [token , j ]
139
- idx = experts_per_token [expert_id ]
137
+ idx = token_counts [expert_id ]
140
138
b_a [expert_id , idx :idx + 1 , :] = a [token , :]
141
- experts_per_token [expert_id ] = experts_per_token [expert_id ] + 1
142
-
143
- if False :
144
- print (f"topk_ids = { topk_ids } " )
145
- print (f"tokens_per_expert = { tokens_per_expert } " )
146
- print (f"experts_per_token = { experts_per_token } " )
139
+ token_counts [expert_id ] = token_counts [expert_id ] + 1
147
140
148
141
return b_a , tokens_per_expert
149
142
150
143
151
144
def unbatch_output (b_out , topk_weight , topk_ids , K ):
152
145
num_tokens , topk = topk_ids .shape
153
146
154
- #print(f"b_out = {b_out.shape} M={num_tokens}, K={K}, topk={topk}")
155
147
num_experts = b_out .shape [0 ]
156
148
topk = topk_ids .shape [1 ]
157
149
out = torch .zeros ((num_tokens , K ), dtype = b_out .dtype , device = b_out .device )
158
150
expert_counts = torch .zeros (num_experts , dtype = torch .int , device = b_out .device )
159
151
experts = torch .arange (0 , num_experts , dtype = torch .int , device = b_out .device )
160
152
for token in range (num_tokens ):
161
153
expert_ids = topk_ids [token ]
162
- #print(f"b_out[0] = {b_out[0].shape}")
163
154
for i in range (expert_ids .numel ()):
164
155
expert_id = expert_ids [i ]
165
156
idx = expert_counts [expert_id ]
166
- #print(f"out = {out[token, :].shape}, b_out = {b_out[expert_id, idx:idx+1, :].shape}, topk_w = {topk_weight[token, i]}")
167
157
out [token , :] = out [token , :] + b_out [expert_id , idx :idx + 1 , :] * topk_weight [token , i ]
168
158
expert_counts [expert_id ] = expert_counts [expert_id ] + 1
169
159
@@ -172,20 +162,19 @@ def unbatch_output(b_out, topk_weight, topk_ids, K):
172
162
173
163
def torch_batched_moe(a, w1, w2, tokens_per_expert, topk_weight, topk_ids):
    """Reference (pure-torch) batched MoE forward pass.

    Runs each expert's gated MLP over the tokens batched for that expert,
    then recombines the per-expert rows back into per-token outputs via
    ``unbatch_output``.

    Args:
        a: batched activations, assumed (num_experts, max_num_tokens, K)
           as produced by ``batch_by_experts`` — TODO confirm with caller.
        w1: stacked first-layer expert weights, (num_experts, 2*N, K).
        w2: stacked second-layer expert weights, indexed per expert.
        tokens_per_expert: per-expert count of valid rows in ``a``.
        topk_weight: per-token routing weights used when unbatching.
        topk_ids: (num_tokens, topk) expert assignments per token.

    Returns:
        (num_tokens, K) tensor of combined expert outputs.
    """
    assert a.dim() == 3
    num_tokens, topk = topk_ids.shape
    _, max_num_tokens, K = a.shape
    num_experts = w1.shape[0]

    out = torch.zeros((num_experts, max_num_tokens, w2.shape[1]),
                      dtype=a.dtype, device=a.device)

    for e in range(num_experts):
        count = tokens_per_expert[e]
        if count > 0:
            # Only the first `count` rows of this expert's batch hold real
            # tokens; the remainder of the slab stays zero.
            hidden = a[e, :count, :] @ w1[e].transpose(0, 1)
            out[e, :count, :] = SiluAndMul()(hidden) @ w2[e].transpose(0, 1)

    # Scatter the per-expert rows back to per-token order, applying the
    # top-k routing weights.
    return unbatch_output(out, topk_weight, topk_ids, K)
189
178
190
179
191
180
def torch_moe2 (a , w1 , w2 , topk_weight , topk_ids ):
@@ -210,12 +199,6 @@ def torch_moe2(a, w1, w2, topk_weight, topk_ids):
210
199
@pytest .mark .parametrize ("e" , NUM_EXPERTS )
211
200
@pytest .mark .parametrize ("topk" , TOP_KS )
212
201
@pytest .mark .parametrize ("dtype" , [torch .float16 , torch .bfloat16 ])
213
- #@pytest.mark.parametrize("m", [33])
214
- #@pytest.mark.parametrize("n", [128])
215
- #@pytest.mark.parametrize("k", [128])
216
- #@pytest.mark.parametrize("e", [8])
217
- #@pytest.mark.parametrize("topk", [2])
218
- #@pytest.mark.parametrize("dtype", [torch.float16])
219
202
def test_fused_moe_batched_experts (
220
203
m : int ,
221
204
n : int ,
@@ -248,7 +231,7 @@ def test_fused_moe_batched_experts(
248
231
topk_weight ,
249
232
topk_ids )
250
233
else :
251
- triton_output = fused_experts (a , # b_a
234
+ triton_output = fused_experts (b_a ,
252
235
w1 ,
253
236
w2 ,
254
237
topk_weight ,
@@ -262,7 +245,6 @@ def test_fused_moe_batched_experts(
262
245
print ("OUTPUT" )
263
246
print (triton_output )
264
247
265
- #torch.testing.assert_close(triton_b_output, torch_b_output, atol=2e-2, rtol=0)
266
248
torch .testing .assert_close (triton_output , torch_output , atol = 2e-2 , rtol = 0 )
267
249
268
250
0 commit comments