
Commit a003bd8

improve ref impl
Signed-off-by: Bill Nell <[email protected]>
1 parent 0f2e37a commit a003bd8

File tree (4 files changed: +46 −37 lines)

tests/kernels/moe/test_pplx_moe.py
vllm/model_executor/layers/fused_moe/fused_batched_moe.py
vllm/model_executor/layers/fused_moe/layer.py
vllm/model_executor/layers/fused_moe/pplx_dispatch_combine.py

tests/kernels/moe/test_pplx_moe.py
Lines changed: 1 addition & 1 deletion

@@ -276,7 +276,7 @@ def torch_moe2(a, w1, w2, topk_weight, topk_ids):
 
 @pytest.mark.parametrize("m", [1, 33, 64, 222])
 @pytest.mark.parametrize("n", [128, 1024, 2048])
-@pytest.mark.parametrize("k", [128, 511, 1024])
+@pytest.mark.parametrize("k", [128, 512, 1024])
 @pytest.mark.parametrize("e", NUM_EXPERTS)
 @pytest.mark.parametrize("topk", TOP_KS)
 @pytest.mark.parametrize("dtype", [torch.bfloat16])

vllm/model_executor/layers/fused_moe/fused_batched_moe.py
Lines changed: 41 additions & 33 deletions

@@ -491,6 +491,7 @@ def dispatch(
         expert_map: Optional[torch.Tensor],
         apply_router_weight_on_input: bool,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
+        assert a1.dim() == 2
         assert topk_ids.dim() == 2
         assert topk_ids.shape[0] == a1.shape[0]

@@ -504,11 +505,13 @@ def dispatch(
         num_tokens, hidden_dim = a1.shape
         topk = topk_ids.shape[1]
 
-        tokens_per_expert = torch.bincount(topk_ids.view(-1),
-                                           minlength=num_experts)
-
         if self.max_num_tokens is None:
+            tokens_per_expert = torch.bincount(topk_ids.view(-1),
+                                               minlength=num_experts)
             self.max_num_tokens = int(tokens_per_expert.max().item())
+        else:
+            tokens_per_expert = torch.zeros(num_experts, dtype=torch.int,
+                                            device=a1.device)
 
         rem_experts = num_experts % self.world_size
         num_local_experts = ((num_experts // self.world_size) +
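
For reference, a tiny illustration (toy values, not from the commit) of what the torch.bincount call above computes when max_num_tokens has to be derived from the routing:

    import torch

    topk_ids = torch.tensor([[0, 2], [1, 2], [2, 3]])  # 3 tokens, top-2 routing
    num_experts = 4
    tokens_per_expert = torch.bincount(topk_ids.view(-1), minlength=num_experts)
    # tensor([1, 1, 3, 1]): expert 2 sees 3 tokens, so max_num_tokens becomes 3.
    max_num_tokens = int(tokens_per_expert.max().item())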
@@ -518,23 +521,27 @@ def dispatch(
             dtype=a1.dtype,
             device=a1.device)
 
-        token_counts = torch.zeros(num_local_experts,
-                                   dtype=torch.int,
-                                   device=a1.device)
-
         first_expert = (((num_experts // self.world_size) * self.rank) +
                         rem_experts - self.rank)
         last_expert = first_expert + num_local_experts
-        #expert_id_range = range(first_expert, last_expert)
 
-        for token in range(num_tokens):
-            for j in range(topk):
-                expert_id = topk_ids[token, j]
-                if expert_id >= first_expert and expert_id < last_expert:
-                    rel_index = expert_id - first_expert
-                    idx = token_counts[rel_index]
-                    b_a1[rel_index, idx:idx + 1, :] = a1[token, :]
-                    token_counts[rel_index] = token_counts[rel_index] + 1
+        # rhs = torch.empty((self.max_num_tokens, hidden_dim),
+        #                   dtype=a1.dtype, device=a1.device)
+
+        # for expert_id in range(first_expert, last_expert):
+        #     topks = torch.any(topk_ids == expert_id, dim=1).flatten()
+        #     rows = torch.count_nonzero(topks.flatten())
+        #     #rhs[:rows] = a1[:topks.numel()][topks]
+        #     topks_idx = topks.nonzero()
+        #     torch.index_select(a1, dim=0, index=topks_idx.flatten(), out=rhs[:rows])
+        #     b_a1[expert_id - first_expert, :rows, :] = rhs[:rows]
+        #     tokens_per_expert[expert_id - first_expert] = rows
+
+        for expert_id in range(first_expert, last_expert):
+            topks = torch.any(topk_ids == expert_id, dim=1).flatten()
+            rows = torch.count_nonzero(topks.flatten())
+            b_a1[expert_id - first_expert, :rows, :] = a1[:topks.numel()][topks]
+            tokens_per_expert[expert_id - first_expert] = rows
 
         return b_a1, a1_scale, tokens_per_expert

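A minimal standalone sketch of the mask-based grouping that the new per-expert dispatch loop above relies on (toy shapes and random routing chosen for illustration; this is not code from the commit):

    import torch

    # Toy sizes, chosen only for illustration.
    num_tokens, hidden_dim, num_experts, topk = 6, 4, 4, 2
    max_num_tokens = num_tokens  # upper bound on rows per expert

    a1 = torch.randn(num_tokens, hidden_dim)
    scores = torch.rand(num_tokens, num_experts)
    _, topk_ids = torch.topk(scores, topk, dim=1)  # distinct experts per token

    b_a1 = torch.zeros(num_experts, max_num_tokens, hidden_dim)
    tokens_per_expert = torch.zeros(num_experts, dtype=torch.int)

    for expert_id in range(num_experts):
        # True for every token whose top-k list contains this expert.
        topks = torch.any(topk_ids == expert_id, dim=1)
        rows = int(torch.count_nonzero(topks))
        # Pack the selected tokens densely at the front of this expert's batch.
        b_a1[expert_id, :rows, :] = a1[topks]
        tokens_per_expert[expert_id] = rows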
@@ -548,31 +555,32 @@ def combine(
     ) -> None:
         num_tokens = topk_ids.shape[0]
         num_local_experts = fused_expert_output.shape[0]
-        num_experts = num_local_experts * self.world_size # NOT QUITE RIGHT
+        topk = topk_weights.shape[1]
         K = fused_expert_output.shape[-1]
         assert output.shape[0] == num_tokens and output.shape[1] == K
-        expert_counts = torch.zeros(
-            num_experts,
-            dtype=torch.int,
-            device=fused_expert_output.device)
 
         output.fill_(0)
 
         first_expert = num_local_experts * self.rank # NOT QUITE RIGHT
         last_expert = first_expert + num_local_experts
 
-        for token in range(num_tokens):
-            expert_ids = topk_ids[token]
-            for i in range(expert_ids.numel()):
-                expert_id = expert_ids[i]
-                if expert_id >= first_expert and expert_id < last_expert:
-                    assert expert_id < num_experts
-                    idx = expert_counts[expert_id]
-                    accum = fused_expert_output[expert_id - first_expert, idx:idx + 1, :]
-                    if not apply_router_weight_on_input:
-                        accum = accum * topk_weights[token, i]
-                    output[token, :] = output[token, :] + accum
-                    expert_counts[expert_id] = expert_counts[expert_id] + 1
+        # for expert_id in range(first_expert, last_expert):
+        #     topkws = topk_ids == expert_id
+        #     topks = torch.any(topkws, dim=1).flatten()
+        #     outrhs = output[topks]
+        #     rhs = fused_expert_output[expert_id - first_expert, :outrhs.shape[0], :]
+        #     if not apply_router_weight_on_input:
+        #         rhs.mul_(topk_weights[topkws].view(rhs.shape[0], 1))
+        #     output[topks] = outrhs + rhs
+
+        for expert_id in range(first_expert, last_expert):
+            topkws = topk_ids == expert_id
+            topks = torch.any(topkws, dim=1).flatten()
+            rows = torch.count_nonzero(topks)
+            rhs = fused_expert_output[expert_id - first_expert, :rows, :]
+            if not apply_router_weight_on_input:
+                rhs.mul_(topk_weights[topkws].view(rhs.shape[0], 1))
+            output[topks] = output[topks] + rhs
 
 
 class BatchedExperts(mk.FusedMoEPermuteExpertsUnpermute):
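
And the matching combine step, again as a self-contained toy sketch rather than the commit's code: each expert's packed output rows are scaled by the corresponding router weights and added back into the per-token output.

    import torch

    num_tokens, hidden_dim, num_experts, topk = 6, 4, 4, 2
    scores = torch.rand(num_tokens, num_experts)
    topk_weights, topk_ids = torch.topk(scores, topk, dim=1)

    # Pretend per-expert outputs, packed in token order as dispatch left them.
    fused_expert_output = torch.randn(num_experts, num_tokens, hidden_dim)

    output = torch.zeros(num_tokens, hidden_dim)
    for expert_id in range(num_experts):
        topkws = topk_ids == expert_id        # (num_tokens, topk) slot mask
        topks = torch.any(topkws, dim=1)      # tokens routed to this expert
        rows = int(torch.count_nonzero(topks))
        rhs = fused_expert_output[expert_id, :rows, :].clone()
        # Scale each row by the router weight of the matching (token, slot) pair.
        rhs.mul_(topk_weights[topkws].view(rows, 1))
        output[topks] = output[topks] + rhs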

vllm/model_executor/layers/fused_moe/layer.py
Lines changed: 3 additions & 3 deletions

@@ -262,7 +262,7 @@ def set_dispatch_combine(
         if isinstance(dispatch_combine,
                       (BatchedDispatchCombine, PplxDispatchCombine)):
             logger.debug("BatchedTritonExperts %s", self.moe)
-            experts = BatchedTritonExperts(
+            experts = BatchedExperts(
                 max_num_tokens=MOE_DP_CHUNK_SIZE,
                 use_fp8_w8a8=False,
                 use_int8_w8a8=False,

@@ -695,8 +695,6 @@ def _construct_dispatch_combine(
                 rank,
                 moe.in_dtype,
             )
-        elif False:
-            return None
         elif self.dp_size > 1:
             logger.debug("using batched dispatch")
             dp_size = moe.ep_size // moe.dp_size  # dp_size actually means TP.

@@ -707,6 +705,8 @@ def _construct_dispatch_combine(
                 dp_size=dp_size,
                 rank=rank,
             )
+        elif False:
+            return None
         else:
             logger.debug("using standard dispatch")
             return StandardDispatchCombine(

vllm/model_executor/layers/fused_moe/pplx_dispatch_combine.py
Lines changed: 1 addition & 0 deletions

@@ -72,6 +72,7 @@ def dispatch(
             per_act_token,
             self.block_shape)
 
+        # TODO: does rem_experts need to be 0 for pplx to work properly?
         rem_experts = num_experts % self.world_size
         num_local_experts = ((num_experts // self.world_size) +
                              (1 if self.rank < rem_experts else 0))
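
For context on the TODO above, the remainder split gives the first rem_experts ranks one extra expert. A quick worked example with toy numbers (not from the commit):

    num_experts, world_size = 10, 4
    rem_experts = num_experts % world_size  # 2
    for rank in range(world_size):
        num_local_experts = ((num_experts // world_size) +
                             (1 if rank < rem_experts else 0))
        print(rank, num_local_experts)  # ranks 0,1 hold 3 experts; ranks 2,3 hold 2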
