Skip to content

Commit 0dda294

Browse files
wenxindongworkjinzhen-lin
authored andcommitted
[TPU] Support Pathways in vLLM (vllm-project#21417)
Signed-off-by: wenxindongwork <[email protected]> Signed-off-by: Jinzhen Lin <[email protected]>
1 parent 3210baa commit 0dda294

File tree

2 files changed

+17
-6
lines changed

2 files changed

+17
-6
lines changed

vllm/envs.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -124,6 +124,7 @@
124124
VLLM_V1_USE_OUTLINES_CACHE: bool = False
125125
VLLM_TPU_BUCKET_PADDING_GAP: int = 0
126126
VLLM_TPU_MOST_MODEL_LEN: Optional[int] = None
127+
VLLM_TPU_USING_PATHWAYS: bool = False
127128
VLLM_USE_DEEP_GEMM: bool = False
128129
VLLM_USE_FLASHINFER_MOE_FP8: bool = False
129130
VLLM_USE_FLASHINFER_MOE_FP4: bool = False
@@ -900,6 +901,10 @@ def get_vllm_port() -> Optional[int]:
900901
"VLLM_TPU_MOST_MODEL_LEN":
901902
lambda: maybe_convert_int(os.environ.get("VLLM_TPU_MOST_MODEL_LEN", None)),
902903

904+
# Whether using Pathways
905+
"VLLM_TPU_USING_PATHWAYS":
906+
lambda: bool("proxy" in os.getenv("JAX_PLATFORMS", "").lower()),
907+
903908
# Allow use of DeepGemm kernels for fused moe ops.
904909
"VLLM_USE_DEEP_GEMM":
905910
lambda: bool(int(os.getenv("VLLM_USE_DEEP_GEMM", "0"))),

vllm/platforms/__init__.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
11
# SPDX-License-Identifier: Apache-2.0
22
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3-
43
import logging
54
import traceback
65
from itertools import chain
76
from typing import TYPE_CHECKING, Optional
87

8+
from vllm import envs
99
from vllm.plugins import load_plugins_by_group
1010
from vllm.utils import resolve_obj_by_qualname, supports_xccl
1111

@@ -31,20 +31,26 @@ def vllm_version_matches_substr(substr: str) -> bool:
3131

3232

3333
def tpu_platform_plugin() -> Optional[str]:
34-
is_tpu = False
3534
logger.debug("Checking if TPU platform is available.")
35+
36+
# Check for Pathways TPU proxy
37+
if envs.VLLM_TPU_USING_PATHWAYS:
38+
logger.debug("Confirmed TPU platform is available via Pathways proxy.")
39+
return "tpu_commons.platforms.tpu_jax.TpuPlatform"
40+
41+
# Check for libtpu installation
3642
try:
3743
# While it's technically possible to install libtpu on a
3844
# non-TPU machine, this is a very uncommon scenario. Therefore,
39-
# we assume that libtpu is installed if and only if the machine
45+
# we assume that libtpu is installed only if the machine
4046
# has TPUs.
47+
4148
import libtpu # noqa: F401
42-
is_tpu = True
4349
logger.debug("Confirmed TPU platform is available.")
50+
return "vllm.platforms.tpu.TpuPlatform"
4451
except Exception as e:
4552
logger.debug("TPU platform is not available because: %s", str(e))
46-
47-
return "vllm.platforms.tpu.TpuPlatform" if is_tpu else None
53+
return None
4854

4955

5056
def cuda_platform_plugin() -> Optional[str]:

0 commit comments

Comments
 (0)