
Commit 2550f14

Authored by Aleksandr Malyshev (maleksan85) and co-authors
llama3.2 + cross attn test (#220)
* llama3.2 + cross attn test
* lint issues fix
* mypy errors
* making yapf happy
* cut off WA for tuned gemms
* try/except for non-contiguous tensor

Co-authored-by: Aleksandr Malyshev <[email protected]>
Parent: 4075b35 · Commit: 2550f14
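Only two of the five changed files are reproduced below. The "try/except for non-contiguous tensor" item in the message refers to a change not shown here; as a purely hypothetical sketch of that common PyTorch pattern (the function and names are illustrative, not from this commit):

```python
import torch

def flatten_maybe_noncontiguous(t: torch.Tensor) -> torch.Tensor:
    """Hypothetical example: view() raises on non-contiguous tensors,
    so fall back to an explicit copy when it does."""
    try:
        return t.view(-1)
    except RuntimeError:
        # e.g. a transposed or sliced tensor; contiguous() materializes a copy
        return t.contiguous().view(-1)
```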

File tree: 5 files changed, +280 −94 lines


tests/kernels/test_encoder_decoder_attn.py

Lines changed: 2 additions & 2 deletions
```diff
@@ -21,7 +21,8 @@
 from vllm.utils import is_hip

 # List of support backends for encoder/decoder models
-LIST_ENC_DEC_SUPPORTED_BACKENDS = [_Backend.XFORMERS]
+LIST_ENC_DEC_SUPPORTED_BACKENDS = [_Backend.XFORMERS] if not is_hip() \
+    else [_Backend.ROCM_FLASH]

 HEAD_SIZES = [64, 256]

@@ -807,7 +808,6 @@ def test_encoder_only(
     assert_actual_matches_ideal(enc_test_params, enc_pckd_act_out)


-@pytest.mark.skipif(is_hip(), reason=STR_NOT_IMPL_ENC_DEC_ROCM_HIP)
 @pytest.mark.parametrize("num_heads", NUM_HEADS)
 @pytest.mark.parametrize("head_size", HEAD_SIZES)
 @pytest.mark.parametrize("attn_backend", LIST_ENC_DEC_SUPPORTED_BACKENDS)
```

tests/kernels/utils.py

Lines changed: 8 additions & 3 deletions
```diff
@@ -12,8 +12,8 @@
 from torch._prims_common import TensorLikeType

 from vllm.attention import AttentionBackend, AttentionMetadata, AttentionType
-from vllm.utils import (STR_BACKEND_ENV_VAR, STR_XFORMERS_ATTN_VAL,
-                        make_tensor_with_pad)
+from vllm.utils import (STR_BACKEND_ENV_VAR, STR_ROCM_FLASH_ATTN_VAL,
+                        STR_XFORMERS_ATTN_VAL, make_tensor_with_pad)

 # For now, disable "test_aot_dispatch_dynamic" since there are some
 # bugs related to this test in PyTorch 2.4.
@@ -524,8 +524,13 @@ def make_backend(backend_name: str) -> AttentionBackend:
     if backend_name == STR_XFORMERS_ATTN_VAL:
         # NOTE: xFormers backend cannot be imported for CPU and AMD GPUs.
         from vllm.attention.backends.xformers import XFormersBackend
-
         return XFormersBackend()
+
+    if backend_name == STR_ROCM_FLASH_ATTN_VAL:
+        from vllm.attention.backends.rocm_flash_attn import (  # noqa: F401
+            ROCmFlashAttentionBackend)
+        return ROCmFlashAttentionBackend
+
     raise AssertionError(
         f"Unrecognized backend_name {backend_name} for unit test")
```
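With this branch in place, tests can construct the ROCm backend by name. Note an asymmetry the diff introduces: the xFormers branch returns an instance (`XFormersBackend()`) while the ROCm branch returns the class itself; this still works here because vLLM attention backends expose their interface as static/class methods, though callers should not assume they receive an instance. A usage sketch, assuming a ROCm build of vLLM and that the test helpers are run from the repo root:

```python
from tests.kernels.utils import make_backend
from vllm.utils import STR_ROCM_FLASH_ATTN_VAL

# Resolves to ROCmFlashAttentionBackend on a ROCm build; get_name() is a
# staticmethod on vLLM attention backends, so it works on class or instance.
backend = make_backend(STR_ROCM_FLASH_ATTN_VAL)
print(backend.get_name())
```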
