@@ -468,7 +468,7 @@ def block_scale_mxfp_matmul( #
                              (128, 128, 256), (128, 256, 256)])
 @pytest.mark.parametrize("NUM_STAGES", [1, 2, 4])
 @pytest.mark.parametrize("USE_2D_SCALE_LOAD", [False, True])
-@pytest.mark.skipif(torch.cuda.get_device_capability()[0] < 10, reason="Requires compute capability >= 10")
+@pytest.mark.skipif(is_hip() or torch.cuda.get_device_capability()[0] < 10, reason="Requires compute capability >= 10")
 def test_blocked_scale_mxfp(M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, NUM_STAGES, USE_2D_SCALE_LOAD, device):
     if BLOCK_N == 256 and BLOCK_K == 256:
         NUM_STAGES = min(NUM_STAGES, 2)
@@ -540,7 +540,7 @@ def flatten_scale(scale):
540540@pytest .mark .parametrize ("BLOCK_M, BLOCK_N, BLOCK_K" , [(128 , 128 , 64 ), (128 , 64 , 128 ), (64 , 128 , 32 ), (128 , 256 , 32 )])
541541@pytest .mark .parametrize ("a_trans" , [False , True ])
542542@pytest .mark .parametrize ("dtype_src_str" , ["float32" , "float16" , "float8e5" ])
543- @pytest .mark .skipif (torch .cuda .get_device_capability ()[0 ] < 10 , reason = "Requires compute capability >= 10" )
543+ @pytest .mark .skipif (is_hip () or torch .cuda .get_device_capability ()[0 ] < 10 , reason = "Requires compute capability >= 10" )
544544def test_lhs_in_tmem (BLOCK_M , BLOCK_N , BLOCK_K , a_trans , dtype_src_str , device , monkeypatch ):
545545 M = 1024
546546 N = 512
@@ -604,7 +604,7 @@ def lhs_in_tmem_kernel_mxfp( #
     tl.store(output_ptrs, accumulator)
 
 
-@pytest.mark.skipif(torch.cuda.get_device_capability()[0] < 10, reason="Requires compute capability >= 10")
+@pytest.mark.skipif(is_hip() or torch.cuda.get_device_capability()[0] < 10, reason="Requires compute capability >= 10")
 def test_lhs_in_tmem_mxfp(device, monkeypatch):
     _knob_promote_lhs_to_tmem(monkeypatch)
     M, N, K = 128, 64, 32
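The new skip conditions assume an `is_hip()` helper is in scope in this test module. As a rough sketch of what such a helper typically looks like, assuming the backend name can be read from Triton's active driver target (the exact attribute path is an assumption here, not part of this diff):

```python
# Illustrative sketch only -- the test suite imports its own is_hip() helper;
# this mirrors the common pattern of checking the active target's backend.
import triton


def is_hip() -> bool:
    # Assumption: the active driver exposes the current target, whose
    # .backend is "hip" on ROCm and "cuda" on NVIDIA.
    target = triton.runtime.driver.active.get_current_target()
    return target is not None and target.backend == "hip"
```

Placing `is_hip()` first in the `or` also short-circuits the `torch.cuda.get_device_capability()` comparison, which expresses an NVIDIA compute-capability requirement and is not meaningful on ROCm devices.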