Skip to content

Commit d6e801e

Browse files
committed
fix test when pplx is missing + minor tweaks
Signed-off-by: Bill Nell <[email protected]>
1 parent 9b97c83 commit d6e801e

File tree

2 files changed: +9 additions, −6 deletions

tests/kernels/moe/test_pplx_moe.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,8 +33,6 @@
3333
get_default_config)
3434
from vllm.model_executor.layers.fused_moe.modular_kernel import (
3535
FusedMoEModularKernel)
36-
from vllm.model_executor.layers.fused_moe.pplx_prepare_finalize import (
37-
PplxPrepareAndFinalize)
3836
from vllm.platforms import current_platform
3937

4038
PPLX_PREPARE_COMBOS = [(4, 128, 128), (32, 1024, 512), (64, 1024, 512),
@@ -350,6 +348,9 @@ def chunk_by_rank(t: torch.Tensor, r: int, w: int) -> torch.Tensor:
350348
def pplx_prepare_finalize(pgi: ProcessGroupInfo, dp_size: int, a: torch.Tensor,
351349
topk_weight: torch.Tensor, topk_ids: torch.Tensor,
352350
num_experts: int) -> torch.Tensor:
351+
from vllm.model_executor.layers.fused_moe.pplx_prepare_finalize import (
352+
PplxPrepareAndFinalize)
353+
353354
assert torch.cuda.current_device() == pgi.local_rank
354355

355356
topk = topk_ids.shape[1]
@@ -499,6 +500,9 @@ def pplx_moe(
499500
use_compile: bool = True,
500501
use_cudagraphs: bool = True,
501502
) -> torch.Tensor:
503+
from vllm.model_executor.layers.fused_moe.pplx_prepare_finalize import (
504+
PplxPrepareAndFinalize)
505+
502506
device = torch.device("cuda", rank)
503507
hidden_dim = a.shape[1]
504508
num_experts = w1.shape[0]

vllm/model_executor/layers/fused_moe/layer.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -833,16 +833,15 @@ def __init__(
833833

834834
# Note: get_quant_method will look at the layer's local_num_experts
835835
# for heuristic purposes, so it must be initialized first.
836-
quant_method: Optional[FusedMoEMethodBase] = None
836+
quant_method: Optional[QuantizeMethodBase] = None
837837

838838
if quant_config is None:
839839
quant_method = UnquantizedFusedMoEMethod(moe)
840840
else:
841-
quant_method = quant_config.get_quant_method(
842-
self, prefix) # type: ignore
843-
assert isinstance(quant_method, FusedMoEMethodBase)
841+
quant_method = quant_config.get_quant_method(self, prefix)
844842

845843
assert quant_method is not None
844+
assert isinstance(quant_method, FusedMoEMethodBase)
846845
self.quant_method = quant_method
847846

848847
prepare_finalize = _construct_prepare_finalize(moe, quant_config)

0 commit comments

Comments (0)