
Commit 43e229c

review comments + cudagraph debugging
Signed-off-by: Bill Nell <[email protected]>
1 parent a674762 commit 43e229c

8 files changed (+80, -63 lines)

tests/kernels/moe/test_batched_moe.py

Lines changed: 2 additions & 2 deletions
@@ -31,10 +31,10 @@ def make_tensors(config: BatchedMMConfig):
     A = torch.randn(
         (config.num_experts, config.max_tokens_per_expert, config.K),
         device="cuda",
-        dtype=config.dtype) / 50.0
+        dtype=config.dtype)
     B = torch.randn((config.num_experts, config.N, config.K),
                     device="cuda",
-                    dtype=config.dtype) / 50.0
+                    dtype=config.dtype)
     C = torch.zeros(
         (config.num_experts, config.max_tokens_per_expert, config.N),
         device="cuda",

tests/kernels/moe/test_moe.py

Lines changed: 0 additions & 1 deletion
@@ -122,7 +122,6 @@ def test_fused_moe(
 def test_fused_moe_wn16(m: int, n: int, k: int, e: int, topk: int,
                         ep_size: int, dtype: torch.dtype, group_size: int,
                         has_zp: bool, weight_bits: int):
-    #print(m, n, k, e, topk, dtype, group_size, has_zp, weight_bits)
     a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
     w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10
     w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10

tests/kernels/moe/test_pplx_moe.py

Lines changed: 59 additions & 41 deletions
@@ -24,21 +24,35 @@
     spawn)  # pyright: ignore[reportPrivateImportUsage]
 from typing_extensions import Concatenate, ParamSpec
 
-import vllm.model_executor.layers.fused_moe  # noqa
 from vllm.config import VllmConfig, set_current_vllm_config
 from vllm.model_executor.layers.activation import SiluAndMul
+from vllm.model_executor.layers.fused_moe import override_config
 from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
-    BatchedDispatchCombine, BatchedExperts)
-from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk
+    BatchedDispatchCombine, BatchedExperts, BatchedTritonExperts)
+from vllm.model_executor.layers.fused_moe.fused_moe import (fused_topk,
+                                                            get_default_config)
 from vllm.model_executor.layers.fused_moe.modular_kernel import (
     FusedMoEModularKernel)
 from vllm.model_executor.layers.fused_moe.pplx_dispatch_combine import (
     PplxDispatchCombine)
 from vllm.platforms import current_platform
 
+PPLX_DISPATCH_COMBOS = [(4, 128, 128), (32, 1024, 512), (64, 1024, 512),
+                        (222, 2048, 1024)]
+
+PPLX_MOE_COMBOS = [
+    (1, 128, 128),
+    (2, 128, 512),
+    (3, 1024, 2048),
+    (32, 128, 1024),
+    (45, 512, 2048),
+    (64, 1024, 1024),
+    (222, 1024, 2048),
+]
+
 NUM_EXPERTS = [8, 64]
 EP_SIZE = [1, 4]
-TOP_KS = [2, 6]
+TOP_KS = [1, 2, 6]
 
 vllm_config = VllmConfig()
 vllm_config.scheduler_config.max_num_seqs = 128

@@ -298,7 +312,6 @@ def test_fused_moe_batched_experts(
                                torch_output,
                                atol=2e-2,
                                rtol=0)
-    torch.set_printoptions(profile="full")
     torch.testing.assert_close(baseline_output,
                                batched_output,
                                atol=2e-2,

@@ -426,25 +439,24 @@ def _pplx_dispatch_combine(
     nvshmem_finalize()
 
 
-# TODO: this test point does not work for M == 1
-@pytest.mark.parametrize("m", [4, 32, 64, 222])
-@pytest.mark.parametrize("n", [128, 1024, 2048])
-@pytest.mark.parametrize("k", [128, 512, 1024])
+# TODO: this test point does not work for odd M due to how the test is
+# written, not due to limitations of the pplx kernels. The pplx_moe
+# test below is able to deal with odd M.
+@pytest.mark.parametrize("mnk", PPLX_DISPATCH_COMBOS)
 @pytest.mark.parametrize("e", NUM_EXPERTS)
 @pytest.mark.parametrize("topk", TOP_KS)
 @pytest.mark.parametrize("dtype", [torch.bfloat16])
 @pytest.mark.parametrize("world_dp_size", [[2, 1]])
 @requires_pplx
 def test_pplx_dispatch_combine(
-    m: int,
-    n: int,
-    k: int,
+    mnk: tuple[int, int, int],
     e: int,
     topk: int,
     dtype: torch.dtype,
     world_dp_size: tuple[int, int],
 ):
     current_platform.seed_everything(7)
+    m, n, k = mnk
     world_size, dp_size = world_dp_size
     device = "cuda"
     a = torch.randn((m, k), device=device, dtype=dtype) / 10

@@ -454,15 +466,11 @@ def test_pplx_dispatch_combine(
                                          topk, e)
 
 
-def pplx_moe(pgi, dp_size, a, w1, w2, topk_weight, topk_ids):
-    assert torch.cuda.current_device() == pgi.local_rank
-
+def pplx_moe(rank, world_size, dp_size, a, w1, w2, topk_weight, topk_ids):
+    device = torch.device("cuda", rank)
     hidden_dim = a.shape[1]
     num_experts = w1.shape[0]
     block_size = 128
-    device = pgi.device
-    rank = pgi.rank
-    world_size = pgi.world_size
     topk = topk_ids.shape[1]
     max_num_tokens = rank_chunk(a.shape[0], 0, world_size)

@@ -490,29 +498,39 @@ def pplx_moe(pgi, dp_size, a, w1, w2, topk_weight, topk_ids):
         dp_size,
     )
 
-    experts = BatchedExperts(max_num_tokens=a.shape[0],
-                             world_size=world_size,
-                             dp_size=dp_size)
+    experts = BatchedTritonExperts(max_num_tokens=a.shape[0],
+                                   world_size=world_size,
+                                   dp_size=dp_size)
 
     fused_experts = FusedMoEModularKernel(
         dispatch_combine,
         experts,
    )
 
-    # TODO: workers with the same dp_rank must use the exact same inputs.
-
+    # Note: workers with the same dp_rank must use the exact same inputs.
     a_chunk = chunk_by_rank(a, rank, world_size).to(device)
     chunk_topk_weight = chunk_by_rank(topk_weight, rank, world_size).to(device)
     chunk_topk_ids = chunk_by_rank(topk_ids, rank, world_size).to(device)
 
-    out = fused_experts(
-        a_chunk,
-        # Chunking weights like this only works for batched format
-        chunk_by_rank(w1, rank, world_size).to(device),
-        chunk_by_rank(w2, rank, world_size).to(device),
-        chunk_topk_weight,
-        chunk_topk_ids,
-        global_num_experts=num_experts)
+    # Chunking weights like this only works for batched format
+    w1_chunk = chunk_by_rank(w1, rank, world_size).to(device)
+    w2_chunk = chunk_by_rank(w2, rank, world_size).to(device)
+
+    @torch.compile(backend='inductor', fullgraph=True)
+    def _fused_experts(a, w1, w2, topk_weight, topk_ids, global_num_experts):
+        return fused_experts(a,
+                             w1,
+                             w2,
+                             topk_weight,
+                             topk_ids,
+                             global_num_experts=global_num_experts)
+
+    out = _fused_experts(a_chunk,
+                         w1_chunk,
+                         w2_chunk,
+                         chunk_topk_weight,
+                         chunk_topk_ids,
+                         global_num_experts=num_experts)
 
     torch.cuda.synchronize()
 

@@ -546,8 +564,7 @@ def _batched_moe(pgi, dp_size, a, w1, w2, topk_weight, topk_ids):
         experts,
     )
 
-    # TODO: workers with the same dp_rank must use the exact same inputs.
-
+    # Note: workers with the same dp_rank must use the exact same inputs.
     a_chunk = chunk_by_rank(a, rank, world_size).to(device)
     chunk_topk_weight = chunk_by_rank(topk_weight, rank, world_size).to(device)
     chunk_topk_ids = chunk_by_rank(topk_ids, rank, world_size).to(device)

@@ -581,10 +598,14 @@ def _pplx_moe(
     m, k = a.shape
     e, _, n = w2.shape
 
-    with set_current_vllm_config(vllm_config):
+    moe_config = get_default_config(m, e, n, k, topk, a.dtype, False)
+
+    with set_current_vllm_config(vllm_config), override_config(moe_config):
         topk_weight, topk_ids, _ = fused_topk(a, score, topk, False)
         torch_output = torch_moe2(a, w1, w2, topk_weight, topk_ids)
-        pplx_output = pplx_moe(pgi, dp_size, a, w1, w2, topk_weight, topk_ids)
+        pplx_output = pplx_moe(pgi.rank, pgi.world_size, dp_size, a, w1, w2,
+                               topk_weight, topk_ids)
+        # TODO: fix + re-enable
         #batched_output = _batched_moe(pgi, dp_size, a, w1, w2, topk_weight,
         #                              topk_ids)
 

@@ -597,24 +618,21 @@ def _pplx_moe(
     nvshmem_finalize()
 
 
-@pytest.mark.parametrize("m", [1, 2, 3, 32, 45, 64, 222])
-@pytest.mark.parametrize("n", [128, 1024, 2048])
-@pytest.mark.parametrize("k", [128, 512, 1024])
+@pytest.mark.parametrize("mnk", PPLX_MOE_COMBOS)
 @pytest.mark.parametrize("e", NUM_EXPERTS)
 @pytest.mark.parametrize("topk", TOP_KS)
 @pytest.mark.parametrize("dtype", [torch.bfloat16])
 @pytest.mark.parametrize("world_dp_size", [[2, 1]])
 @requires_pplx
 def test_pplx_moe(
-    m: int,
-    n: int,
-    k: int,
+    mnk: tuple[int, int, int],
     e: int,
     topk: int,
     dtype: torch.dtype,
     world_dp_size: tuple[int, int],
 ):
     current_platform.seed_everything(7)
+    m, n, k = mnk
     world_size, dp_size = world_dp_size
     a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
     w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10
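
The main functional change above is that the modular kernel call inside `pplx_moe` is now routed through `torch.compile(backend='inductor', fullgraph=True)`, so the pplx path is exercised under inductor as well as in eager mode. A minimal self-contained sketch of that pattern, with a plain matmul standing in for `fused_experts` (all names and shapes here are illustrative, not from the repo):

```python
import torch


def run_kernel(a: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
    # Stand-in for the fused_experts modular kernel call.
    return a @ w.t()


@torch.compile(backend="inductor", fullgraph=True)
def _compiled(a: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
    # fullgraph=True turns any graph break inside the wrapped call into a
    # hard error, which is the point of compiling the test path.
    return run_kernel(a, w)


if __name__ == "__main__":
    dev = "cuda" if torch.cuda.is_available() else "cpu"
    a = torch.randn((32, 128), device=dev)
    w = torch.randn((256, 128), device=dev)
    # The compiled wrapper must match the eager reference.
    torch.testing.assert_close(_compiled(a, w), run_kernel(a, w))
```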

vllm/distributed/parallel_state.py

Lines changed: 6 additions & 4 deletions
@@ -949,25 +949,27 @@ def pplx_init(rank, world_size):
                                       nvshmem_get_unique_id, nvshmem_init)
     try:
         global PPLX_DID_INIT
-        logger.info("PPLX_INIT rank=%d world=%d", rank, world_size)
+        logger.debug(
+            "Initialize NVSHMEM for PPLX kernels: rank=%d, "
+            "world size=%d", rank, world_size)
         uid = nvshmem_get_unique_id(
         ) if rank == 0 else nvshmem_alloc_empty_unique_id()
         uid_gpu = uid.cuda()
         get_world_group().broadcast(uid_gpu, src=0)
-        logger.debug("PPLX_INIT UID = %s", uid_gpu)
         uid = uid_gpu.to(device='cpu')
+        logger.debug("PPLX NVSHMEM UID = %s", uid)
         nvshmem_init(uid, rank, world_size)
         PPLX_DID_INIT = True
     except Exception as ex:
-        logger.error("Failed to initialize nvshmem for pplx: %s", ex)
+        logger.error("Failed to initialize NVSHMEM for PPLX: %s", ex)
 
 
 @run_once
 def pplx_finalize():
     global PPLX_DID_INIT
     if PPLX_DID_INIT:
         from pplx_kernels.nvshmem import nvshmem_finalize
-        logger.info("PPLX finalize")
+        logger.debug("PPLX NVSHMEM finalize")
         from vllm.model_executor.layers.fused_moe.layer import (
             _all_to_all_cache)
         _all_to_all_cache.destroy()

vllm/model_executor/layers/fused_moe/fused_batched_moe.py

Lines changed: 6 additions & 4 deletions
@@ -333,7 +333,7 @@ def invoke_moe_batched_triton_kernel(
     BLOCK_M = config['BLOCK_SIZE_M']
     BLOCK_N = config['BLOCK_SIZE_N']
     BLOCK_K = config['BLOCK_SIZE_K']
-    assert max_num_tokens % BLOCK_M == 0
+    assert (torch.compiler.is_compiling() or max_num_tokens % BLOCK_M == 0)
 
     grid = (expert_num_tokens.size(0), triton.cdiv(max_num_tokens, BLOCK_M) *
             triton.cdiv(B.size(1), BLOCK_N))

@@ -559,13 +559,15 @@ def apply(
         N = w1.size(1) // 2
 
         # Not cudagraph friendly
-        assert (torch.cuda.is_current_stream_capturing()
+        assert (torch.compiler.is_compiling()
+                or torch.cuda.is_current_stream_capturing()
                 or torch.all(expert_num_tokens <= max_num_tokens * num_dp)), (
                     f"{expert_num_tokens} <= {max_num_tokens * num_dp}")
 
         for expert in range(num_local_experts):
-            # Indexing expert_num_tokens doesn't work w/cudagraphs
-            if torch.cuda.is_current_stream_capturing():
+            # Indexing expert_num_tokens doesn't work w/cudagraphs or inductor
+            if (torch.compiler.is_compiling()
+                    or torch.cuda.is_current_stream_capturing()):
                 num = max_num_tokens * num_dp
             else:
                 num = int(expert_num_tokens[expert].item())
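
The new checks above follow one pattern: when inductor is tracing or a CUDA graph is being captured, data-dependent reads such as `expert_num_tokens[expert].item()` cannot be used, so the code falls back to the static worst-case token count. A small sketch of that gating, assuming a per-expert count tensor and a static `max_tokens` bound (names are illustrative; the extra `torch.cuda.is_available()` guard only keeps the sketch runnable on CPU):

```python
import torch


def tokens_for_expert(expert_num_tokens: torch.Tensor, expert: int,
                      max_tokens: int) -> int:
    # Under torch.compile (inductor) or CUDA graph capture, reading a GPU
    # tensor with .item() either breaks the graph or bakes in a stale value,
    # so use the static worst-case bound instead of the true count.
    if (torch.compiler.is_compiling()
            or (torch.cuda.is_available()
                and torch.cuda.is_current_stream_capturing())):
        return max_tokens
    return int(expert_num_tokens[expert].item())


if __name__ == "__main__":
    counts = torch.tensor([3, 7, 0, 5])
    # Eager mode: returns the real per-expert count (7 for expert 1).
    print(tokens_for_expert(counts, expert=1, max_tokens=16))
```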

vllm/model_executor/layers/fused_moe/pplx_dispatch_combine.py

Lines changed: 4 additions & 3 deletions
@@ -103,7 +103,7 @@ def dispatch(
 
         # This argument is optional, defaults to indices.size(0)
         # There's not much point setting this unless it is != indices.size(0)
-        bound_m = None
+        bound_m: Optional[torch.Tensor] = None
 
         self.a2a.dispatch(
             out_expert_num_tokens=expert_num_tokens,

@@ -128,9 +128,10 @@ def combine(
         num_tokens = output.size(0)  # M
         # This argument is optional
         # There's not much point setting this unless it is != topk_ids.size(0)
-        bound_m = None
+        bound_m: Optional[torch.Tensor] = None
 
-        assert topk_ids.size(0) == num_tokens
+        assert topk_ids.size(0) == num_tokens, (
+            f"{topk_ids.size(0)} == {num_tokens}")
         assert output.size(0) <= self.max_num_tokens, (
             f"{output.size(0)} <= {self.max_num_tokens}")
         assert output.size(1) == fused_expert_output.size(-1)

vllm/model_executor/models/deepseek_v2.py

Lines changed: 3 additions & 2 deletions
@@ -173,8 +173,9 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
                 * (1. / self.routed_scaling_factor)
 
         if self.tp_size > 1:
-            final_hidden_states = self.experts.maybe_all_reduce_tensor_model_parallel(  # noqa E501
-                final_hidden_states)
+            final_hidden_states = (
+                self.experts.maybe_all_reduce_tensor_model_parallel(
+                    final_hidden_states))
 
         return final_hidden_states.view(num_tokens, hidden_dim)

vllm/model_executor/models/granitemoe.py

Lines changed: 0 additions & 6 deletions
@@ -70,7 +70,6 @@ def __init__(self,
                  prefix: str = ""):
         super().__init__()
         self.hidden_size = hidden_size
-        self.tp_size = get_tensor_model_parallel_world_size()
 
         # Gate always runs at half / full precision for now.
         self.gate = ReplicatedLinear(hidden_size,

@@ -98,11 +97,6 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         # router_logits: (num_tokens, n_experts)
         router_logits, _ = self.gate(hidden_states)
         final_hidden_states = self.experts(hidden_states, router_logits)
-
-        if self.tp_size > 1:
-            final_hidden_states = self.experts.maybe_all_reduce_tensor_model_parallel(  # noqa E501
-                final_hidden_states)
-
         return final_hidden_states.view(orig_shape)
