
Commit 9c5ee91

[ROCm] [VL] [Bugfix] Fix vit flash attn dispatcher logic for ROCm (vllm-project#26104)
Signed-off-by: tjtanaa <[email protected]>
1 parent 27edd2a commit 9c5ee91

File tree

9 files changed: +154 -141 lines changed


vllm/attention/layer.py

Lines changed: 49 additions & 23 deletions
@@ -1,7 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 """Attention layer."""
-from typing import List, Optional
+from typing import Callable, List, Optional
 
 import torch
 import torch.nn as nn
@@ -68,9 +68,39 @@ def check_upstream_fa_availability(dtype: torch.dtype):
     ) and current_platform.has_device_capability(80):
         from transformers.utils import is_flash_attn_2_available
         return is_flash_attn_2_available()
+    if current_platform.is_rocm():
+        from importlib.util import find_spec
+        return find_spec("flash_attn") is not None
     return False
 
 
+def maybe_get_vit_flash_attn_backend(
+        attn_backend: _Backend,
+        use_upstream_fa: bool) -> tuple[_Backend, Callable]:
+    if attn_backend != _Backend.FLASH_ATTN and \
+            attn_backend != _Backend.ROCM_AITER_FA and \
+            check_upstream_fa_availability(torch.get_default_dtype()):
+        attn_backend = _Backend.FLASH_ATTN
+        use_upstream_fa = True
+
+    if current_platform.is_rocm() and \
+            attn_backend == _Backend.FLASH_ATTN:
+        use_upstream_fa = True
+
+    if (attn_backend in {_Backend.FLASH_ATTN, _Backend.ROCM_AITER_FA}):
+        if attn_backend == _Backend.ROCM_AITER_FA:
+            from aiter import flash_attn_varlen_func
+        else:
+            if use_upstream_fa:
+                from flash_attn import flash_attn_varlen_func
+            else:
+                from vllm.vllm_flash_attn import flash_attn_varlen_func
+    else:
+        flash_attn_varlen_func = None
+
+    return attn_backend, flash_attn_varlen_func
+
+
 class Attention(nn.Module, AttentionLayerBase):
     """Attention layer.
 
@@ -410,13 +440,9 @@ def __init__(
         # to upstream flash attention if available.
         # If vllm native fa is selected, we use it directly.
         use_upstream_fa = False
-        if backend != _Backend.FLASH_ATTN and check_upstream_fa_availability(
-                dtype):
-            backend = _Backend.FLASH_ATTN
-            use_upstream_fa = True
 
-        if current_platform.is_rocm() or current_platform.is_xpu():
-            # currently, only torch_sdpa is supported on rocm/xpu
+        if current_platform.is_xpu():
+            # currently, only torch_sdpa is supported on xpu
            self.attn_backend = _Backend.TORCH_SDPA
         else:
 
@@ -428,17 +454,25 @@ def __init__(
             _Backend.FLASH_ATTN,
         } else _Backend.TORCH_SDPA
 
+        self.attn_backend, self._flash_attn_varlen_func \
+            = maybe_get_vit_flash_attn_backend(
+                self.attn_backend,
+                use_upstream_fa,
+            )
+
         if (self.attn_backend == _Backend.XFORMERS
                 and not check_xformers_availability()):
             self.attn_backend = _Backend.TORCH_SDPA
 
-        if self.attn_backend == _Backend.FLASH_ATTN:
-            if use_upstream_fa:
-                from flash_attn import flash_attn_varlen_func
-                self._flash_attn_varlen_func = flash_attn_varlen_func
-            else:
-                from vllm.vllm_flash_attn import flash_attn_varlen_func
-                self._flash_attn_varlen_func = flash_attn_varlen_func
+        self.is_flash_attn_backend = self.attn_backend in {
+            _Backend.FLASH_ATTN, _Backend.ROCM_AITER_FA
+        }
+
+        # this condition is just to make sure that the
+        # use_upstream_fa in the log is correct
+        if current_platform.is_rocm() \
+                and self.attn_backend == _Backend.FLASH_ATTN:
+            use_upstream_fa = True
 
         logger.info_once(
             f"MultiHeadAttention attn_backend: {self.attn_backend}, "
@@ -466,7 +500,7 @@ def forward(
             key = torch.repeat_interleave(key, num_repeat, dim=2)
             value = torch.repeat_interleave(value, num_repeat, dim=2)
 
-        if self.attn_backend == _Backend.FLASH_ATTN:
+        if self.is_flash_attn_backend:
             cu_seqlens_q = torch.arange(0, (bsz + 1) * q_len,
                                         step=q_len,
                                         dtype=torch.int32,
@@ -507,14 +541,6 @@ def forward(
             from torch_xla.experimental.custom_kernel import flash_attention
             out = flash_attention(query, key, value, sm_scale=self.scale)
             out = out.transpose(1, 2)
-        elif self.attn_backend == _Backend.ROCM_AITER_FA:
-            from aiter import flash_attn_varlen_func
-
-            # ROCm Flash Attention expects (batch, seq, heads, head_dim)
-            out = flash_attn_varlen_func(query,
-                                         key,
-                                         value,
-                                         softmax_scale=self.scale)
         else:
             # ViT attention hasn't supported this backend yet
             raise NotImplementedError(
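
The new maybe_get_vit_flash_attn_backend helper centralizes the ViT flash-attention dispatch that the model files below previously duplicated inline. As a rough guide to its outcomes, here is a minimal sketch (the example call and the comment table are illustrative, not part of the diff; the import inside the helper will fail on a machine that lacks the selected flash-attention package):

from vllm.attention.backends.registry import _Backend
from vllm.attention.layer import maybe_get_vit_flash_attn_backend

# Dispatch outcomes encoded in the helper above:
#   ROCM_AITER_FA                         -> aiter.flash_attn_varlen_func
#   FLASH_ATTN on ROCm                    -> upstream flash_attn.flash_attn_varlen_func
#   FLASH_ATTN elsewhere                  -> vllm.vllm_flash_attn.flash_attn_varlen_func
#                                            (or upstream flash_attn if use_upstream_fa)
#   other backend, upstream FA available  -> upgraded to FLASH_ATTN with upstream flash_attn
#   other backend, no upstream FA         -> backend returned unchanged, function is None
backend, varlen_fn = maybe_get_vit_flash_attn_backend(_Backend.TORCH_SDPA,
                                                      use_upstream_fa=False)
if varlen_fn is None:
    # Caller keeps its non-flash path (e.g. torch SDPA), exactly as the ViT
    # modules in the files below do.
    pass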

vllm/model_executor/models/dots_ocr.py

Lines changed: 19 additions & 22 deletions
@@ -10,7 +10,8 @@
 from transformers.models.qwen2_vl import Qwen2VLProcessor
 
 from vllm.attention.backends.registry import _Backend
-from vllm.attention.layer import check_upstream_fa_availability
+from vllm.attention.layer import (check_upstream_fa_availability,
+                                  maybe_get_vit_flash_attn_backend)
 from vllm.config import VllmConfig
 from vllm.distributed import utils as dist_utils
 from vllm.distributed.parallel_state import (
@@ -267,10 +268,12 @@ def __init__(self,
         self.attn_backend = get_vit_attn_backend(
             self.hidden_size_per_attention_head, torch.get_default_dtype())
         self.use_upstream_fa = False
-        if self.attn_backend != _Backend.FLASH_ATTN and \
-                check_upstream_fa_availability(torch.get_default_dtype()):
-            self.attn_backend = _Backend.FLASH_ATTN
-            self.use_upstream_fa = True
+
+        self.attn_backend, self.flash_attn_varlen_func \
+            = maybe_get_vit_flash_attn_backend(
+                self.attn_backend,
+                self.use_upstream_fa,
+            )
         if self.attn_backend not in {
                 _Backend.FLASH_ATTN, _Backend.TORCH_SDPA, _Backend.XFORMERS,
                 _Backend.ROCM_AITER_FA
@@ -306,25 +309,18 @@ def forward(
         q, k = torch.chunk(qk_rotated, 2, dim=0)
 
         if self.is_flash_attn_backend:
-            if self.attn_backend == _Backend.ROCM_AITER_FA:
-                from aiter import flash_attn_varlen_func
-            else:
-                if self.use_upstream_fa:
-                    from flash_attn import flash_attn_varlen_func
-                else:
-                    from vllm.vllm_flash_attn import flash_attn_varlen_func
             q_ = q.reshape(bs * q.shape[1], q.shape[2], q.shape[3])
             k_ = k.reshape(bs * k.shape[1], k.shape[2], k.shape[3])
             v_ = v.reshape(bs * v.shape[1], v.shape[2], v.shape[3])
-            output = flash_attn_varlen_func(q_,
-                                            k_,
-                                            v_,
-                                            cu_seqlens_q=cu_seqlens,
-                                            cu_seqlens_k=cu_seqlens,
-                                            max_seqlen_q=max_seqlen,
-                                            max_seqlen_k=max_seqlen,
-                                            dropout_p=0.0,
-                                            causal=False)
+            output = self.flash_attn_varlen_func(q_,
+                                                 k_,
+                                                 v_,
+                                                 cu_seqlens_q=cu_seqlens,
+                                                 cu_seqlens_k=cu_seqlens,
+                                                 max_seqlen_q=max_seqlen,
+                                                 max_seqlen_k=max_seqlen,
+                                                 dropout_p=0.0,
+                                                 causal=False)
         context_layer = output.view(bs, -1,
                                     self.num_attention_heads_per_partition,
                                     self.hidden_size_per_attention_head)
@@ -611,7 +607,8 @@ def compute_attn_mask_seqlen(
         self, cu_seqlens: torch.Tensor
     ) -> tuple[Optional[int], Optional[list[int]]]:
         max_seqlen, seqlens = None, None
-        if self.attn_backend == _Backend.FLASH_ATTN:
+        if (self.attn_backend == _Backend.FLASH_ATTN
+                or self.attn_backend == _Backend.ROCM_AITER_FA):
             max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
         elif self.attn_backend == _Backend.XFORMERS:
             seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
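
All of the varlen calls above (and in the files that follow) pack every image's tokens into one (total_tokens, heads, head_dim) tensor and mark per-image boundaries with cu_seqlens; max_seqlen is the longest single image. A small self-contained sketch of that bookkeeping with a stand-in function; the sequence lengths and the stand-in are illustrative, not from this commit:

import torch


def stand_in_varlen_attention(q, k, v, cu_seqlens_q, cu_seqlens_k,
                              max_seqlen_q, max_seqlen_k,
                              dropout_p=0.0, causal=False):
    # Same keyword interface as the flash_attn_varlen_func calls in the diff;
    # a real run would use the function returned by
    # maybe_get_vit_flash_attn_backend instead of this identity stand-in.
    return q


seq_lens = torch.tensor([16, 9, 25], dtype=torch.int32)  # tokens per image
cu_seqlens = torch.cat([torch.zeros(1, dtype=torch.int32),
                        torch.cumsum(seq_lens, 0, dtype=torch.int32)])
max_seqlen = int(seq_lens.max())  # 25
# cu_seqlens is tensor([ 0, 16, 25, 50], dtype=torch.int32)

heads, head_dim = 8, 64
q = k = v = torch.randn(int(seq_lens.sum()), heads, head_dim)

out = stand_in_varlen_attention(q, k, v,
                                cu_seqlens_q=cu_seqlens,
                                cu_seqlens_k=cu_seqlens,
                                max_seqlen_q=max_seqlen,
                                max_seqlen_k=max_seqlen,
                                dropout_p=0.0,
                                causal=False)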

vllm/model_executor/models/ernie45_vl.py

Lines changed: 23 additions & 26 deletions
@@ -35,7 +35,8 @@
 from transformers import BatchFeature
 
 from vllm.attention.backends.registry import _Backend
-from vllm.attention.layer import check_upstream_fa_availability
+from vllm.attention.layer import (check_upstream_fa_availability,
+                                  maybe_get_vit_flash_attn_backend)
 from vllm.config import VllmConfig
 from vllm.distributed import parallel_state
 from vllm.distributed import utils as dist_utils
@@ -176,14 +177,18 @@ def __init__(
             dtype=torch.get_default_dtype())
 
         self.use_upstream_fa = False
-        if self.attn_backend != _Backend.FLASH_ATTN and \
-                check_upstream_fa_availability(torch.get_default_dtype()):
-            self.attn_backend = _Backend.FLASH_ATTN
-            self.use_upstream_fa = True
+
+        self.attn_backend, self.flash_attn_varlen_func \
+            = maybe_get_vit_flash_attn_backend(
+                self.attn_backend,
+                self.use_upstream_fa,
+            )
 
         if self.attn_backend not in {
-                _Backend.FLASH_ATTN, _Backend.TORCH_SDPA, _Backend.XFORMERS,
-                _Backend.ROCM_AITER_FA
+                _Backend.FLASH_ATTN,
+                _Backend.TORCH_SDPA,
+                _Backend.XFORMERS,
+                _Backend.ROCM_AITER_FA,
         }:
             raise RuntimeError(
                 f"Ernie45-VL does not support {self.attn_backend} backend now."
@@ -239,27 +244,18 @@ def forward(
         q, k = torch.chunk(qk_rotated, 2, dim=0)
 
         if self.is_flash_attn_backend:
-            # from vllm_flash_attn.flash_attn_interface import (
-            #     flash_attn_varlen_func)
-            if self.attn_backend == _Backend.ROCM_AITER_FA:
-                from aiter import flash_attn_varlen_func
-            else:
-                if self.use_upstream_fa:
-                    from flash_attn import flash_attn_varlen_func
-                else:
-                    from vllm.vllm_flash_attn import flash_attn_varlen_func
 
             q, k, v = (rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v])
 
-            output = flash_attn_varlen_func(q,
-                                            k,
-                                            v,
-                                            cu_seqlens_q=cu_seqlens,
-                                            cu_seqlens_k=cu_seqlens,
-                                            max_seqlen_q=max_seqlen,
-                                            max_seqlen_k=max_seqlen,
-                                            dropout_p=0.0,
-                                            causal=False)
+            output = self.flash_attn_varlen_func(q,
+                                                 k,
+                                                 v,
+                                                 cu_seqlens_q=cu_seqlens,
+                                                 cu_seqlens_k=cu_seqlens,
+                                                 max_seqlen_q=max_seqlen,
+                                                 max_seqlen_k=max_seqlen,
+                                                 dropout_p=0.0,
+                                                 causal=False)
 
             context_layer = rearrange(output,
                                       "(b s) h d -> s b (h d)",
@@ -516,7 +512,8 @@ def compute_attn_mask_seqlen(
         self, cu_seqlens: torch.Tensor
     ) -> tuple[Optional[int], Optional[list[int]]]:
         max_seqlen, seqlens = None, None
-        if self.attn_backend == _Backend.FLASH_ATTN:
+        if (self.attn_backend == _Backend.FLASH_ATTN
+                or self.attn_backend == _Backend.ROCM_AITER_FA):
             max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
         elif self.attn_backend == _Backend.XFORMERS:
             seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()

vllm/model_executor/models/glm4_1v.py

Lines changed: 17 additions & 14 deletions
@@ -47,7 +47,8 @@
 from transformers.video_utils import VideoMetadata
 
 from vllm.attention.backends.registry import _Backend
-from vllm.attention.layer import check_upstream_fa_availability
+from vllm.attention.layer import (check_upstream_fa_availability,
+                                  maybe_get_vit_flash_attn_backend)
 from vllm.config import VllmConfig
 from vllm.distributed import (get_tensor_model_parallel_world_size,
                               parallel_state)
@@ -263,19 +264,26 @@ def __init__(
             head_size=self.hidden_size_per_attention_head,
             dtype=torch.get_default_dtype())
         self.use_upstream_fa = False
-        if self.attn_backend != _Backend.FLASH_ATTN and \
-                check_upstream_fa_availability(torch.get_default_dtype()):
-            self.attn_backend = _Backend.FLASH_ATTN
-            self.use_upstream_fa = True
+
+        self.attn_backend, self.flash_attn_varlen_func \
+            = maybe_get_vit_flash_attn_backend(
+                self.attn_backend,
+                self.use_upstream_fa,
+            )
 
         if self.attn_backend not in {
                 _Backend.FLASH_ATTN,
                 _Backend.TORCH_SDPA,
                 _Backend.XFORMERS,
+                _Backend.ROCM_AITER_FA,
         }:
             raise RuntimeError(
                 f"GLM-4V does not support {self.attn_backend} backend now.")
 
+        self.is_flash_attn_backend = self.attn_backend in {
+            _Backend.FLASH_ATTN, _Backend.ROCM_AITER_FA
+        }
+
     def split_qkv(self, qkv: torch.Tensor) -> tuple[torch.Tensor, ...]:
         # [s, b, 3 * head * head_dim]
         seq_len, bs, _ = qkv.shape
@@ -316,17 +324,11 @@ def forward(
         qk_rotated = apply_rotary_pos_emb_vision(qk_concat, rotary_pos_emb)
         q, k = torch.chunk(qk_rotated, 2, dim=0)
 
-        if self.attn_backend == _Backend.FLASH_ATTN:
-            # from vllm_flash_attn.flash_attn_interface import (
-            #     flash_attn_varlen_func)
-            if self.use_upstream_fa:
-                from flash_attn import flash_attn_varlen_func
-            else:
-                from vllm.vllm_flash_attn import flash_attn_varlen_func
+        if self.is_flash_attn_backend:
 
             q, k, v = (rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v])
 
-            output = flash_attn_varlen_func(
+            output = self.flash_attn_varlen_func(
                 q,
                 k,
                 v,
@@ -774,7 +776,8 @@ def compute_attn_mask_seqlen(
     ) -> tuple[Optional[int], Optional[list[int]]]:
         max_seqlen, seqlens = None, None
         seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
-        if self.attn_backend == _Backend.FLASH_ATTN:
+        if (self.attn_backend == _Backend.FLASH_ATTN
+                or self.attn_backend == _Backend.ROCM_AITER_FA):
             max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
         return max_seqlen, seqlens
 
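Taken together, the vision attention modules in dots_ocr, ernie45_vl, and glm4_1v now share one initialization pattern. A minimal sketch of that pattern in isolation (the class name and constructor below are illustrative, not vLLM API):

import torch

from vllm.attention.backends.registry import _Backend
from vllm.attention.layer import maybe_get_vit_flash_attn_backend


class VisionAttentionSketch:
    # Illustrative only: mirrors the dispatch pattern used in the diffs above.

    def __init__(self, attn_backend: _Backend):
        self.use_upstream_fa = False
        # The helper may upgrade the backend and returns the matching
        # flash_attn_varlen_func, or None when no flash-attention path applies.
        self.attn_backend, self.flash_attn_varlen_func = \
            maybe_get_vit_flash_attn_backend(attn_backend, self.use_upstream_fa)
        self.is_flash_attn_backend = self.attn_backend in {
            _Backend.FLASH_ATTN, _Backend.ROCM_AITER_FA
        }

    def compute_attn_mask_seqlen(self, cu_seqlens: torch.Tensor):
        # FLASH_ATTN and ROCM_AITER_FA need max_seqlen; XFORMERS needs the
        # per-image seqlens list, matching the model files above.
        max_seqlen, seqlens = None, None
        if self.is_flash_attn_backend:
            max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
        elif self.attn_backend == _Backend.XFORMERS:
            seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
        return max_seqlen, seqlens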