Commit 4fb31ef

some cleanup
Signed-off-by: Bill Nell <[email protected]>
1 parent 58fe406 commit 4fb31ef

3 files changed: 23 additions, 42 deletions

tests/kernels/test_pplx_moe.py

Lines changed: 18 additions & 23 deletions
@@ -299,10 +299,13 @@ def test_fused_moe_batched_experts(
     torch.testing.assert_close(triton_output, torch_output, atol=2e-2, rtol=0)
 
 
+def rank_chunk(num, r, w):
+    rem = num % w
+    return (num // w) + (1 if r < rem else 0)
+
+
 def chunk_by_rank(t, r, w):
-    num = t.shape[0]
-    assert num % w == 0, f"{num}, {w}" # for now
-    chunk = num // w
+    chunk = rank_chunk(t.shape[0], r, w)
     #print(f"chunk {t.shape}, {w}, {r}, {chunk}, {r*chunk}:{(r + 1)*chunk}")
     return t[(r * chunk):(r + 1)*chunk]
 
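The new `rank_chunk` helper drops the old even-division requirement by giving the first `num % w` ranks one extra element. A minimal standalone sketch of the two helpers (PyTorch only; the sample values are illustrative, not taken from the test):

import torch

def rank_chunk(num, r, w):
    rem = num % w
    return (num // w) + (1 if r < rem else 0)

def chunk_by_rank(t, r, w):
    chunk = rank_chunk(t.shape[0], r, w)
    return t[(r * chunk):(r + 1) * chunk]

# Even split: same result as the old `num // w` path.
t = torch.arange(8)
print([chunk_by_rank(t, r, 2).tolist() for r in range(2)])  # [[0, 1, 2, 3], [4, 5, 6, 7]]

# Uneven split: the remainder goes to the lowest ranks, so the old
# `assert num % w == 0` is no longer needed.
print([rank_chunk(7, r, 2) for r in range(2)])  # [4, 3]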

@@ -312,12 +315,11 @@ def torch_pplx_dispatch_combine(pgi, dp_size, a, w1, w2, scores, topk):
 
     num_tokens, hidden_dim = a.shape
     num_experts = w1.shape[0]
-    num_local_experts = w1.shape[0] // pgi.world_size
     block_size = 128
     device = pgi.device
-    rank_num_tokens = num_tokens // pgi.world_size
     rank = pgi.rank
     world_size = pgi.world_size
+    rank_num_tokens = rank_chunk(num_tokens, rank, world_size)
     max_num_tokens = num_tokens
     #print(f"device = {device}, max_num_tokens = {max_num_tokens}, topk = {topk}, num_ex = {num_experts}, dp_size = {dp_size}")
 
@@ -354,7 +356,7 @@ def torch_pplx_dispatch_combine(pgi, dp_size, a, w1, w2, scores, topk):
     score_chunk = chunk_by_rank(scores, rank, world_size).to(device)
     chunk_topk_weight, chunk_topk_ids = fused_topk(a_chunk, score_chunk, topk, False)
 
-    print(f"chunk_topk_ids = {chunk_topk_ids.view(-1)}")
+    #print(f"chunk_topk_ids = {chunk_topk_ids.view(-1)}")
 
     b_a, b_a_scale, expert_num_tokens = dispatch_combine.dispatch(
         a_chunk,
@@ -372,8 +374,8 @@ def torch_pplx_dispatch_combine(pgi, dp_size, a, w1, w2, scores, topk):
     #max_num = tokens_per_expert.max()
     tokens_per_expert = chunk_by_rank(tokens_per_expert, rank, world_size).to(dtype=torch.int32)
 
-    print(f"tpe {tokens_per_expert}")
-    print(f"ent {expert_num_tokens}")
+    #print(f"tpe {tokens_per_expert}")
+    #print(f"ent {expert_num_tokens}")
 
     #torch.set_printoptions(profile="full")
     #torch.distributed.all_reduce(naive_b_a, op=torch.distributed.ReduceOp.MAX)
@@ -501,15 +503,12 @@ def torch_pplx_moe(pgi, dp_size, a, w1, w2, scores, topk):
 
     num_tokens, hidden_dim = a.shape
    num_experts = w1.shape[0]
-    num_local_experts = num_experts // pgi.world_size
     block_size = 128
     device = pgi.device
-    rank_num_tokens = num_tokens // pgi.world_size # TODO even divide
-
-    max_num_tokens = num_tokens
-    #print(f"device = {device}, max_num_tokens = {max_num_tokens}, topk = {topk}, num_ex = {num_experts}, dp_size = {dp_size}")
     rank = pgi.rank
     world_size = pgi.world_size
+    rank_num_tokens = rank_chunk(num_tokens, rank, world_size)
+    max_num_tokens = num_tokens
 
     ata = AllToAll(
         max_num_tokens=max_num_tokens,
@@ -558,6 +557,7 @@ def torch_pplx_moe(pgi, dp_size, a, w1, w2, scores, topk):
 
     out = fused_experts(
         a_chunk,
+        # Chunking weights like this only works for batched format
         chunk_by_rank(w1, rank, world_size),
         chunk_by_rank(w2, rank, world_size),
         chunk_topk_weight,
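The new inline comment flags that `chunk_by_rank` on the weights slices dim 0 (the expert dimension), so each rank hands `fused_experts` only its local experts. A rough shape check under assumed sizes (e = 8 experts, world_size = 2, hidden = 128, intermediate = 512; the `(E, 2N, K)` / `(E, K, N)` layout appears to match how the `apply` hunk in fused_moe.py below consumes `w1` and `w2`):

import torch

def rank_chunk(num, r, w):
    rem = num % w
    return (num // w) + (1 if r < rem else 0)

def chunk_by_rank(t, r, w):
    chunk = rank_chunk(t.shape[0], r, w)
    return t[(r * chunk):(r + 1) * chunk]

w1 = torch.empty(8, 2 * 512, 128)  # (num_experts, 2 * intermediate, hidden)
w2 = torch.empty(8, 128, 512)      # (num_experts, hidden, intermediate)
for rank in range(2):
    print(rank, chunk_by_rank(w1, rank, 2).shape, chunk_by_rank(w2, rank, 2).shape)
# each rank sees 4 local experts: (4, 1024, 128) and (4, 128, 512)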
@@ -571,7 +571,7 @@ def torch_pplx_moe(pgi, dp_size, a, w1, w2, scores, topk):
 
     #print(f"OUT {rank}: {out.shape} {out}")
 
-    return out[:rank_num_tokens] # chunk_by_rank?
+    return out[:rank_num_tokens]
 
 
 def _pplx_moe(
@@ -624,18 +624,13 @@ def _pplx_moe(
     nvshmem_finalize()
 
 
-@pytest.mark.parametrize("m", [2, 32, 64, 222]) #, 1024 * 128])
-@pytest.mark.parametrize("n", [128, 1024, 2048])
-@pytest.mark.parametrize("k", [128, 512, 1024])
+# TODO: M == 1 doesn't work
+@pytest.mark.parametrize("m", [2, 3, 32, 45, 64, 222]) #, 1024 * 128])
+@pytest.mark.parametrize("n", [128, 1024])# , 2048])
+@pytest.mark.parametrize("k", [128, 512]) # , 1024])
 @pytest.mark.parametrize("e", NUM_EXPERTS)
 @pytest.mark.parametrize("topk", TOP_KS)
 @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
-# @pytest.mark.parametrize("m", [64]) ##, 32]) #, 1024 * 128])
-# @pytest.mark.parametrize("n", [128])
-# @pytest.mark.parametrize("k", [128])
-# @pytest.mark.parametrize("e", [8]) #NUM_EXPERTS)
-# @pytest.mark.parametrize("topk", [2]) #TOP_KS)
-# @pytest.mark.parametrize("dtype", [torch.bfloat16])
 @pytest.mark.parametrize("world_dp_size", [[2, 1]]) #, [4, 2]])
 def test_pplx_moe(
     m: int,
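The widened `m` sweep adds 3 and 45, token counts that no longer divide evenly by the world size, which is exactly the case `rank_chunk` was added for; `m == 1` stays excluded per the TODO. For reference, the per-rank token counts at world_size = 2 (matching the `[2, 1]` `world_dp_size` case); the diff does not state why `m == 1` fails, the zero-token rank is only an observation:

def rank_chunk(num, r, w):
    rem = num % w
    return (num // w) + (1 if r < rem else 0)

world_size = 2
for m in (1, 2, 3, 32, 45, 64, 222):
    print(m, [rank_chunk(m, r, world_size) for r in range(world_size)])
# 1  -> [1, 0]   (excluded case: one rank would get zero tokens)
# 3  -> [2, 1]
# 45 -> [23, 22]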

vllm/model_executor/layers/fused_moe/fused_moe.py

Lines changed: 4 additions & 18 deletions
@@ -1777,9 +1777,6 @@ def dispatch(
         num_tokens = a1.shape[0]
         topk = topk_ids.shape[1]
 
-        #assert num_experts % self.world_size == 0
-        #num_local_experts = num_experts // self.world_size
-
         tokens_per_expert = torch.bincount(topk_ids.view(-1), minlength=num_experts)
         max_num_tokens = tokens_per_expert.max()
         expert_counts = torch.zeros(num_experts, dtype=torch.int, device=a1.device)
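For context on the surviving lines: `tokens_per_expert` is a plain histogram of the flattened `topk_ids`, and its max sets the padded slab size per expert. A tiny illustration with made-up routing (4 experts, 3 tokens, topk = 2):

import torch

topk_ids = torch.tensor([[0, 2], [2, 3], [2, 0]])  # each row: experts chosen for one token
tokens_per_expert = torch.bincount(topk_ids.view(-1), minlength=4)
print(tokens_per_expert.tolist())    # [2, 0, 3, 1]
print(int(tokens_per_expert.max()))  # 3 -> max_num_tokens for the batched layout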
@@ -1892,31 +1889,20 @@ def apply(
         num_tokens, topk = topk_ids.shape
         _, tmp_max_num_tokens, K = hidden_states.shape
         max_num_tokens = tmp_max_num_tokens if self.max_num_tokens is None else self.max_num_tokens
-        print(f"global_num_experts = {global_num_experts}")
+        #print(f"global_num_experts = {global_num_experts}")
         num_experts = global_num_experts
         out = _resize_cache(workspace13, (num_experts, max_num_tokens, w2.shape[1]))
         num_local_experts = expert_num_tokens.numel()
-        #assert num_local_experts >= topk_ids.view(-1).max()
-        #print(f"apply a={hidden_states}")
-        #print(f"apply topk={topk_ids}")
-        #print(f"apply num_tokens={expert_num_tokens}")
+        #print(f"shapes = {hidden_states.shape}, {w1.shape}, {w2.shape}, {out.shape} {expert_num_tokens.shape} {workspace2.shape} {num_experts}")
 
         for expert in range(num_local_experts): # num_experts
             num = expert_num_tokens[expert]
-            assert num <= max_num_tokens
+            assert num <= max_num_tokens, f"{num}, {max_num_tokens}"
+            #print(f"{type(num)}, {num}, {max_num_tokens}")
             if num > 0:
                 tmp = _resize_cache(workspace2, (num, w1.shape[1] // 2))
                 self.activation(activation, tmp, hidden_states[expert,:num,:] @ w1[expert].transpose(0, 1))
                 out[expert, :num, :] = tmp @ w2[expert].transpose(0, 1)
-            # fill remainder with 0???
-            #out[expert, num:, :].fill_(0)
-            else:
-                #out[expert, :, :].fill_(0) # ??
-                pass
-
-        #print("END EXPERTS")
-
-        #print(f"apply out={out}")
 
         return out
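The loop that remains is the heart of the batched expert compute: for each local expert only the first `expert_num_tokens[expert]` rows are pushed through `w1`, the activation, and `w2`; this commit drops the commented-out zero-fill experiments and the empty else branch, so the padded tail rows are simply left untouched. A self-contained sketch of that per-expert math, with a plain SiLU-and-mul standing in for `self.activation` and made-up sizes (the real code writes into pre-sized workspaces via `_resize_cache`):

import torch

E, max_tokens, K, N = 2, 4, 8, 16          # local experts, padded tokens, hidden, intermediate
hidden_states = torch.randn(E, max_tokens, K)
w1 = torch.randn(E, 2 * N, K)              # fused gate/up projection
w2 = torch.randn(E, K, N)                  # down projection
expert_num_tokens = torch.tensor([3, 0])   # expert 0 has 3 real tokens, expert 1 none
out = torch.empty(E, max_tokens, K)

def silu_and_mul(x):                       # stand-in for self.activation(activation, ...)
    gate, up = x.chunk(2, dim=-1)
    return torch.nn.functional.silu(gate) * up

for expert in range(E):
    num = int(expert_num_tokens[expert])
    if num > 0:
        tmp = silu_and_mul(hidden_states[expert, :num, :] @ w1[expert].transpose(0, 1))
        out[expert, :num, :] = tmp @ w2[expert].transpose(0, 1)
    # rows past `num` are left as-is; the caller is expected to ignore them

print(out[0, :3, :].shape)                 # torch.Size([3, 8])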

vllm/model_executor/layers/fused_moe/pplx_dispatch_combine.py

Lines changed: 1 addition & 1 deletion
@@ -84,7 +84,7 @@ def dispatch(
             dtype=a1q.dtype,
             device=device,
         )
-        expert_x.fill_(0) #torch.nan # debugging, remove later
+        #expert_x.fill_(0) #torch.nan # debugging, remove later
 
         logger.debug(f"GOT HERE B {self.rank}")
 