
Commit c7ddca4

add guards for pplx import
Signed-off-by: Bill Nell <[email protected]>
1 parent 448658a commit c7ddca4

File tree

3 files changed: 34 additions, 11 deletions

tests/kernels/moe/test_pplx_moe.py
vllm/distributed/parallel_state.py
vllm/model_executor/layers/fused_moe/layer.py


tests/kernels/moe/test_pplx_moe.py

Lines changed: 17 additions & 4 deletions
@@ -10,10 +10,16 @@
 
 import pytest
 import torch
-from pplx_kernels import AllToAll
-from pplx_kernels.nvshmem import (nvshmem_alloc_empty_unique_id,
-                                  nvshmem_finalize, nvshmem_get_unique_id,
-                                  nvshmem_init)
+
+try:
+    from pplx_kernels import AllToAll
+    from pplx_kernels.nvshmem import (nvshmem_alloc_empty_unique_id,
+                                      nvshmem_finalize, nvshmem_get_unique_id,
+                                      nvshmem_init)
+    has_pplx = True
+except ImportError as ex:
+    has_pplx = False
+
 from torch.multiprocessing import (
     spawn)  # pyright: ignore[reportPrivateImportUsage]
 from typing_extensions import Concatenate, ParamSpec
@@ -45,6 +51,11 @@
     reason="Requires multi-node environment",
 )
 
+requires_pplx = pytest.mark.skipif(
+    not has_pplx,
+    reason="Requires PPLX kernels",
+)
+
 
 @dataclasses.dataclass
 class ProcessGroupInfo:
@@ -420,6 +431,7 @@ def _pplx_dispatch_combine(
 @pytest.mark.parametrize("topk", TOP_KS)
 @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
 @pytest.mark.parametrize("world_dp_size", [[2, 1]]) #, [[4, 2]])
+@pytest.mark.skipif(not has_pplx, reason="PPLX kernels not available.")
 def test_pplx_dispatch_combine(
     m: int,
     n: int,
@@ -543,6 +555,7 @@ def _pplx_moe(
 @pytest.mark.parametrize("topk", TOP_KS)
 @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
 @pytest.mark.parametrize("world_dp_size", [[2, 1]]) #, [4, 2]])
+@pytest.mark.skipif(not has_pplx, reason="PPLX kernels not available.")
 def test_pplx_moe(
     m: int,
     n: int,
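
The requires_pplx marker defined above and the inline pytest.mark.skipif(not has_pplx, ...) decorators on the two tests express the same optional-dependency skip. A minimal standalone sketch of that pattern (the test name and body below are hypothetical, not part of this commit):

import importlib.util

import pytest

# Presence check that never imports the optional package.
has_pplx = importlib.util.find_spec("pplx_kernels") is not None

requires_pplx = pytest.mark.skipif(
    not has_pplx,
    reason="Requires PPLX kernels",
)


@requires_pplx
def test_pplx_alltoall_importable():
    # Hypothetical placeholder: runs only when pplx_kernels is installed;
    # otherwise pytest reports the test as skipped.
    from pplx_kernels import AllToAll  # noqa: F401

Either form (the reusable marker or the inline skipif) gives the same skip behavior; the marker just avoids repeating the condition on every test.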

vllm/distributed/parallel_state.py

Lines changed: 8 additions & 4 deletions
@@ -23,6 +23,7 @@
 """
 import contextlib
 import gc
+import importlib
 import pickle
 import weakref
 from collections import namedtuple
@@ -34,9 +35,6 @@
 
 import torch
 import torch.distributed
-from pplx_kernels.nvshmem import (nvshmem_alloc_empty_unique_id,
-                                  nvshmem_finalize, nvshmem_get_unique_id,
-                                  nvshmem_init)
 from torch.distributed import Backend, ProcessGroup
 
 import vllm.envs as envs
@@ -920,7 +918,12 @@ def init_distributed_environment(
 
 @run_once
 def pplx_init(rank, world_size):
-    if world_size > 1:
+    has_pplx = importlib.util.find_spec("pplx_kernels") is not None
+
+    if has_pplx and world_size > 1:
+        from pplx_kernels.nvshmem import (nvshmem_alloc_empty_unique_id,
+                                          nvshmem_get_unique_id,
+                                          nvshmem_init)
         try:
             global PPLX_DID_INIT
             logger.debug(f"PPLX_INIT {rank} {world_size}")
@@ -940,6 +943,7 @@ def pplx_init(rank, world_size):
 def pplx_finalize():
     global PPLX_DID_INIT
     if PPLX_DID_INIT:
+        from pplx_kernels.nvshmem import nvshmem_finalize
         nvshmem_finalize()
 
 
vllm/model_executor/layers/fused_moe/layer.py

Lines changed: 9 additions & 3 deletions
@@ -1,13 +1,13 @@
 # SPDX-License-Identifier: Apache-2.0
 
+import importlib
 import threading
 import weakref
 from abc import abstractmethod
 from dataclasses import dataclass
 from enum import Enum
 from typing import Callable, List, Optional, Tuple
 
-import pplx_kernels as pplx # TODO: guard this
 import torch
 import torch.nn.functional as F
 from torch.nn.parameter import UninitializedParameter
@@ -27,14 +27,17 @@
 from vllm.platforms.interface import CpuArchEnum
 from vllm.utils import direct_register_custom_op
 
+has_pplx = importlib.util.find_spec("pplx_kernels") is not None
+
 if current_platform.is_cuda_alike():
     from .dispatch_combine import StandardDispatchCombine
     from .fused_batched_moe import BatchedDispatchCombine, BatchedTritonExperts
     from .fused_moe import TritonExperts, fused_experts
     from .modular_kernel import (FusedMoEModularKernel,
                                  FusedMoEPermuteExpertsUnpermute,
                                  FusedMoEQuantizeDispatchCombine)
-    from .pplx_dispatch_combine import PplxDispatchCombine
+    if has_pplx:
+        from .pplx_dispatch_combine import PplxDispatchCombine
 else:
     fused_experts = None  # type: ignore
 if current_platform.is_tpu():
@@ -115,6 +118,9 @@ def __init__(self):
         self._lock = threading.RLock()  # Reentrant lock for thread safety
 
     def get_or_create(self, **kwargs):
+        assert has_pplx
+        import pplx_kernels as pplx
+
         # Create a hashable key from the kwargs
         key = tuple(sorted((k, v) for k, v in kwargs.items()))
 
@@ -625,7 +631,7 @@ def __init__(
         dispatch_combine: FusedMoEQuantizeDispatchCombine = None
 
         # TODO: move to method?
-        if self.dp_size > 1:
+        if self.dp_size > 1 and has_pplx:
             logger.info("using pplx dispatch")
             max_num_tokens = MOE_DP_CHUNK_SIZE # // moe.dp_size
             world_size = moe.ep_size
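
In this file the guard appears at three levels: the module-level has_pplx flag, the conditional import of PplxDispatchCombine, and the lazy import pplx_kernels as pplx behind an assert in get_or_create. A reduced sketch of the selection logic from the last hunk (a standalone illustration with a hypothetical helper, not the actual FusedMoE constructor):

import importlib.util

has_pplx = importlib.util.find_spec("pplx_kernels") is not None


def select_dispatch_combine(dp_size: int) -> str:
    # Mirrors the guarded branch above: the pplx dispatch path is chosen only
    # when data parallelism is in use AND pplx_kernels is installed; otherwise
    # the standard dispatch/combine path is used.
    if dp_size > 1 and has_pplx:
        return "pplx"      # vLLM would construct PplxDispatchCombine here
    return "standard"      # vLLM would fall back to StandardDispatchCombine


# Example: on a machine without pplx_kernels this prints "standard".
print(select_dispatch_combine(dp_size=2))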
