10 | 10 | from vllm.attention.backends.abstract import AttentionLayer |
11 | 11 | from vllm.attention.ops.rocm_aiter_mla import aiter_mla_decode_fwd |
12 | 12 | from vllm.config import VllmConfig |
| 13 | +from vllm.platforms.rocm import on_gfx950 |
13 | 14 | from vllm.utils import cdiv |
14 | 15 | from vllm.v1.attention.backends.mla.common import ( |
15 | 16 |     MLACommonBackend,
@@ -243,23 +244,59 @@ def __init__( |
243 | 244 |                 "alibi_slopes, sliding_window, logits_soft_cap"
244 | 245 |             )
245 | 246 |
246 | | -        from aiter import flash_attn_varlen_func
| 247 | +        from aiter import flash_attn_varlen_func as aiter_flash_attn_varlen_func
| 248 | +        from aiter.ops.triton.mha import (
| 249 | +            flash_attn_varlen_func as triton_flash_attn_varlen_func,
| 250 | +        )
| 251 | +
| 252 | +        self.triton_flash_attn_varlen_func = triton_flash_attn_varlen_func
| 253 | +        self.aiter_flash_attn_varlen_func = aiter_flash_attn_varlen_func
| 254 | +
| 255 | +    def _use_triton_mha(self, q, k, **kwargs) -> bool:
| 256 | +        # TODO: refine the dispatch logic for non-GFX950 GPUs
| 257 | +        if not on_gfx950():
| 258 | +            return False
| 259 | +
| 260 | +        cu_seqlens_q = kwargs.get("cu_seqlens_q")
| 261 | +        max_seqlen_q = kwargs.get("max_seqlen_q", q.size(0))
| 262 | +        max_seqlen_k = kwargs.get("max_seqlen_k", k.size(0))
| 263 | +
| 264 | +        bs = cu_seqlens_q.shape[0] - 1 if cu_seqlens_q is not None else 1
| 265 | +
| 266 | +        # TODO: consider more comprehensive conditions here
| 267 | +        use_triton_mha = bs <= 32
| 268 | +        use_triton_mha = use_triton_mha and (max_seqlen_q <= 1024)
| 269 | +        use_triton_mha = use_triton_mha and (max_seqlen_k <= 1024)
247 | 270 |
248 | | -        self.flash_attn_varlen_func = flash_attn_varlen_func
| 271 | +        return use_triton_mha
249 | 272 |
250 | 273 |     def _flash_attn_varlen_diff_headdims(
251 | 274 |         self, q, k, v, return_softmax_lse=False, softmax_scale=None, **kwargs
252 | 275 |     ):
253 | | -        output = self.flash_attn_varlen_func(
254 | | -            q=q,
255 | | -            k=k,
256 | | -            v=v,
257 | | -            softmax_scale=softmax_scale,
258 | | -            return_lse=return_softmax_lse,
259 | | -            **kwargs,
260 | | -        )
261 | | -
262 | | -        return output
| 276 | +        # Force Triton MHA when the env var is set; otherwise dispatch heuristically.
| 277 | +        if envs.VLLM_ROCM_USE_AITER_TRITON_MLA or self._use_triton_mha(q, k, **kwargs):
| 278 | +            result = self.triton_flash_attn_varlen_func(
| 279 | +                q=q,
| 280 | +                k=k,
| 281 | +                v=v,
| 282 | +                softmax_scale=softmax_scale,
| 283 | +                return_lse=return_softmax_lse,
| 284 | +                **kwargs,
| 285 | +            )
| 286 | +            if return_softmax_lse and type(result) is tuple:
| 287 | +                output, lse = result
| 288 | +                return (output, lse.T.contiguous())
| 289 | +            return result
| 290 | +        else:
| 291 | +            output = self.aiter_flash_attn_varlen_func(
| 292 | +                q=q,
| 293 | +                k=k,
| 294 | +                v=v,
| 295 | +                softmax_scale=softmax_scale,
| 296 | +                return_lse=return_softmax_lse,
| 297 | +                **kwargs,
| 298 | +            )
| 299 | +            return output
263 | 300 |
264 | 301 |     def _forward_decode(
265 | 302 |         self,
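
For readers skimming the hunk: the change routes MLA prefill attention between two varlen flash-attention kernels. Below is a minimal, standalone sketch of that dispatch decision, not the vLLM code itself: the helper name `pick_prefill_backend` and the example shapes are hypothetical, the env-var read is simplified (vLLM reads it through `vllm.envs`), and the real code additionally gates the heuristic on `on_gfx950()` and transposes the Triton kernel's LSE (`lse.T.contiguous()`) to match the layout the caller expects.

```python
import os


def pick_prefill_backend(batch_size: int, max_seqlen_q: int, max_seqlen_k: int) -> str:
    """Sketch of which varlen flash-attention kernel MLA prefill would pick."""
    # Simplified env-var override; forces the Triton MHA kernel regardless of shapes.
    if os.environ.get("VLLM_ROCM_USE_AITER_TRITON_MLA", "0").lower() in ("1", "true"):
        return "triton"
    # Heuristic mirroring _use_triton_mha(): small batches with short sequences
    # (plus the gfx950 hardware check, omitted here) go to the Triton MHA kernel.
    if batch_size <= 32 and max_seqlen_q <= 1024 and max_seqlen_k <= 1024:
        return "triton"
    # Everything else stays on the CK-based aiter flash_attn_varlen_func.
    return "aiter"


if __name__ == "__main__":
    print(pick_prefill_backend(8, 512, 512))     # -> triton
    print(pick_prefill_backend(64, 4096, 4096))  # -> aiter
```

The thresholds (batch size <= 32, sequence lengths <= 1024) come straight from `_use_triton_mha` and are flagged with a TODO in the diff, so treat them as tunable heuristics rather than final values.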