
Commit ec10fd0

[Bugfix] Move current_platform import to avoid python import cache. (vllm-project#16601)
Signed-off-by: iwzbi <[email protected]>
1 parent 0426e3c commit ec10fd0

File tree

2 files changed: +8 −7 lines changed

tests/kernels/attention/test_attention_selector.py

Lines changed: 6 additions & 6 deletions
@@ -84,12 +84,12 @@ def test_env(
         m.setenv("VLLM_MLA_DISABLE", "1" if use_mla else "0")

         if device == "cpu":
-            with patch("vllm.attention.selector.current_platform", CpuPlatform()):
+            with patch("vllm.platforms.current_platform", CpuPlatform()):
                 backend = get_attn_backend(16, torch.float16, None, block_size)
                 assert backend.get_name() == "TORCH_SDPA"

         elif device == "hip":
-            with patch("vllm.attention.selector.current_platform", RocmPlatform()):
+            with patch("vllm.platforms.current_platform", RocmPlatform()):
                 if use_mla:
                     # ROCm MLA backend logic:
                     # - TRITON_MLA: supported when block_size != 1
@@ -126,7 +126,7 @@ def test_env(
                     assert backend.get_name() == expected

         elif device == "cuda":
-            with patch("vllm.attention.selector.current_platform", CudaPlatform()):
+            with patch("vllm.platforms.current_platform", CudaPlatform()):
                 if use_mla:
                     # CUDA MLA backend logic:
                     # - CUTLASS_MLA: only supported with block_size == 128
@@ -214,12 +214,12 @@ def test_env(
 def test_fp32_fallback(device: str):
     """Test attention backend selection with fp32."""
     if device == "cpu":
-        with patch("vllm.attention.selector.current_platform", CpuPlatform()):
+        with patch("vllm.platforms.current_platform", CpuPlatform()):
             backend = get_attn_backend(16, torch.float32, None, 16)
             assert backend.get_name() == "TORCH_SDPA"

     elif device == "cuda":
-        with patch("vllm.attention.selector.current_platform", CudaPlatform()):
+        with patch("vllm.platforms.current_platform", CudaPlatform()):
             backend = get_attn_backend(16, torch.float32, None, 16)
             assert backend.get_name() == "FLEX_ATTENTION"

@@ -277,7 +277,7 @@ def test_invalid_env(monkeypatch: pytest.MonkeyPatch):
     """Test that invalid attention backend names raise ValueError."""
     with (
         monkeypatch.context() as m,
-        patch("vllm.attention.selector.current_platform", CudaPlatform()),
+        patch("vllm.platforms.current_platform", CudaPlatform()),
     ):
         m.setenv(STR_BACKEND_ENV_VAR, STR_INVALID_VAL)

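Why the patch target changes: `unittest.mock.patch` only rebinds the attribute on the module it names. With the old module-level `from vllm.platforms import current_platform`, the selector module held its own copy of the name, so the tests had to patch `vllm.attention.selector.current_platform`; patching the defining module alone would not reach that stale binding. A minimal standalone sketch of that behaviour, using toy in-memory modules (`fake_platforms`, `fake_selector`) rather than vLLM itself:

# Sketch (toy modules, not vLLM code): a module-level "from X import name"
# copies the binding into the consuming module, so patching the defining
# module does not affect the copy made at import time.
import sys
import types
from unittest.mock import patch

platforms = types.ModuleType("fake_platforms")
platforms.current_platform = "cuda-platform"
sys.modules["fake_platforms"] = platforms

selector = types.ModuleType("fake_selector")
sys.modules["fake_selector"] = selector
# Equivalent of a module-level "from fake_platforms import current_platform":
selector.current_platform = platforms.current_platform

with patch("fake_platforms.current_platform", "cpu-platform"):
    print(platforms.current_platform)  # "cpu-platform"  (patched)
    print(selector.current_platform)   # "cuda-platform" (stale import-time copy)

Since the selector now looks the platform up through `vllm.platforms` at call time, the tests patch `vllm.platforms.current_platform` directly.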

vllm/attention/selector.py

Lines changed: 2 additions & 1 deletion
@@ -14,7 +14,6 @@
 from vllm.attention.backends.abstract import AttentionBackend
 from vllm.attention.backends.registry import _Backend, backend_name_to_enum
 from vllm.logger import init_logger
-from vllm.platforms import current_platform
 from vllm.utils import STR_BACKEND_ENV_VAR, resolve_obj_by_qualname

 logger = init_logger(__name__)
@@ -192,6 +191,8 @@ def _cached_get_attn_backend(
     )

     # get device-specific attn_backend
+    from vllm.platforms import current_platform
+
     attention_cls = current_platform.get_attn_backend_cls(
         selected_backend,
         head_size,
