@@ -274,10 +274,14 @@ def __init__(
274
274
# Detect attention implementation.
275
275
self .attn_backend : _Backend = get_vit_attn_backend (support_fa = True )
276
276
if self .attn_backend not in {
277
- _Backend .FLASH_ATTN , _Backend .TORCH_SDPA , _Backend .XFORMERS
277
+ _Backend .FLASH_ATTN , _Backend .TORCH_SDPA , _Backend .XFORMERS ,
278
+ _Backend .ROCM_AITER_FA
278
279
}:
279
280
raise RuntimeError (
280
281
f"Qwen2-VL does not support { self .attn_backend } backend now." )
282
+ self .is_flash_attn_backend = self .attn_backend in {
283
+ _Backend .FLASH_ATTN , _Backend .ROCM_AITER_FA
284
+ }
281
285
282
286
def split_qkv (self , qkv : torch .Tensor ) -> tuple [torch .Tensor , ...]:
283
287
# [s, b, 3 * head * head_dim]
@@ -324,10 +328,13 @@ def forward(
324
328
q = apply_rotary_pos_emb_vision (q , rotary_pos_emb )
325
329
k = apply_rotary_pos_emb_vision (k , rotary_pos_emb )
326
330
327
- if self .attn_backend == _Backend . FLASH_ATTN :
331
+ if self .is_flash_attn_backend :
328
332
# from vllm_flash_attn.flash_attn_interface import (
329
333
# flash_attn_varlen_func)
330
- from flash_attn import flash_attn_varlen_func
334
+ if self .attn_backend == _Backend .ROCM_AITER_FA :
335
+ from aiter import flash_attn_varlen_func
336
+ else :
337
+ from flash_attn import flash_attn_varlen_func
331
338
332
339
q , k , v = (rearrange (x , "b s ... -> (b s) ..." ) for x in [q , k , v ])
333
340
@@ -338,7 +345,7 @@ def forward(
338
345
cu_seqlens_k = cu_seqlens ,
339
346
max_seqlen_q = max_seqlen ,
340
347
max_seqlen_k = max_seqlen ,
341
- dropout_p = 0 ,
348
+ dropout_p = 0.0 ,
342
349
causal = False )
343
350
344
351
context_layer = rearrange (output ,
@@ -620,7 +627,8 @@ def compute_attn_mask_seqlen(
620
627
self , cu_seqlens : torch .Tensor
621
628
) -> tuple [Optional [int ], Optional [list [int ]]]:
622
629
max_seqlen , seqlens = None , None
623
- if self .attn_backend == _Backend .FLASH_ATTN :
630
+ if (self .attn_backend == _Backend .FLASH_ATTN
631
+ or self .attn_backend == _Backend .ROCM_AITER_FA ):
624
632
max_seqlen = (cu_seqlens [1 :] - cu_seqlens [:- 1 ]).max ().item ()
625
633
elif self .attn_backend == _Backend .XFORMERS :
626
634
seqlens = (cu_seqlens [1 :] - cu_seqlens [:- 1 ]).tolist ()
0 commit comments