Commit 3b72bc5

wip ref impl
Signed-off-by: Bill Nell <[email protected]>
1 parent aaefc27 commit 3b72bc5

7 files changed: +126 -42 lines changed

csrc/activation_kernels.cu

Lines changed: 1 addition & 0 deletions
@@ -70,6 +70,7 @@ __device__ __forceinline__ T gelu_tanh_kernel(const T& x) {
   int64_t num_tokens = input.numel() / input.size(-1);                 \
   dim3 grid(num_tokens);                                               \
   dim3 block(std::min(d, 1024));                                       \
+  if (num_tokens == 0) { return; }                                     \
   const at::cuda::OptionalCUDAGuard device_guard(device_of(input));    \
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();        \
   VLLM_DISPATCH_FLOATING_TYPES(                                        \
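
Note: the added guard returns before the launch when the input has zero tokens; without it the macro would request a zero-sized grid (dim3 grid(0)), which CUDA rejects as an invalid launch configuration. A hypothetical repro sketch of the case the guard covers, assuming the op is reached through vLLM's SiluAndMul layer (the layer and shapes here are illustrative, not part of this commit):

# Hypothetical sketch: an empty batch (e.g. an idle data-parallel rank) reaches
# the fused activation op with zero rows; with the guard this becomes a no-op.
import torch
from vllm.model_executor.layers.activation import SiluAndMul

act = SiluAndMul()
x = torch.empty((0, 2 * 4096), dtype=torch.float16, device="cuda")
out = act(x)                      # previously this attempted a zero-sized grid launch
assert out.shape == (0, 4096)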

examples/offline_inference/data_parallel.py

Lines changed: 4 additions & 2 deletions
@@ -114,9 +114,11 @@ def main(model, dp_size, local_dp_rank, global_dp_rank, dp_master_ip,
 
     # Create an LLM.
     cconfig = CompilationConfig(
-        level=0,
+        level=3,
         #cudagraph_capture_sizes=[512,504,496,488,480,472,464,456,448,440,432,424,416,408,400,392,384,376,368,360,352,344,336,328,320,312,304,296,288,280,272,264,256,248,240,232,224,216,208],
         #cudagraph_capture_sizes=[512,256,1],
+        #cudagraph_capture_sizes=[192,184,176,168,160,152,144,136,128,120,112,104,96,88,80,72,64,56,48,40,32,24,16,8,4,2,1]
+        #cudagraph_capture_sizes=[128,120,112,104,96,88,80,72,64,56,48,40,32,24,16,8,4,2,1]
     )
     llm = LLM(model=model,
               tensor_parallel_size=GPUs_per_dp_rank,
@@ -171,7 +173,7 @@ def main(model, dp_size, local_dp_rank, global_dp_rank, dp_master_ip,
         procs.append(proc)
     exit_code = 0
     for proc in procs:
-        proc.join(timeout=300)
+        proc.join(timeout=3000)
         if proc.exitcode is None:
             print(f"Killing process {proc.pid} that "
                   f"didn't stop within 5 minutes.")

tests/kernels/moe/test_pplx_moe.py

Lines changed: 46 additions & 0 deletions
@@ -515,6 +515,50 @@ def pplx_moe(pgi, dp_size, a, w1, w2, topk_weight, topk_ids):
     return out
 
 
+def _batched_moe(pgi, dp_size, a, w1, w2, topk_weight, topk_ids):
+    assert torch.cuda.current_device() == pgi.local_rank
+
+    hidden_dim = a.shape[1]
+    num_experts = w1.shape[0]
+    block_size = 128
+    device = pgi.device
+    rank = pgi.rank
+    world_size = pgi.world_size
+    topk = topk_ids.shape[1]
+    max_num_tokens = rank_chunk(a.shape[0], 0, world_size)
+
+    dispatch_combine = BatchedDispatchCombine(
+        max_num_tokens=max_num_tokens,
+        world_size=world_size,
+        dp_size=dp_size,
+        rank=rank,
+    )
+
+    experts = BatchedExperts(a.shape[0])
+
+    fused_experts = FusedMoEModularKernel(
+        dispatch_combine,
+        experts,
+    )
+
+    # TODO: workers with the same dp_rank must use the exact same inputs.
+
+    a_chunk = chunk_by_rank(a, rank, world_size).to(device)
+    chunk_topk_weight = chunk_by_rank(topk_weight, rank, world_size).to(device)
+    chunk_topk_ids = chunk_by_rank(topk_ids, rank, world_size).to(device)
+
+    out = fused_experts(
+        a_chunk,
+        # Chunking weights like this only works for batched format
+        chunk_by_rank(w1, rank, world_size).to(device),
+        chunk_by_rank(w2, rank, world_size).to(device),
+        chunk_topk_weight,
+        chunk_topk_ids,
+        global_num_experts=num_experts)
+
+    return out
+
+
 def _pplx_moe(
     pgi: ProcessGroupInfo,
     dp_size: int,
@@ -536,11 +580,13 @@ def _pplx_moe(
     topk_weight, topk_ids = fused_topk(a, score, topk, False)
     torch_output = torch_moe2(a, w1, w2, topk_weight, topk_ids)
     pplx_output = pplx_moe(pgi, dp_size, a, w1, w2, topk_weight, topk_ids)
+    batched_output = _batched_moe(pgi, dp_size, a, w1, w2, topk_weight, topk_ids)
 
     torch_output = chunk_by_rank(torch_output, pgi.rank,
                                  pgi.world_size).to(pplx_output.device)
 
     torch.testing.assert_close(pplx_output, torch_output, atol=2e-2, rtol=0)
+    torch.testing.assert_close(batched_output, torch_output, atol=2e-2, rtol=0)
 
     nvshmem_finalize()
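
Note: each rank checks only its own slice of the dense reference output, so the slicing convention must match how work was split across ranks. A self-contained sketch of that convention (this chunk_by_rank is written here for illustration and is only assumed to mirror the test helper of the same name):

# Illustrative: split rows across ranks; the first `num % w` ranks get one extra row.
import torch

def rank_chunk(num: int, r: int, w: int) -> int:
    rem = num % w
    return (num // w) + (1 if r < rem else 0)

def chunk_by_rank(t: torch.Tensor, r: int, w: int) -> torch.Tensor:
    start = sum(rank_chunk(t.shape[0], i, w) for i in range(r))
    return t[start:start + rank_chunk(t.shape[0], r, w)]

ref = torch.arange(10).reshape(5, 2)
assert [chunk_by_rank(ref, r, 2).shape[0] for r in range(2)] == [3, 2]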

vllm/model_executor/layers/fused_moe/fused_batched_moe.py

Lines changed: 50 additions & 23 deletions
@@ -466,6 +466,11 @@ def invoke_batched_silu_and_mul(
                                 compute_tl_dtype, D, BLOCK_M, BLOCK_D)
 
 
+def rank_chunk(num, r, w):
+    rem = num % w
+    return (num // w) + (1 if r < rem else 0)
+
+
 class BatchedDispatchCombine(mk.FusedMoEQuantizeDispatchCombine):
 
     def __init__(self, max_num_tokens: Optional[int], world_size: int, dp_size: int, rank: int):
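
Note: rank_chunk gives rank r's share of a count that may not divide evenly across w ranks; the first num % w ranks get one extra item. A small worked example, including contiguous per-rank expert ranges derived with a plain prefix sum (the prefix sum is for illustration only; the offset arithmetic in the dispatch hunk below is still marked as WIP):

# Worked example: 6 experts split over 4 ranks.
def rank_chunk(num, r, w):
    rem = num % w
    return (num // w) + (1 if r < rem else 0)

num_experts, world_size = 6, 4
counts = [rank_chunk(num_experts, r, world_size) for r in range(world_size)]
assert counts == [2, 2, 1, 1]

# Prefix-sum offsets give each rank a contiguous [first_expert, last_expert) range.
firsts = [sum(counts[:r]) for r in range(world_size)]
assert [(f, f + c) for f, c in zip(firsts, counts)] == [(0, 2), (2, 4), (4, 5), (5, 6)]
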
@@ -505,20 +510,31 @@ def dispatch(
         if self.max_num_tokens is None:
             self.max_num_tokens = int(tokens_per_expert.max().item())
 
-        b_a1 = torch.zeros((num_experts, self.max_num_tokens, hidden_dim),
+        rem_experts = num_experts % self.world_size
+        num_local_experts = ((num_experts // self.world_size) +
+                             (1 if self.rank < rem_experts else 0))
+
+        b_a1 = torch.zeros((num_local_experts, self.max_num_tokens, hidden_dim),
                            dtype=a1.dtype,
                            device=a1.device)
 
-        token_counts = torch.zeros(num_experts,
+        token_counts = torch.zeros(num_local_experts,
                                    dtype=torch.int,
                                    device=a1.device)
 
+        first_expert = (((num_experts // self.world_size) * self.rank) +
+                        rem_experts - self.rank)
+        last_expert = first_expert + num_local_experts
+        #expert_id_range = range(first_expert, last_expert)
+
         for token in range(num_tokens):
             for j in range(topk):
                 expert_id = topk_ids[token, j]
-                idx = token_counts[expert_id]
-                b_a1[expert_id, idx:idx + 1, :] = a1[token, :]
-                token_counts[expert_id] = token_counts[expert_id] + 1
+                if expert_id >= first_expert and expert_id < last_expert:
+                    rel_index = expert_id - first_expert
+                    idx = token_counts[rel_index]
+                    b_a1[rel_index, idx:idx + 1, :] = a1[token, :]
+                    token_counts[rel_index] = token_counts[rel_index] + 1
 
         return b_a1, a1_scale, tokens_per_expert
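
Note: dispatch now scatters tokens only into this rank's experts, producing activations of shape (num_local_experts, max_num_tokens, hidden_dim) plus a per-expert fill count. A toy illustration of that batched layout with plain tensors (shapes and routing chosen for the sketch, not taken from the kernel):

import torch

# 4 tokens, hidden_dim 2, topk 1, and 2 local experts covering expert ids [0, 2).
a1 = torch.arange(8, dtype=torch.float32).reshape(4, 2)
topk_ids = torch.tensor([[0], [1], [0], [1]])
num_local_experts, max_num_tokens, hidden_dim = 2, 4, 2

b_a1 = torch.zeros((num_local_experts, max_num_tokens, hidden_dim))
token_counts = torch.zeros(num_local_experts, dtype=torch.int)
for token in range(a1.shape[0]):
    for expert_id in topk_ids[token].tolist():
        idx = int(token_counts[expert_id])     # next free slot for that expert
        b_a1[expert_id, idx] = a1[token]
        token_counts[expert_id] += 1

assert token_counts.tolist() == [2, 2]
assert torch.equal(b_a1[0, :2], a1[[0, 2]])    # rows routed to expert 0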

@@ -531,7 +547,8 @@ def combine(
         apply_router_weight_on_input: bool,
     ) -> None:
         num_tokens = topk_ids.shape[0]
-        num_experts = fused_expert_output.shape[0]
+        num_local_experts = fused_expert_output.shape[0]
+        num_experts = num_local_experts * self.world_size  # NOT QUITE RIGHT
         K = fused_expert_output.shape[-1]
         assert output.shape[0] == num_tokens and output.shape[1] == K
         expert_counts = torch.zeros(
@@ -541,17 +558,21 @@ def combine(
 
         output.fill_(0)
 
+        first_expert = num_local_experts * self.rank  # NOT QUITE RIGHT
+        last_expert = first_expert + num_local_experts
+
         for token in range(num_tokens):
             expert_ids = topk_ids[token]
             for i in range(expert_ids.numel()):
                 expert_id = expert_ids[i]
-                assert expert_id < num_experts
-                idx = expert_counts[expert_id]
-                accum = fused_expert_output[expert_id, idx:idx + 1, :]
-                if not apply_router_weight_on_input:
-                    accum = accum * topk_weights[token, i]
-                output[token, :] = output[token, :] + accum
-                expert_counts[expert_id] = expert_counts[expert_id] + 1
+                if expert_id >= first_expert and expert_id < last_expert:
+                    assert expert_id < num_experts
+                    idx = expert_counts[expert_id]
+                    accum = fused_expert_output[expert_id - first_expert, idx:idx + 1, :]
+                    if not apply_router_weight_on_input:
+                        accum = accum * topk_weights[token, i]
+                    output[token, :] = output[token, :] + accum
+                    expert_counts[expert_id] = expert_counts[expert_id] + 1
 
 
 class BatchedExperts(mk.FusedMoEPermuteExpertsUnpermute):
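
Note: combine is the inverse of dispatch: each token walks its top-k experts, pulls the next unread row from that expert's batch, and accumulates it, scaled by the router weight unless the weight was already applied to the input. A toy illustration continuing the layout from the dispatch sketch above (illustrative only):

import torch

# 2 local experts, each holding up to 2 output rows of width K=2; topk=1 routing.
fused_expert_output = torch.tensor([[[1., 1.], [2., 2.]],
                                    [[3., 3.], [4., 4.]]])
topk_ids = torch.tensor([[0], [1], [0], [1]])
topk_weights = torch.tensor([[0.5], [1.0], [2.0], [0.25]])

output = torch.zeros(4, 2)
expert_counts = torch.zeros(2, dtype=torch.int)
for token in range(4):
    for i, expert_id in enumerate(topk_ids[token].tolist()):
        idx = int(expert_counts[expert_id])    # rows are consumed in dispatch order
        output[token] += topk_weights[token, i] * fused_expert_output[expert_id, idx]
        expert_counts[expert_id] += 1

assert torch.allclose(output[0], torch.tensor([0.5, 0.5]))
assert torch.allclose(output[3], torch.tensor([1.0, 1.0]))
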
@@ -622,20 +643,26 @@ def apply(
         num_experts = global_num_experts
         out = _resize_cache(workspace13,
                             (num_experts, max_num_tokens, hidden_dim))
-        num_local_experts = expert_num_tokens.numel()
-        assert num_local_experts == w1.shape[0]
+        num_local_experts = w1.shape[0]  #expert_num_tokens.numel()
+        assert num_local_experts == w1.shape[0], f"{num_local_experts} == {w1.shape[0]}"
 
         N = w1.shape[1] // 2
 
+        # Not cudagraph friendly
+        assert (torch.cuda.is_current_stream_capturing() or
+                torch.all(expert_num_tokens <= max_num_tokens)), (
+            f"{expert_num_tokens} <= {max_num_tokens}")
+
         for expert in range(num_local_experts):
-            num = expert_num_tokens[expert].item()
-            assert num <= max_num_tokens, f"{num} <= {max_num_tokens}"
-            if num > 0:  # CUDAGRAPH unfriendly
-                tmp = _resize_cache(workspace2, (num, N))
-                input = hidden_states[expert, :num, :] @ w1[expert].transpose(0, 1)
-                assert input.shape[1] == N * 2
-                self.activation(activation, tmp, input)
-                out[expert, :num, :] = tmp @ w2[expert].transpose(0, 1)
+            # Indexing expert_num_tokens doesn't work w/cudagraphs
+            if torch.cuda.is_current_stream_capturing():
+                num = max_num_tokens
+            else:
+                num = int(expert_num_tokens[expert].item())
+            tmp = _resize_cache(workspace2, (num, N))
+            input = hidden_states[expert, :num, :] @ w1[expert].transpose(0, 1)
+            self.activation(activation, tmp, input)
+            out[expert, :num, :] = tmp @ w2[expert].transpose(0, 1)
 
         return out
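
Note: the loop no longer reads expert_num_tokens[expert].item() while a CUDA graph is being captured, since .item() forces a device-to-host sync that is not allowed during capture; under capture it conservatively processes the padded max_num_tokens rows instead. The capture check is a public PyTorch API. A small sketch of that bound selection (illustrative, not the module's code):

import torch

def rows_to_process(expert_num_tokens: torch.Tensor, expert: int,
                    max_num_tokens: int) -> int:
    if torch.cuda.is_available() and torch.cuda.is_current_stream_capturing():
        # Cannot sync on a device value during CUDA graph capture,
        # so fall back to the worst-case (padded) row count.
        return max_num_tokens
    return int(expert_num_tokens[expert].item())

assert rows_to_process(torch.tensor([3]), 0, 8) == 3   # outside capture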

vllm/model_executor/layers/fused_moe/fused_moe.py

Lines changed: 7 additions & 5 deletions
@@ -870,7 +870,7 @@ def fused_topk(
     gating_output: torch.Tensor,
     topk: int,
     renormalize: bool,
-    indices_type: torch.dtype = torch.int32,
+    indices_type: Optional[torch.dtype] = None,
 ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
     assert hidden_states.shape[0] == gating_output.shape[0], (
         "Number of tokens mismatch")
@@ -881,10 +881,12 @@ def fused_topk(
                                topk,
                                dtype=torch.float32,
                                device=hidden_states.device)
-    topk_ids = torch.empty(M,
-                           topk,
-                           dtype=indices_type,
-                           device=hidden_states.device)
+    topk_ids = torch.empty(
+        M,
+        topk,
+        dtype=torch.int32 if indices_type is None else indices_type,
+        device=hidden_states.device
+    )
     token_expert_indicies = torch.empty(M,
                                         topk,
                                         dtype=torch.int32,
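
Note: with indices_type now Optional, callers that pass nothing keep getting int32 topk_ids as before, while the pplx path can ask for uint32. A stand-in illustration of the same defaulting rule using plain torch.topk (fused_topk itself needs the CUDA custom op; uint32 also requires a PyTorch build that supports it):

import torch
from typing import Optional

def toy_topk(gating_output: torch.Tensor, topk: int,
             indices_type: Optional[torch.dtype] = None):
    # Same defaulting rule as the patched fused_topk: None means int32.
    weights, ids = torch.topk(torch.softmax(gating_output, dim=-1), topk, dim=-1)
    return weights, ids.to(torch.int32 if indices_type is None else indices_type)

logits = torch.randn(4, 8)
_, ids_default = toy_topk(logits, 2)
_, ids_u32 = toy_topk(logits, 2, indices_type=torch.uint32)
assert ids_default.dtype == torch.int32 and ids_u32.dtype == torch.uint32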

vllm/model_executor/layers/fused_moe/layer.py

Lines changed: 14 additions & 7 deletions
@@ -136,7 +136,7 @@ def get_or_create(self, **kwargs):
 
         with self._lock:
             instance = self._cache.get(key)
-            if instance is None:
+            if True or instance is None:
                 # TODO: should be intranode
                 instance = pplx.AllToAll.internode(**kwargs)
                 self._cache[key] = instance
@@ -272,6 +272,8 @@ def set_dispatch_combine(
 
         experts: Optional[FusedMoEPermuteExpertsUnpermute] = None
 
+        self.using_pplx = False
+
         if isinstance(dispatch_combine,
                       (BatchedDispatchCombine, PplxDispatchCombine)):
             logger.debug("BatchedTritonExperts %s", self.moe)
@@ -283,6 +285,7 @@ def set_dispatch_combine(
                 use_int4_w4a16=False,
                 block_shape=None,
             )
+            self.using_pplx = isinstance(dispatch_combine, PplxDispatchCombine)
         else:
             logger.debug("TritonExperts %s", self.moe)
             experts = TritonExperts(
@@ -329,7 +332,8 @@ def forward_cuda(
             num_expert_group=num_expert_group,
             custom_routing_function=custom_routing_function,
             scoring_func=scoring_func,
-            e_score_correction_bias=e_score_correction_bias)
+            e_score_correction_bias=e_score_correction_bias,
+            indices_type=torch.uint32 if self.using_pplx else None)
 
         if self.rocm_aiter_moe_enabled:
             return self.rocm_aiter_fused_experts(
@@ -687,7 +691,7 @@ def _construct_dispatch_combine(
         max_num_tokens = MOE_DP_CHUNK_SIZE
         world_size = moe.ep_size
 
-        if False and self.dp_size > 1 and has_pplx:
+        if self.dp_size > 1 and has_pplx:
             logger.debug("using pplx dispatch")
             dp_size = moe.ep_size // moe.dp_size  # dp_size actually means TP.
             rank = moe.ep_rank
@@ -1020,13 +1024,16 @@ def select_experts(hidden_states: torch.Tensor,
                        num_expert_group: Optional[int] = None,
                        custom_routing_function: Optional[Callable] = None,
                        scoring_func: str = "softmax",
-                       e_score_correction_bias: Optional[torch.Tensor] = None):
-        from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk
+                       e_score_correction_bias: Optional[torch.Tensor] = None,
+                       indices_type: Optional[torch.dtype] = None):
+        from vllm.model_executor.layers.fused_moe.fused_moe import (
+            fused_topk, grouped_topk)
 
         # DeekSeekv2 uses grouped_top_k
         if use_grouped_topk:
             assert topk_group is not None
             assert num_expert_group is not None
+            assert indices_type is None or indices_type == torch.int32
             topk_weights, topk_ids = grouped_topk(
                 hidden_states=hidden_states,
                 gating_output=router_logits,
@@ -1041,10 +1048,10 @@ def select_experts(hidden_states: torch.Tensor,
                 gating_output=router_logits,
                 topk=top_k,
                 renormalize=renormalize,
-                # XXXXX how to do this?
-                #indices_type=torch.uint32,
+                indices_type=indices_type,
             )
         else:
+            assert indices_type is None or indices_type == torch.int32
             topk_weights, topk_ids = custom_routing_function(
                 hidden_states=hidden_states,
                 gating_output=router_logits,
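
Note: the plumbing added here is two-sided: forward_cuda requests uint32 indices only when the pplx dispatcher is in use, and select_experts forwards that request to fused_topk while asserting that the grouped and custom routing paths stay on int32. A condensed sketch of that selection logic (illustrative, not the module's code):

import torch
from typing import Optional

def pick_indices_type(using_pplx: bool) -> Optional[torch.dtype]:
    # The pplx path asks for uint32 expert ids in this commit; every other
    # path keeps the historical int32 default, signalled by None.
    return torch.uint32 if using_pplx else None

def check_indices_type(path: str, indices_type: Optional[torch.dtype]) -> None:
    # Only the fused_topk path honors a non-default dtype.
    if path != "fused_topk":
        assert indices_type is None or indices_type == torch.int32

check_indices_type("grouped_topk", pick_indices_type(False))   # ok
check_indices_type("fused_topk", pick_indices_type(True))      # ok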

vllm/model_executor/layers/fused_moe/pplx_dispatch_combine.py

Lines changed: 4 additions & 5 deletions
@@ -105,12 +105,10 @@ def dispatch(
         )
 
         # This argument is optional, defaults to indices.shape[0]
+        # There's not much point setting this unless it is != indices.shape[0]
         #bound_m = torch.tensor([num_tokens], dtype=torch.uint32, device=device)
         bound_m = None
 
-        # TODO: optimize this?
-        #indices = rank_topk_ids.to(dtype=torch.uint32)
-
         self.a2a.dispatch(
             out_expert_num_tokens=expert_num_tokens,
             out_expert_x=expert_x,
@@ -130,14 +128,15 @@ def combine(
         topk_ids: torch.Tensor,
         apply_router_weight_on_input: bool,
     ) -> None:
-        # This argument is optional
         num_tokens = output.shape[0]  # M
+        # This argument is optional
+        # There's not much point setting this unless it is != topk_ids.shape[0]
         #bound_m = torch.tensor([num_tokens],
         #                       dtype=torch.uint32,
         #                       device=fused_expert_output.device)
         bound_m = None
 
-        assert topk_ids.shape[0] <= num_tokens
+        assert topk_ids.shape[0] == num_tokens
         assert output.shape[0] <= self.max_num_tokens, \
             f"{output.shape[0]} <= {self.max_num_tokens}"
         assert output.shape[1] == fused_expert_output.shape[-1]
