import traceback

from torch.multiprocessing import spawn  # pyright: ignore[reportPrivateImportUsage]
- from typing import Callable, Concatenate, ParamSpec, Tuple
+ from typing import Callable, Concatenate, Optional, ParamSpec, Tuple

from pplx_kernels import AllToAll
from pplx_kernels.nvshmem import (
@@ -163,7 +163,8 @@ def parallel_launch_from_env(
def torch_dispatch(
    a: torch.Tensor,
    topk_ids: torch.Tensor,
-     num_experts: int
+     num_experts: int,
+     max_num_tokens: Optional[int] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    assert topk_ids.dim() == 2
    assert topk_ids.shape[0] == a.shape[0]
@@ -172,7 +173,8 @@ def torch_dispatch(
    topk = topk_ids.shape[1]

    tokens_per_expert = torch.bincount(topk_ids.view(-1), minlength=num_experts)
-     max_num_tokens = tokens_per_expert.max()
+     if max_num_tokens is None:
+         max_num_tokens = tokens_per_expert.max()

    b_a = torch.zeros((num_experts, max_num_tokens, a.shape[1]),
                      dtype=a.dtype, device=a.device)
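`torch_dispatch` pads each expert's tokens into a `(num_experts, max_num_tokens, hidden)` buffer, so whichever cap is in effect fixes the buffer shape: a cap derived from the local `bincount` can differ from rank to rank, while an explicit cap keeps the shapes consistent. A minimal standalone sketch of that sizing logic, with made-up values (not part of the test file):

```python
import torch

# Standalone sketch of the buffer-sizing logic above (illustrative values).
# With a locally derived cap, the padded buffer shape depends on this rank's
# routing; an explicit cap keeps it fixed across ranks.
num_experts, hidden_dim = 4, 8
topk_ids = torch.tensor([[0, 1], [1, 2], [1, 3]])  # 3 tokens, topk=2

tokens_per_expert = torch.bincount(topk_ids.view(-1), minlength=num_experts)
derived_cap = int(tokens_per_expert.max())   # 3: expert 1 receives 3 tokens
explicit_cap = 16                            # e.g. the global token count

local_shape = (num_experts, derived_cap, hidden_dim)    # (4, 3, 8)
shared_shape = (num_experts, explicit_cap, hidden_dim)  # (4, 16, 8)
print(local_shape, shared_shape)
```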
@@ -314,11 +316,10 @@ def torch_pplx_dispatch_combine(pgi, dp_size, a, w1, w2, scores, topk):
    block_size = 128
    device = pgi.device
    rank_num_tokens = num_tokens // pgi.world_size
- 
-     max_num_tokens = num_tokens
-     #print(f"device = {device}, max_num_tokens = {max_num_tokens}, topk = {topk}, num_ex = {num_experts}, dp_size = {dp_size}")
    rank = pgi.rank
    world_size = pgi.world_size
+     max_num_tokens = num_tokens
+     #print(f"device = {device}, max_num_tokens = {max_num_tokens}, topk = {topk}, num_ex = {num_experts}, dp_size = {dp_size}")

    ata = AllToAll(
        max_num_tokens=max_num_tokens,
@@ -342,7 +343,7 @@ def torch_pplx_dispatch_combine(pgi, dp_size, a, w1, w2, scores, topk):

    dispatch_combine = PplxDispatchCombine(
        ata,
-         max_num_tokens,  # // world_size?
+         max_num_tokens,
        pgi.world_size,
        dp_size,
        rank,
@@ -353,7 +354,7 @@ def torch_pplx_dispatch_combine(pgi, dp_size, a, w1, w2, scores, topk):
    score_chunk = chunk_by_rank(scores, rank, world_size).to(device)
    chunk_topk_weight, chunk_topk_ids = fused_topk(a_chunk, score_chunk, topk, False)

-     # print(f"chunk_topk_ids = {chunk_topk_ids.view(-1)}")
+     print(f"chunk_topk_ids = {chunk_topk_ids.view(-1)}")

    b_a, b_a_scale, expert_num_tokens = dispatch_combine.dispatch(
        a_chunk,
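`chunk_by_rank` is a helper defined elsewhere in this test file; its call sites here treat it as "split dim 0 evenly across `world_size` ranks and take this rank's slice". A hedged sketch of that behavior (the real helper may differ in details such as uneven splits):

```python
import torch

# Hedged sketch of what chunk_by_rank appears to do at its call sites:
# split dim 0 evenly across world_size ranks and return this rank's slice.
def chunk_by_rank(t: torch.Tensor, rank: int, world_size: int) -> torch.Tensor:
    chunk = t.shape[0] // world_size
    return t[rank * chunk:(rank + 1) * chunk]

scores = torch.randn(16, 32)  # (num_tokens, num_experts)
print(chunk_by_rank(scores, rank=1, world_size=4).shape)  # torch.Size([4, 32])
```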
@@ -371,22 +372,25 @@ def torch_pplx_dispatch_combine(pgi, dp_size, a, w1, w2, scores, topk):
    #max_num = tokens_per_expert.max()
    tokens_per_expert = chunk_by_rank(tokens_per_expert, rank, world_size).to(dtype=torch.int32)

-     #print(f"tpe {tokens_per_expert}")
-     #print(f"ent {expert_num_tokens}")
+     print(f"tpe {tokens_per_expert}")
+     print(f"ent {expert_num_tokens}")
+
+     #torch.set_printoptions(profile="full")
+     #torch.distributed.all_reduce(naive_b_a, op=torch.distributed.ReduceOp.MAX)
+     #torch.distributed.broadcast(naive_b_a, src=rank)

    #naive_b_a = chunk_by_rank(naive_b_a, rank, world_size)

-     #torch.set_printoptions(profile="full")
-     #print("b_a", b_a[:naive_b_a.shape[1]])
-     #print("naive_b_a", naive_b_a)
+     #print("b_a", b_a.shape, b_a)  #[:, :naive_b_a.shape[1]])
+     #print("naive_b_a", naive_b_a.shape, naive_b_a)

    torch.testing.assert_close(tokens_per_expert, expert_num_tokens, atol=0, rtol=0)
    #torch.testing.assert_close(b_a[:, :naive_b_a.shape[1]], naive_b_a, atol=2e-2, rtol=0)

    b_a = b_a * 1.5

    out = torch.full(
-         (max_num_tokens, hidden_dim),
+         (rank_num_tokens * world_size, hidden_dim),
        torch.nan,
        dtype=a.dtype,
        device=device,
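Two checks appear in this hunk: `tokens_per_expert` and `expert_num_tokens` are compared exactly (`atol=0, rtol=0`) because both are integer counts, and the combine output buffer is pre-filled with `torch.nan` so any row the combine step never writes stays visibly invalid. A small sketch of that NaN-sentinel idea, with illustrative shapes only:

```python
import torch

# Sketch of the NaN-sentinel pattern used for the combine output buffer:
# pre-fill with NaN, then rows that were never written stand out.
rows, hidden_dim = 8, 4
out = torch.full((rows, hidden_dim), torch.nan, dtype=torch.float32)
out[:6] = 1.0                          # pretend combine produced 6 valid rows
written = ~out.isnan().any(dim=1)      # per-row "was this row written?"
print(int(written.sum()))              # 6
```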
@@ -539,7 +543,7 @@ def torch_pplx_moe(pgi, dp_size, a, w1, w2, scores, topk):
        a.dtype,
    )

-     experts = BatchedExperts()
+     experts = BatchedExperts(max_num_tokens, rank)

    fused_experts = FusedMoEModularKernel(
        dispatch_combine,
@@ -554,24 +558,20 @@ def torch_pplx_moe(pgi, dp_size, a, w1, w2, scores, topk):

    out = fused_experts(
        a_chunk,
-         w1,
-         w2,
+         chunk_by_rank(w1, rank, world_size),
+         chunk_by_rank(w2, rank, world_size),
        chunk_topk_weight,
        chunk_topk_ids,
-         global_num_experts=num_local_experts  #? num_local_experts?
+         global_num_experts=num_experts  #? num_local_experts?
    )

    torch.cuda.synchronize()

    ata.destroy()

-     #print(f"OUT {rank}: {out.shape} {out[:rank_num_tokens]}")
-
-     #torch.distributed.all_reduce(out)
-
    #print(f"OUT {rank}: {out.shape} {out}")

-     return out[:rank_num_tokens]
+     return out[:rank_num_tokens]  # chunk_by_rank?


def _pplx_moe(
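The `fused_experts` call above slices `w1`/`w2` by rank while still passing the global expert count, which matches the usual expert-parallel layout: each rank stores `num_experts // world_size` experts, but the topk ids produced by routing stay global. A small illustration of the resulting id mapping, assuming contiguous expert chunks per rank (names are illustrative, not the vLLM API):

```python
# Illustration of the expert-parallel bookkeeping behind chunking w1/w2 by
# rank while passing the global expert count: routing ids stay global, and a
# contiguous-chunk layout maps each one to (owner rank, local expert id).
num_experts, world_size = 8, 4
local_experts = num_experts // world_size  # experts held per rank

def owner_of(global_expert_id: int) -> tuple[int, int]:
    return global_expert_id // local_experts, global_expert_id % local_experts

print([owner_of(e) for e in range(num_experts)])
# [(0, 0), (0, 1), (1, 0), (1, 1), (2, 0), (2, 1), (3, 0), (3, 1)]
```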