Skip to content

Commit 38a7521

Browse files
bottler authored and meta-codesync[bot] committed
Fixes to FMHA tests on amd (#238)
Summary: Pull Request resolved: #238 Further fixes to tests Reviewed By: cthi Differential Revision: D97479860 fbshipit-source-id: 3ff44140478550a79346c9f5be2b7d8a0f003920
1 parent 74820e9 commit 38a7521

File tree

3 files changed

+10
-1
lines changed

3 files changed

+10
-1
lines changed

test/attention/fmha/test_fmha_merge_attentions.py

Lines changed: 3 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -22,6 +22,7 @@
2222

2323
from .utils import (
2424
assert_allclose,
25+
cuda_only,
2526
disable_on_rocm,
2627
sm80_or_better_only,
2728
UNSUPPORTED_OP_PASSES,
@@ -479,6 +480,7 @@ def test_merge_attentions_sharedinput(
479480
)
480481

481482

483+
@cuda_only
482484
@sm80_or_better_only
483485
@pytest.mark.parametrize("bmghk", (False, True))
484486
def test_merge_attentions_against_ref(bmghk: bool):
@@ -685,6 +687,7 @@ def test_merge_training_zilch():
685687

686688

687689
@sm80_or_better_only
690+
@cuda_only
688691
def test_merge_training_undilate():
689692
torch.manual_seed(1)
690693

test/attention/fmha/test_fmha_split_blocks_fairinternal.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@ def test_split_blocks_for_decoding():
4646
assert (chunked_bias.k_seqinfo.seqstart >= attn_bias.k_seqinfo.seqstart).all()
4747

4848

49+
@cuda_only
4950
def test_split_blocks_for_decoding_with_paged():
5051
torch.manual_seed(0)
5152
max_len_kv = 2048

test/attention/fmha/test_mem_eff_attention.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -313,7 +313,7 @@ def test_dropout_ck(q_len, kv_len, batch_size, k_len, p, seed, attn_bias):
313313
def test_dropout_backward_ck(q_len, kv_len, batch_size, k, p):
314314
op = fmha.ck.FwOp
315315
dtype = torch.float16
316-
if not op.is_available():
316+
if not fmha.ck.BwOp.is_available():
317317
if UNSUPPORTED_OP_PASSES:
318318
return
319319
pytest.skip()
@@ -614,6 +614,7 @@ def test_unsupported_stride_alignment(op: Type[fmha.AttentionFwOpBase]):
614614

615615

616616
@sm75_or_better_only
617+
@cuda_only
617618
def test_unsupported_dropout_combine_flash_cutlass() -> None:
618619
q = torch.empty(
619620
[1, 4, 1, 16], device="cuda", dtype=torch.float16, requires_grad=True
@@ -1893,6 +1894,10 @@ def test_memeff_compile(bias_t, create_bias_inside_compiled: bool, op) -> None:
18931894
if UNSUPPORTED_OP_PASSES:
18941895
return
18951896
pytest.skip("Op is not available")
1897+
if (not not torch.version.hip) and not fmha.ck.BwOp.is_available():
1898+
if UNSUPPORTED_OP_PASSES:
1899+
return
1900+
pytest.skip("Op is not available")
18961901
torch._dynamo.reset_code_caches() # avoids hitting recompilation limit
18971902
B, M, H, K = 1, 256, 2, 64
18981903
q, k, v, bias = create_tensors(

0 commit comments

Comments (0)