
Commit 402759d

[Attention] FlashAttn MLA (vllm-project#14258)
Signed-off-by: Lucas Wilkinson <[email protected]>
Signed-off-by: Lucas Wilkinson <[email protected]>
Signed-off-by: Matthew Bonanni <[email protected]>
Co-authored-by: Matthew Bonanni <[email protected]>
Co-authored-by: Matthew Bonanni <[email protected]>
1 parent 2c301ee commit 402759d

File tree

22 files changed: +491 -211 lines

.buildkite/check-wheel-size.py

Lines changed: 4 additions & 4 deletions

@@ -5,11 +5,11 @@
 import sys
 import zipfile
 
-# Read the VLLM_MAX_SIZE_MB environment variable, defaulting to 400 MiB
-# Note that we have 400 MiB quota, please use it wisely.
-# See https://github.com/pypi/support/issues/3792 .
+# Read the VLLM_MAX_SIZE_MB environment variable, defaulting to 450 MiB
+# Note that we have 800 MiB quota, please use it wisely.
+# See https://github.com/pypi/support/issues/6326 .
 # Please also sync the value with the one in Dockerfile.
-VLLM_MAX_SIZE_MB = int(os.environ.get("VLLM_MAX_SIZE_MB", 400))
+VLLM_MAX_SIZE_MB = int(os.environ.get("VLLM_MAX_SIZE_MB", 450))
 
 
 def print_top_10_largest_files(zip_file):
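
For context on how such a limit is typically enforced, here is a minimal sketch of a wheel-size check built around the same VLLM_MAX_SIZE_MB knob. The function name check_wheel_size and the exact messages are illustrative assumptions, not code copied from this commit.

import os
import sys
import zipfile

# Default mirrors the new value in .buildkite/check-wheel-size.py.
VLLM_MAX_SIZE_MB = int(os.environ.get("VLLM_MAX_SIZE_MB", 450))


def check_wheel_size(wheel_path: str) -> int:
    """Return 0 if the wheel fits the size budget, 1 otherwise."""
    size_mb = os.path.getsize(wheel_path) / (1024 * 1024)
    if size_mb > VLLM_MAX_SIZE_MB:
        print(f"{wheel_path} is {size_mb:.1f} MiB, "
              f"exceeding the {VLLM_MAX_SIZE_MB} MiB limit")
        # Show the largest archive members to help track down the bloat.
        with zipfile.ZipFile(wheel_path) as zf:
            largest = sorted(zf.infolist(),
                             key=lambda info: info.file_size,
                             reverse=True)[:10]
            for info in largest:
                print(f"  {info.filename}: "
                      f"{info.file_size / (1024 * 1024):.1f} MiB")
        return 1
    print(f"{wheel_path} is {size_mb:.1f} MiB, within the limit")
    return 0


if __name__ == "__main__":
    sys.exit(check_wheel_size(sys.argv[1]))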

cmake/external_projects/vllm_flash_attn.cmake

Lines changed: 1 addition & 1 deletion

@@ -38,7 +38,7 @@ else()
   FetchContent_Declare(
         vllm-flash-attn
         GIT_REPOSITORY https://github.com/vllm-project/flash-attention.git
-        GIT_TAG 57b4e68b9f9d94750b46de8f8dbd2bfcc86edd4f
+        GIT_TAG ee4d25bd84e0cbc7e0b9b9685085fd5db2dcb62a
         GIT_PROGRESS TRUE
         # Don't share the vllm-flash-attn build between build types
         BINARY_DIR ${CMAKE_BINARY_DIR}/vllm-flash-attn

docker/Dockerfile

Lines changed: 1 addition & 1 deletion

@@ -237,7 +237,7 @@ RUN --mount=type=cache,target=/root/.cache/ccache \
 # Check the size of the wheel if RUN_WHEEL_CHECK is true
 COPY .buildkite/check-wheel-size.py check-wheel-size.py
 # sync the default value with .buildkite/check-wheel-size.py
-ARG VLLM_MAX_SIZE_MB=400
+ARG VLLM_MAX_SIZE_MB=450
 ENV VLLM_MAX_SIZE_MB=$VLLM_MAX_SIZE_MB
 ARG RUN_WHEEL_CHECK=true
 RUN if [ "$RUN_WHEEL_CHECK" = "true" ]; then \
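
Because both knobs are declared as Docker build arguments, they can be overridden at image build time without editing the Dockerfile, for example with --build-arg VLLM_MAX_SIZE_MB=450 or --build-arg RUN_WHEEL_CHECK=false (illustrative invocations, not taken from this commit); the comment above is the reminder to keep the default in step with .buildkite/check-wheel-size.py.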

tests/kernels/attention/test_attention_selector.py

Lines changed: 83 additions & 27 deletions

@@ -22,7 +22,7 @@ def clear_cache():
 
 # Define MLA and non-MLA backends separately
 DEVICE_MLA_BACKENDS = {
-    "cuda": ["TRITON_MLA", "FLASHMLA"],
+    "cuda": ["TRITON_MLA", "FLASHMLA", "FLASH_ATTN_MLA", "CUTLASS_MLA"],
     "hip": ["TRITON_MLA", "ROCM_AITER_MLA"],
     "cpu": [],
 }
@@ -98,21 +98,14 @@ def test_env(
         with patch("vllm.attention.selector.current_platform",
                    RocmPlatform()):
             if use_mla:
-                # Validate HIP MLA backend-block_size combinations
-                valid_combination = (
-                    (name == "TRITON_MLA" and block_size != 1)
-                    or (name == "ROCM_AITER_MLA" and block_size == 1))
-
-                if valid_combination:
-                    backend = get_attn_backend(16,
-                                               torch.float16,
-                                               torch.float16,
-                                               block_size,
-                                               False,
-                                               use_mla=use_mla)
-                    expected = f"{name}_VLLM_V1" if use_v1 else name
-                    assert backend.get_name() == expected
-                else:
+                # ROCm MLA backend logic:
+                # - TRITON_MLA: supported when block_size != 1
+                # - ROCM_AITER_MLA: supported when block_size == 1
+                # If backend is forced but doesn't match block_size,
+                # should raise ValueError
+
+                if name == "TRITON_MLA" and block_size == 1:
+                    # TRITON_MLA doesn't support block_size == 1
                     with pytest.raises(ValueError) as exc_info:
                         get_attn_backend(16,
                                          torch.float16,
@@ -122,6 +115,27 @@ def test_env(
                                          use_mla=use_mla)
                     assert f"The selected backend, {name}" in str(
                         exc_info.value)
+                elif name == "ROCM_AITER_MLA" and block_size != 1:
+                    # ROCM_AITER_MLA only supports block_size == 1
+                    with pytest.raises(ValueError) as exc_info:
+                        get_attn_backend(16,
+                                         torch.float16,
+                                         torch.float16,
+                                         block_size,
+                                         False,
+                                         use_mla=use_mla)
+                    assert f"The selected backend, {name}" in str(
+                        exc_info.value)
+                else:
+                    # Valid backend-block_size combination
+                    backend = get_attn_backend(16,
+                                               torch.float16,
+                                               torch.float16,
+                                               block_size,
+                                               False,
+                                               use_mla=use_mla)
+                    expected = f"{name}_VLLM_V1" if use_v1 else name
+                    assert backend.get_name() == expected
             else:
                 backend = get_attn_backend(16,
                                            torch.float16,
@@ -136,26 +150,68 @@ def test_env(
         with patch("vllm.attention.selector.current_platform",
                    CudaPlatform()):
             if use_mla:
-                if name == "FLASHMLA" and block_size == 64:
-                    from vllm.attention.backends.flashmla import (
-                        is_flashmla_supported)
-
-                    # only on cuda platforms with specific capability.
-                    is_supported, _ = is_flashmla_supported()
-
-                    if not is_supported:
-                        # if platform is not supported then skip this case.
-                        pytest.skip()
+                # CUDA MLA backend logic:
+                # - CUTLASS_MLA: only supported with block_size == 128
+                #   and Blackwell GPUs (SM 10.0), V1 only
+                # - FLASHMLA: only supported with block_size == 64
+                # - FLASH_ATTN_MLA: V1 only
+                # - TRITON_MLA: fallback for other cases
+
+                if name == "CUTLASS_MLA":
+                    if not use_v1:
+                        # CUTLASS_MLA only supported on V1 engine
+                        pytest.skip(
+                            "CUTLASS_MLA only supported on V1 engine")
+                    elif block_size != 128:
+                        # CUTLASS_MLA only supports block_size == 128
+                        pytest.skip(
+                            "CUTLASS_MLA only supports block_size 128")
+                    else:
+                        backend = get_attn_backend(16,
+                                                   torch.float16,
+                                                   torch.float16,
+                                                   block_size,
+                                                   False,
+                                                   use_mla=use_mla)
+                        expected = "CUTLASS_MLA_VLLM_V1"
+                        assert backend.get_name() == expected
+                elif name == "FLASHMLA":
+                    if block_size != 64:
+                        # FlashMLA only supports block_size == 64
+                        pytest.skip("FlashMLA only supports block_size 64")
+                    else:
+                        from vllm.attention.backends.flashmla import (
+                            is_flashmla_supported)
+                        is_supported, _ = is_flashmla_supported()
+                        if not is_supported:
+                            pytest.skip(
+                                "FlashMLA not supported on this platform")
+                        else:
+                            backend = get_attn_backend(16,
+                                                       torch.float16,
+                                                       torch.float16,
+                                                       block_size,
+                                                       False,
+                                                       use_mla=use_mla)
+                            expected = f"{name}_VLLM_V1" if use_v1 else name
+                            assert backend.get_name() == expected
+                elif name == "FLASH_ATTN_MLA":
+                    if not use_v1:
+                        # FlashAttention MLA only supported on V1 engine
+                        pytest.skip(
+                            "FlashAttention MLA only supported on V1 engine"
+                        )
                     else:
                         backend = get_attn_backend(16,
                                                    torch.float16,
                                                    torch.float16,
                                                    block_size,
                                                    False,
                                                    use_mla=use_mla)
-                        expected = f"{name}_VLLM_V1" if use_v1 else name
+                        expected = "FLASH_ATTN_MLA"
                         assert backend.get_name() == expected
                 else:
+                    # TRITON_MLA or other fallback
                     backend = get_attn_backend(16,
                                                torch.float16,
                                                torch.float16,
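
The new CUDA branch of this test encodes the per-backend constraints in its comments and skip conditions. The helper below is a standalone restatement of those rules for readers, under stated assumptions: pick_cuda_mla_backend is a hypothetical name, not a vLLM API, and where the test skips an unsupported combination this sketch raises ValueError purely to make the constraints explicit.

def pick_cuda_mla_backend(requested: str | None, block_size: int,
                          use_v1: bool) -> str:
    """Return the backend name the test expects for a forced choice."""
    if requested == "CUTLASS_MLA":
        # Requires the V1 engine and 128-token KV-cache blocks
        # (and Blackwell / SM 10.0 hardware, not modeled here).
        if not use_v1 or block_size != 128:
            raise ValueError("CUTLASS_MLA needs V1 and block_size == 128")
        return "CUTLASS_MLA_VLLM_V1"
    if requested == "FLASHMLA":
        # FlashMLA kernels only handle 64-token blocks.
        if block_size != 64:
            raise ValueError("FLASHMLA needs block_size == 64")
        return "FLASHMLA_VLLM_V1" if use_v1 else "FLASHMLA"
    if requested == "FLASH_ATTN_MLA":
        # The new FlashAttention MLA backend is V1-only.
        if not use_v1:
            raise ValueError("FLASH_ATTN_MLA needs the V1 engine")
        return "FLASH_ATTN_MLA"
    # Anything else falls back to the Triton MLA implementation.
    return "TRITON_MLA_VLLM_V1" if use_v1 else "TRITON_MLA"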

tests/v1/attention/test_attention_backends.py

Lines changed: 0 additions & 16 deletions

@@ -70,22 +70,6 @@ def _convert_dtype_to_torch(dtype):
 }
 
 
-def create_dummy_kv_cache(kv_cache_spec: FullAttentionSpec,
-                          device: torch.device,
-                          num_blocks: int = 100) -> torch.Tensor:
-    """Create a dummy KV cache tensor for testing."""
-    kv_cache = torch.randn(
-        2,  # K and V
-        num_blocks,
-        kv_cache_spec.block_size,
-        kv_cache_spec.num_kv_heads,
-        kv_cache_spec.head_size,
-        dtype=_convert_dtype_to_torch(kv_cache_spec.dtype),
-        device=device,
-    )
-    return kv_cache
-
-
 def create_and_prepopulate_kv_cache(
     k_contexts: list[torch.Tensor],
     v_contexts: list[torch.Tensor],
