@@ -639,6 +639,9 @@ def test_gpt_selective_activation_recompute(dtype, bs, model, fp8, recipe, fp8_m
         pytest.skip("FP8 parameters are not supported in debug mode.")
     if recipe.float8_block_scaling() and not fp8_block_scaling_available:
         pytest.skip(reason_for_no_fp8_block_scaling)
+    if (IS_HIP_EXTENSION and get_device_compute_capability() == (9, 5) and
+        dtype in (torch.float16, torch.bfloat16) and rocm_attn_backend()[2]):
+        pytest.skip("Test is not supported on GFX950 with current parameters and CK fused attention backend and non-zero dropout.")
 
     config = model_configs[model]
 
@@ -761,6 +764,9 @@ def test_gpt_full_activation_recompute(
         pytest.skip("FP8 parameters are not supported in debug mode.")
     if recipe.float8_block_scaling() and not fp8_block_scaling_available:
         pytest.skip(reason_for_no_fp8_block_scaling)
+    if (IS_HIP_EXTENSION and get_device_compute_capability() == (9, 5) and
+        dtype in (torch.float16, torch.bfloat16) and rocm_attn_backend()[2]):
+        pytest.skip("Test is not supported on GFX950 with current parameters and CK fused attention backend and non-zero dropout.")
 
     config = model_configs[model]
     torch.compiler.reset()  # avoid cache size limit overflow
@@ -909,6 +915,10 @@ def _test_e2e_checkpointing(bs, dtype, config, checkpoint=False, steps=10, path=
 @pytest.mark.parametrize("bs", batch_sizes)
 @pytest.mark.parametrize("model", ["126m"])
 def test_gpt_checkpointing(dtype, bs, model):
+    if (IS_HIP_EXTENSION and get_device_compute_capability() == (9, 5) and
+        dtype in (torch.float16, torch.bfloat16) and rocm_attn_backend()[2]):
+        pytest.skip("Test is not supported on GFX950 with current parameters and CK fused attention backend and non-zero dropout.")
+
     config = model_configs[model]
     outputs = _test_e2e_checkpointing(bs, dtype, config, checkpoint=False)
     outputs_checkpoint = _test_e2e_checkpointing(bs, dtype, config, checkpoint=True)
@@ -2410,6 +2420,9 @@ def test_gpt_fp8_parameters(dtype, bs, model, recipe):
         pytest.skip("FP8 parameters are not supported in debug mode.")
     if recipe.float8_block_scaling() and not fp8_block_scaling_available:
         pytest.skip(reason_for_no_fp8_block_scaling)
+    if (IS_HIP_EXTENSION and get_device_compute_capability() == (9, 5) and
+        dtype in (torch.float16, torch.bfloat16) and rocm_attn_backend()[2]):
+        pytest.skip("Test is not supported on GFX950 with current parameters and CK fused attention backend and non-zero dropout.")
 
     config = model_configs[model]
 
@@ -2561,10 +2574,21 @@ def test_transformer_layer_hidden_states_format(dtype, bs, model):
         max_seqlen_kv=config.seq_len,
     )
 
-    torch.testing.assert_close(
-        y_bshd,
-        y_thd.reshape(bs, config.seq_len, config.hidden_size).contiguous(),
-    )
+    if IS_HIP_EXTENSION:
+        # On some GPUs CK fused attention with THD can produce larger error
+        tols = dtype_tols(dtype)
+        tols["atol"] = 1e-3
+        torch.testing.assert_close(
+            y_bshd,
+            y_thd.reshape(bs, config.seq_len, config.hidden_size).contiguous(),
+            **tols,
+        )
+    else:
+
+        torch.testing.assert_close(
+            y_bshd,
+            y_thd.reshape(bs, config.seq_len, config.hidden_size).contiguous(),
+        )
 
 
 @pytest.mark.parametrize("dtype", param_types)