Commit 7ab33a4

Author: guanbao (committed)
add mha dispatch logic
Signed-off-by: guanbao <[email protected]>
1 parent c783a5e commit 7ab33a4

File tree

2 files changed: 55 additions & 12 deletions

vllm/envs.py

Lines changed: 6 additions & 0 deletions
@@ -108,6 +108,7 @@
     VLLM_ROCM_USE_AITER_RMSNORM: bool = True
     VLLM_ROCM_USE_AITER_MOE: bool = True
     VLLM_ROCM_USE_AITER_MLA: bool = True
+    VLLM_ROCM_USE_AITER_TRITON_MLA: bool = False
     VLLM_ROCM_USE_AITER_MHA: bool = True
     VLLM_ROCM_USE_AITER_FP4_ASM_GEMM: bool = False
     VLLM_ROCM_USE_TRITON_ROPE: bool = True
@@ -879,6 +880,11 @@ def get_vllm_port() -> int | None:
     "VLLM_ROCM_USE_AITER_MLA": lambda: (
         os.getenv("VLLM_ROCM_USE_AITER_MLA", "True").lower() in ("true", "1")
     ),
+    # Whether to use aiter triton mla ops.
+    # By default is disabled.
+    "VLLM_ROCM_USE_AITER_TRITON_MLA": lambda: (
+        os.getenv("VLLM_ROCM_USE_AITER_TRITON_MLA", "False").lower() in ("true", "1")
+    ),
     # Whether to use aiter mha ops.
     # By default is enabled.
     "VLLM_ROCM_USE_AITER_MHA": lambda: (

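For reference, a minimal standalone sketch of how the new flag is read; the read_use_aiter_triton_mla helper below is illustrative only and not part of the commit, which accesses the value through envs.VLLM_ROCM_USE_AITER_TRITON_MLA:

import os

def read_use_aiter_triton_mla() -> bool:
    # Mirrors the lambda registered above: only "true" or "1" (case-insensitive)
    # enables the Triton path; anything else, including an unset variable,
    # leaves it disabled.
    return os.getenv("VLLM_ROCM_USE_AITER_TRITON_MLA", "False").lower() in ("true", "1")

print(read_use_aiter_triton_mla())  # False while the variable is unset
os.environ["VLLM_ROCM_USE_AITER_TRITON_MLA"] = "1"
print(read_use_aiter_triton_mla())  # True

Setting VLLM_ROCM_USE_AITER_TRITON_MLA=1 therefore forces the Triton MHA path in the dispatch logic added below, bypassing its batch-size and sequence-length heuristic.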
vllm/v1/attention/backends/mla/rocm_aiter_mla.py

Lines changed: 49 additions & 12 deletions
@@ -10,6 +10,7 @@
 from vllm.attention.backends.abstract import AttentionLayer
 from vllm.attention.ops.rocm_aiter_mla import aiter_mla_decode_fwd
 from vllm.config import VllmConfig
+from vllm.platforms.rocm import on_gfx950
 from vllm.utils import cdiv
 from vllm.v1.attention.backends.mla.common import (
     MLACommonBackend,
@@ -243,23 +244,59 @@ def __init__(
                 "alibi_slopes, sliding_window, logits_soft_cap"
             )
 
-        from aiter import flash_attn_varlen_func
+        from aiter import flash_attn_varlen_func as aiter_flash_attn_varlen_func
+        from aiter.ops.triton.mha import (
+            flash_attn_varlen_func as triton_flash_attn_varlen_func,
+        )
+
+        self.triton_flash_attn_varlen_func = triton_flash_attn_varlen_func
+        self.aiter_flash_attn_varlen_func = aiter_flash_attn_varlen_func
+
+    def _use_triton_mha(self, q, k, **kwargs) -> bool:
+        # TODO: refine dispatch logic on other non-GFX950 GPUs
+        if not on_gfx950():
+            return False
+
+        cu_seqlens_q = kwargs.get("cu_seqlens_q")
+        max_seqlen_q = kwargs.get("max_seqlen_q", q.size(0))
+        max_seqlen_k = kwargs.get("max_seqlen_k", k.size(0))
+
+        bs = cu_seqlens_q.shape[0] - 1 if cu_seqlens_q is not None else 1
+
+        # TODO: consider more comprehensive conditions here
+        use_triton_mha = bs <= 32
+        use_triton_mha = use_triton_mha and (max_seqlen_q <= 1024)
+        use_triton_mha = use_triton_mha and (max_seqlen_k <= 1024)
 
-        self.flash_attn_varlen_func = flash_attn_varlen_func
+        return use_triton_mha
 
     def _flash_attn_varlen_diff_headdims(
         self, q, k, v, return_softmax_lse=False, softmax_scale=None, **kwargs
     ):
-        output = self.flash_attn_varlen_func(
-            q=q,
-            k=k,
-            v=v,
-            softmax_scale=softmax_scale,
-            return_lse=return_softmax_lse,
-            **kwargs,
-        )
-
-        return output
+        # force to use triton mha if env var is set, otherwise do dispatch
+        if envs.VLLM_ROCM_USE_AITER_TRITON_MLA or self._use_triton_mha(q, k, **kwargs):
+            result = self.triton_flash_attn_varlen_func(
+                q=q,
+                k=k,
+                v=v,
+                softmax_scale=softmax_scale,
+                return_lse=return_softmax_lse,
+                **kwargs,
+            )
+            if return_softmax_lse and type(result) is tuple:
+                output, lse = result
+                return (output, lse.T.contiguous())
+            return result
+        else:
+            output = self.aiter_flash_attn_varlen_func(
+                q=q,
+                k=k,
+                v=v,
+                softmax_scale=softmax_scale,
+                return_lse=return_softmax_lse,
+                **kwargs,
+            )
+            return output
 
     def _forward_decode(
         self,
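The dispatch added above can be read as: if VLLM_ROCM_USE_AITER_TRITON_MLA is set, always take the Triton kernel; otherwise take it only on GFX950 when the batch holds at most 32 requests and neither the query nor the key side exceeds 1024 tokens, and fall back to the aiter flash attention kernel in every other case. A minimal standalone sketch of that decision follows; the function and its parameters are illustrative only (the real method reads the flag via envs, derives its defaults from the q/k tensors, and calls on_gfx950()):

def should_use_triton_mha(
    batch_size: int,
    max_seqlen_q: int,
    max_seqlen_k: int,
    force_triton: bool = False,  # stands in for VLLM_ROCM_USE_AITER_TRITON_MLA
    is_gfx950: bool = True,      # stands in for on_gfx950()
) -> bool:
    # The environment variable overrides the heuristic entirely.
    if force_triton:
        return True
    # Otherwise only GFX950 with small batches and short sequences
    # takes the Triton path.
    if not is_gfx950:
        return False
    return batch_size <= 32 and max_seqlen_q <= 1024 and max_seqlen_k <= 1024

# A 16-request batch with 512-token queries and keys goes to the Triton kernel;
# a 64-request batch falls back to the aiter flash attention path.
assert should_use_triton_mha(16, 512, 512)
assert not should_use_triton_mha(64, 512, 512)

Note also that when return_softmax_lse is requested, the Triton branch transposes the returned LSE (lse.T.contiguous()) before returning it, presumably so callers see the same LSE layout as the aiter path produces.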
