
Commit 58fe406

fix test
Signed-off-by: Bill Nell <[email protected]>
1 parent 86c2055 commit 58fe406


2 files changed: 39 additions & 25 deletions


tests/kernels/test_pplx_moe.py

Lines changed: 23 additions & 23 deletions
@@ -10,7 +10,7 @@
 import traceback
 
 from torch.multiprocessing import spawn  # pyright: ignore[reportPrivateImportUsage]
-from typing import Callable, Concatenate, ParamSpec, Tuple
+from typing import Callable, Concatenate, Optional, ParamSpec, Tuple
 
 from pplx_kernels import AllToAll
 from pplx_kernels.nvshmem import (
@@ -163,7 +163,8 @@ def parallel_launch_from_env(
 def torch_dispatch(
     a: torch.Tensor,
     topk_ids: torch.Tensor,
-    num_experts: int
+    num_experts: int,
+    max_num_tokens: Optional[int] = None,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
     assert topk_ids.dim() == 2
     assert topk_ids.shape[0] == a.shape[0]
@@ -172,7 +173,8 @@ def torch_dispatch(
     topk = topk_ids.shape[1]
 
     tokens_per_expert = torch.bincount(topk_ids.view(-1), minlength=num_experts)
-    max_num_tokens = tokens_per_expert.max()
+    if max_num_tokens is None:
+        max_num_tokens = tokens_per_expert.max()
 
     b_a = torch.zeros((num_experts, max_num_tokens, a.shape[1]),
                       dtype=a.dtype, device=a.device)
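
With this change the reference dispatch can be told how far to pad each expert's token buffer instead of always padding to the batch maximum, so its output shape can line up with the fixed-size buffers the pplx kernels use. A minimal, self-contained sketch of that padding logic (the scatter loop is an illustrative reconstruction, since the rest of torch_dispatch is not shown in this diff):

import torch
from typing import Optional, Tuple

def reference_dispatch(a: torch.Tensor,
                       topk_ids: torch.Tensor,
                       num_experts: int,
                       max_num_tokens: Optional[int] = None) -> Tuple[torch.Tensor, torch.Tensor]:
    # Count how many tokens were routed to each expert.
    tokens_per_expert = torch.bincount(topk_ids.view(-1), minlength=num_experts)
    # Pad to the caller-supplied size if given, otherwise to the batch maximum.
    if max_num_tokens is None:
        max_num_tokens = int(tokens_per_expert.max())
    b_a = torch.zeros((num_experts, max_num_tokens, a.shape[1]),
                      dtype=a.dtype, device=a.device)
    # Scatter each token into the next free slot of every expert it was routed to.
    fill = torch.zeros(num_experts, dtype=torch.long)
    for tok in range(topk_ids.shape[0]):
        for expert in topk_ids[tok].tolist():
            b_a[expert, fill[expert]] = a[tok]
            fill[expert] += 1
    return b_a, tokens_per_expert

Pinning max_num_tokens to the same value on every rank keeps b_a's shape rank-independent.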
@@ -314,11 +316,10 @@ def torch_pplx_dispatch_combine(pgi, dp_size, a, w1, w2, scores, topk):
     block_size = 128
     device = pgi.device
     rank_num_tokens = num_tokens // pgi.world_size
-
-    max_num_tokens = num_tokens
-    #print(f"device = {device}, max_num_tokens = {max_num_tokens}, topk = {topk}, num_ex = {num_experts}, dp_size = {dp_size}")
     rank = pgi.rank
     world_size = pgi.world_size
+    max_num_tokens = num_tokens
+    #print(f"device = {device}, max_num_tokens = {max_num_tokens}, topk = {topk}, num_ex = {num_experts}, dp_size = {dp_size}")
 
     ata = AllToAll(
         max_num_tokens=max_num_tokens,
@@ -342,7 +343,7 @@ def torch_pplx_dispatch_combine(pgi, dp_size, a, w1, w2, scores, topk):
 
     dispatch_combine = PplxDispatchCombine(
         ata,
-        max_num_tokens, # // world_size?
+        max_num_tokens,
         pgi.world_size,
         dp_size,
         rank,
@@ -353,7 +354,7 @@ def torch_pplx_dispatch_combine(pgi, dp_size, a, w1, w2, scores, topk):
     score_chunk = chunk_by_rank(scores, rank, world_size).to(device)
     chunk_topk_weight, chunk_topk_ids = fused_topk(a_chunk, score_chunk, topk, False)
 
-    #print(f"chunk_topk_ids = {chunk_topk_ids.view(-1)}")
+    print(f"chunk_topk_ids = {chunk_topk_ids.view(-1)}")
 
     b_a, b_a_scale, expert_num_tokens = dispatch_combine.dispatch(
         a_chunk,
@@ -371,22 +372,25 @@ def torch_pplx_dispatch_combine(pgi, dp_size, a, w1, w2, scores, topk):
     #max_num = tokens_per_expert.max()
     tokens_per_expert = chunk_by_rank(tokens_per_expert, rank, world_size).to(dtype=torch.int32)
 
-    #print(f"tpe {tokens_per_expert}")
-    #print(f"ent {expert_num_tokens}")
+    print(f"tpe {tokens_per_expert}")
+    print(f"ent {expert_num_tokens}")
+
+    #torch.set_printoptions(profile="full")
+    #torch.distributed.all_reduce(naive_b_a, op=torch.distributed.ReduceOp.MAX)
+    #torch.distributed.broadcast(naive_b_a, src=rank)
 
     #naive_b_a = chunk_by_rank(naive_b_a, rank, world_size)
 
-    #torch.set_printoptions(profile="full")
-    #print("b_a", b_a[:naive_b_a.shape[1]])
-    #print("naive_b_a", naive_b_a)
+    #print("b_a", b_a.shape, b_a) #[:, :naive_b_a.shape[1]])
+    #print("naive_b_a", naive_b_a.shape, naive_b_a)
 
     torch.testing.assert_close(tokens_per_expert, expert_num_tokens, atol=0, rtol=0)
     #torch.testing.assert_close(b_a[:, :naive_b_a.shape[1]], naive_b_a, atol=2e-2, rtol=0)
 
     b_a = b_a * 1.5
 
     out = torch.full(
-        (max_num_tokens, hidden_dim),
+        (rank_num_tokens * world_size, hidden_dim),
         torch.nan,
         dtype=a.dtype,
         device=device,
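
The combine output is now sized by the tokens the ranks jointly own rather than by max_num_tokens. A quick illustration of the shape arithmetic with made-up numbers (not values from the test):

# Hypothetical sizes, chosen only to show the new output shape.
num_tokens, world_size, hidden_dim = 64, 4, 2048
rank_num_tokens = num_tokens // world_size              # 16 tokens per rank
out_shape = (rank_num_tokens * world_size, hidden_dim)  # (64, 2048)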
@@ -539,7 +543,7 @@ def torch_pplx_moe(pgi, dp_size, a, w1, w2, scores, topk):
         a.dtype,
     )
 
-    experts = BatchedExperts()
+    experts = BatchedExperts(max_num_tokens, rank)
 
     fused_experts = FusedMoEModularKernel(
         dispatch_combine,
@@ -554,24 +558,20 @@ def torch_pplx_moe(pgi, dp_size, a, w1, w2, scores, topk):
 
     out = fused_experts(
         a_chunk,
-        w1,
-        w2,
+        chunk_by_rank(w1, rank, world_size),
+        chunk_by_rank(w2, rank, world_size),
         chunk_topk_weight,
         chunk_topk_ids,
-        global_num_experts=num_local_experts #? num_local_experts?
+        global_num_experts=num_experts #? num_local_experts?
     )
 
     torch.cuda.synchronize()
 
     ata.destroy()
 
-    #print(f"OUT {rank}: {out.shape} {out[:rank_num_tokens]}")
-
-    #torch.distributed.all_reduce(out)
-
     #print(f"OUT {rank}: {out.shape} {out}")
 
-    return out[:rank_num_tokens]
+    return out[:rank_num_tokens] # chunk_by_rank?
 
 
 def _pplx_moe(
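
fused_experts is now handed only this rank's slice of the expert weights while global_num_experts stays at the full count. chunk_by_rank is defined elsewhere in test_pplx_moe.py and does not appear in this diff; a plausible minimal version, assuming it splits the leading dimension into world_size contiguous chunks, would be:

import torch

def chunk_by_rank_sketch(t: torch.Tensor, rank: int, world_size: int) -> torch.Tensor:
    # Split the leading (expert or token) dimension into equal contiguous
    # chunks and return the chunk owned by this rank.
    chunk = t.shape[0] // world_size
    return t[rank * chunk:(rank + 1) * chunk]

# e.g. with 8 experts and world_size = 2, rank 0 would see experts 0-3 and
# rank 1 experts 4-7, while global_num_experts stays at the full 8.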

vllm/model_executor/layers/fused_moe/fused_moe.py

Lines changed: 16 additions & 2 deletions
@@ -1834,6 +1834,8 @@ class BatchedExperts(mk.FusedMoEPermuteExpertsUnpermute):
 
     def __init__(
         self,
+        max_num_tokens: Optional[int] = None,
+        rank: int = 0,
         use_fp8_w8a8: bool = False,
         use_int8_w8a16: bool = False,
         use_int4_w4a16: bool = False,
@@ -1846,6 +1848,8 @@ def __init__(
         assert not use_int8_w8a16
         assert block_shape is None
         assert block_m is None
+        self.max_num_tokens = max_num_tokens
+        self.rank = rank
 
     def workspace_shapes(
         self,
@@ -1857,7 +1861,8 @@ def workspace_shapes(
         num_experts: int,
         a: torch.Tensor,
     ) -> Tuple[int, int, torch.dtype]:
-        max_num_tokens = a.shape[1]
+        #assert self.max_num_tokens >= a.shape[1]
+        max_num_tokens = a.shape[1] if self.max_num_tokens is None else self.max_num_tokens
         workspace13 = num_experts * max_num_tokens * K * topk * 2  # TODO: *2 is a hack
         workspace2 = max_num_tokens * N
         return (workspace13, workspace2, a_dtype)
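
To illustrate how the stored max_num_tokens now drives workspace sizing (hypothetical shapes, not values from the test):

# 8 local experts, max_num_tokens = 64, K = 2048, N = 4096, topk = 2.
num_experts, max_num_tokens, K, N, topk = 8, 64, 2048, 4096, 2
workspace13 = num_experts * max_num_tokens * K * topk * 2  # 4_194_304 elements (the "*2 is a hack" factor)
workspace2 = max_num_tokens * N                            # 262_144 elements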
@@ -1885,13 +1890,20 @@ def apply(
         assert hidden_states.dim() == 3
         assert expert_num_tokens is not None
         num_tokens, topk = topk_ids.shape
-        _, max_num_tokens, K = hidden_states.shape
+        _, tmp_max_num_tokens, K = hidden_states.shape
+        max_num_tokens = tmp_max_num_tokens if self.max_num_tokens is None else self.max_num_tokens
         print(f"global_num_experts = {global_num_experts}")
         num_experts = global_num_experts
         out = _resize_cache(workspace13, (num_experts, max_num_tokens, w2.shape[1]))
         num_local_experts = expert_num_tokens.numel()
+        #assert num_local_experts >= topk_ids.view(-1).max()
+        #print(f"apply a={hidden_states}")
+        #print(f"apply topk={topk_ids}")
+        #print(f"apply num_tokens={expert_num_tokens}")
+
         for expert in range(num_local_experts):  # num_experts
             num = expert_num_tokens[expert]
+            assert num <= max_num_tokens
             if num > 0:
                 tmp = _resize_cache(workspace2, (num, w1.shape[1] // 2))
                 self.activation(activation, tmp, hidden_states[expert,:num,:] @ w1[expert].transpose(0, 1))
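
For reference, a self-contained sketch of the per-expert loop above; the activation and the second matmul are not fully shown in this hunk, so silu_and_mul and the w2 projection below are assumed reconstructions:

import torch
import torch.nn.functional as F

def silu_and_mul(x: torch.Tensor) -> torch.Tensor:
    # Gate/up split followed by SiLU(gate) * up, the usual fused-MoE MLP activation.
    gate, up = x.chunk(2, dim=-1)
    return F.silu(gate) * up

def batched_experts_sketch(hidden_states, w1, w2, expert_num_tokens, max_num_tokens):
    # hidden_states: (num_local_experts, max_num_tokens, K), padded per expert
    # w1: (num_local_experts, 2N, K), w2: (num_local_experts, K, N)
    num_local_experts, _, K = hidden_states.shape
    out = torch.zeros((num_local_experts, max_num_tokens, w2.shape[1]),
                      dtype=hidden_states.dtype, device=hidden_states.device)
    for expert in range(num_local_experts):
        num = int(expert_num_tokens[expert])
        assert num <= max_num_tokens  # mirrors the new guard in the diff
        if num > 0:
            tmp = silu_and_mul(hidden_states[expert, :num, :] @ w1[expert].transpose(0, 1))
            out[expert, :num, :] = tmp @ w2[expert].transpose(0, 1)
    return out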
@@ -1904,6 +1916,8 @@
 
         #print("END EXPERTS")
 
+        #print(f"apply out={out}")
+
         return out
 
 