@@ -164,18 +164,19 @@ def torch_dispatch(
     a: torch.Tensor,
     topk_ids: torch.Tensor,
     num_experts: int
-) -> torch.Tensor:
+) -> Tuple[torch.Tensor, torch.Tensor]:
     assert topk_ids.dim() == 2
     assert topk_ids.shape[0] == a.shape[0]
 
     num_tokens = a.shape[0]
     topk = topk_ids.shape[1]
 
     tokens_per_expert = torch.bincount(topk_ids.view(-1), minlength=num_experts)
-
     max_num_tokens = tokens_per_expert.max()
+
     b_a = torch.zeros((num_experts, max_num_tokens, a.shape[1]),
                       dtype=a.dtype, device=a.device)
+
     #print(f"b_a shape {b_a.shape}")
 
     token_counts = torch.zeros(num_experts, dtype=torch.int, device=a.device)
@@ -242,59 +243,58 @@ def torch_moe2(a, w1, w2, topk_weight, topk_ids):
             topk_weight.view(M, -1, 1).to(out.dtype)).sum(dim=1)
 
 
-# @pytest.mark.parametrize("m", [1, 33, 64, 222]) #, 1024 * 128])
-# @pytest.mark.parametrize("n", [128, 1024, 2048])
-# @pytest.mark.parametrize("k", [128, 511, 1024])
-# @pytest.mark.parametrize("e", NUM_EXPERTS)
-# @pytest.mark.parametrize("topk", TOP_KS)
-# @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
-# def test_fused_moe_batched_experts(
-#     m: int,
-#     n: int,
-#     k: int,
-#     e: int,
-#     topk: int,
-#     dtype: torch.dtype,
-# ):
-#     current_platform.seed_everything(7)
-
-#     a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
-#     w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10
-#     w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10
-
-#     score = torch.randn((m, e), device="cuda", dtype=dtype)
-
-#     vllm_config = VllmConfig()
-#     with set_current_vllm_config(vllm_config):
-#         topk_weight, topk_ids = fused_topk(a, score, topk, False)
-
-#         torch_output = torch_moe2(a, w1, w2, topk_weight, topk_ids)
-
-#         if True:
-#             triton_output = torch_batched_moe(a,
-#                                               w1,
-#                                               w2,
-#                                               topk_weight,
-#                                               topk_ids)
-#         else:
-#             b_a, tokens_per_expert = batch_by_experts(a, topk_ids, e)
-#             triton_output = fused_batched_experts(
-#                 b_a,
-#                 w1,
-#                 w2,
-#                 topk_weight,
-#                 topk_ids,
-#                 global_num_experts=e
-#             )
-
-#     if False:
-#         torch.set_printoptions(profile="full")
-#         print("BASELINE")
-#         print(torch_output)
-#         print("OUTPUT")
-#         print(triton_output)
-
-#     torch.testing.assert_close(triton_output, torch_output, atol=2e-2, rtol=0)
+@pytest.mark.parametrize("m", [1, 33, 64, 222]) #, 1024 * 128])
+@pytest.mark.parametrize("n", [128, 1024, 2048])
+@pytest.mark.parametrize("k", [128, 511, 1024])
+@pytest.mark.parametrize("e", NUM_EXPERTS)
+@pytest.mark.parametrize("topk", TOP_KS)
+@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
+def test_fused_moe_batched_experts(
+    m: int,
+    n: int,
+    k: int,
+    e: int,
+    topk: int,
+    dtype: torch.dtype,
+):
+    current_platform.seed_everything(7)
+
+    a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
+    w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10
+    w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10
+    score = torch.randn((m, e), device="cuda", dtype=dtype)
+
+    vllm_config = VllmConfig()
+    with set_current_vllm_config(vllm_config):
+        topk_weight, topk_ids = fused_topk(a, score, topk, False)
+
+        torch_output = torch_moe2(a, w1, w2, topk_weight, topk_ids)
+
+        if True:
+            triton_output = torch_batched_moe(a,
+                                              w1,
+                                              w2,
+                                              topk_weight,
+                                              topk_ids)
+        else:
+            b_a, tokens_per_expert = batch_by_experts(a, topk_ids, e)
+            triton_output = fused_batched_experts(
+                b_a,
+                w1,
+                w2,
+                topk_weight,
+                topk_ids,
+                global_num_experts=e
+            )
+
+    if False:
+        torch.set_printoptions(profile="full")
+        print("BASELINE")
+        print(torch_output)
+        print("OUTPUT")
+        print(triton_output)
+
+    torch.testing.assert_close(triton_output, torch_output, atol=2e-2, rtol=0)
 
 
 def chunk_by_rank(t, r, w):
@@ -310,6 +310,7 @@ def torch_pplx_dispatch_combine(pgi, dp_size, a, w1, w2, scores, topk):
 
     num_tokens, hidden_dim = a.shape
     num_experts = w1.shape[0]
+    num_local_experts = w1.shape[0] // pgi.world_size
     block_size = 128
     device = pgi.device
     rank_num_tokens = num_tokens // pgi.world_size
