Skip to content

Commit fe1974a

Browse files
committed
somewhat working unit test
Signed-off-by: Bill Nell <[email protected]>
1 parent cb7320d commit fe1974a

File tree

4 files changed

+78
-78
lines changed

4 files changed

+78
-78
lines changed

tests/kernels/test_pplx_moe.py

Lines changed: 69 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,8 @@
99
import torch
1010
import traceback
1111

12-
from torch.nn import Parameter
13-
from torch.nn import functional as F
1412
from torch.multiprocessing import spawn # pyright: ignore[reportPrivateImportUsage]
15-
from typing import Callable, Concatenate, ParamSpec
13+
from typing import Callable, Concatenate, ParamSpec, Tuple
1614

1715
from pplx_kernels import AllToAll
1816
from pplx_kernels.nvshmem import (
@@ -25,27 +23,18 @@
2523
import vllm.model_executor.layers.fused_moe # noqa
2624
from tests.kernels.utils import (compute_max_diff, opcheck, stack_and_dev,
2725
torch_moe, torch_moe_single)
28-
from vllm import _custom_ops as ops
26+
#from vllm import _custom_ops as ops
2927
from vllm.config import VllmConfig, set_current_vllm_config
30-
from vllm.model_executor.layers.fused_moe import fused_moe
28+
#from vllm.model_executor.layers.fused_moe import fused_moe
3129
#from vllm.model_executor.layers.fused_moe.fused_batched_moe import fused_batched_experts
3230
from vllm.model_executor.layers.fused_moe.fused_moe import (
3331
fused_topk, moe_align_block_size)
34-
from vllm.model_executor.layers.fused_moe.moe_torch_iterative import (
35-
fused_moe as iterative_moe)
36-
from vllm.model_executor.layers.quantization.utils.marlin_utils_test import (
37-
marlin_quantize)
38-
from vllm.model_executor.layers.quantization.utils.quant_utils import (
39-
quantize_weights)
40-
from vllm.model_executor.models.mixtral import MixtralMoE
4132
from vllm.platforms import current_platform
42-
from vllm.scalar_type import scalar_types
43-
from vllm.utils import round_up
4433

4534
from vllm.model_executor.layers.activation import SiluAndMul
4635

4736
from vllm.model_executor.layers.fused_moe.fused_moe import TritonExperts, BatchedDispatchCombine, BatchedExperts, fused_experts
48-
from vllm.model_executor.layers.fused_moe.modular_kernel import FusedMoEModularKernel, FusedMoEQuantizeDispatchCombine
37+
from vllm.model_executor.layers.fused_moe.modular_kernel import FusedMoEModularKernel
4938
from vllm.model_executor.layers.fused_moe.pplx_dispatch_combine import PplxDispatchCombine
5039

5140
NUM_EXPERTS = [8, 64]
@@ -373,7 +362,8 @@ def torch_pplx_dispatch_combine(pgi, dp_size, a, w1, w2, scores, topk):
373362
num_experts, # store at PplxDispatchCombine creation?
374363
None
375364
)
376-
#torch.cuda.synchronize() # necessary?
365+
366+
b_a = b_a * 1.5
377367

