
Commit 938c516

cleanups
Signed-off-by: Bill Nell <[email protected]>
1 parent c09cefd commit 938c516

File tree

7 files changed: +90, -87 lines changed


tests/kernels/moe/test_pplx_moe.py

Lines changed: 23 additions & 25 deletions
@@ -16,7 +16,7 @@
     from pplx_kernels.nvshmem import (nvshmem_alloc_empty_unique_id,
                                       nvshmem_finalize, nvshmem_get_unique_id,
                                       nvshmem_init)
-    has_pplx = False
+    has_pplx = True
 except ImportError:
     has_pplx = False

@@ -46,11 +46,6 @@

 P = ParamSpec("P")

-require_multi_node = pytest.mark.skipif(
-    "MASTER_ADDR" not in os.environ,
-    reason="Requires multi-node environment",
-)
-
 requires_pplx = pytest.mark.skipif(
     not has_pplx,
     reason="Requires PPLX kernels",

@@ -180,6 +175,9 @@ def torch_dispatch(

     tokens_per_expert = torch.bincount(topk_ids.view(-1),
                                        minlength=num_experts)
+
+    assert tokens_per_expert.numel() == num_experts
+
     if max_num_tokens is None:
         max_num_tokens = int(tokens_per_expert.max().item())

@@ -259,7 +257,7 @@ def torch_moe2(a, w1, w2, topk_weight, topk_ids):
             topk_weight.view(M, -1, 1).to(out.dtype)).sum(dim=1)


-@pytest.mark.parametrize("m", [1, 33, 64, 222])  #, 1024 * 128])
+@pytest.mark.parametrize("m", [1, 33, 64, 222])
 @pytest.mark.parametrize("n", [128, 1024, 2048])
 @pytest.mark.parametrize("k", [128, 511, 1024])
 @pytest.mark.parametrize("e", NUM_EXPERTS)

@@ -309,7 +307,7 @@ def torch_pplx_dispatch_combine(pgi, dp_size, a, w1, w2, scores, topk):
     rank = pgi.rank
     world_size = pgi.world_size
     rank_num_tokens = rank_chunk(num_tokens, rank, world_size)
-    max_num_tokens = num_tokens
+    max_num_tokens = max(num_tokens, 1)

     ata = AllToAll.internode(
         max_num_tokens=max_num_tokens,

@@ -350,22 +348,23 @@ def torch_pplx_dispatch_combine(pgi, dp_size, a, w1, w2, scores, topk):
         False,
     )

-    naive_b_a, tokens_per_expert = torch_dispatch(a_chunk, chunk_topk_ids,
-                                                  num_experts)
+    if False:
+        naive_b_a, tokens_per_expert = torch_dispatch(a_chunk, chunk_topk_ids,
+                                                      num_experts)

-    torch.distributed.all_reduce(tokens_per_expert)
-    tokens_per_expert = chunk_by_rank(tokens_per_expert, rank,
-                                      world_size).to(dtype=torch.int32)
+        torch.distributed.all_reduce(tokens_per_expert)
+        tokens_per_expert = chunk_by_rank(tokens_per_expert, rank,
+                                          world_size).to(dtype=torch.int32)

-    torch.testing.assert_close(tokens_per_expert,
-                               expert_num_tokens,
-                               atol=0,
-                               rtol=0)
+        torch.testing.assert_close(tokens_per_expert,
+                                   expert_num_tokens,
+                                   atol=0,
+                                   rtol=0)

     b_a = b_a * 1.5

     out = torch.full(
-        (rank_num_tokens * world_size, hidden_dim),
+        (rank_num_tokens, hidden_dim),
         torch.nan,
         dtype=a.dtype,
         device=device,

@@ -424,14 +423,15 @@ def _pplx_dispatch_combine(
     nvshmem_finalize()


+# TODO: M < world_size doesn't appear to be supported by pplx?
 @pytest.mark.parametrize("m", [4, 32, 64, 222])
 @pytest.mark.parametrize("n", [128, 1024, 2048])
-@pytest.mark.parametrize("k", [128, 512, 1024])  # restrictions? % 128?
+@pytest.mark.parametrize("k", [128, 512, 1024])
 @pytest.mark.parametrize("e", NUM_EXPERTS)
 @pytest.mark.parametrize("topk", TOP_KS)
 @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
 @pytest.mark.parametrize("world_dp_size", [[2, 1]])  #, [[4, 2]])
-@pytest.mark.skipif(not has_pplx, reason="PPLX kernels not available.")
+@requires_pplx
 def test_pplx_dispatch_combine(
     m: int,
     n: int,

@@ -502,11 +502,9 @@ def torch_pplx_moe(pgi, dp_size, a, w1, w2, scores, topk):
         # Chunking weights like this only works for batched format
         chunk_by_rank(w1, rank, world_size),
         chunk_by_rank(w2, rank, world_size),
-        #w1,
-        #w2,
         chunk_topk_weight,
         chunk_topk_ids,
-        global_num_experts=num_experts  #? num_local_experts?
+        global_num_experts=num_experts
     )

     torch.cuda.synchronize()

@@ -547,15 +545,15 @@ def _pplx_moe(
     nvshmem_finalize()


-# TODO: M == 1 doesn't work
+# TODO: M < world_size doesn't appear to be supported by pplx?
 @pytest.mark.parametrize("m", [2, 3, 32, 45, 64, 222])
 @pytest.mark.parametrize("n", [128, 1024, 2048])
 @pytest.mark.parametrize("k", [128, 512, 1024])
 @pytest.mark.parametrize("e", NUM_EXPERTS)
 @pytest.mark.parametrize("topk", TOP_KS)
 @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
 @pytest.mark.parametrize("world_dp_size", [[2, 1]])  #, [4, 2]])
-@pytest.mark.skipif(not has_pplx, reason="PPLX kernels not available.")
+@requires_pplx
 def test_pplx_moe(
     m: int,
     n: int,
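
Note on the test changes above: they converge on the usual optional-dependency guard, where has_pplx is True only when the pplx_kernels import actually succeeds, and every PPLX test is gated by the single shared @requires_pplx marker instead of repeating pytest.mark.skipif. A minimal, self-contained sketch of that pattern (the test name below is illustrative, not from this file):

    import pytest

    try:
        from pplx_kernels import AllToAll  # noqa: F401
        has_pplx = True
    except ImportError:
        has_pplx = False

    requires_pplx = pytest.mark.skipif(
        not has_pplx,
        reason="Requires PPLX kernels",
    )


    @requires_pplx
    def test_needs_pplx():
        # Only runs when the optional dependency imported cleanly.
        assert has_pplx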

vllm/forward_context.py

Lines changed: 1 addition & 1 deletion
@@ -92,7 +92,7 @@ def set_forward_context(attn_metadata: Any,
                                         dtype=torch.int32)
         from vllm.distributed.parallel_state import get_dp_group
         dist.all_reduce(num_tokens_tensor, group=get_dp_group().cpu_group)
-        #TODO device?
+        #TODO device? (tms)
         max_tokens_across_dp = torch.max(
             num_tokens_tensor)  #.to(device="cuda")
         cu_tokens_across_dp_cpu = torch.cumsum(num_tokens_tensor, dim=0)
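
Note: the touched comment sits in the block that all-reduces each data-parallel rank's token count on the CPU group and then derives the max and the cumulative sum. A single-process sketch of just that reduction shape (the counts are made up; no distributed group involved):

    import torch

    # Stand-in for num_tokens_tensor after dist.all_reduce has filled in
    # every DP rank's token count.
    num_tokens_tensor = torch.tensor([7, 3, 5, 9], dtype=torch.int32)

    max_tokens_across_dp = torch.max(num_tokens_tensor)               # tensor(9)
    cu_tokens_across_dp_cpu = torch.cumsum(num_tokens_tensor, dim=0)  # 7, 10, 15, 24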

vllm/model_executor/layers/fused_moe/deep_gemm_moe.py

Lines changed: 2 additions & 0 deletions
@@ -1,4 +1,5 @@
 # SPDX-License-Identifier: Apache-2.0
+import functools
 import importlib.util
 from typing import Optional, Tuple

@@ -19,6 +20,7 @@
 has_deep_gemm = importlib.util.find_spec("deep_gemm") is not None


+@functools.cache
 def deep_gemm_block_shape() -> list[int]:
     # Lazy import to avoid CUDA initialization problems.
     import deep_gemm as dg
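
The new functools.cache decorator memoizes this zero-argument helper, so the lazy import of deep_gemm and the block-shape lookup run once per process and later calls return the cached list. A standalone sketch of the same idea (the body is stand-in code, not the real deep_gemm API):

    import functools

    @functools.cache
    def block_shape() -> list[int]:
        # Pretend this is the expensive lazy import / device query; it only
        # runs on the first call.
        print("computing block shape")
        return [128, 128]

    block_shape()  # prints once and caches the result
    block_shape()  # served from the cache, no print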

vllm/model_executor/layers/fused_moe/layer.py

Lines changed: 50 additions & 44 deletions
@@ -63,8 +63,7 @@ class MoEConfig:
     ep_size: int
     ep_rank: int

-    in_dtype: torch.dtype
-    out_dtype: torch.dtype
+    in_dtype: torch.dtype  # The activation type.

     # TODO: add more quantization params, blocked, per-token, etc.
     block_size: int = 128

@@ -142,7 +141,6 @@ def get_all_to_all(**kwargs):
     return _all_to_all_cache.get_or_create(**kwargs)


-#TODO: Every change in this class is a broken hack!!
 @CustomOp.register("unquantized_fused_moe")
 class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
     """MoE method without quantization."""

@@ -249,18 +247,15 @@ def apply(
             activation=activation,
             apply_router_weight_on_input=apply_router_weight_on_input)

-    # Maybe extra args
     def set_dispatch_combine(
             self, dispatch_combine: FusedMoEQuantizeDispatchCombine) -> bool:
         assert self.fused_experts == fused_experts

-        #block_m = MOE_DP_CHUNK_SIZE * (self.moe.ep_size // self.moe.dp_size)
-
         experts: FusedMoEPermuteExpertsUnpermute = None

         if isinstance(dispatch_combine,
                       (BatchedDispatchCombine, PplxDispatchCombine)):
-            logger.info("BatchedTritonExperts %s", self.moe)
+            logger.debug("BatchedTritonExperts %s", self.moe)
             experts = BatchedTritonExperts(
                 use_fp8_w8a8=False,
                 use_int8_w8a8=False,

@@ -269,7 +264,7 @@ def set_dispatch_combine(
                 block_shape=None,
             )
         else:
-            logger.info("TritonExperts %s", self.moe)
+            logger.debug("TritonExperts %s", self.moe)
             experts = TritonExperts(
                 use_fp8_w8a8=False,
                 use_int8_w8a8=False,

@@ -611,8 +606,7 @@ def __init__(
             dp_rank=self.dp_rank,
             ep_size=self.ep_size,
             ep_rank=self.ep_rank,
-            in_dtype=params_dtype,  # this is probably not right, where to get?
-            out_dtype=params_dtype,  # ditto.
+            in_dtype=params_dtype,  # TODO: is this right?
         )

         # Note: get_quant_method will look at the layer's local_num_experts

@@ -628,12 +622,42 @@ def __init__(
         assert quant_method is not None
         self.quant_method = quant_method

-        dispatch_combine: FusedMoEQuantizeDispatchCombine = None
+        dispatch_combine = self._construct_dispatch_combine(
+            moe, quant_config)
+
+        success = self.quant_method.set_dispatch_combine(dispatch_combine)
+
+        if not success:
+            logger.warning("DP+EP not supported for %s.",
+                           type(self.quant_method))
+
+        self.apply_router_weight_on_input = apply_router_weight_on_input
+        moe_quant_params = {
+            "num_experts": self.local_num_experts,
+            "hidden_size": hidden_size,
+            "intermediate_size_per_partition":
+            self.intermediate_size_per_partition,
+            "params_dtype": params_dtype,
+            "weight_loader": self.weight_loader,
+        }
+        # need full intermediate size pre-sharding for WNA16 act order
+        if (self.quant_method.__class__.__name__
+                in ("GPTQMarlinMoEMethod",
+                    "CompressedTensorsWNA16MarlinMoEMethod",
+                    "CompressedTensorsWNA16MoEMethod")):
+            moe_quant_params["intermediate_size_full"] = intermediate_size
+
+        self.quant_method.create_weights(layer=self, **moe_quant_params)

-        # TODO: move to method?
+    # TODO: return Optional?
+    def _construct_dispatch_combine(
+        self,
+        moe: MoEConfig,
+        quant_config: Optional[QuantizationConfig],
+    ) -> FusedMoEQuantizeDispatchCombine:
         if self.dp_size > 1 and has_pplx:
-            logger.info("using pplx dispatch")
-            max_num_tokens = MOE_DP_CHUNK_SIZE  # // moe.dp_size
+            logger.debug("using pplx dispatch")
+            max_num_tokens = MOE_DP_CHUNK_SIZE
             world_size = moe.ep_size
             dp_size = moe.ep_size // moe.dp_size  # dp_size actually means TP.
             rank = moe.ep_rank

@@ -654,51 +678,28 @@ def __init__(
                 (moe.hidden_dim + moe.block_size - 1) // moe.block_size *
                 torch.float32.itemsize)))

-            dispatch_combine = PplxDispatchCombine(
+            return PplxDispatchCombine(
                 all_to_all,
                 max_num_tokens,
                 world_size,
                 dp_size,
-                rank,  # just for debugging
+                rank,
                 moe.in_dtype,
             )
         elif True:
-            logger.info("using standard dispatch")
-            dispatch_combine = StandardDispatchCombine(
+            logger.debug("using standard dispatch")
+            return StandardDispatchCombine(
                 moe.in_dtype,
                 quant_config.weight_block_size
                 if quant_config is not None else None,
            )
         else:
-            logger.info("using batched dispatch")
-            dispatch_combine = BatchedDispatchCombine(
+            logger.debug("using batched dispatch")
+            return BatchedDispatchCombine(
                 moe.ep_size,
                 moe.ep_rank,
             )

-        success = self.quant_method.set_dispatch_combine(dispatch_combine)
-        if not success:
-            logger.warning("DP+EP not supported for %s.",
-                           type(self.quant_method))
-
-        self.apply_router_weight_on_input = apply_router_weight_on_input
-        moe_quant_params = {
-            "num_experts": self.local_num_experts,
-            "hidden_size": hidden_size,
-            "intermediate_size_per_partition":
-            self.intermediate_size_per_partition,
-            "params_dtype": params_dtype,
-            "weight_loader": self.weight_loader,
-        }
-        # need full intermediate size pre-sharding for WNA16 act order
-        if (self.quant_method.__class__.__name__
-                in ("GPTQMarlinMoEMethod",
-                    "CompressedTensorsWNA16MarlinMoEMethod",
-                    "CompressedTensorsWNA16MoEMethod")):
-            moe_quant_params["intermediate_size_full"] = intermediate_size
-
-        self.quant_method.create_weights(layer=self, **moe_quant_params)
-
     def _load_per_tensor_weight_scale(self, shard_id: str,
                                       param: torch.nn.Parameter,
                                       loaded_weight: torch.Tensor,

@@ -1015,9 +1016,14 @@ def naive_multicast(self, x: torch.Tensor,

         return buffer

+    # TODO: will this be cudagraph-able? (probably not)
+    # This should not be necessary.
+    def invalid_pplx(self, hidden_states: torch.Tensor) -> bool:
+        return has_pplx and hidden_states.shape[0] < self.dp_size
+
     def forward(self, hidden_states: torch.Tensor,
                 router_logits: torch.Tensor):
-        if self.use_direct_call:
+        if self.use_direct_call or self.invalid_pplx(hidden_states):
             return self.forward_impl(hidden_states, router_logits)
         else:
             return torch.ops.vllm.moe_forward(hidden_states, router_logits,
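
The bulk of this diff moves the dispatch/combine selection out of __init__ into a _construct_dispatch_combine helper that simply returns whichever implementation applies, and __init__ hands the result to the quant method. A stripped-down sketch of that shape (the class and strategy names here are placeholders, not the vLLM types):

    class _DispatchCombineSketch:
        """Stand-in for the FusedMoEQuantizeDispatchCombine implementations."""

        def __init__(self, name: str):
            self.name = name


    class _MoELayerSketch:

        def __init__(self, dp_size: int, pplx_available: bool):
            self.dp_size = dp_size
            self.pplx_available = pplx_available
            # __init__ just asks the helper which strategy applies.
            self.dispatch_combine = self._construct_dispatch_combine()

        def _construct_dispatch_combine(self) -> _DispatchCombineSketch:
            # Same selection order as the diff: pplx when DP > 1 and the
            # kernels are importable, otherwise the standard path.
            if self.dp_size > 1 and self.pplx_available:
                return _DispatchCombineSketch("pplx")
            return _DispatchCombineSketch("standard")


    print(_MoELayerSketch(2, False).dispatch_combine.name)  # -> standard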

vllm/model_executor/layers/fused_moe/pplx_dispatch_combine.py

Lines changed: 9 additions & 5 deletions
@@ -28,6 +28,7 @@ def __init__(self,
                  quant_dtype: Optional[torch.dtype] = None,
                  block_shape: Optional[List[int]] = None):
         super().__init__()
+        assert max_num_tokens > 0
         self.a2a = a2a
         self.block_shape = block_shape
         self.max_num_tokens = max_num_tokens

@@ -47,13 +48,15 @@ def dispatch(
         expert_map: Optional[torch.Tensor],
         apply_router_weight_on_input: bool,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
-        # Is this always going to be a1.device?
-        device = a1.device
+        num_tokens = a1.shape[0]  # M
         hidden_dim = a1.shape[-1]  # K

-        # ??
+        assert rank_topk_ids.shape[0] == num_tokens
         # assert expert_map is None, "NYI"

+        # Is this always going to be a1.device?
+        device = a1.device
+
         if apply_router_weight_on_input:
             topk = rank_topk_ids.shape[1]
             # TODO: this only works for topK=1, will need to update for topK>1

@@ -102,7 +105,6 @@ def dispatch(
         )

         # This argument is optional, defaults to indices.shape[0]
-        num_tokens = a1.shape[0]  # M
         bound_m = torch.tensor([num_tokens], dtype=torch.uint32, device=device)

         # TODO: optimize this?

@@ -133,7 +135,9 @@ def combine(
                              dtype=torch.uint32,
                              device=fused_expert_output.device)

-        assert output.shape[0] <= self.max_num_tokens
+        assert topk_ids.shape[0] <= num_tokens
+        assert output.shape[0] <= self.max_num_tokens, \
+            f"{output.shape[0]} <= {self.max_num_tokens}"
         assert output.shape[1] == fused_expert_output.shape[-1]

         # Set weights to 1 if we did them in dispatch. This is hacky.
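
For reference, bound_m (mentioned in the hunk above) is just a one-element uint32 tensor holding the real token count for this rank, and the diff notes the argument is optional, defaulting to indices.shape[0]. A hedged, single-process sketch of how it gets built (assumes a PyTorch new enough to have torch.uint32; the activation tensor is made up):

    import torch

    a1 = torch.randn(7, 128)   # stand-in activations, shape M x K
    num_tokens = a1.shape[0]   # M
    device = a1.device

    # One-element tensor carrying the number of valid rows for this rank.
    bound_m = torch.tensor([num_tokens], dtype=torch.uint32, device=device)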

vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py

Lines changed: 5 additions & 5 deletions
@@ -12,11 +12,11 @@
 class TritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):

     def __init__(self,
-                 use_fp8_w8a8: bool,
-                 use_int8_w8a8: bool,
-                 use_int8_w8a16: bool,
-                 use_int4_w4a16: bool,
-                 per_channel_quant: bool,
+                 use_fp8_w8a8: bool = False,
+                 use_int8_w8a8: bool = False,
+                 use_int8_w8a16: bool = False,
+                 use_int4_w4a16: bool = False,
+                 per_channel_quant: bool = False,
                  block_shape: Optional[List[int]] = None,
                  block_m: Optional[int] = None,
                  allow_deep_gemm: bool = False):
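
Defaulting the quantization flags to False lets the common unquantized case construct the experts object without spelling out every flag. A toy illustration of the same keyword-defaults pattern (a stand-in class, not the real TritonOrDeepGemmExperts):

    from typing import List, Optional

    class ExpertsSketch:

        def __init__(self,
                     use_fp8_w8a8: bool = False,
                     use_int8_w8a8: bool = False,
                     use_int8_w8a16: bool = False,
                     use_int4_w4a16: bool = False,
                     per_channel_quant: bool = False,
                     block_shape: Optional[List[int]] = None,
                     allow_deep_gemm: bool = False):
            self.use_fp8_w8a8 = use_fp8_w8a8
            self.allow_deep_gemm = allow_deep_gemm

    # Before the change every flag had to be passed explicitly; now the
    # unquantized default is a one-liner.
    experts = ExpertsSketch(allow_deep_gemm=True)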
