Commit caf1b51

xinyazhang authored and jithunnair-amd committed
[ROCm] Transformer/SDPA unit test parity (pytorch#163745)
## Major Changes

* Efficient Attention on ROCm requires the last dimension of input tensors to be aligned to 16 bytes.
  - Unlike FA, ME does not pad input tensors in `scaled_dot_product_attention`, hence this requirement.
* Fix `atomic_counter` handling in the varlen FA API.
* Unskip several unit tests.

Fixes pytorch#157120
Fixes pytorch#157121
Fixes pytorch#157122
Fixes pytorch#157167
Fixes pytorch#155217
Fixes pytorch#157043
Fixes pytorch#157060

Pull Request resolved: pytorch#163745
Approved by: https://github.com/jeffdaily
1 parent 5f9c692 commit caf1b51
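To illustrate the first bullet: the memory-efficient backend on ROCm now rejects query/key/value tensors whose last (head) dimension is not 16-byte aligned, because it does not pad inputs the way Flash Attention does. Below is a minimal workaround sketch, not part of this PR: the `pad_head_dim` helper, shapes, and dtype are hypothetical, and it simply zero-pads the head dimension and slices the output back while keeping the softmax scale tied to the original head size.

```python
import math
import torch
import torch.nn.functional as F

# Hypothetical example: head_dim=107 is not divisible by 8
# (fp16 needs 8 elements * 2 bytes = 16 bytes), so the
# memory-efficient backend would reject these tensors on ROCm.
q, k, v = (torch.randn(2, 8, 128, 107, device="cuda", dtype=torch.float16)
           for _ in range(3))

def pad_head_dim(t: torch.Tensor, multiple: int = 8) -> torch.Tensor:
    # Zero-pad the last dimension up to the next multiple of `multiple`.
    pad = (-t.size(-1)) % multiple
    return F.pad(t, (0, pad)) if pad else t

q_p, k_p, v_p = (pad_head_dim(t) for t in (q, k, v))

# Zero columns in q/k do not change the attention scores, and zero columns
# in v only add output channels that are sliced off below. Pass the scale
# explicitly, since the default 1/sqrt(E) would otherwise use the padded size.
out = F.scaled_dot_product_attention(
    q_p, k_p, v_p, scale=1.0 / math.sqrt(q.size(-1)))
out = out[..., : v.size(-1)]
```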

File tree: 6 files changed (+29, -30 lines)

- aten/src/ATen/native/transformers/cuda/sdp_utils.cpp
- aten/src/ATen/native/transformers/hip/flash_attn/aot/mha_all_aot.hip
- test/nn/test_multihead_attention.py
- test/test_flop_counter.py
- test/test_nn.py
- test/test_transformers.py

aten/src/ATen/native/transformers/cuda/sdp_utils.cpp

Lines changed: 22 additions & 0 deletions
@@ -176,6 +176,28 @@ bool check_head_dim_size_flash(sdp_params const& params, bool debug) {
     }
     return false;
   }
+  if constexpr(caller_is_meff) {
+    bool is_half = (params.query.dtype() == at::kHalf) ||
+        (params.query.dtype() == at::kBFloat16);
+    const int64_t alignment = is_half ? 8 : 4;
+    if (!(query_size_last % alignment == 0 && query_size_last > 0 &&
+          value_size_last % alignment == 0 && value_size_last > 0)) {
+      if (debug) {
+        TORCH_WARN(
+            "Mem efficient attention requires last dimension of inputs to be divisible by ",
+            alignment,
+            ". ",
+            "Got Query.size(-1): ",
+            query_size_last,
+            ", Key.size(-1): ",
+            params.key.sym_size(-1),
+            ", Value.size(-1): ",
+            params.value.sym_size(-1),
+            " instead.");
+      }
+      return false;
+    }
+  }
   return true;
 }
 
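The alignment constant above amounts to 16 bytes in either case: 8 elements of 2-byte half/bfloat16, or 4 elements of 4-byte float32. A rough Python rendering of the same predicate, purely for illustration (the C++ above is the authoritative version):

```python
import torch

def meff_last_dim_ok(query: torch.Tensor, value: torch.Tensor) -> bool:
    # Same rule as the C++ check: last dims must be positive and divisible
    # by 8 for fp16/bf16 or 4 for fp32, i.e. a multiple of 16 bytes either way.
    alignment = 8 if query.dtype in (torch.float16, torch.bfloat16) else 4
    q_last, v_last = query.size(-1), value.size(-1)
    return (q_last > 0 and q_last % alignment == 0
            and v_last > 0 and v_last % alignment == 0)
```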

aten/src/ATen/native/transformers/hip/flash_attn/aot/mha_all_aot.hip

Lines changed: 3 additions & 2 deletions
@@ -462,10 +462,11 @@ mha_varlen_fwd_aot(const at::Tensor &q, // total_q x num_heads x head_size, tot
   using sdp::aotriton_adapter::mk_aotensor;
   using sdp::aotriton_adapter::mk_aoscalartensor;
   using sdp::aotriton_adapter::mk_philoxtensor;
+  using sdp::aotriton_adapter::mk_atomictensor;
   using sdp::aotriton_adapter::cast_dtype;
   at::Tensor atomic_counter;
   if (is_causal) {
-    atomic_counter = at::zeros({1}, q.options());
+    atomic_counter = at::zeros({1}, q.options().dtype(at::kInt));
   }
   aotriton::TensorView<4> empty_bias(0, {0,0,0,0}, {0,0,0,0}, cast_dtype(q.dtype()));
   auto seed = use_philox_state ? mk_philoxtensor(philox_state.seed_.ptr) : mk_aoscalartensor(seed_t);
@@ -474,7 +475,7 @@ mha_varlen_fwd_aot(const at::Tensor &q, // total_q x num_heads x head_size, tot
   auto nullscalar = mk_philoxtensor(nullptr);
   auto seed_output = use_philox_state ? mk_philoxtensor(seed_t.data_ptr<int64_t>()) : nullscalar;
   auto offset_output = use_philox_state ? mk_philoxtensor(offset_t.data_ptr<int64_t>()) : nullscalar;
-  auto persistent_counter = is_causal ? mk_philoxtensor(atomic_counter.data_ptr<int64_t>()) : nullscalar;
+  auto persistent_counter = mk_atomictensor(is_causal ? atomic_counter.data_ptr<int32_t>() : nullptr);
   if (uses_swa || AOTRITON_ALWAYS_V3_API) {
 #if AOTRITON_V3_API
     using aotriton::v3::flash::CausalType;
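Two things change here: the counter is allocated as an explicit one-element int32 tensor (previously it inherited `q.options()`, i.e. the query dtype, and was then read through an `int64_t` pointer), and it is handed to AOTriton via `mk_atomictensor`, which also takes a null pointer in the non-causal case. A hedged Python-level analogue of just the allocation change (the function name here is illustrative, not an API):

```python
from typing import Optional
import torch

def make_atomic_counter(q: torch.Tensor, is_causal: bool) -> Optional[torch.Tensor]:
    # Analogue of `at::zeros({1}, q.options().dtype(at::kInt))`: a one-element
    # int32 tensor on q's device, or absent (null pointer on the C++ side)
    # when the kernel is not causal.
    if not is_causal:
        return None
    return torch.zeros(1, dtype=torch.int32, device=q.device)
```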

test/nn/test_multihead_attention.py

Lines changed: 0 additions & 2 deletions
@@ -17,7 +17,6 @@
     instantiate_parametrized_tests,
     parametrize as parametrize_test,
     run_tests,
-    skipIfRocm,
     TEST_NUMPY,
     TEST_WITH_CROSSREF,
 )
@@ -746,7 +745,6 @@ def test_multihead_attn_nested_tensor_outside_fast_path(self):
 
 
 class TestMultiheadAttentionNNDeviceType(NNTestCase):
-    @skipIfRocm(msg="To investigate: yields NaN")
     def test_multihead_self_attn_two_masks_fast_path(self, device):
         """
         Multihead self-attention should give the same result on the fast path (BetterTransformer) as on the slow path

test/test_flop_counter.py

Lines changed: 0 additions & 3 deletions
@@ -15,7 +15,6 @@
 )
 from torch.testing._internal.common_utils import (
     run_tests,
-    skipIfRocm,
     TEST_WITH_TORCHDYNAMO,
     TestCase,
 )
@@ -463,7 +462,6 @@ def get_flops(
         self.assertExpectedInline(str(flops_fw_bw_math), """805306368""")
         self.assertExpectedInline(str(flops_fw_bw_efficient), """939524096""")
 
-    @skipIfRocm  # Nested tensor
     @unittest.skipIf(not HAS_CUDA, "CUDA not available")
     @unittest.skipIf(
         not PLATFORM_SUPPORTS_FLASH_ATTENTION
@@ -683,7 +681,6 @@ def split_tensor(x):
            ),
        )
 
-    @skipIfRocm  # Nested tensor
     @unittest.skipIf(not HAS_CUDA, "CUDA not available")
     @unittest.skipIf(
         not PLATFORM_SUPPORTS_FLASH_ATTENTION,

test/test_nn.py

Lines changed: 1 addition & 6 deletions
@@ -39,7 +39,7 @@
     parametrize as parametrize_test, subtest, instantiate_parametrized_tests, \
     skipIfTorchDynamo, gcIfJetson, set_default_dtype
 from torch.testing._internal.common_cuda import TEST_CUDA, TEST_MULTIGPU, TEST_CUDNN, \
-    PLATFORM_SUPPORTS_FLASH_ATTENTION, _get_torch_rocm_version
+    _get_torch_rocm_version
 from torch.testing._internal.common_nn import NNTestCase, NewModuleTest, CriterionTest, \
     module_tests, criterion_tests, loss_reference_fns, _create_basic_net, \
     ctcloss_reference, get_new_module_tests, single_batch_reference_fn, _test_bfloat16_ops, _test_module_empty_input
@@ -3166,7 +3166,6 @@ def perm_fn(x):
                                [2.42240309, 0.0354595, -0.60659063, -0.05378816]]]))
         torch.testing.assert_close(result, ref_output, rtol=1e-5, atol=0)
 
-    @skipIfRocm(msg='Large numerical errors')
     def test_transformerdecoder(self):
         def get_a_test_layer(use_cuda, activation, batch_first=False):
             d_model = 4
@@ -12998,8 +12997,6 @@ def test_skip_init(self, device):
     @dtypes(torch.float)
     @dtypesIfCUDA(torch.double, torch.float, torch.half)
     def test_transformerencoderlayer(self, device, dtype):
-        if TEST_WITH_ROCM and PLATFORM_SUPPORTS_FLASH_ATTENTION and dtype == torch.half:
-            self.skipTest("Skip on ROCM due to Flash Attention tolerances")
         # this is a deterministic test for TransformerEncoderLayer
         d_model = 4
         nhead = 2
@@ -13221,8 +13218,6 @@ def test_transformerencoderlayer_fast_path(self, device, dtype):
     @dtypes(torch.float)
     @dtypesIfCUDA(torch.half, torch.float)
     def test_transformerencoderlayer_gelu(self, device, dtype):
-        if TEST_WITH_ROCM and PLATFORM_SUPPORTS_FLASH_ATTENTION and dtype == torch.half:
-            self.skipTest("Skip on ROCM due to Flash Attention tolerances")
         # this is a deterministic test for TransformerEncoderLayer with gelu activation
         d_model = 4
         nhead = 2

test/test_transformers.py

Lines changed: 3 additions & 17 deletions
@@ -344,9 +344,6 @@ def test_train_with_pad_and_catch_error(self, device):
     @parametrize("key_padding_mask_dim", [2, None])
     @parametrize("mask_dtype", [torch.bool, torch.float32])
     def test_multiheadattention_fastpath_attn_mask(self, device, attn_mask_dim, key_padding_mask_dim, mask_dtype):
-        if TEST_WITH_ROCM:
-            if attn_mask_dim is not None and mask_dtype == torch.bool:
-                self.skipTest("boolean mask is not fully supported on ROCm yet.")
         # MHA converts all
         with torch.no_grad():
             B = 2
@@ -429,8 +426,7 @@ def hook(module, inputs, output):
         # remove hook
         handle.remove()
 
-    @skipIfRocm
-    @tf32_on_and_off(0.001)
+    @tf32_on_and_off(0.0021 if TEST_WITH_ROCM else 0.001)
     @parametrize("use_torchscript", [False])
     @parametrize("enable_nested_tensor", [True, False])
     @parametrize("use_autocast", [True, False])
@@ -1420,7 +1416,6 @@ def ones_tensor(*shape):
         _ = mha_f(qkv_f, qkv_f, qkv_f, attn_mask=mask, need_weights=False, is_causal=True)
         torch.cuda.synchronize()
 
-    @skipIfRocm  # Missing EFFICIENT_ATTENTION
     @unittest.skipIf(
         not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Platform does not supposrt fused SDPA or pre-SM80 hardware"
     )
@@ -1713,7 +1708,7 @@ def test_unaligned_tensors(self, device):
         make_tensor = partial(torch.rand, size, device=device, dtype=dtype)
         q, k, v = make_tensor(), make_tensor(), make_tensor()
         with sdpa_kernel(backends=[SDPBackend.EFFICIENT_ATTENTION]):
-            ctxmgr = self.assertRaises(RuntimeError) if not TEST_WITH_ROCM else contextlib.nullcontext()
+            ctxmgr = self.assertRaises(RuntimeError)
            with ctxmgr:
                torch.nn.functional.scaled_dot_product_attention(q, k, v, None, 0.0, False)
 
@@ -2611,7 +2606,6 @@ def convert_flash_attn_S_to_softmax(
         S_converted = F.pad(S_converted, (0, seqlen_k_og - seqlen_k_rounded))
         return S_converted[:, :, :seqlen_q, :seqlen_k]
 
-    @skipIfRocm  # No cuDNN Attention
     @unittest.skipIf(not PLATFORM_SUPPORTS_CUDNN_ATTENTION, "cuDNN Attention is not supported on this system")
     def test_cudnn_attention_different_dk_dv(self, device):
         dtype = torch.bfloat16
@@ -2635,7 +2629,6 @@ def test_cudnn_attention_different_dk_dv(self, device):
 
         self.assertEqual(actual.contiguous(), math_ref.contiguous().to(dtype), atol=1e-3, rtol=1e-2)
 
-    @skipIfRocm  # No cuDNN Attention
     @unittest.skipIf(not PLATFORM_SUPPORTS_CUDNN_ATTENTION, "cuDNN Attention is not supported on this system")
     def test_cudnn_attention_gqa(self, device):
         batch = 4
@@ -2659,7 +2652,6 @@ def test_cudnn_attention_gqa(self, device):
 
         self.assertEqual(output_math, output_cudnn)
 
-    @skipIfRocm  # No cuDNN Attention
     @unittest.skipIf(not PLATFORM_SUPPORTS_CUDNN_ATTENTION, "cuDNN Attention is not supported on this system")
     def test_cudnn_attention_d256_heuristic(self, device):
         dtype = torch.bfloat16
@@ -2690,7 +2682,6 @@ def test():
         with self.assertRaisesRegex(RuntimeError, "No available kernel."):
             test()
 
-    @skipIfRocm(msg="No cuDNN on ROCm")
     @unittest.skipIf(not PLATFORM_SUPPORTS_CUDNN_ATTENTION, "cuDNN Attention is not supported on this system")
     def test_fused_attention_different_dk_dv(self, device):
         dtype = torch.bfloat16
@@ -2714,7 +2705,7 @@ def test_fused_attention_different_dk_dv(self, device):
         self.assertEqual(actual.contiguous(), math_ref.contiguous().to(dtype), atol=1e-3, rtol=1e-2)
 
 
-    @skipIfRocm  # No cuDNN Attention
+    @unittest.skipIf(not PLATFORM_SUPPORTS_CUDNN_ATTENTION, "cuDNN Attention is not supported on this system")
     @unittest.skipIf(True, "broken as of cuDNN 9.10")
     def test_cudnn_attention_fail_d128(self, device):
         # Test that cuDNN attention dispatching correctly bails out on d > 128
@@ -2736,7 +2727,6 @@ def test_cudnn_attention_fail_d128(self, device):
         with self.assertRaisesRegex(RuntimeError, "No available kernel."):
             torch.nn.functional.scaled_dot_product_attention(q, k, v)
 
-    @skipIfRocm(msg="No cuDNN on ROCm")
     @unittest.skipIf(not PLATFORM_SUPPORTS_CUDNN_ATTENTION, "cudnn Attention is not supported on this system")
     def test_cudnn_attention_trivial_output_transpose(self, device):
         # see also: https://github.com/pytorch/pytorch/issues/134001
@@ -2752,7 +2742,6 @@ def test_cudnn_attention_trivial_output_transpose(self, device):
         o.backward(o)
         torch.testing.assert_close(x.grad, x_cpu.grad.cuda(), atol=7e-3, rtol=7e-3)
 
-    @skipIfRocm  # No cuDNN Attention
     @unittest.skipIf(not PLATFORM_SUPPORTS_CUDNN_ATTENTION, "cudnn Attention is not supported on this system")
     def test_cudnn_attention_nonmodulo64seqlen(self, device):
         # see also: https://github.com/pytorch/pytorch/issues/137347
@@ -2792,7 +2781,6 @@ def test_cudnn_attention_nonmodulo64seqlen(self, device):
         torch.testing.assert_close(k.grad, k_cpu.grad.cuda(), atol=3e-3, rtol=2e-3)
         torch.testing.assert_close(v.grad, v_cpu.grad.cuda(), atol=3e-3, rtol=2e-3)
 
-    @skipIfRocm
     @unittest.skipIf(not PLATFORM_SUPPORTS_CUDNN_ATTENTION, "cudnn Attention is not supported on this system")
     def test_cudnn_attention_preserves_query_layout(self, device):
 
@@ -2822,7 +2810,6 @@ def test_attention(backend: SDPBackend, permute_order: list[list[int]]):
         for permute_order in permute_orders:
             test_attention(SDPBackend.CUDNN_ATTENTION, list(permute_order) + [3])
 
-    @skipIfRocm
     @unittest.skipIf(not PLATFORM_SUPPORTS_CUDNN_ATTENTION, "cudnn Attention is not supported on this system")
     def test_cudnn_attention_compiles(self):
         q = torch.randn(2, 8, 1024, 128, dtype=torch.half, device='cuda', requires_grad=True)
@@ -3241,7 +3228,6 @@ def test_sdp_choice_with_determinism(self, device, warn_only):
         with sdpa_kernel(backends=[SDPBackend.EFFICIENT_ATTENTION, SDPBackend.MATH]):
             assert torch._fused_sdp_choice(query, key, value) == SDPBackend.EFFICIENT_ATTENTION.value
 
-    @skipIfRocm
     @onlyCUDA
     @unittest.skipIf(not PLATFORM_SUPPORTS_CUDNN_ATTENTION, "cuDNN Attention is not supported on this system")
     @unittest.skipIf(not PLATFORM_SUPPORTS_MEM_EFF_ATTENTION, "Platform does not support fused SDPA")
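With the memory-efficient alignment check in place, `test_unaligned_tensors` no longer special-cases ROCm: restricting SDPA to the EFFICIENT_ATTENTION backend with an unaligned head dimension is expected to raise `RuntimeError` on ROCm just as on CUDA. A small standalone sketch of that expectation (shape and dtype chosen here for illustration, not taken from the test):

```python
import torch
import torch.nn.functional as F
from torch.nn.attention import sdpa_kernel, SDPBackend

# head_dim=5 is neither a multiple of 8 (fp16/bf16) nor of 4 (fp32).
q, k, v = (torch.rand(2, 2, 8, 5, device="cuda", dtype=torch.float16)
           for _ in range(3))

with sdpa_kernel(backends=[SDPBackend.EFFICIENT_ATTENTION]):
    try:
        F.scaled_dot_product_attention(q, k, v, None, 0.0, False)
    except RuntimeError as err:
        # No eligible fused kernel accepts the unaligned inputs.
        print("rejected as expected:", err)
```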