@@ -352,7 +353,7 @@ def torch_pplx_dispatch_combine(pgi, dp_size, a, w1, w2, scores, topk):
     score_chunk = chunk_by_rank(scores, rank, world_size).to(device)
     chunk_topk_weight, chunk_topk_ids = fused_topk(a_chunk, score_chunk, topk, False)
 
-    #print(f"chunk_topk_ids = {chunk_topk_ids}")
+    #print(f"chunk_topk_ids = {chunk_topk_ids.view(-1)}")
 
     b_a, b_a_scale, expert_num_tokens = dispatch_combine.dispatch(
         a_chunk,
@@ -363,6 +364,25 @@ def torch_pplx_dispatch_combine(pgi, dp_size, a, w1, w2, scores, topk):
         None
     )
 
+    #topk_weight, topk_ids = fused_topk(a_chunk, score_chunk, topk, False)
+    naive_b_a, tokens_per_expert = torch_dispatch(a_chunk, chunk_topk_ids, num_experts)
+
+    torch.distributed.all_reduce(tokens_per_expert)
+    #max_num = tokens_per_expert.max()
+    tokens_per_expert = chunk_by_rank(tokens_per_expert, rank, world_size).to(dtype=torch.int32)
+
+    #print(f"tpe {tokens_per_expert}")
+    #print(f"ent {expert_num_tokens}")
+
+    #naive_b_a = chunk_by_rank(naive_b_a, rank, world_size)
+
+    #torch.set_printoptions(profile="full")
+    #print("b_a", b_a[:naive_b_a.shape[1]])
+    #print("naive_b_a", naive_b_a)
+
+    torch.testing.assert_close(tokens_per_expert, expert_num_tokens, atol=0, rtol=0)
+    #torch.testing.assert_close(b_a[:, :naive_b_a.shape[1]], naive_b_a, atol=2e-2, rtol=0)
+
     b_a = b_a * 1.5
 
     out = torch.full(
@@ -382,8 +402,6 @@ def torch_pplx_dispatch_combine(pgi, dp_size, a, w1, w2, scores, topk):
 
     ata.destroy()
 
-    #torch.distributed.barrier()
-
     #print(f"OUT {rank}: {out.shape} {out[:rank_num_tokens]}")
 
     #torch.distributed.all_reduce(out)
@@ -547,8 +565,6 @@ def torch_pplx_moe(pgi, dp_size, a, w1, w2, scores, topk):
 
     ata.destroy()
 
-    #torch.distributed.barrier()
-
     #print(f"OUT {rank}: {out.shape} {out[:rank_num_tokens]}")
 
     #torch.distributed.all_reduce(out)
@@ -593,8 +609,6 @@ def _pplx_moe(
                                  score,
                                  topk)
 
-    #print(f"torch_output {pgi.rank}: {torch_output}")
-
     if False:
         print("BASELINE")
         print(torch_output)
@@ -603,23 +617,25 @@ def _pplx_moe(
 
     torch_output = chunk_by_rank(torch_output, pgi.rank, pgi.world_size).to(pplx_output.device)
 
+    #print(f"torch_output {pgi.rank}: {torch_output.shape} {torch_output}")
+
     torch.testing.assert_close(pplx_output, torch_output, atol=2e-2, rtol=0)
 
     nvshmem_finalize()
 
 
-# @pytest.mark.parametrize("m", [1, 33, 64, 222]) #, 1024 * 128])
-# @pytest.mark.parametrize("n", [128, 1024, 2048])
-# @pytest.mark.parametrize("k", [128, 512, 1024])
-# @pytest.mark.parametrize("e", NUM_EXPERTS)
-# @pytest.mark.parametrize("topk", TOP_KS)
-# @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
-@pytest.mark.parametrize("m", [64]) ##, 32]) #, 1024 * 128])
-@pytest.mark.parametrize("n", [128])
-@pytest.mark.parametrize("k", [128])
-@pytest.mark.parametrize("e", [8]) #NUM_EXPERTS)
-@pytest.mark.parametrize("topk", [2]) #TOP_KS)
-@pytest.mark.parametrize("dtype", [torch.bfloat16])
+@pytest.mark.parametrize("m", [2, 32, 64, 222]) #, 1024 * 128])
+@pytest.mark.parametrize("n", [128, 1024, 2048])
+@pytest.mark.parametrize("k", [128, 512, 1024])
+@pytest.mark.parametrize("e", NUM_EXPERTS)
+@pytest.mark.parametrize("topk", TOP_KS)
+@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
+# @pytest.mark.parametrize("m", [64]) ##, 32]) #, 1024 * 128])
+# @pytest.mark.parametrize("n", [128])
+# @pytest.mark.parametrize("k", [128])
+# @pytest.mark.parametrize("e", [8]) #NUM_EXPERTS)
+# @pytest.mark.parametrize("topk", [2]) #TOP_KS)
+# @pytest.mark.parametrize("dtype", [torch.bfloat16])
 @pytest.mark.parametrize("world_dp_size", [[2, 1]]) #, [4, 2]])
 def test_pplx_moe(
     m: int,