
Commit e0560d5

wip
Signed-off-by: Bill Nell <[email protected]>
1 parent 4fb31ef commit e0560d5

File tree: 3 files changed, +21 −7 lines


tests/kernels/test_pplx_moe.py

Lines changed: 4 additions & 2 deletions
@@ -535,14 +535,14 @@ def torch_pplx_moe(pgi, dp_size, a, w1, w2, scores, topk):
 
     dispatch_combine = PplxDispatchCombine(
         ata,
-        max_num_tokens, # // world_size?
+        max_num_tokens,
         pgi.world_size,
         dp_size,
         rank,
         a.dtype,
     )
 
-    experts = BatchedExperts(max_num_tokens, rank)
+    experts = BatchedExperts(rank, pgi.world_size, max_num_tokens)
 
     fused_experts = FusedMoEModularKernel(
         dispatch_combine,
@@ -560,6 +560,8 @@ def torch_pplx_moe(pgi, dp_size, a, w1, w2, scores, topk):
         # Chunking weights like this only works for batched format
         chunk_by_rank(w1, rank, world_size),
         chunk_by_rank(w2, rank, world_size),
+        #w1,
+        #w2,
         chunk_topk_weight,
         chunk_topk_ids,
         global_num_experts=num_experts #? num_local_experts?
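Note: chunk_by_rank is assumed here to slice the expert dimension of a weight tensor so each rank keeps only its own contiguous block of experts, consistent with the rank_chunk helper added in fused_moe.py below. A minimal sketch (the actual helper in the test file may differ):

    import torch

    def chunk_by_rank(t: torch.Tensor, r: int, w: int) -> torch.Tensor:
        # Slice dim 0 (experts) into w near-equal chunks; the first
        # (t.shape[0] % w) ranks each take one extra expert.
        num = t.shape[0]
        rem = num % w
        chunk = num // w + (1 if r < rem else 0)
        start = r * (num // w) + min(r, rem)
        return t[start:start + chunk]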

vllm/model_executor/layers/fused_moe/fused_moe.py

Lines changed: 15 additions & 3 deletions
@@ -1827,12 +1827,18 @@ def combine(
         #print(f"END COMBINE {hex(id(self))}")
 
 
+def rank_chunk(num, r, w):
+    rem = num % w
+    return (num // w) + (1 if r < rem else 0)
+
+
 class BatchedExperts(mk.FusedMoEPermuteExpertsUnpermute):
 
     def __init__(
         self,
-        max_num_tokens: Optional[int] = None,
         rank: int = 0,
+        world_size: int = 1,
+        max_num_tokens: Optional[int] = None,
         use_fp8_w8a8: bool = False,
         use_int8_w8a16: bool = False,
         use_int4_w4a16: bool = False,
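The new rank_chunk helper computes how many of num items rank r owns when splitting across w ranks: every rank gets num // w, and the first num % w ranks absorb the remainder. For example:

    def rank_chunk(num, r, w):
        rem = num % w
        return (num // w) + (1 if r < rem else 0)

    # 10 experts over 4 ranks: ranks 0-1 own 3 experts each, ranks 2-3 own 2.
    assert [rank_chunk(10, r, 4) for r in range(4)] == [3, 3, 2, 2]
    assert sum(rank_chunk(10, r, 4) for r in range(4)) == 10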
@@ -1847,6 +1853,7 @@ def __init__(
         assert block_m is None
         self.max_num_tokens = max_num_tokens
         self.rank = rank
+        self.world_size = world_size
 
     def workspace_shapes(
         self,
@@ -1895,14 +1902,19 @@ def apply(
         num_local_experts = expert_num_tokens.numel()
         #print(f"shapes = {hidden_states.shape}, {w1.shape}, {w2.shape}, {out.shape} {expert_num_tokens.shape} {workspace2.shape} {num_experts}")
 
+        # TODO: don't need world_size or rank if expert_base always == 0
+        #assert w1.shape[0] == num_experts, f"{w1.shape} == {num_experts}"
+        #expert_base = rank_chunk(w1.shape[0], self.rank, self.world_size) * self.rank
+        expert_base = 0
+
         for expert in range(num_local_experts): # num_experts
             num = expert_num_tokens[expert]
             assert num <= max_num_tokens, f"{num}, {max_num_tokens}"
             #print(f"{type(num)}, {num}, {max_num_tokens}")
             if num > 0:
                 tmp = _resize_cache(workspace2, (num, w1.shape[1] // 2))
-                self.activation(activation, tmp, hidden_states[expert,:num,:] @ w1[expert].transpose(0, 1))
-                out[expert, :num, :] = tmp @ w2[expert].transpose(0, 1)
+                self.activation(activation, tmp, hidden_states[expert,:num,:] @ w1[expert_base + expert].transpose(0, 1))
+                out[expert, :num, :] = tmp @ w2[expert_base + expert].transpose(0, 1)
 
         return out
 
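A standalone sketch of the batched per-expert loop above, assuming the activation is SiLU-and-mul (which the (num, w1.shape[1] // 2) workspace shape suggests) and that weights are pre-chunked per rank so expert_base stays 0; with full, unchunked weight tensors, expert_base would instead be the index of the first global expert owned by this rank:

    import torch
    import torch.nn.functional as F

    E_local, M, K, N = 2, 4, 8, 16              # hypothetical sizes
    hidden_states = torch.randn(E_local, M, K)  # [experts, max tokens, hidden]
    w1 = torch.randn(E_local, 2 * N, K)         # fused gate/up projection
    w2 = torch.randn(E_local, K, N)             # down projection
    expert_num_tokens = torch.tensor([3, 1])    # valid rows per local expert
    out = torch.zeros(E_local, M, K)
    expert_base = 0

    for expert in range(E_local):
        num = int(expert_num_tokens[expert])
        if num > 0:
            h = hidden_states[expert, :num, :] @ w1[expert_base + expert].t()
            gate, up = h.chunk(2, dim=-1)       # split the 2N activation halves
            tmp = F.silu(gate) * up             # SiLU-and-mul into the workspace
            out[expert, :num, :] = tmp @ w2[expert_base + expert].t()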
vllm/model_executor/layers/fused_moe/layer.py

Lines changed: 2 additions & 2 deletions
@@ -246,8 +246,8 @@ def set_dispatch_combine(self, dispatch_combine: FusedMoEQuantizeDispatchCombine
         #print(f"block_m = {block_m}")
 
         if isinstance(dispatch_combine, (BatchedDispatchCombine, PplxDispatchCombine)):
-            logger.info("BatchedExperts")
-            experts = BatchedExperts()
+            logger.info(f"BatchedExperts {self.moe}")
+            experts = BatchedExperts() #rank=self.moe.ep_rank, world_size=self.moe.ep_size)
         else:
             experts = TritonExperts(
                 use_fp8_w8a8 = False,
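Hypothetical: if the commented-out arguments were enabled, the construction would presumably read as below, with self.moe.ep_rank and self.moe.ep_size taken from the comment rather than verified against this revision:

    experts = BatchedExperts(
        rank=self.moe.ep_rank,
        world_size=self.moe.ep_size,
    )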