378368
out = torch.full(
379369
(max_num_tokens, hidden_dim),
@@ -392,7 +382,7 @@ def torch_pplx_dispatch_combine(pgi, dp_size, a, w1, w2, scores, topk):
392382

393383
ata.destroy()
394384

395-
torch.distributed.barrier()
385+
#torch.distributed.barrier()
396386

397387
#print(f"OUT {rank}: {out.shape} {out[:rank_num_tokens]}")
398388

@@ -406,27 +396,34 @@ def torch_pplx_dispatch_combine(pgi, dp_size, a, w1, w2, scores, topk):
406396
def _pplx_dispatch_combine(
407397
pgi: ProcessGroupInfo,
408398
dp_size: int,
409-
a: torch.Tensor,
410-
w1: torch.Tensor,
411-
w2: torch.Tensor,
412-
score: torch.Tensor,
399+
m, n, k, e,
400+
#a: torch.Tensor,
401+
#w1: torch.Tensor,
402+
#w2: torch.Tensor,
403+
#score: torch.Tensor,
413404
topk: int,
414405
dtype: torch.dtype,
415406
):
416407
uid = nvshmem_get_unique_id() if pgi.rank == 0 else nvshmem_alloc_empty_unique_id()
417408
torch.distributed.broadcast(uid, src=0)
418409
nvshmem_init(uid, pgi.rank, pgi.world_size)
410+
device = pgi.device
419411

420-
m, k = a.shape
421-
e, _, n = w2.shape
412+
a = torch.randn((m, k), device=device, dtype=dtype) / 10
413+
w1 = torch.randn((e, 2 * n, k), device=device, dtype=dtype) / 10
414+
w2 = torch.randn((e, k, n), device=device, dtype=dtype) / 10
415+
score = torch.randn((m, e), device=device, dtype=dtype)
416+
417+
#m, k = a.shape
418+
#e, _, n = w2.shape
422419

423420
topk_weight, topk_ids = fused_topk(a, score, topk, False)
424421

425422
#print(f"a {a.shape}")
426423
a_rep = torch.repeat_interleave(a, topk, dim=0)
427424
#print(f"a_rep {a_rep.shape} {a_rep.view(-1, topk, k)}")
428425

429-
torch_output = (a_rep.view(-1, topk, k) * topk_weight.view(-1, topk, 1)).to(a.dtype).sum(dim=1)
426+
torch_output = (a_rep.view(-1, topk, k) * 1.5 * topk_weight.view(-1, topk, 1)).sum(dim=1).to(a.dtype)
430427

431428
#print(f"torch_output {pgi.rank}: {torch_output.shape} {torch_output}")
432429

@@ -452,35 +449,28 @@ def _pplx_dispatch_combine(
452449
nvshmem_finalize()
453450

454451

455-
@pytest.mark.parametrize("m", [2, 32, 64, 222]) #, 1024 * 128])
452+
@pytest.mark.parametrize("m", [4, 32, 64, 222]) #, 1024 * 128])
456453
@pytest.mark.parametrize("n", [128, 1024, 2048])
457454
@pytest.mark.parametrize("k", [128, 512, 1024]) # restrictions? % 128?
458455
@pytest.mark.parametrize("e", NUM_EXPERTS)
459456
@pytest.mark.parametrize("topk", TOP_KS)
460457
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
458+
@pytest.mark.parametrize("world_dp_size", [[2, 1]]) #, [[4, 2]])
461459
def test_pplx_dispatch_combine(
462460
m: int,
463461
n: int,
464462
k: int,
465463
e: int,
466464
topk: int,
467465
dtype: torch.dtype,
466+
world_dp_size: Tuple[int, int],
468467
):
469468
current_platform.seed_everything(7)
470-
if False:
471-
world_size = 4
472-
dp_size = 2
473-
else:
474-
world_size = 2
475-
dp_size = 1
476-
477-
a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
478-
w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10
479-
w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10
480-
score = torch.randn((m, e), device="cuda", dtype=dtype)
469+
world_size, dp_size = world_dp_size
481470

482471
parallel_launch(
483-
world_size, _pplx_dispatch_combine, dp_size, a, w1, w2, score, topk, dtype
472+
#world_size, _pplx_dispatch_combine, dp_size, a, w1, w2, score, topk, dtype
473+
world_size, _pplx_dispatch_combine, dp_size, m, n, k, e, topk, dtype
484474
)
485475

486476

@@ -489,9 +479,10 @@ def torch_pplx_moe(pgi, dp_size, a, w1, w2, scores, topk):
489479

490480
num_tokens, hidden_dim = a.shape
491481
num_experts = w1.shape[0]
482+
num_local_experts = num_experts // pgi.world_size
492483
block_size = 128
493484
device = pgi.device
494-
rank_num_tokens = num_tokens // pgi.world_size
485+
rank_num_tokens = num_tokens // pgi.world_size # TODO even divide
495486

496487
max_num_tokens = num_tokens
497488
#print(f"device = {device}, max_num_tokens = {max_num_tokens}, topk = {topk}, num_ex = {num_experts}, dp_size = {dp_size}")
@@ -518,6 +509,9 @@ def torch_pplx_moe(pgi, dp_size, a, w1, w2, scores, topk):
518509
),
519510
)
520511

