
Commit aade718

[release/2.5] fix and skip test_causal_variants on Navi4x (#2304)
This PR fixes test_transformers.py and skips one test:

- Add the missing lines in test_transformers.py to fix `NameError: name '_cur_sdpa_kernel_backends' is not defined`.
- Skip `test_transformers.py::TestAttnBias::test_causal_variants_causal_variant_CausalVariant_UPPER_LEFT_shape3_cuda` on Navi4x (tested on gfx1200 and gfx1201). The test fails only for shape (1, 1, 23, 56, 15), only on Navi4x, and only in release/2.5. The test uses `hipBlas`, and switching to `hipBlasLT` doesn't help. release/2.6 has significant changes in the `SDPBackend.MATH` backend, and the test passes there.

Fixes SWDEV-522844
1 parent 71fc73f commit aade718

File tree: 2 files changed, +30 -9 lines


test/test_transformers.py

Lines changed: 4 additions & 0 deletions
@@ -3779,6 +3779,10 @@ def test_causal_variants(self, device, causal_variant: CausalVariant, shape: Lis
         if TEST_WITH_ROCM and causal_variant == CausalVariant.LOWER_RIGHT:
             self.skipTest("No support for LOWER_RIGHT variant for now")
             return
+        if (TEST_WITH_ROCM
+            and "gfx12" in torch.cuda.get_device_properties(0).gcnArchName.split(":")[0]
+            and self._testMethodName == "test_causal_variants_causal_variant_CausalVariant_UPPER_LEFT_shape3_cuda"):
+            self.skipTest(f"Failed on Navi4x in release/2.5 for shape {shape}")
 
         bsz, num_heads, seq_len_q, seq_len_kv, head_dim = shape
         make_q_tensor = partial(make_tensor, SdpaShape(bsz, num_heads, seq_len_q, head_dim))
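
For context on the new skip, Navi4x is detected from the ROCm architecture string exposed via `torch.cuda.get_device_properties(0).gcnArchName`. A minimal standalone probe of that condition might look like the sketch below (illustrative only; `gcnArchName` is present on ROCm builds, and the split on ":" assumes the arch string may carry feature suffixes after the arch name):

```python
import torch

# Illustrative probe of the skip condition from the diff above: Navi4x parts
# (gfx1200 / gfx1201) have an architecture name starting with "gfx12"; the
# split on ":" keeps only the arch name, dropping any feature suffixes.
if torch.cuda.is_available() and torch.version.hip is not None:
    arch = torch.cuda.get_device_properties(0).gcnArchName.split(":")[0]
    print(f"arch prefix: {arch}")
    print("test would be skipped:", "gfx12" in arch)
else:
    print("not a ROCm device; the Navi4x skip does not apply")
```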

torch/nn/attention/__init__.py

Lines changed: 26 additions & 9 deletions
@@ -1,21 +1,14 @@
 # mypy: allow-untyped-defs
 """ This module contains functions and classes that alter the behavior of torch.nn.functional.scaled_dot_product_attention """
 import contextlib
-from typing import List, Union
+from typing import Iterable, List, Union
 from warnings import warn
 
+import torch.backends.cuda
 from torch._C import _SDPBackend as SDPBackend
 from torch.backends.cuda import (
     can_use_efficient_attention,
     can_use_flash_attention,
-    cudnn_sdp_enabled,
-    enable_cudnn_sdp,
-    enable_flash_sdp,
-    enable_math_sdp,
-    enable_mem_efficient_sdp,
-    flash_sdp_enabled,
-    math_sdp_enabled,
-    mem_efficient_sdp_enabled,
     SDPAParams,
 )
 

@@ -66,6 +59,30 @@ def _raise_kernel_warnings(params: SDPAParams) -> None:
             warn("Flash attention can't be used because:")
             can_use_flash_attention(params, True)
 
+_backend_names = {
+    "cudnn": "CUDNN_ATTENTION",
+    "flash": "FLASH_ATTENTION",
+    "mem_efficient": "EFFICIENT_ATTENTION",
+    "math": "MATH",
+}
+
+
+def _backend_from_string(name: str):
+    return getattr(SDPBackend, name)
+
+
+def _cur_sdpa_kernel_backends():
+    backends: List[SDPBackend] = []
+    for name, val in _backend_names.items():
+        if getattr(torch.backends.cuda, f"{name}_sdp_enabled")():
+            backends.append(getattr(SDPBackend, val))
+    return backends
+
+
+def _sdpa_kernel(backends: Iterable[SDPBackend]):
+    for name, val in _backend_names.items():
+        enabled = getattr(SDPBackend, val) in backends
+        getattr(torch.backends.cuda, f"enable_{name}_sdp")(enabled)
 
 @contextlib.contextmanager
 def sdpa_kernel(
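
The `NameError` mentioned in the commit message came from the `sdpa_kernel` context manager referencing `_cur_sdpa_kernel_backends`, which was missing on this branch. As a rough sketch (not the exact body of `sdpa_kernel` in this file, and assumed to live in the same module so that `SDPBackend` and the two helpers are in scope), a context manager built on these helpers saves the currently enabled backends, switches to the requested ones, and restores the old set on exit:

```python
import contextlib
from typing import List, Union


@contextlib.contextmanager
def _sdpa_kernel_sketch(backends: Union[List[SDPBackend], SDPBackend]):
    # Hypothetical helper name; mirrors how sdpa_kernel is expected to use
    # _cur_sdpa_kernel_backends() and _sdpa_kernel() from the diff above.
    if isinstance(backends, SDPBackend):
        backends = [backends]
    previous = _cur_sdpa_kernel_backends()  # remember what is enabled now
    try:
        _sdpa_kernel(backends)  # enable only the requested backends
        yield
    finally:
        _sdpa_kernel(previous)  # restore the previous backend set
```

Callers then use the public API as usual, e.g. `with sdpa_kernel(SDPBackend.MATH): ...`, and the previously enabled backends are restored when the block exits.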
