@@ -344,9 +344,6 @@ def test_train_with_pad_and_catch_error(self, device):
     @parametrize("key_padding_mask_dim", [2, None])
     @parametrize("mask_dtype", [torch.bool, torch.float32])
     def test_multiheadattention_fastpath_attn_mask(self, device, attn_mask_dim, key_padding_mask_dim, mask_dtype):
-        if TEST_WITH_ROCM:
-            if attn_mask_dim is not None and mask_dtype == torch.bool:
-                self.skipTest("boolean mask is not fully supported on ROCm yet.")
         # MHA converts all
         with torch.no_grad():
             B = 2
@@ -429,8 +426,7 @@ def hook(module, inputs, output):
         # remove hook
         handle.remove()
 
-    @skipIfRocm
-    @tf32_on_and_off(0.001)
+    @tf32_on_and_off(0.0021 if TEST_WITH_ROCM else 0.001)
     @parametrize("use_torchscript", [False])
     @parametrize("enable_nested_tensor", [True, False])
     @parametrize("use_autocast", [True, False])
@@ -1420,7 +1416,6 @@ def ones_tensor(*shape):
             _ = mha_f(qkv_f, qkv_f, qkv_f, attn_mask=mask, need_weights=False, is_causal=True)
             torch.cuda.synchronize()
 
-    @skipIfRocm  # Missing EFFICIENT_ATTENTION
     @unittest.skipIf(
         not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Platform does not supposrt fused SDPA or pre-SM80 hardware"
     )
@@ -1713,7 +1708,7 @@ def test_unaligned_tensors(self, device):
         make_tensor = partial(torch.rand, size, device=device, dtype=dtype)
         q, k, v = make_tensor(), make_tensor(), make_tensor()
         with sdpa_kernel(backends=[SDPBackend.EFFICIENT_ATTENTION]):
-            ctxmgr = self.assertRaises(RuntimeError) if not TEST_WITH_ROCM else contextlib.nullcontext()
+            ctxmgr = self.assertRaises(RuntimeError)
             with ctxmgr:
                 torch.nn.functional.scaled_dot_product_attention(q, k, v, None, 0.0, False)
 
