
Commit 41aa578

[NVIDIA] Add Cutlass MLA backend (vllm-project#17625)
1 parent 8d646c2 commit 41aa578

7 files changed: +111, -3 lines


csrc/attention/mla/cutlass_mla_kernels.cu

Lines changed: 1 addition & 1 deletion

@@ -119,7 +119,7 @@ typename T::Fmha::Arguments args_from_options(
       {static_cast<ElementOut*>(out.data_ptr()), stride_O,
        static_cast<ElementAcc*>(nullptr), stride_LSE},
       hw_info,
-      -1,       // split_kv
+      1,        // split_kv
       nullptr,  // is_var_split_kv
   };
   // TODO(kaixih@nvidia): When split_kv=-1 and is_var_split_kv=false, we compute

tests/kernels/test_cutlass_mla_decode.py

Lines changed: 3 additions & 1 deletion

@@ -76,7 +76,9 @@ def test_cutlass_mla_decode(dtype: torch.dtype, mean_seq_len: int, bs: int,
     pack_factor = 128 // block_size
     block_num = ((block_num + pack_factor - 1) // pack_factor) * pack_factor
 
-    q = torch.randn(bs, h_q, d)
+    # Amplify input values to ensure test coverage of edge cases where CUTLASS
+    # kernel errors occur with split_k settings.
+    q = torch.randn(bs, h_q, d) * 100
     block_table = torch.randint(0,
                                 bs * block_num, (bs, block_num),
                                 dtype=torch.int32)
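For reference, a pure-torch sketch of the input setup around the changed lines; the concrete values for bs, h_q, d, block_size, and the per-sequence block count below are illustrative assumptions, not the test's parametrized values.

import torch

# Illustrative sizes (assumptions, not taken from the test parametrization).
bs, h_q, d = 4, 128, 576
block_size, mean_seq_len = 64, 256

# Blocks per sequence, rounded up to a multiple of the pack factor as in the test.
block_num = (mean_seq_len + block_size - 1) // block_size
pack_factor = 128 // block_size
block_num = ((block_num + pack_factor - 1) // pack_factor) * pack_factor

# Amplified queries, as in the diff above, to hit the split_k edge cases.
q = torch.randn(bs, h_q, d) * 100
block_table = torch.randint(0,
                            bs * block_num, (bs, block_num),
                            dtype=torch.int32)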

vllm/engine/arg_utils.py

Lines changed: 1 addition & 0 deletions

@@ -1395,6 +1395,7 @@ def _is_v1_supported_oracle(self, model_config: ModelConfig) -> bool:
             "PALLAS_VLLM_V1",
             "TRITON_ATTN_VLLM_V1",
             "TRITON_MLA",
+            "CUTLASS_MLA_VLLM_V1",
             "FLASHMLA",
             "FLASHINFER",
             "FLASHINFER_VLLM_V1",

vllm/platforms/cuda.py

Lines changed: 8 additions & 0 deletions

@@ -183,6 +183,14 @@ def get_attn_backend_cls(cls, selected_backend, head_size, dtype,
         if use_mla:
             # TODO(lucas): refactor to be more concise
             # we should probably consider factoring out V1 here
+            if selected_backend == _Backend.CUTLASS_MLA_VLLM_V1:
+                if use_v1:
+                    logger.info_once("Using Cutlass MLA backend on V1 engine.")
+                    return ("vllm.v1.attention.backends.mla."
+                            "cutlass_mla.CutlassMLABackend")
+                else:
+                    logger.warning(
+                        "Cutlass MLA backend is only supported on V1 engine")
             if selected_backend == _Backend.TRITON_MLA or block_size != 64:
                 if use_v1:
                     logger.info_once("Using Triton MLA backend on V1 engine.")

vllm/platforms/interface.py

Lines changed: 1 addition & 0 deletions

@@ -51,6 +51,7 @@ class _Backend(enum.Enum):
     TRITON_MLA_VLLM_V1 = enum.auto()
     FLASHMLA_VLLM_V1 = enum.auto()
     FLASHMLA = enum.auto()  # Supported by V1
+    CUTLASS_MLA_VLLM_V1 = enum.auto()
     HPU_ATTN = enum.auto()
     PALLAS = enum.auto()
     PALLAS_VLLM_V1 = enum.auto()

vllm/v1/attention/backends/mla/common.py

Lines changed: 1 addition & 1 deletion

@@ -350,7 +350,7 @@ def __init__(self,
         self.num_heads = model_config.get_num_attention_heads(
             runner.parallel_config)
         self.mla_dims = get_mla_dims(model_config)
-        self.aot_schedule = is_vllm_fa and (get_flash_attn_version() == 3)
+        self.aot_schedule = current_platform.is_cuda()
         self.kv_cache_spec = kv_cache_spec
 
         # Dont try to access the runner on AMD
Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,96 @@
1+
# SPDX-License-Identifier: Apache-2.0
2+
3+
from typing import Any, Optional
4+
5+
import torch
6+
7+
import vllm._custom_ops as ops
8+
from vllm.attention.backends.abstract import (AttentionType,
9+
is_quantized_kv_cache)
10+
from vllm.logger import init_logger
11+
from vllm.v1.attention.backends.mla.common import (MLACommonBackend,
12+
MLACommonImpl,
13+
MLACommonMetadata)
14+
15+
logger = init_logger(__name__)
16+
17+
18+
class CutlassMLABackend(MLACommonBackend):
19+
20+
@staticmethod
21+
def get_name() -> str:
22+
return "CUTLASS_MLA_VLLM_V1"
23+
24+
@staticmethod
25+
def get_impl_cls() -> type["CutlassMLAImpl"]:
26+
return CutlassMLAImpl
27+
28+
29+
class CutlassMLAImpl(MLACommonImpl[MLACommonMetadata]):
30+
31+
def __init__(
32+
self,
33+
num_heads: int,
34+
head_size: int,
35+
scale: float,
36+
num_kv_heads: int,
37+
alibi_slopes: Optional[list[float]],
38+
sliding_window: Optional[int],
39+
kv_cache_dtype: str,
40+
blocksparse_params: Optional[dict[str, Any]],
41+
logits_soft_cap: Optional[float],
42+
attn_type: str,
43+
# MLA Specific Arguments
44+
**mla_args) -> None:
45+
super().__init__(num_heads, head_size, scale, num_kv_heads,
46+
alibi_slopes, sliding_window, kv_cache_dtype,
47+
blocksparse_params, logits_soft_cap, attn_type,
48+
**mla_args)
49+
50+
unsupported_features = [
51+
alibi_slopes, sliding_window, blocksparse_params, logits_soft_cap
52+
]
53+
if any(unsupported_features):
54+
raise NotImplementedError(
55+
"CutlassMLAImpl does not support one of the following: "
56+
"alibi_slopes, sliding_window, blocksparse_params, "
57+
"logits_soft_cap")
58+
59+
if attn_type != AttentionType.DECODER:
60+
raise NotImplementedError("Encoder self-attention and "
61+
"encoder/decoder cross-attention "
62+
"are not implemented for "
63+
"CutlassMLAImpl")
64+
65+
if is_quantized_kv_cache(self.kv_cache_dtype):
66+
raise NotImplementedError(
67+
"CutlassMLA V1 with FP8 KV cache not yet supported")
68+
69+
def _forward_decode(
70+
self,
71+
q_nope: torch.Tensor,
72+
q_pe: torch.Tensor,
73+
kv_c_and_k_pe_cache: torch.Tensor,
74+
attn_metadata: MLACommonMetadata,
75+
) -> torch.Tensor:
76+
assert kv_c_and_k_pe_cache.numel() > 0
77+
assert attn_metadata.decode is not None
78+
79+
if self.kv_cache_dtype.startswith("fp8"):
80+
raise NotImplementedError("FP8 Cutlass MLA not yet supported")
81+
82+
B = q_nope.shape[0]
83+
84+
o = torch.empty((B, self.num_heads, self.kv_lora_rank),
85+
dtype=q_nope.dtype,
86+
device=q_nope.device)
87+
88+
# Run MLA
89+
# Clone q_nope and q_pe to make sure strides computation is correct.
90+
q_nope = q_nope.clone()
91+
q_pe = q_pe.clone()
92+
ops.cutlass_mla_decode(o, q_nope, q_pe, kv_c_and_k_pe_cache,
93+
attn_metadata.decode.seq_lens,
94+
attn_metadata.decode.block_table, self.scale)
95+
96+
return self._v_up_proj(o)
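For readers unfamiliar with the final _v_up_proj(o) step: the kernel returns the per-head output in the compressed kv_lora_rank space, and the common MLA implementation then up-projects it to the value head dimension. A conceptual pure-torch sketch; the names W_UV and v_head_dim and the dims below are assumptions, and the real helper in MLACommonImpl may fuse or lay the weights out differently.

import torch

# Illustrative DeepSeek-style dims (assumptions, not taken from this commit).
B, num_heads, kv_lora_rank, v_head_dim = 2, 128, 512, 128

o = torch.randn(B, num_heads, kv_lora_rank)               # kernel output
W_UV = torch.randn(num_heads, kv_lora_rank, v_head_dim)   # per-head value up-projection

# (B, H, L) x (H, L, Dv) -> (B, H, Dv), then flatten heads for the output projection.
v = torch.einsum("bhl,hld->bhd", o, W_UV).reshape(B, num_heads * v_head_dim)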
