@@ -84,12 +84,12 @@ def test_env(
         m.setenv("VLLM_MLA_DISABLE", "1" if use_mla else "0")
 
         if device == "cpu":
-            with patch("vllm.attention.selector.current_platform", CpuPlatform()):
+            with patch("vllm.platforms.current_platform", CpuPlatform()):
                 backend = get_attn_backend(16, torch.float16, None, block_size)
             assert backend.get_name() == "TORCH_SDPA"
 
         elif device == "hip":
-            with patch("vllm.attention.selector.current_platform", RocmPlatform()):
+            with patch("vllm.platforms.current_platform", RocmPlatform()):
                 if use_mla:
                     # ROCm MLA backend logic:
                     # - TRITON_MLA: supported when block_size != 1
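The only functional change in each hunk is the mock target: the tests now patch current_platform on vllm.platforms, where the object actually lives, instead of on vllm.attention.selector. Under the usual unittest.mock rule of patching where the object is looked up, this is only correct if get_attn_backend resolves the platform through vllm.platforms at call time rather than holding a module-level copy. A minimal sketch of the new target in isolation, assuming vLLM is installed and that CpuPlatform is importable from vllm.platforms.cpu:

    from unittest.mock import patch

    from vllm.platforms.cpu import CpuPlatform

    # patch() swaps the attribute on the vllm.platforms module object for the
    # duration of the block; any code that looks the platform up at call time
    # sees the mock instead of the auto-detected platform.
    with patch("vllm.platforms.current_platform", CpuPlatform()):
        from vllm.platforms import current_platform
        assert isinstance(current_platform, CpuPlatform)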
@@ -126,7 +126,7 @@ def test_env(
                     assert backend.get_name() == expected
 
         elif device == "cuda":
-            with patch("vllm.attention.selector.current_platform", CudaPlatform()):
+            with patch("vllm.platforms.current_platform", CudaPlatform()):
                 if use_mla:
                     # CUDA MLA backend logic:
                     # - CUTLASS_MLA: only supported with block_size == 128
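Between them, the ROCm and CUDA comments pin down two MLA dispatch rules. A plain-Python paraphrase of just those two rules, for illustration only (the real selection logic lives in the platform classes and covers more backends than the two quoted here):

    def expected_mla_backend(device: str, block_size: int) -> str:
        """Paraphrase of the two MLA rules quoted in the hunks above."""
        if device == "hip" and block_size != 1:
            # ROCm: TRITON_MLA is supported when block_size != 1.
            return "TRITON_MLA"
        if device == "cuda" and block_size == 128:
            # CUDA: CUTLASS_MLA is only supported with block_size == 128.
            return "CUTLASS_MLA"
        # Every other combination is decided by branches outside this diff.
        raise NotImplementedError(f"not covered by the quoted rules: {device}, {block_size}")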
@@ -214,12 +214,12 @@ def test_env(
 def test_fp32_fallback(device: str):
     """Test attention backend selection with fp32."""
     if device == "cpu":
-        with patch("vllm.attention.selector.current_platform", CpuPlatform()):
+        with patch("vllm.platforms.current_platform", CpuPlatform()):
             backend = get_attn_backend(16, torch.float32, None, 16)
         assert backend.get_name() == "TORCH_SDPA"
 
     elif device == "cuda":
-        with patch("vllm.attention.selector.current_platform", CudaPlatform()):
+        with patch("vllm.platforms.current_platform", CudaPlatform()):
             backend = get_attn_backend(16, torch.float32, None, 16)
             assert backend.get_name() == "FLEX_ATTENTION"
 
@@ -277,7 +277,7 @@ def test_invalid_env(monkeypatch: pytest.MonkeyPatch):
     """Test that invalid attention backend names raise ValueError."""
     with (
         monkeypatch.context() as m,
-        patch("vllm.attention.selector.current_platform", CudaPlatform()),
+        patch("vllm.platforms.current_platform", CudaPlatform()),
     ):
         m.setenv(STR_BACKEND_ENV_VAR, STR_INVALID_VAL)
 
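This last hunk stops right after the environment setup; per the docstring, the test expects a ValueError. A hedged sketch of how such a test typically concludes (the actual continuation is outside this hunk, and the get_attn_backend arguments below are assumed to mirror the calls earlier in the file):

    import pytest
    import torch

    # Hypothetical continuation: with STR_BACKEND_ENV_VAR set to an invalid
    # backend name, resolution should raise as the docstring states.
    with pytest.raises(ValueError):
        get_attn_backend(16, torch.float16, None, 16)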