Commit 0f2e37a

wip ref impl

Signed-off-by: Bill Nell <[email protected]>
1 parent 1938bc8 commit 0f2e37a

7 files changed: +125 −41 lines changed

csrc/activation_kernels.cu

Lines changed: 1 addition & 0 deletions
@@ -70,6 +70,7 @@ __device__ __forceinline__ T gelu_tanh_kernel(const T& x) {
   int64_t num_tokens = input.numel() / input.size(-1);                \
   dim3 grid(num_tokens);                                              \
   dim3 block(std::min(d, 1024));                                      \
+  if (num_tokens == 0) { return; }                                    \
   const at::cuda::OptionalCUDAGuard device_guard(device_of(input));   \
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();       \
   VLLM_DISPATCH_FLOATING_TYPES(                                       \
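
For context, a minimal pure-PyTorch sketch (names here are hypothetical, not the vLLM API) of the failure mode this guard avoids: the macro launches the kernel with grid = num_tokens blocks, and a grid of zero blocks is an invalid CUDA launch configuration, so an empty batch has to return before the launch.

import torch

def launch_activation(out: torch.Tensor, x: torch.Tensor) -> None:
    # Mirrors the macro: num_tokens = input.numel() / input.size(-1)
    num_tokens = x.numel() // x.size(-1)
    if num_tokens == 0:
        # A 0-block grid cannot be launched, so bail out early;
        # the (empty) output already has the right shape.
        return
    d = x.size(-1) // 2
    # Stand-in for the real silu_and_mul kernel.
    out.copy_(torch.nn.functional.silu(x[..., :d]) * x[..., d:])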

examples/offline_inference/data_parallel.py

Lines changed: 4 additions & 2 deletions
@@ -114,9 +114,11 @@ def main(model, dp_size, local_dp_rank, global_dp_rank, dp_master_ip,
 
     # Create an LLM.
     cconfig = CompilationConfig(
-        level=0,
+        level=3,
         #cudagraph_capture_sizes=[512,504,496,488,480,472,464,456,448,440,432,424,416,408,400,392,384,376,368,360,352,344,336,328,320,312,304,296,288,280,272,264,256,248,240,232,224,216,208],
         #cudagraph_capture_sizes=[512,256,1],
+        #cudagraph_capture_sizes=[192,184,176,168,160,152,144,136,128,120,112,104,96,88,80,72,64,56,48,40,32,24,16,8,4,2,1]
+        #cudagraph_capture_sizes=[128,120,112,104,96,88,80,72,64,56,48,40,32,24,16,8,4,2,1]
     )
     llm = LLM(model=model,
               tensor_parallel_size=GPUs_per_dp_rank,
@@ -171,7 +173,7 @@ def main(model, dp_size, local_dp_rank, global_dp_rank, dp_master_ip,
         procs.append(proc)
     exit_code = 0
     for proc in procs:
-        proc.join(timeout=300)
+        proc.join(timeout=3000)
         if proc.exitcode is None:
             print(f"Killing process {proc.pid} that "
                   f"didn't stop within 5 minutes.")

tests/kernels/moe/test_pplx_moe.py

Lines changed: 46 additions & 0 deletions
@@ -515,6 +515,50 @@ def pplx_moe(pgi, dp_size, a, w1, w2, topk_weight, topk_ids):
     return out
 
 
+def _batched_moe(pgi, dp_size, a, w1, w2, topk_weight, topk_ids):
+    assert torch.cuda.current_device() == pgi.local_rank
+
+    hidden_dim = a.shape[1]
+    num_experts = w1.shape[0]
+    block_size = 128
+    device = pgi.device
+    rank = pgi.rank
+    world_size = pgi.world_size
+    topk = topk_ids.shape[1]
+    max_num_tokens = rank_chunk(a.shape[0], 0, world_size)
+
+    dispatch_combine = BatchedDispatchCombine(
+        max_num_tokens=max_num_tokens,
+        world_size=world_size,
+        dp_size=dp_size,
+        rank=rank,
+    )
+
+    experts = BatchedExperts(a.shape[0])
+
+    fused_experts = FusedMoEModularKernel(
+        dispatch_combine,
+        experts,
+    )
+
+    # TODO: workers with the same dp_rank must use the exact same inputs.
+
+    a_chunk = chunk_by_rank(a, rank, world_size).to(device)
+    chunk_topk_weight = chunk_by_rank(topk_weight, rank, world_size).to(device)
+    chunk_topk_ids = chunk_by_rank(topk_ids, rank, world_size).to(device)
+
+    out = fused_experts(
+        a_chunk,
+        # Chunking weights like this only works for batched format
+        chunk_by_rank(w1, rank, world_size).to(device),
+        chunk_by_rank(w2, rank, world_size).to(device),
+        chunk_topk_weight,
+        chunk_topk_ids,
+        global_num_experts=num_experts)
+
+    return out
+
+
 def _pplx_moe(
     pgi: ProcessGroupInfo,
     dp_size: int,
@@ -536,11 +580,13 @@
     topk_weight, topk_ids = fused_topk(a, score, topk, False)
     torch_output = torch_moe2(a, w1, w2, topk_weight, topk_ids)
     pplx_output = pplx_moe(pgi, dp_size, a, w1, w2, topk_weight, topk_ids)
+    batched_output = _batched_moe(pgi, dp_size, a, w1, w2, topk_weight, topk_ids)
 
     torch_output = chunk_by_rank(torch_output, pgi.rank,
                                  pgi.world_size).to(pplx_output.device)
 
     torch.testing.assert_close(pplx_output, torch_output, atol=2e-2, rtol=0)
+    torch.testing.assert_close(batched_output, torch_output, atol=2e-2, rtol=0)
 
     nvshmem_finalize()
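For orientation, a small standalone sketch of how each rank picks its slice of the global batch in this test. rank_chunk below mirrors the helper added in fused_batched_moe.py; chunk_by_rank is an approximation of the test helper, so treat its exact implementation as an assumption.

import torch

def rank_chunk(num: int, r: int, w: int) -> int:
    # Rows owned by rank r when num rows are split across w ranks,
    # with the remainder spread over the lowest ranks.
    rem = num % w
    return (num // w) + (1 if r < rem else 0)

def chunk_by_rank(t: torch.Tensor, r: int, w: int) -> torch.Tensor:
    # Contiguous block of rows owned by rank r under the same split.
    start = sum(rank_chunk(t.shape[0], i, w) for i in range(r))
    return t[start:start + rank_chunk(t.shape[0], r, w)]

a = torch.randn(10, 4)
assert sum(chunk_by_rank(a, r, 3).shape[0] for r in range(3)) == a.shape[0]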

vllm/model_executor/layers/fused_moe/fused_batched_moe.py

Lines changed: 50 additions & 23 deletions
@@ -466,6 +466,11 @@ def invoke_batched_silu_and_mul(
                                 compute_tl_dtype, D, BLOCK_M, BLOCK_D)
 
 
+def rank_chunk(num, r, w):
+    rem = num % w
+    return (num // w) + (1 if r < rem else 0)
+
+
 class BatchedDispatchCombine(mk.FusedMoEQuantizeDispatchCombine):
 
     def __init__(self, max_num_tokens: Optional[int], world_size: int, dp_size: int, rank: int):
@@ -505,20 +510,31 @@ def dispatch(
         if self.max_num_tokens is None:
             self.max_num_tokens = int(tokens_per_expert.max().item())
 
-        b_a1 = torch.zeros((num_experts, self.max_num_tokens, hidden_dim),
+        rem_experts = num_experts % self.world_size
+        num_local_experts = ((num_experts // self.world_size) +
+                             (1 if self.rank < rem_experts else 0))
+
+        b_a1 = torch.zeros((num_local_experts, self.max_num_tokens, hidden_dim),
                            dtype=a1.dtype,
                            device=a1.device)
 
-        token_counts = torch.zeros(num_experts,
+        token_counts = torch.zeros(num_local_experts,
                                    dtype=torch.int,
                                    device=a1.device)
 
+        first_expert = (((num_experts // self.world_size) * self.rank) +
+                        rem_experts - self.rank)
+        last_expert = first_expert + num_local_experts
+        #expert_id_range = range(first_expert, last_expert)
+
         for token in range(num_tokens):
             for j in range(topk):
                 expert_id = topk_ids[token, j]
-                idx = token_counts[expert_id]
-                b_a1[expert_id, idx:idx + 1, :] = a1[token, :]
-                token_counts[expert_id] = token_counts[expert_id] + 1
+                if expert_id >= first_expert and expert_id < last_expert:
+                    rel_index = expert_id - first_expert
+                    idx = token_counts[rel_index]
+                    b_a1[rel_index, idx:idx + 1, :] = a1[token, :]
+                    token_counts[rel_index] = token_counts[rel_index] + 1
 
         return b_a1, a1_scale, tokens_per_expert
 
@@ -531,7 +547,8 @@ def combine(
         apply_router_weight_on_input: bool,
     ) -> None:
         num_tokens = topk_ids.shape[0]
-        num_experts = fused_expert_output.shape[0]
+        num_local_experts = fused_expert_output.shape[0]
+        num_experts = num_local_experts * self.world_size # NOT QUITE RIGHT
         K = fused_expert_output.shape[-1]
         assert output.shape[0] == num_tokens and output.shape[1] == K
         expert_counts = torch.zeros(
@@ -541,17 +558,21 @@
 
         output.fill_(0)
 
+        first_expert = num_local_experts * self.rank # NOT QUITE RIGHT
+        last_expert = first_expert + num_local_experts
+
         for token in range(num_tokens):
             expert_ids = topk_ids[token]
             for i in range(expert_ids.numel()):
                 expert_id = expert_ids[i]
-                assert expert_id < num_experts
-                idx = expert_counts[expert_id]
-                accum = fused_expert_output[expert_id, idx:idx + 1, :]
-                if not apply_router_weight_on_input:
-                    accum = accum * topk_weights[token, i]
-                output[token, :] = output[token, :] + accum
-                expert_counts[expert_id] = expert_counts[expert_id] + 1
+                if expert_id >= first_expert and expert_id < last_expert:
+                    assert expert_id < num_experts
+                    idx = expert_counts[expert_id]
+                    accum = fused_expert_output[expert_id - first_expert, idx:idx + 1, :]
+                    if not apply_router_weight_on_input:
+                        accum = accum * topk_weights[token, i]
+                    output[token, :] = output[token, :] + accum
+                    expert_counts[expert_id] = expert_counts[expert_id] + 1
 
 
 class BatchedExperts(mk.FusedMoEPermuteExpertsUnpermute):
@@ -622,20 +643,26 @@ def apply(
         num_experts = global_num_experts
         out = _resize_cache(workspace13,
                             (num_experts, max_num_tokens, hidden_dim))
-        num_local_experts = expert_num_tokens.numel()
-        assert num_local_experts == w1.shape[0]
+        num_local_experts = w1.shape[0] #expert_num_tokens.numel()
+        assert num_local_experts == w1.shape[0], f"{num_local_experts} == {w1.shape[0]}"
 
         N = w1.shape[1] // 2
 
+        # Not cudagraph friendly
+        assert (torch.cuda.is_current_stream_capturing() or
+                torch.all(expert_num_tokens <= max_num_tokens)), (
+            f"{expert_num_tokens} <= {max_num_tokens}")
+
         for expert in range(num_local_experts):
-            num = expert_num_tokens[expert].item()
-            assert num <= max_num_tokens, f"{num} <= {max_num_tokens}"
-            if num > 0: # CUDAGRAPH unfriendly
-                tmp = _resize_cache(workspace2, (num, N))
-                input = hidden_states[expert, :num, :] @ w1[expert].transpose(0, 1)
-                assert input.shape[1] == N * 2
-                self.activation(activation, tmp, input)
-                out[expert, :num, :] = tmp @ w2[expert].transpose(0, 1)
+            # Indexing expert_num_tokens doesn't work w/cudagraphs
+            if torch.cuda.is_current_stream_capturing():
+                num = max_num_tokens
+            else:
+                num = int(expert_num_tokens[expert].item())
+            tmp = _resize_cache(workspace2, (num, N))
+            input = hidden_states[expert, :num, :] @ w1[expert].transpose(0, 1)
+            self.activation(activation, tmp, input)
+            out[expert, :num, :] = tmp @ w2[expert].transpose(0, 1)
 
         return out
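
To make the expert-locality logic above easier to follow, here is a small standalone sketch of the dispatch scatter (plain PyTorch; simplified by assuming experts divide evenly across ranks, so first_expert is just num_local_experts * rank, and that max_num_tokens is large enough for every local expert):

import torch

def batched_dispatch(a1, topk_ids, num_experts, max_num_tokens, rank, world_size):
    # Each rank owns a contiguous range of expert ids and only buffers
    # tokens that were routed to one of its local experts.
    num_local_experts = num_experts // world_size
    first_expert = num_local_experts * rank
    last_expert = first_expert + num_local_experts

    b_a1 = torch.zeros((num_local_experts, max_num_tokens, a1.shape[1]),
                       dtype=a1.dtype, device=a1.device)
    token_counts = torch.zeros(num_local_experts, dtype=torch.int, device=a1.device)

    for token in range(topk_ids.shape[0]):
        for expert_id in topk_ids[token].tolist():
            if first_expert <= expert_id < last_expert:
                rel = expert_id - first_expert
                # Append the token to this local expert's batch slot.
                b_a1[rel, token_counts[rel]] = a1[token]
                token_counts[rel] += 1
    return b_a1, token_counts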

vllm/model_executor/layers/fused_moe/fused_moe.py

Lines changed: 7 additions & 5 deletions
@@ -870,7 +870,7 @@ def fused_topk(
     gating_output: torch.Tensor,
     topk: int,
     renormalize: bool,
-    indices_type: torch.dtype = torch.int32,
+    indices_type: Optional[torch.dtype] = None,
 ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
     assert hidden_states.shape[0] == gating_output.shape[0], (
         "Number of tokens mismatch")
@@ -881,10 +881,12 @@
                                topk,
                                dtype=torch.float32,
                                device=hidden_states.device)
-    topk_ids = torch.empty(M,
-                           topk,
-                           dtype=indices_type,
-                           device=hidden_states.device)
+    topk_ids = torch.empty(
+        M,
+        topk,
+        dtype=torch.int32 if indices_type is None else indices_type,
+        device=hidden_states.device
+    )
     token_expert_indicies = torch.empty(M,
                                         topk,
                                         dtype=torch.int32,
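
A naive sketch of the semantics this signature now exposes: routing weights come from softmax + top-k, and callers that do not care about the index dtype get int32, while callers that need a specific dtype (e.g. uint32 for the pplx all-to-all) can request it. The helper below is an illustration, not the real fused_topk kernel path.

from typing import Optional
import torch

def naive_fused_topk(gating_output: torch.Tensor,
                     topk: int,
                     renormalize: bool,
                     indices_type: Optional[torch.dtype] = None):
    # Reference-style top-k routing: softmax over experts, pick top-k.
    scores = torch.softmax(gating_output, dim=-1)
    topk_weights, topk_ids = torch.topk(scores, topk, dim=-1)
    if renormalize:
        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
    # None keeps the historical int32 default; pplx wants uint32.
    topk_ids = topk_ids.to(torch.int32 if indices_type is None else indices_type)
    return topk_weights, topk_ids

w, ids = naive_fused_topk(torch.randn(4, 8), topk=2, renormalize=True)
assert ids.dtype == torch.int32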

vllm/model_executor/layers/fused_moe/layer.py

Lines changed: 13 additions & 6 deletions
@@ -128,7 +128,7 @@ def get_or_create(self, **kwargs):
 
         with self._lock:
             instance = self._cache.get(key)
-            if instance is None:
+            if True or instance is None:
                 # TODO: should be intranode
                 instance = pplx.AllToAll.internode(**kwargs)
                 self._cache[key] = instance
@@ -152,6 +152,7 @@ def __init__(self, moe: MoEConfig):
         super().__init__()
         self.fused_experts = fused_experts
         self.moe = moe
+        self.using_pplx = False
 
     def create_weights(self, layer: torch.nn.Module, num_experts: int,
                        hidden_size: int, intermediate_size_per_partition: int,
@@ -256,6 +257,8 @@ def set_dispatch_combine(
 
         experts: Optional[FusedMoEPermuteExpertsUnpermute] = None
 
+        self.using_pplx = False
+
         if isinstance(dispatch_combine,
                       (BatchedDispatchCombine, PplxDispatchCombine)):
             logger.debug("BatchedTritonExperts %s", self.moe)
@@ -267,6 +270,7 @@ def set_dispatch_combine(
                 use_int4_w4a16=False,
                 block_shape=None,
             )
+            self.using_pplx = isinstance(dispatch_combine, PplxDispatchCombine)
         else:
             logger.debug("TritonExperts %s", self.moe)
             experts = TritonExperts(
@@ -313,7 +317,8 @@ def forward_cuda(
             num_expert_group=num_expert_group,
             custom_routing_function=custom_routing_function,
             scoring_func=scoring_func,
-            e_score_correction_bias=e_score_correction_bias)
+            e_score_correction_bias=e_score_correction_bias,
+            indices_type=torch.uint32 if self.using_pplx else None)
 
         return self.fused_experts(
             hidden_states=x,
@@ -661,7 +666,7 @@ def _construct_dispatch_combine(
         max_num_tokens = MOE_DP_CHUNK_SIZE
         world_size = moe.ep_size
 
-        if False and self.dp_size > 1 and has_pplx:
+        if self.dp_size > 1 and has_pplx:
             logger.debug("using pplx dispatch")
             dp_size = moe.ep_size // moe.dp_size  # dp_size actually means TP.
             rank = moe.ep_rank
@@ -977,14 +982,16 @@ def select_experts(hidden_states: torch.Tensor,
                        num_expert_group: Optional[int] = None,
                        custom_routing_function: Optional[Callable] = None,
                        scoring_func: str = "softmax",
-                       e_score_correction_bias: Optional[torch.Tensor] = None):
+                       e_score_correction_bias: Optional[torch.Tensor] = None,
+                       indices_type: Optional[torch.dtype] = None):
         from vllm.model_executor.layers.fused_moe.fused_moe import (
            fused_topk, grouped_topk)
 
         # DeekSeekv2 uses grouped_top_k
         if use_grouped_topk:
            assert topk_group is not None
            assert num_expert_group is not None
+           assert indices_type is None or indices_type == torch.int32
            topk_weights, topk_ids = grouped_topk(
                hidden_states=hidden_states,
                gating_output=router_logits,
@@ -999,10 +1006,10 @@
                gating_output=router_logits,
                topk=top_k,
                renormalize=renormalize,
-                # XXXXX how to do this?
-                #indices_type=torch.uint32,
+                indices_type=indices_type,
            )
        else:
+           assert indices_type is None or indices_type == torch.int32
            topk_weights, topk_ids = custom_routing_function(
                hidden_states=hidden_states,
                gating_output=router_logits,
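
To summarize the plumbing this commit adds, here is a compressed sketch (standalone, with an invented class name; the real code is spread across the MoE method's set_dispatch_combine/forward_cuda and select_experts) of how the routing-index dtype is chosen per dispatch backend:

from typing import Optional
import torch

class MoEMethodSketch:
    def __init__(self) -> None:
        # Set when a pplx dispatch/combine is installed; pplx's all-to-all
        # expects uint32 expert indices, while the Triton paths use int32.
        self.using_pplx = False

    def set_dispatch_combine(self, dispatch_combine) -> None:
        self.using_pplx = type(dispatch_combine).__name__ == "PplxDispatchCombine"

    def routing_indices_dtype(self) -> Optional[torch.dtype]:
        # Passed as indices_type into select_experts -> fused_topk;
        # None keeps the historical int32 default.
        return torch.uint32 if self.using_pplx else None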

vllm/model_executor/layers/fused_moe/pplx_dispatch_combine.py

Lines changed: 4 additions & 5 deletions
@@ -105,12 +105,10 @@ def dispatch(
         )
 
         # This argument is optional, defaults to indices.shape[0]
+        # There's not much point setting this unless it is != indices.shape[0]
         #bound_m = torch.tensor([num_tokens], dtype=torch.uint32, device=device)
         bound_m = None
 
-        # TODO: optimize this?
-        #indices = rank_topk_ids.to(dtype=torch.uint32)
-
         self.a2a.dispatch(
             out_expert_num_tokens=expert_num_tokens,
             out_expert_x=expert_x,
@@ -130,14 +128,15 @@ def combine(
         topk_ids: torch.Tensor,
         apply_router_weight_on_input: bool,
     ) -> None:
-        # This argument is optional
         num_tokens = output.shape[0]  # M
+        # This argument is optional
+        # There's not much point setting this unless it is != topk_ids.shape[0]
         #bound_m = torch.tensor([num_tokens],
        #                        dtype=torch.uint32,
        #                        device=fused_expert_output.device)
         bound_m = None
 
-        assert topk_ids.shape[0] <= num_tokens
+        assert topk_ids.shape[0] == num_tokens
         assert output.shape[0] <= self.max_num_tokens, \
             f"{output.shape[0]} <= {self.max_num_tokens}"
         assert output.shape[1] == fused_expert_output.shape[-1]
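
For illustration, a tiny sketch of the one case where passing bound_m would matter, following the commented-out construction above (the numbers are purely hypothetical; real use would place the tensor on the same GPU as the dispatch buffers):

import torch

# bound_m defaults to indices.shape[0], i.e. the number of rows in the
# padded index buffer, so it only changes behaviour when fewer rows are
# actually valid.
padded_num_tokens = 16   # rows in the padded topk_ids buffer
num_valid_tokens = 7     # hypothetical: only the first 7 rows are real

bound_m = (torch.tensor([num_valid_tokens], dtype=torch.uint32)
           if num_valid_tokens != padded_num_tokens else None)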
