@@ -116,9 +116,12 @@ def batch_by_experts(
     assert topk_ids.dim() == 2
     assert topk_ids.shape[0] == a.shape[0]

+    num_tokens = a.shape[0]
+    topk = topk_ids.shape[1]
+
     tokens_per_expert = torch.zeros(num_experts, dtype=torch.int, device=a.device)
-    for i in range(topk_ids.shape[0]):
-        for j in range(topk_ids.shape[1]):
+    for i in range(num_tokens):
+        for j in range(topk):
             expert_id = topk_ids[i, j]
             tokens_per_expert[expert_id] = tokens_per_expert[expert_id] + 1
@@ -128,34 +131,41 @@ def batch_by_experts(
                       dtype=a.dtype, device=a.device)
     #print(f"b_a shape {b_a.shape}")

-    # experts_per_token = torch.zeros(a.shape[0], dtype=torch.int, device=a.device)
+    experts_per_token = torch.zeros(num_experts, dtype=torch.int, device=a.device)

-    for i in range(topk_ids.shape[0]):
-        for j in range(topk_ids.shape[1]):
-            expert_id = topk_ids[i, j]
-            #idx = experts_per_token[i]
-            b_a[expert_id, j:j + 1, :] = a[i, :]
-            #experts_per_token[i] = experts_per_token[i] + 1
+    for token in range(num_tokens):
+        for j in range(topk):
+            expert_id = topk_ids[token, j]
+            idx = experts_per_token[expert_id]
+            b_a[expert_id, idx:idx + 1, :] = a[token, :]
+            experts_per_token[expert_id] = experts_per_token[expert_id] + 1
+
+    if False:
+        print(f"topk_ids = {topk_ids}")
+        print(f"tokens_per_expert = {tokens_per_expert}")
+        print(f"experts_per_token = {experts_per_token}")

     return b_a, tokens_per_expert

-def unbatch_output(b_out, topk_ids, K):
+def unbatch_output(b_out, topk_weight, topk_ids, K):
     num_tokens, topk = topk_ids.shape

     #print(f"b_out = {b_out.shape} M={num_tokens}, K={K}, topk={topk}")
     num_experts = b_out.shape[0]
-    out = torch.zeros((num_tokens, topk, K), dtype=b_out.dtype, device=b_out.device)
+    topk = topk_ids.shape[1]
+    out = torch.zeros((num_tokens, K), dtype=b_out.dtype, device=b_out.device)
     expert_counts = torch.zeros(num_experts, dtype=torch.int, device=b_out.device)
+    experts = torch.arange(0, num_experts, dtype=torch.int, device=b_out.device)
     for token in range(num_tokens):
         expert_ids = topk_ids[token]
         #print(f"b_out[0] = {b_out[0].shape}")
         for i in range(expert_ids.numel()):
             expert_id = expert_ids[i]
             idx = expert_counts[expert_id]
-            out[token, i:i + 1, :] = b_out[expert_id, idx:idx + 1, :]
-            idx = idx + 1
-            expert_counts[expert_id] = idx
+            #print(f"out = {out[token, :].shape}, b_out = {b_out[expert_id, idx:idx+1, :].shape}, topk_w = {topk_weight[token, i]}")
+            out[token, :] = out[token, :] + b_out[expert_id, idx:idx + 1, :] * topk_weight[token, i]
+            expert_counts[expert_id] = expert_counts[expert_id] + 1

     return out
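A minimal, self-contained sketch (not from this commit) of the round trip that batch_by_experts and unbatch_output implement: scatter each token into its experts' row slabs with a per-expert write cursor, then gather the rows back per token while folding in topk_weight. The names write/read and the generous slab capacity num_tokens * topk are illustrative assumptions, not the test's actual sizing.

import torch

num_tokens, hidden, num_experts, topk = 4, 8, 3, 2
a = torch.randn(num_tokens, hidden)
topk_ids = torch.randint(0, num_experts, (num_tokens, topk))
topk_weight = torch.rand(num_tokens, topk)

# Scatter: one slab of rows per expert, filled through a per-expert write cursor.
b_a = torch.zeros(num_experts, num_tokens * topk, hidden)  # generous capacity for the sketch
write = [0] * num_experts
for t in range(num_tokens):
    for j in range(topk):
        e = int(topk_ids[t, j])
        b_a[e, write[e]] = a[t]
        write[e] += 1

# Gather: read each expert slab back in the same order, folding in the routing weights.
out = torch.zeros(num_tokens, hidden)
read = [0] * num_experts
for t in range(num_tokens):
    for j in range(topk):
        e = int(topk_ids[t, j])
        out[t] += b_a[e, read[e]] * topk_weight[t, j]
        read[e] += 1

# With an identity "expert" in between, the round trip is just the top-k weighted sum.
ref = (a.unsqueeze(1) * topk_weight.unsqueeze(-1)).sum(dim=1)
torch.testing.assert_close(out, ref)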
@@ -173,9 +183,9 @@ def torch_batched_moe(a, w1, w2, tokens_per_expert, topk_weight, topk_ids):
         #out[expert, :num, :] = SiluAndMul()(a[expert,:num,:] @ w1[expert].transpose(0, 1)) @ w2[expert].transpose(0, 1)
         out[expert, :, :] = SiluAndMul()(a[expert, :, :] @ w1[expert].transpose(0, 1)) @ w2[expert].transpose(0, 1)

-    out = unbatch_output(out, topk_ids, w2.shape[1])
+    out = unbatch_output(out, topk_weight, topk_ids, K)

-    return (out * topk_weight.view(num_tokens, -1, 1).to(out.dtype)).sum(dim=1)
+    return out  # (out * topk_weight.view(num_tokens, -1, 1).to(out.dtype)).sum(dim=1)


 def torch_moe2(a, w1, w2, topk_weight, topk_ids):
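For reference, passing topk_weight into unbatch_output performs the same reduction the removed return line did over a (num_tokens, topk, K) tensor; a small standalone check under assumed toy shapes:

import torch

num_tokens, topk, K = 5, 2, 7
per_choice = torch.randn(num_tokens, topk, K)  # stands in for the old 3D unbatched output
topk_weight = torch.rand(num_tokens, topk)

# Old formulation: keep per-choice rows, then take a weighted sum over the topk dim.
old_style = (per_choice * topk_weight.view(num_tokens, -1, 1)).sum(dim=1)

# New formulation: accumulate the weighted rows directly into a (num_tokens, K) output.
new_style = torch.zeros(num_tokens, K)
for t in range(num_tokens):
    for j in range(topk):
        new_style[t] += per_choice[t, j] * topk_weight[t, j]

torch.testing.assert_close(new_style, old_style)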
@@ -200,6 +210,12 @@ def torch_moe2(a, w1, w2, topk_weight, topk_ids):
 @pytest.mark.parametrize("e", NUM_EXPERTS)
 @pytest.mark.parametrize("topk", TOP_KS)
 @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
+#@pytest.mark.parametrize("m", [33])
+#@pytest.mark.parametrize("n", [128])
+#@pytest.mark.parametrize("k", [128])
+#@pytest.mark.parametrize("e", [8])
+#@pytest.mark.parametrize("topk", [2])
+#@pytest.mark.parametrize("dtype", [torch.float16])
 def test_fused_moe_batched_experts(
     m: int,
     n: int,
@@ -208,12 +224,13 @@ def test_fused_moe_batched_experts(
     topk: int,
     dtype: torch.dtype,
 ):
+    current_platform.seed_everything(7)
+
     a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
     w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10
     w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10

     score = torch.randn((m, e), device="cuda", dtype=dtype)
-    e_map = None

     vllm_config = VllmConfig()
     with set_current_vllm_config(vllm_config):
@@ -238,6 +255,13 @@ def test_fused_moe_batched_experts(
             topk_ids,
             global_num_experts=e)

+        if False:
+            torch.set_printoptions(profile="full")
+            print("BASELINE")
+            print(torch_output)
+            print("OUTPUT")
+            print(triton_output)
+
         #torch.testing.assert_close(triton_b_output, torch_b_output, atol=2e-2, rtol=0)
         torch.testing.assert_close(triton_output, torch_output, atol=2e-2, rtol=0)