512+
w1 = w1.to(device)
513+
w2 = w2.to(device)
514+
521515
dispatch_combine = PplxDispatchCombine(
522516
ata,
523517
max_num_tokens, # // world_size?
@@ -538,73 +532,77 @@ def torch_pplx_moe(pgi, dp_size, a, w1, w2, scores, topk):
538532
score_chunk = chunk_by_rank(scores, rank, world_size).to(device)
539533
chunk_topk_weight, chunk_topk_ids = fused_topk(a_chunk, score_chunk, topk, False)
540534

541-
#print(f"chunk_topk_ids = {chunk_topk_ids}")
535+
#print(f"chunk_topk_ids {rank} {chunk_topk_ids.shape} {chunk_topk_ids.view(-1)}")
542536

543537
out = fused_experts(
544538
a_chunk,
545-
w1, # chunk?
546-
w2, # chunk?
539+
w1,
540+
w2,
547541
chunk_topk_weight,
548542
chunk_topk_ids,
549-
global_num_experts=num_experts #? num_local_experts?
543+
global_num_experts=num_local_experts #? num_local_experts?
550544
)
551545

552546
torch.cuda.synchronize()
553547

554548
ata.destroy()
555549

556-
torch.distributed.barrier()
550+
#torch.distributed.barrier()
557551

558552
#print(f"OUT {rank}: {out.shape} {out[:rank_num_tokens]}")
559553

560554
#torch.distributed.all_reduce(out)
561555

562-
print(f"OUT {rank}: {out.shape} {out}")
556+
#print(f"OUT {rank}: {out.shape} {out}")
563557

564558
return out[:rank_num_tokens]
565559

566560

567561
def _pplx_moe(
568562
pgi: ProcessGroupInfo,
569563
dp_size: int,
570-
m: int,
571-
n: int,
572-
k: int,
573-
e: int,
564+
a: torch.Tensor,
565+
w1: torch.Tensor,
566+
w2: torch.Tensor,
567+
score: torch.Tensor,
574568
topk: int,
575569
dtype: torch.dtype,
576570
):
577571
uid = nvshmem_get_unique_id() if pgi.rank == 0 else nvshmem_alloc_empty_unique_id()
578572
torch.distributed.broadcast(uid, src=0)
579573
nvshmem_init(uid, pgi.rank, pgi.world_size)
580574

581-
a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
582-
w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10
583-
w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10
575+
m, k = a.shape
576+
e, _, n = w2.shape
584577

585-
score = torch.randn((m, e), device="cuda", dtype=dtype)
578+
torch.set_printoptions(profile="full")
586579

587580
vllm_config = VllmConfig()
588581
with set_current_vllm_config(vllm_config):
589582
topk_weight, topk_ids = fused_topk(a, score, topk, False)
590583

584+
#print(f"topk_ids {pgi.rank} {topk_ids.shape} {topk_ids.view(-1)}")
585+
591586
torch_output = torch_moe2(a, w1, w2, topk_weight, topk_ids)
592587

593-
pplxd_output = torch_pplx_moe(pgi,
594-
dp_size,
595-
a,
596-
w1,
597-
w2,
598-
score,
599-
topk)
588+
pplx_output = torch_pplx_moe(pgi,
589+
dp_size,
590+
a,
591+
w1,
592+
w2,
593+
score,
594+
topk)
595+
596+
#print(f"torch_output {pgi.rank}: {torch_output}")
600597

601598
if False:
602-
torch.set_printoptions(profile="full")
603599
print("BASELINE")
604600
print(torch_output)
605601
print("OUTPUT")
606602
print(pplx_output)
607603

604+
torch_output = chunk_by_rank(torch_output, pgi.rank, pgi.world_size).to(pplx_output.device)
605+
608606
torch.testing.assert_close(pplx_output, torch_output, atol=2e-2, rtol=0)
609607

610608
nvshmem_finalize()
@@ -616,28 +614,31 @@ def _pplx_moe(
616614
# @pytest.mark.parametrize("e", NUM_EXPERTS)
617615
# @pytest.mark.parametrize("topk", TOP_KS)
618616
# @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
619-
@pytest.mark.parametrize("m", [128]) ##, 32]) #, 1024 * 128])
617+
@pytest.mark.parametrize("m", [64]) ##, 32]) #, 1024 * 128])
620618
@pytest.mark.parametrize("n", [128])
621619
@pytest.mark.parametrize("k", [128])
622620
@pytest.mark.parametrize("e", [8]) #NUM_EXPERTS)
623621
@pytest.mark.parametrize("topk", [2]) #TOP_KS)
624622
@pytest.mark.parametrize("dtype", [torch.bfloat16])
623+
@pytest.mark.parametrize("world_dp_size", [[2, 1]]) #, [4, 2]])
625624
def test_pplx_moe(
626625
m: int,
627626
n: int,
628627
k: int,
629628
e: int,
630629
topk: int,
631630
dtype: torch.dtype,
631+
world_dp_size: Tuple[int, int],
632632
):
633633
current_platform.seed_everything(7)
634-
if False:
635-
world_size = 4
636-
dp_size = 2
637-
else:
638-
world_size = 2
639-
dp_size = 1
634+
world_size, dp_size = world_dp_size
635+
a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
636+
w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10
637+
w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10
638+
score = torch.randn((m, e), device="cuda", dtype=dtype)
639+
640640
parallel_launch(
641-
world_size, _pplx_moe, dp_size, m, n, k, e, topk, dtype
641+
world_size, _pplx_moe, dp_size, a, w1, w2, score, topk, dtype
642+
#world_size, _pplx_moe, dp_size, m, n, k, e, topk, dtype
642643
)
643644

vllm/model_executor/layers/fused_moe/fused_moe.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1858,7 +1858,7 @@ def workspace_shapes(
18581858
a: torch.Tensor,
18591859
) -> Tuple[int, int, torch.dtype]:
18601860
max_num_tokens = a.shape[1]
1861-
workspace13 = num_experts * max_num_tokens * K
1861+
workspace13 = num_experts * max_num_tokens * K * 2 # *2 = HACK!!!!!
18621862
workspace2 = max_num_tokens * (N // 2)
18631863
return (workspace13, workspace2, a_dtype)
18641864

@@ -1889,7 +1889,8 @@ def apply(
18891889
print(f"global_num_experts = {global_num_experts}")
18901890
num_experts = global_num_experts
18911891
out = _resize_cache(workspace13, (num_experts, max_num_tokens, w2.shape[1]))
1892-
for expert in range(num_experts):
1892+
num_local_experts = expert_num_tokens.numel()
1893+
for expert in range(num_local_experts): # num_experts
18931894
num = expert_num_tokens[expert]
18941895
if num > 0:
18951896
tmp = _resize_cache(workspace2, (num, w1.shape[1] // 2))

vllm/model_executor/layers/fused_moe/modular_kernel.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -323,7 +323,7 @@ def forward(
323323
if global_num_experts == -1:
324324
global_num_experts = E
325325

326-
output = a1 if inplace else torch.zeros_like(a1)
326+
output = a1 if inplace else torch.empty_like(a1)
327327

328328
workspace13_shape, workspace2_shape, workspace_dtype = (
329329
self.fused_experts.workspace_shapes(a1, M, N, K, top_k,

vllm/model_executor/layers/fused_moe/pplx_dispatch_combine.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -78,15 +78,15 @@ def dispatch(
7878
#expert_num_tokens.fill_(-1) # debugging, remove later
7979

8080
num_dp = self.world_size // self.dp_size
81-
print(f"GOT HERE A {self.rank}: {self.max_num_tokens} {num_dp} {hidden_dim}")
81+
logger.debug(f"GOT HERE A {self.rank}: {self.max_num_tokens} {num_dp} {hidden_dim}")
8282
expert_x = torch.empty(
8383
(num_local_experts, self.max_num_tokens * num_dp, a1q.shape[-1]),
8484
dtype=a1q.dtype,
8585
device=device,
8686
)
8787
expert_x.fill_(torch.nan) # debugging, remove later
8888

89-
print(f"GOT HERE B {self.rank}")
89+
logger.debug(f"GOT HERE B {self.rank}")
9090

9191
expert_x_scale: Optional[torch.Tensor] = None
9292
if a1q.dtype.itemsize == 1:
@@ -103,7 +103,7 @@ def dispatch(
103103
device=device,
104104
)
105105

106-
print(f"GOT HERE C {self.rank}")
106+
logger.debug(f"GOT HERE C {self.rank}")
107107

108108
# This argument is optional, defaults to indices.shape[0]
109109
# This causes a deadlock????
@@ -114,8 +114,6 @@ def dispatch(
114114
# TODO: optimize this?
115115
indices = rank_topk_ids.to(dtype=torch.uint32)
116116

117-
print(f"GOT HERE D {self.rank}")
118-
119117
self.a2a.dispatch(
120118
out_expert_num_tokens=expert_num_tokens,
121119
out_expert_x=expert_x,
@@ -140,7 +138,7 @@ def combine(
140138
#device = get_dp_group().device
141139
#assert fused_expert_output.device == device
142140

143-
print(f"COMBINE START {self.rank}")
141+
logger.debug(f"COMBINE START {self.rank}")
144142

145143
# This argument is optional
146144
#bound_m = get_forward_context().dp_metadata.dp_rank_num_tokens
@@ -161,4 +159,4 @@ def combine(
161159
expert_y=fused_expert_output,
162160
bound_m=bound_m)
163161

164-
print(f"COMBINE END {self.rank}")
162+
logger.debug(f"COMBINE END {self.rank}")

0 commit comments

Comments
 (0)