@@ -22,7 +22,7 @@ def clear_cache():
 
 # Define MLA and non-MLA backends separately
 DEVICE_MLA_BACKENDS = {
-    "cuda": ["TRITON_MLA", "FLASHMLA"],
+    "cuda": ["TRITON_MLA", "FLASHMLA", "FLASH_ATTN_MLA", "CUTLASS_MLA"],
     "hip": ["TRITON_MLA", "ROCM_AITER_MLA"],
     "cpu": [],
 }
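
Quick sketch of how the extended mapping enumerates test cases; `mla_test_cases` is a hypothetical helper added here only for illustration, and the dict literal simply mirrors the patched value (the real test presumably drives these combinations via pytest.mark.parametrize):

DEVICE_MLA_BACKENDS = {
    "cuda": ["TRITON_MLA", "FLASHMLA", "FLASH_ATTN_MLA", "CUTLASS_MLA"],
    "hip": ["TRITON_MLA", "ROCM_AITER_MLA"],
    "cpu": [],
}

def mla_test_cases(mapping):
    """Yield every (device, backend) combination in the mapping."""
    for device, backends in mapping.items():
        for backend in backends:
            yield device, backend

# e.g. ('cuda', 'CUTLASS_MLA'), ('hip', 'ROCM_AITER_MLA'), ...
print(list(mla_test_cases(DEVICE_MLA_BACKENDS)))
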
@@ -98,21 +98,14 @@ def test_env(
             with patch("vllm.attention.selector.current_platform",
                        RocmPlatform()):
                 if use_mla:
-                    # Validate HIP MLA backend-block_size combinations
-                    valid_combination = (
-                        (name == "TRITON_MLA" and block_size != 1)
-                        or (name == "ROCM_AITER_MLA" and block_size == 1))
-
-                    if valid_combination:
-                        backend = get_attn_backend(16,
-                                                   torch.float16,
-                                                   torch.float16,
-                                                   block_size,
-                                                   False,
-                                                   use_mla=use_mla)
-                        expected = f"{name}_VLLM_V1" if use_v1 else name
-                        assert backend.get_name() == expected
-                    else:
+                    # ROCm MLA backend logic:
+                    # - TRITON_MLA: supported when block_size != 1
+                    # - ROCM_AITER_MLA: supported when block_size == 1
+                    # If backend is forced but doesn't match block_size,
+                    # should raise ValueError
+
+                    if name == "TRITON_MLA" and block_size == 1:
+                        # TRITON_MLA doesn't support block_size == 1
                         with pytest.raises(ValueError) as exc_info:
                             get_attn_backend(16,
                                              torch.float16,
@@ -122,6 +115,27 @@ def test_env(
                                              use_mla=use_mla)
                         assert f"The selected backend, {name}" in str(
                             exc_info.value)
+                    elif name == "ROCM_AITER_MLA" and block_size != 1:
+                        # ROCM_AITER_MLA only supports block_size == 1
+                        with pytest.raises(ValueError) as exc_info:
+                            get_attn_backend(16,
+                                             torch.float16,
+                                             torch.float16,
+                                             block_size,
+                                             False,
+                                             use_mla=use_mla)
+                        assert f"The selected backend, {name}" in str(
+                            exc_info.value)
+                    else:
+                        # Valid backend-block_size combination
+                        backend = get_attn_backend(16,
+                                                   torch.float16,
+                                                   torch.float16,
+                                                   block_size,
+                                                   False,
+                                                   use_mla=use_mla)
+                        expected = f"{name}_VLLM_V1" if use_v1 else name
+                        assert backend.get_name() == expected
                 else:
                     backend = get_attn_backend(16,
                                                torch.float16,
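
The ROCm branch above boils down to a small compatibility rule. A minimal standalone sketch of what the test now expects; the helper name and the asserts are illustrative, not part of vLLM:

def expected_rocm_mla_outcome(name: str, block_size: int, use_v1: bool) -> str:
    """Hypothetical helper mirroring the test's ROCm MLA expectations:
    return the backend name the selector should report, or "ValueError"
    when the forced backend and block_size are incompatible."""
    if name == "TRITON_MLA" and block_size == 1:
        return "ValueError"  # TRITON_MLA rejects block_size == 1
    if name == "ROCM_AITER_MLA" and block_size != 1:
        return "ValueError"  # ROCM_AITER_MLA requires block_size == 1
    return f"{name}_VLLM_V1" if use_v1 else name

assert expected_rocm_mla_outcome("TRITON_MLA", 16, True) == "TRITON_MLA_VLLM_V1"
assert expected_rocm_mla_outcome("ROCM_AITER_MLA", 16, True) == "ValueError"
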
@@ -136,26 +150,68 @@ def test_env(
             with patch("vllm.attention.selector.current_platform",
                        CudaPlatform()):
                 if use_mla:
-                    if name == "FLASHMLA" and block_size == 64:
-                        from vllm.attention.backends.flashmla import (
-                            is_flashmla_supported)
-
-                        # only on cuda platforms with specific capability.
-                        is_supported, _ = is_flashmla_supported()
-
-                        if not is_supported:
-                            # if platform is not supported then skip this case.
-                            pytest.skip()
+                    # CUDA MLA backend logic:
+                    # - CUTLASS_MLA: only supported with block_size == 128
+                    #   and Blackwell GPUs (SM 10.0), V1 only
+                    # - FLASHMLA: only supported with block_size == 64
+                    # - FLASH_ATTN_MLA: V1 only
+                    # - TRITON_MLA: fallback for other cases
+
+                    if name == "CUTLASS_MLA":
+                        if not use_v1:
+                            # CUTLASS_MLA only supported on V1 engine
+                            pytest.skip(
+                                "CUTLASS_MLA only supported on V1 engine")
+                        elif block_size != 128:
+                            # CUTLASS_MLA only supports block_size == 128
+                            pytest.skip(
+                                "CUTLASS_MLA only supports block_size 128")
+                        else:
+                            backend = get_attn_backend(16,
+                                                       torch.float16,
+                                                       torch.float16,
+                                                       block_size,
+                                                       False,
+                                                       use_mla=use_mla)
+                            expected = "CUTLASS_MLA_VLLM_V1"
+                            assert backend.get_name() == expected
+                    elif name == "FLASHMLA":
+                        if block_size != 64:
+                            # FlashMLA only supports block_size == 64
+                            pytest.skip("FlashMLA only supports block_size 64")
+                        else:
+                            from vllm.attention.backends.flashmla import (
+                                is_flashmla_supported)
+                            is_supported, _ = is_flashmla_supported()
+                            if not is_supported:
+                                pytest.skip(
+                                    "FlashMLA not supported on this platform")
+                            else:
+                                backend = get_attn_backend(16,
+                                                           torch.float16,
+                                                           torch.float16,
+                                                           block_size,
+                                                           False,
+                                                           use_mla=use_mla)
+                                expected = f"{name}_VLLM_V1" if use_v1 else name
+                                assert backend.get_name() == expected
+                    elif name == "FLASH_ATTN_MLA":
+                        if not use_v1:
+                            # FlashAttention MLA only supported on V1 engine
+                            pytest.skip(
+                                "FlashAttention MLA only supported on V1 engine"
+                            )
                         else:
                             backend = get_attn_backend(16,
                                                        torch.float16,
                                                        torch.float16,
                                                        block_size,
                                                        False,
                                                        use_mla=use_mla)
-                            expected = f"{name}_VLLM_V1" if use_v1 else name
+                            expected = "FLASH_ATTN_MLA"
                             assert backend.get_name() == expected
                     else:
+                        # TRITON_MLA or other fallback
                         backend = get_attn_backend(16,
                                                    torch.float16,
                                                    torch.float16,
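
The CUDA branch encodes a similar set of rules. A standalone sketch of the expectations, with the hardware probes (is_flashmla_supported, the Blackwell/SM 10.0 requirement) deliberately omitted; the helper is illustrative, not vLLM's selector code:

def expected_cuda_mla_outcome(name: str, block_size: int, use_v1: bool) -> str:
    """Hypothetical helper mirroring the test's CUDA MLA expectations,
    with hardware probes left out. Returns the expected backend name,
    or "skip" where the test would call pytest.skip()."""
    if name == "CUTLASS_MLA":
        # V1-only and block_size 128 only (Blackwell/SM 10.0 check omitted)
        if not use_v1 or block_size != 128:
            return "skip"
        return "CUTLASS_MLA_VLLM_V1"
    if name == "FLASHMLA":
        # block_size 64 only; the is_flashmla_supported() probe is omitted
        if block_size != 64:
            return "skip"
        return f"{name}_VLLM_V1" if use_v1 else name
    if name == "FLASH_ATTN_MLA":
        # V1-only; the test expects the plain "FLASH_ATTN_MLA" name
        return "FLASH_ATTN_MLA" if use_v1 else "skip"
    # Anything else (e.g. TRITON_MLA) is assumed to take the fallback path
    return f"{name}_VLLM_V1" if use_v1 else name

assert expected_cuda_mla_outcome("CUTLASS_MLA", 128, True) == "CUTLASS_MLA_VLLM_V1"
assert expected_cuda_mla_outcome("FLASHMLA", 16, True) == "skip"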