
Commit b5be324

rename dispatch combine -> prepare finalize
Signed-off-by: Bill Nell <[email protected]>
1 parent 93dd74f commit b5be324

File tree

10 files changed: +152 −103 lines changed

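For orientation, the rename maps BatchedDispatchCombine -> BatchedPrepareAndFinalize, PplxDispatchCombine -> PplxPrepareAndFinalize, StandardDispatchCombine -> StandardPrepareAndFinalize, and the dispatch()/combine() methods -> prepare()/finalize(). A minimal sketch of how the renamed pieces compose, based on the single-rank test helper in the diff below (the tensors a, w1, w2, topk_weight, topk_ids and num_experts are assumed inputs; constructor arguments may differ in other call sites):

# Sketch only: composing a modular MoE kernel with the renamed classes,
# mirroring batched_moe() in tests/kernels/moe/test_pplx_moe.py.
from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
    BatchedExperts, BatchedPrepareAndFinalize)
from vllm.model_executor.layers.fused_moe.modular_kernel import (
    FusedMoEModularKernel)

max_num_tokens = a.shape[0]  # `a` is the (M, K) activation tensor
fused_experts = FusedMoEModularKernel(
    BatchedPrepareAndFinalize(max_num_tokens, world_size=1, dp_size=1, rank=0),
    BatchedExperts(max_num_tokens=max_num_tokens, dp_size=1, world_size=1))
out = fused_experts(a, w1, w2, topk_weight, topk_ids, num_experts)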

tests/kernels/moe/test_pplx_moe.py

Lines changed: 89 additions & 44 deletions

@@ -28,17 +28,17 @@
 from vllm.model_executor.layers.activation import SiluAndMul
 from vllm.model_executor.layers.fused_moe import override_config
 from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
-    BatchedDispatchCombine, BatchedExperts, BatchedTritonExperts)
+    BatchedExperts, BatchedPrepareAndFinalize, BatchedTritonExperts)
 from vllm.model_executor.layers.fused_moe.fused_moe import (fused_topk,
                                                             get_default_config)
 from vllm.model_executor.layers.fused_moe.modular_kernel import (
     FusedMoEModularKernel)
-from vllm.model_executor.layers.fused_moe.pplx_dispatch_combine import (
-    PplxDispatchCombine)
+from vllm.model_executor.layers.fused_moe.pplx_prepare_finalize import (
+    PplxPrepareAndFinalize)
 from vllm.platforms import current_platform

-PPLX_DISPATCH_COMBOS = [(4, 128, 128), (32, 1024, 512), (64, 1024, 512),
-                        (222, 2048, 1024)]
+PPLX_PREPARE_COMBOS = [(4, 128, 128), (32, 1024, 512), (64, 1024, 512),
+                       (222, 2048, 1024)]

 PPLX_MOE_COMBOS = [
     (1, 128, 128),

@@ -175,7 +175,7 @@ def parallel_launch_from_env(
     )


-def torch_dispatch(
+def torch_prepare(
     a: torch.Tensor,
     topk_ids: torch.Tensor,
     num_experts: int,

@@ -211,7 +211,8 @@ def torch_dispatch(
     return b_a, tokens_per_expert


-def torch_combine(b_out, topk_weight, topk_ids):
+def torch_finalize(b_out: torch.Tensor, topk_weight: torch.Tensor,
+                   topk_ids: torch.Tensor) -> torch.Tensor:
     num_tokens = topk_ids.shape[0]
     num_experts = b_out.shape[0]
     K = b_out.shape[-1]

@@ -231,9 +232,15 @@ def torch_combine(b_out, topk_weight, topk_ids):
     return out


-def torch_batched_moe(a, w1, w2, topk_weight, topk_ids):
+def torch_batched_moe(
+    a: torch.Tensor,
+    w1: torch.Tensor,
+    w2: torch.Tensor,
+    topk_weight: torch.Tensor,
+    topk_ids: torch.Tensor,
+) -> torch.Tensor:
     num_experts = w1.shape[0]
-    b_a, tokens_per_expert = torch_dispatch(a, topk_ids, num_experts)
+    b_a, tokens_per_expert = torch_prepare(a, topk_ids, num_experts)
     assert b_a.dim() == 3
     num_tokens, topk = topk_ids.shape
     _, max_num_tokens, K = b_a.shape

@@ -251,21 +258,33 @@ def torch_batched_moe(a, w1, w2, topk_weight, topk_ids):
             tmp[:num], b_a[expert, :num, :] @ w1[expert].transpose(0, 1))
         out[expert, :num, :] = tmp[:num] @ w2[expert].transpose(0, 1)

-    return torch_combine(out, topk_weight, topk_ids)
+    return torch_finalize(out, topk_weight, topk_ids)


-def batched_moe(a, w1, w2, topk_weight, topk_ids):
+def batched_moe(
+    a: torch.Tensor,
+    w1: torch.Tensor,
+    w2: torch.Tensor,
+    topk_weight: torch.Tensor,
+    topk_ids: torch.Tensor,
+) -> torch.Tensor:
     num_experts = w1.shape[0]

     fused_experts = FusedMoEModularKernel(
-        BatchedDispatchCombine(a.shape[0], world_size=1, dp_size=1, rank=0),
+        BatchedPrepareAndFinalize(a.shape[0], world_size=1, dp_size=1, rank=0),
         BatchedExperts(max_num_tokens=a.shape[0], dp_size=1, world_size=1))

     return fused_experts(a, w1, w2, topk_weight, topk_ids, num_experts)


-# TODO: same as torch_moe but with fused_topk factored out.
-def torch_moe2(a, w1, w2, topk_weight, topk_ids):
+# Note: same as torch_moe but with fused_topk factored out.
+def torch_moe2(
+    a: torch.Tensor,
+    w1: torch.Tensor,
+    w2: torch.Tensor,
+    topk_weight: torch.Tensor,
+    topk_ids: torch.Tensor,
+) -> torch.Tensor:
     M, K = a.shape
     topk = topk_ids.shape[1]
     a = a.view(M, -1, K).repeat(1, topk, 1).reshape(-1, K)

@@ -318,17 +337,19 @@ def test_fused_moe_batched_experts(
                       rtol=0)


-def rank_chunk(num, r, w):
+def rank_chunk(num: int, r: int, w: int) -> int:
     rem = num % w
     return (num // w) + (1 if r < rem else 0)


-def chunk_by_rank(t, r, w):
+def chunk_by_rank(t: torch.Tensor, r: int, w: int) -> torch.Tensor:
     chunk = rank_chunk(t.shape[0], r, w)
     return t[(r * chunk):(r + 1) * chunk]


-def pplx_dispatch_combine(pgi, dp_size, a, topk_weight, topk_ids, num_experts):
+def pplx_prepare_finalize(pgi: ProcessGroupInfo, dp_size: int, a: torch.Tensor,
+                          topk_weight: torch.Tensor, topk_ids: torch.Tensor,
+                          num_experts: int) -> torch.Tensor:
     assert torch.cuda.current_device() == pgi.local_rank

     topk = topk_ids.shape[1]
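A quick worked example of the chunking helpers in the hunk above (definitions copied verbatim; the 64-token, 2-rank split is illustrative and even, matching the test combos):

# Worked example of rank_chunk / chunk_by_rank from the hunk above.
import torch

def rank_chunk(num: int, r: int, w: int) -> int:
    rem = num % w
    return (num // w) + (1 if r < rem else 0)

def chunk_by_rank(t: torch.Tensor, r: int, w: int) -> torch.Tensor:
    chunk = rank_chunk(t.shape[0], r, w)
    return t[(r * chunk):(r + 1) * chunk]

t = torch.arange(64)
parts = [chunk_by_rank(t, r, 2) for r in range(2)]
assert [p.shape[0] for p in parts] == [32, 32]  # rank 0 gets rows 0..31, rank 1 gets 32..63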
@@ -355,7 +376,7 @@ def pplx_dispatch_combine(pgi, dp_size, a, topk_weight, topk_ids, num_experts):

     topk_ids = topk_ids.to(dtype=torch.uint32)

-    dispatch_combine = PplxDispatchCombine(
+    prepare_finalize = PplxPrepareAndFinalize(
         ata,
         max_num_tokens,
         world_size,

@@ -368,7 +389,7 @@ def pplx_dispatch_combine(pgi, dp_size, a, topk_weight, topk_ids, num_experts):
     chunk_topk_weight = chunk_by_rank(topk_weight, rank, world_size).to(device)
     chunk_topk_ids = chunk_by_rank(topk_ids, rank, world_size).to(device)

-    b_a, b_a_scale, expert_num_tokens = dispatch_combine.dispatch(
+    b_a, b_a_scale, expert_num_tokens = prepare_finalize.prepare(
         a_chunk,
         None,
         None,

@@ -388,7 +409,7 @@ def pplx_dispatch_combine(pgi, dp_size, a, topk_weight, topk_ids, num_experts):
         device=device,
     )

-    dispatch_combine.combine(
+    prepare_finalize.finalize(
         out,
         b_a,
         chunk_topk_weight,
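The two hunks above show the renamed halves of the interface in use: prepare() scatters the rank-local activations into the expert-batched buffers before the expert computation, and finalize() writes the weighted combination back into the token-ordered output. A minimal sketch of the combine arithmetic that finalize performs, mirroring the reference torch_finalize above (shapes and values are illustrative):

# Per-token "finalize": a weighted sum of that token's top-k expert outputs.
import torch

topk = 2
expert_out = torch.randn(topk, 8)       # rows gathered from the batched expert output
topk_weight = torch.tensor([0.7, 0.3])  # routing weights for this token
token_out = (expert_out * topk_weight.view(topk, 1)).sum(dim=0)  # shape (8,)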
@@ -405,13 +426,13 @@ def pplx_dispatch_combine(pgi, dp_size, a, topk_weight, topk_ids, num_experts):
     return out[:num_tokens]


-def _pplx_dispatch_combine(
+def _pplx_prepare_finalize(
     pgi: ProcessGroupInfo,
     dp_size: int,
-    a,
-    score,
-    topk,
-    num_experts,
+    a: torch.Tensor,
+    score: torch.Tensor,
+    topk: torch.Tensor,
+    num_experts: int,
 ):
     uid = nvshmem_get_unique_id(
     ) if pgi.rank == 0 else nvshmem_alloc_empty_unique_id()

@@ -428,7 +449,7 @@ def _pplx_dispatch_combine(
         topk_weight.view(-1, topk, 1).to(device)).sum(dim=1).to(
             a.dtype)

-    pplx_output = pplx_dispatch_combine(pgi, dp_size, a, topk_weight, topk_ids,
+    pplx_output = pplx_prepare_finalize(pgi, dp_size, a, topk_weight, topk_ids,
                                         num_experts)

     torch_output = chunk_by_rank(torch_output, pgi.rank,

@@ -439,16 +460,16 @@
     nvshmem_finalize()


-# TODO: this test point does not work for odd M due to how the test is
+# TODO (bnell): this test point does not work for odd M due to how the test is
 # written, not due to limitations of the pplx kernels. The pplx_moe
 # test below is able to deal with odd M.
-@pytest.mark.parametrize("mnk", PPLX_DISPATCH_COMBOS)
+@pytest.mark.parametrize("mnk", PPLX_PREPARE_COMBOS)
 @pytest.mark.parametrize("e", NUM_EXPERTS)
 @pytest.mark.parametrize("topk", TOP_KS)
 @pytest.mark.parametrize("dtype", [torch.bfloat16])
 @pytest.mark.parametrize("world_dp_size", [[2, 1]])
 @requires_pplx
-def test_pplx_dispatch_combine(
+def test_pplx_prepare_finalize(
     mnk: tuple[int, int, int],
     e: int,
     topk: int,

@@ -462,11 +483,22 @@ def test_pplx_dispatch_combine(
     a = torch.randn((m, k), device=device, dtype=dtype) / 10
     score = torch.randn((m, e), device=device, dtype=dtype)

-    parallel_launch(world_size, _pplx_dispatch_combine, dp_size, a, score,
+    parallel_launch(world_size, _pplx_prepare_finalize, dp_size, a, score,
                     topk, e)


-def pplx_moe(rank, world_size, dp_size, a, w1, w2, topk_weight, topk_ids):
+def pplx_moe(
+    rank: int,
+    world_size: int,
+    dp_size: int,
+    a: torch.Tensor,
+    w1: torch.Tensor,
+    w2: torch.Tensor,
+    topk_weight: torch.Tensor,
+    topk_ids: torch.Tensor,
+    use_compile: bool = True,
+    use_cudagraphs: bool = True,
+) -> torch.Tensor:
     device = torch.device("cuda", rank)
     hidden_dim = a.shape[1]
     num_experts = w1.shape[0]

@@ -490,7 +522,7 @@ def pplx_moe(rank, world_size, dp_size, a, w1, w2, topk_weight, topk_ids):

     topk_ids = topk_ids.to(dtype=torch.uint32)

-    dispatch_combine = PplxDispatchCombine(
+    prepare_finalize = PplxPrepareAndFinalize(
         ata,
         max_num_tokens,
         world_size,

@@ -503,7 +535,7 @@ def pplx_moe(rank, world_size, dp_size, a, w1, w2, topk_weight, topk_ids):
         dp_size=dp_size)

     fused_experts = FusedMoEModularKernel(
-        dispatch_combine,
+        prepare_finalize,
         experts,
     )

@@ -516,14 +548,12 @@ def pplx_moe(rank, world_size, dp_size, a, w1, w2, topk_weight, topk_ids):
     w1_chunk = chunk_by_rank(w1, rank, world_size).to(device)
     w2_chunk = chunk_by_rank(w2, rank, world_size).to(device)

-    @torch.compile(backend='inductor', fullgraph=True)
-    def _fused_experts(a, w1, w2, topk_weight, topk_ids, global_num_experts):
-        return fused_experts(a,
-                             w1,
-                             w2,
-                             topk_weight,
-                             topk_ids,
-                             global_num_experts=global_num_experts)
+    if use_compile:
+        _fused_experts = torch.compile(fused_experts,
+                                       backend='inductor',
+                                       fullgraph=True)
+    else:
+        _fused_experts = fused_experts

     out = _fused_experts(a_chunk,
                          w1_chunk,

@@ -532,6 +562,21 @@ def _fused_experts(a, w1, w2, topk_weight, topk_ids, global_num_experts):
                          chunk_topk_ids,
                          global_num_experts=num_experts)

+    if use_cudagraphs:
+        out.fill_(0)
+        stream = torch.cuda.Stream()
+        graph = torch.cuda.CUDAGraph()
+        with torch.cuda.graph(graph, stream=stream):
+            out = _fused_experts(a_chunk,
+                                 w1_chunk,
+                                 w2_chunk,
+                                 chunk_topk_weight,
+                                 chunk_topk_ids,
+                                 global_num_experts=num_experts)
+
+        torch.cuda.synchronize()
+        graph.replay()
+
     torch.cuda.synchronize()

     ata.destroy()
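The new use_compile/use_cudagraphs paths above exercise the modular kernel both under torch.compile and under CUDA graph capture and replay. A standalone, hedged sketch of the same capture-and-replay pattern on a generic callable (fn and the tensor sizes here are stand-ins, not part of vLLM):

# Minimal sketch of the compile + CUDA graph capture/replay pattern used above.
# Real inputs must keep fixed shapes and addresses between capture and replay.
import torch

def fn(x: torch.Tensor) -> torch.Tensor:
    return torch.relu(x @ x.T)

x = torch.randn(64, 64, device="cuda")
compiled = torch.compile(fn, backend="inductor", fullgraph=True)
compiled(x)  # warm up / trigger compilation before capturing

graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph):
    out = compiled(x)            # work is recorded into the graph, not run eagerly

torch.cuda.synchronize()
x.copy_(torch.randn(64, 64, device="cuda"))  # update the static input in place
graph.replay()                                # re-run the captured work
torch.cuda.synchronize()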
@@ -548,7 +593,7 @@ def _batched_moe(pgi, dp_size, a, w1, w2, topk_weight, topk_ids):
     world_size = pgi.world_size
     max_num_tokens = rank_chunk(a.shape[0], 0, world_size)

-    dispatch_combine = BatchedDispatchCombine(
+    prepare_finalize = BatchedPrepareAndFinalize(
         max_num_tokens=max_num_tokens,
         world_size=world_size,
         dp_size=dp_size,

@@ -560,7 +605,7 @@ def _batched_moe(pgi, dp_size, a, w1, w2, topk_weight, topk_ids):
         dp_size=1)

     fused_experts = FusedMoEModularKernel(
-        dispatch_combine,
+        prepare_finalize,
         experts,
     )

@@ -605,7 +650,7 @@ def _pplx_moe(
     torch_output = torch_moe2(a, w1, w2, topk_weight, topk_ids)
     pplx_output = pplx_moe(pgi.rank, pgi.world_size, dp_size, a, w1, w2,
                            topk_weight, topk_ids)
-    # TODO: fix + re-enable
+    # TODO (bnell): fix + re-enable
     #batched_output = _batched_moe(pgi, dp_size, a, w1, w2, topk_weight,
     #                              topk_ids)

vllm/model_executor/layers/fused_moe/cutlass_moe.py

Lines changed: 3 additions & 3 deletions

@@ -7,8 +7,8 @@

 import vllm.model_executor.layers.fused_moe.modular_kernel as mk
 from vllm import _custom_ops as ops
-from vllm.model_executor.layers.fused_moe.dispatch_combine import (
-    StandardDispatchCombine)
+from vllm.model_executor.layers.fused_moe.prepare_finalize import (
+    StandardPrepareAndFinalize)
 from vllm.model_executor.layers.fused_moe.utils import _fp8_perm, _resize_cache
 from vllm.scalar_type import scalar_types

@@ -184,7 +184,7 @@ def modular_cutlass_moe_fp8(
     out_dtype: torch.dtype = torch.half,
 ) -> mk.FusedMoEModularKernel:
     return mk.FusedMoEModularKernel(
-        StandardDispatchCombine(
+        StandardPrepareAndFinalize(
             per_channel_quant=per_act_token,
             quant_dtype=torch.float8_e4m3fn,
         ),

vllm/model_executor/layers/fused_moe/deep_gemm_moe.py

Lines changed: 4 additions & 4 deletions

@@ -7,10 +7,10 @@

 import vllm.model_executor.layers.fused_moe.modular_kernel as mk
 from vllm.logger import init_logger
-from vllm.model_executor.layers.fused_moe.dispatch_combine import (
-    StandardDispatchCombine)
 from vllm.model_executor.layers.fused_moe.moe_permute_unpermute import (
     _moe_permute)
+from vllm.model_executor.layers.fused_moe.prepare_finalize import (
+    StandardPrepareAndFinalize)
 from vllm.model_executor.layers.fused_moe.utils import (_fp8_quantize,
                                                         _resize_cache)
 from vllm.utils import round_up

@@ -153,8 +153,8 @@ def apply(

 def modular_deep_gemm_fused_moe_fp8() -> mk.FusedMoEModularKernel:
     return mk.FusedMoEModularKernel(
-        StandardDispatchCombine(quant_dtype=torch.float8_e4m3fn,
-                                block_shape=deep_gemm_block_shape()),
+        StandardPrepareAndFinalize(quant_dtype=torch.float8_e4m3fn,
+                                   block_shape=deep_gemm_block_shape()),
         DeepGemmExperts(),
     )

vllm/model_executor/layers/fused_moe/fused_batched_moe.py

Lines changed: 7 additions & 5 deletions

@@ -333,7 +333,9 @@ def invoke_moe_batched_triton_kernel(
     BLOCK_M = config['BLOCK_SIZE_M']
     BLOCK_N = config['BLOCK_SIZE_N']
     BLOCK_K = config['BLOCK_SIZE_K']
-    assert (torch.compiler.is_compiling() or max_num_tokens % BLOCK_M == 0)
+    assert (torch.compiler.is_compiling()
+            or torch.cuda.is_current_stream_capturing()
+            or max_num_tokens % BLOCK_M == 0)

     grid = (expert_num_tokens.size(0), triton.cdiv(max_num_tokens, BLOCK_M) *
             triton.cdiv(B.size(1), BLOCK_N))
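The widened assertion above additionally skips the host-side divisibility check while a CUDA graph is being captured, alongside the existing torch.compile escape hatch. A small standalone sketch of the same guard pattern (the helper and values below are illustrative, not part of vLLM):

# Sketch: relax a host-side shape check while compiling or capturing a CUDA graph.
import torch

def check_divisible(max_num_tokens: int, block_m: int) -> None:
    assert (torch.compiler.is_compiling()
            or torch.cuda.is_current_stream_capturing()
            or max_num_tokens % block_m == 0), \
        f"max_num_tokens={max_num_tokens} must be a multiple of BLOCK_M={block_m}"

check_divisible(128, 16)  # passes eagerly: 128 % 16 == 0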
@@ -384,9 +386,9 @@ def rank_chunk(num, r, w):
     return (num // w) + (1 if r < rem else 0)


-class BatchedDispatchCombine(mk.FusedMoEQuantizeDispatchCombine):
+class BatchedPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
     """
-    A reference dispatch/combine class that reorganizes the tokens into
+    A reference prepare/finalize class that reorganizes the tokens into
     expert batched format, i.e. E x max_num_tokens x K. This is the format
     that the PPLX dispatch/combine kernels use.
     """

@@ -399,7 +401,7 @@ def __init__(self, max_num_tokens: Optional[int], world_size: int,
         self.rank = rank
         self.max_num_tokens = max_num_tokens

-    def dispatch(
+    def prepare(
         self,
         a1: torch.Tensor,
         a1_scale: Optional[torch.Tensor],

@@ -454,7 +456,7 @@ def dispatch(

         return b_a1, a1_scale, tokens_per_expert

-    def combine(
+    def finalize(
         self,
         output: torch.Tensor,
         fused_expert_output: torch.Tensor,
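The docstring above describes the expert-batched layout: activations are regrouped into a tensor of shape E x max_num_tokens x K, plus a per-expert count of how many rows are valid. A toy, hedged illustration of that reorganization (a stand-in for intuition only, not the actual BatchedPrepareAndFinalize.prepare implementation):

# Toy illustration of the E x max_num_tokens x K "expert batched" layout.
import torch

num_experts, max_num_tokens = 4, 8
a = torch.randn(6, 16)  # 6 tokens, hidden size 16
topk_ids = torch.tensor([[0, 1], [1, 2], [2, 3], [0, 3], [1, 3], [0, 2]])  # top-2 routing

b_a = torch.zeros(num_experts, max_num_tokens, a.shape[-1])
tokens_per_expert = torch.zeros(num_experts, dtype=torch.int32)

for token, experts in enumerate(topk_ids.tolist()):
    for e in experts:
        slot = int(tokens_per_expert[e])
        b_a[e, slot] = a[token]          # copy the token into expert e's batch
        tokens_per_expert[e] += 1        # track how many rows of b_a[e] are valid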