@@ -2611,7 +2606,6 @@ def convert_flash_attn_S_to_softmax(
             S_converted = F.pad(S_converted, (0, seqlen_k_og - seqlen_k_rounded))
         return S_converted[:, :, :seqlen_q, :seqlen_k]
 
-    @skipIfRocm  # No cuDNN Attention
     @unittest.skipIf(not PLATFORM_SUPPORTS_CUDNN_ATTENTION, "cuDNN Attention is not supported on this system")
     def test_cudnn_attention_different_dk_dv(self, device):
         dtype = torch.bfloat16
@@ -2635,7 +2629,6 @@ def test_cudnn_attention_different_dk_dv(self, device):
 
         self.assertEqual(actual.contiguous(), math_ref.contiguous().to(dtype), atol=1e-3, rtol=1e-2)
 
-    @skipIfRocm  # No cuDNN Attention
     @unittest.skipIf(not PLATFORM_SUPPORTS_CUDNN_ATTENTION, "cuDNN Attention is not supported on this system")
     def test_cudnn_attention_gqa(self, device):
         batch = 4
@@ -2659,7 +2652,6 @@ def test_cudnn_attention_gqa(self, device):
 
         self.assertEqual(output_math, output_cudnn)
 
-    @skipIfRocm  # No cuDNN Attention
     @unittest.skipIf(not PLATFORM_SUPPORTS_CUDNN_ATTENTION, "cuDNN Attention is not supported on this system")
     def test_cudnn_attention_d256_heuristic(self, device):
         dtype = torch.bfloat16
@@ -2690,7 +2682,6 @@ def test():
         with self.assertRaisesRegex(RuntimeError, "No available kernel."):
             test()
 
-    @skipIfRocm(msg="No cuDNN on ROCm")
     @unittest.skipIf(not PLATFORM_SUPPORTS_CUDNN_ATTENTION, "cuDNN Attention is not supported on this system")
     def test_fused_attention_different_dk_dv(self, device):
         dtype = torch.bfloat16
@@ -2714,7 +2705,7 @@ def test_fused_attention_different_dk_dv(self, device):
         self.assertEqual(actual.contiguous(), math_ref.contiguous().to(dtype), atol=1e-3, rtol=1e-2)
 
 
-    @skipIfRocm  # No cuDNN Attention
+    @unittest.skipIf(not PLATFORM_SUPPORTS_CUDNN_ATTENTION, "cuDNN Attention is not supported on this system")
     @unittest.skipIf(True, "broken as of cuDNN 9.10")
     def test_cudnn_attention_fail_d128(self, device):
         # Test that cuDNN attention dispatching correctly bails out on d > 128
@@ -2736,7 +2727,6 @@ def test_cudnn_attention_fail_d128(self, device):
         with self.assertRaisesRegex(RuntimeError, "No available kernel."):
             torch.nn.functional.scaled_dot_product_attention(q, k, v)
 
-    @skipIfRocm(msg="No cuDNN on ROCm")
     @unittest.skipIf(not PLATFORM_SUPPORTS_CUDNN_ATTENTION, "cudnn Attention is not supported on this system")
     def test_cudnn_attention_trivial_output_transpose(self, device):
         # see also: https://github.com/pytorch/pytorch/issues/134001
@@ -2752,7 +2742,6 @@ def test_cudnn_attention_trivial_output_transpose(self, device):
         o.backward(o)
         torch.testing.assert_close(x.grad, x_cpu.grad.cuda(), atol=7e-3, rtol=7e-3)
 
-    @skipIfRocm  # No cuDNN Attention
     @unittest.skipIf(not PLATFORM_SUPPORTS_CUDNN_ATTENTION, "cudnn Attention is not supported on this system")
     def test_cudnn_attention_nonmodulo64seqlen(self, device):
         # see also: https://github.com/pytorch/pytorch/issues/137347
@@ -2792,7 +2781,6 @@ def test_cudnn_attention_nonmodulo64seqlen(self, device):
         torch.testing.assert_close(k.grad, k_cpu.grad.cuda(), atol=3e-3, rtol=2e-3)
         torch.testing.assert_close(v.grad, v_cpu.grad.cuda(), atol=3e-3, rtol=2e-3)
 
-    @skipIfRocm
     @unittest.skipIf(not PLATFORM_SUPPORTS_CUDNN_ATTENTION, "cudnn Attention is not supported on this system")
     def test_cudnn_attention_preserves_query_layout(self, device):
 
@@ -2822,7 +2810,6 @@ def test_attention(backend: SDPBackend, permute_order: list[list[int]]):
         for permute_order in permute_orders:
             test_attention(SDPBackend.CUDNN_ATTENTION, list(permute_order) + [3])
 
-    @skipIfRocm
     @unittest.skipIf(not PLATFORM_SUPPORTS_CUDNN_ATTENTION, "cudnn Attention is not supported on this system")
     def test_cudnn_attention_compiles(self):
         q = torch.randn(2, 8, 1024, 128, dtype=torch.half, device='cuda', requires_grad=True)
@@ -3265,7 +3252,6 @@ def test_sdp_choice_with_determinism(self, device, warn_only):
         with sdpa_kernel(backends=[SDPBackend.EFFICIENT_ATTENTION, SDPBackend.MATH]):
             assert torch._fused_sdp_choice(query, key, value) == SDPBackend.EFFICIENT_ATTENTION.value
 
-    @skipIfRocm
     @onlyCUDA
     @unittest.skipIf(not PLATFORM_SUPPORTS_CUDNN_ATTENTION, "cuDNN Attention is not supported on this system")
     @unittest.skipIf(not PLATFORM_SUPPORTS_MEM_EFF_ATTENTION, "Platform does not support fused SDPA")