@@ -491,6 +491,7 @@ def dispatch(
491
491
expert_map : Optional [torch .Tensor ],
492
492
apply_router_weight_on_input : bool ,
493
493
) -> Tuple [torch .Tensor , Optional [torch .Tensor ], Optional [torch .Tensor ]]:
494
+ assert a1 .dim () == 2
494
495
assert topk_ids .dim () == 2
495
496
assert topk_ids .shape [0 ] == a1 .shape [0 ]
496
497
@@ -504,11 +505,13 @@ def dispatch(
504
505
num_tokens , hidden_dim = a1 .shape
505
506
topk = topk_ids .shape [1 ]
506
507
507
- tokens_per_expert = torch .bincount (topk_ids .view (- 1 ),
508
- minlength = num_experts )
509
-
510
508
if self .max_num_tokens is None :
509
+ tokens_per_expert = torch .bincount (topk_ids .view (- 1 ),
510
+ minlength = num_experts )
511
511
self .max_num_tokens = int (tokens_per_expert .max ().item ())
512
+ else :
513
+ tokens_per_expert = torch .zeros (num_experts , dtype = torch .int ,
514
+ device = a1 .device )
512
515
513
516
rem_experts = num_experts % self .world_size
514
517
num_local_experts = ((num_experts // self .world_size ) +
@@ -518,23 +521,27 @@ def dispatch(
518
521
dtype = a1 .dtype ,
519
522
device = a1 .device )
520
523
521
- token_counts = torch .zeros (num_local_experts ,
522
- dtype = torch .int ,
523
- device = a1 .device )
524
-
525
524
first_expert = (((num_experts // self .world_size ) * self .rank ) +
526
525
rem_experts - self .rank )
527
526
last_expert = first_expert + num_local_experts
528
- #expert_id_range = range(first_expert, last_expert)
529
527
530
- for token in range (num_tokens ):
531
- for j in range (topk ):
532
- expert_id = topk_ids [token , j ]
533
- if expert_id >= first_expert and expert_id < last_expert :
534
- rel_index = expert_id - first_expert
535
- idx = token_counts [rel_index ]
536
- b_a1 [rel_index , idx :idx + 1 , :] = a1 [token , :]
537
- token_counts [rel_index ] = token_counts [rel_index ] + 1
528
+ # rhs = torch.empty((self.max_num_tokens, hidden_dim),
529
+ # dtype=a1.dtype, device=a1.device)
530
+
531
+ # for expert_id in range(first_expert, last_expert):
532
+ # topks = torch.any(topk_ids == expert_id, dim=1).flatten()
533
+ # rows = torch.count_nonzero(topks.flatten())
534
+ # #rhs[:rows] = a1[:topks.numel()][topks]
535
+ # topks_idx = topks.nonzero()
536
+ # torch.index_select(a1, dim=0, index=topks_idx.flatten(), out=rhs[:rows])
537
+ # b_a1[expert_id - first_expert, :rows, :] = rhs[:rows]
538
+ # tokens_per_expert[expert_id - first_expert] = rows
539
+
540
+ for expert_id in range (first_expert , last_expert ):
541
+ topks = torch .any (topk_ids == expert_id , dim = 1 ).flatten ()
542
+ rows = torch .count_nonzero (topks .flatten ())
543
+ b_a1 [expert_id - first_expert , :rows , :] = a1 [:topks .numel ()][topks ]
544
+ tokens_per_expert [expert_id - first_expert ] = rows
538
545
539
546
return b_a1 , a1_scale , tokens_per_expert
540
547
@@ -548,31 +555,32 @@ def combine(
548
555
) -> None :
549
556
num_tokens = topk_ids .shape [0 ]
550
557
num_local_experts = fused_expert_output .shape [0 ]
551
- num_experts = num_local_experts * self . world_size # NOT QUITE RIGHT
558
+ topk = topk_weights . shape [ 1 ]
552
559
K = fused_expert_output .shape [- 1 ]
553
560
assert output .shape [0 ] == num_tokens and output .shape [1 ] == K
554
- expert_counts = torch .zeros (
555
- num_experts ,
556
- dtype = torch .int ,
557
- device = fused_expert_output .device )
558
561
559
562
output .fill_ (0 )
560
563
561
564
first_expert = num_local_experts * self .rank # NOT QUITE RIGHT
562
565
last_expert = first_expert + num_local_experts
563
566
564
- for token in range (num_tokens ):
565
- expert_ids = topk_ids [token ]
566
- for i in range (expert_ids .numel ()):
567
- expert_id = expert_ids [i ]
568
- if expert_id >= first_expert and expert_id < last_expert :
569
- assert expert_id < num_experts
570
- idx = expert_counts [expert_id ]
571
- accum = fused_expert_output [expert_id - first_expert , idx :idx + 1 , :]
572
- if not apply_router_weight_on_input :
573
- accum = accum * topk_weights [token , i ]
574
- output [token , :] = output [token , :] + accum
575
- expert_counts [expert_id ] = expert_counts [expert_id ] + 1
567
+ # for expert_id in range(first_expert, last_expert):
568
+ # topkws = topk_ids == expert_id
569
+ # topks = torch.any(topkws, dim=1).flatten()
570
+ # outrhs = output[topks]
571
+ # rhs = fused_expert_output[expert_id - first_expert, :outrhs.shape[0], :]
572
+ # if not apply_router_weight_on_input:
573
+ # rhs.mul_(topk_weights[topkws].view(rhs.shape[0], 1))
574
+ # output[topks] = outrhs + rhs
575
+
576
+ for expert_id in range (first_expert , last_expert ):
577
+ topkws = topk_ids == expert_id
578
+ topks = torch .any (topkws , dim = 1 ).flatten ()
579
+ rows = torch .count_nonzero (topks )
580
+ rhs = fused_expert_output [expert_id - first_expert , :rows , :]
581
+ if not apply_router_weight_on_input :
582
+ rhs .mul_ (topk_weights [topkws ].view (rhs .shape [0 ], 1 ))
583
+ output [topks ] = output [topks ] + rhs
576
584
577
585
578
586
class BatchedExperts (mk .FusedMoEPermuteExpertsUnpermute ):
0 commit comments