diff --git a/tests/python/direct/test_cutlass_nvfp4_gemm.py b/tests/python/direct/test_cutlass_nvfp4_gemm.py
index 8b73f384338..22cdb7aaca4 100644
--- a/tests/python/direct/test_cutlass_nvfp4_gemm.py
+++ b/tests/python/direct/test_cutlass_nvfp4_gemm.py
@@ -6,15 +6,8 @@ import pytest
 import torch
 
 from nvfuser_direct import nvf_cutlass
-
-compute_cap = torch.cuda.get_device_capability()
-if compute_cap < (10, 0) or compute_cap >= (12, 0):
-    pytest.skip(
-        reason="Nvfp4 Requires compute capability 10.",
-        allow_module_level=True,
-    )
-
 from python.direct_utils import (
+    microarchitecture_is,
     FLOAT4_E2M1_MAX,
     FLOAT8_E4M3_MAX,
     dequantize_to_dtype,
@@ -52,6 +45,10 @@ def get_ref_results(
     return torch.matmul(a_in_dtype, b_in_dtype.t())
 
 
+@pytest.mark.skipif(
+    not microarchitecture_is(10, 0),
+    reason="Does not support blackwell compute 12.0, other arches are not tested.",
+)
 @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
 @pytest.mark.parametrize(
     "shape", [(128, 128, 64), (128, 128, 128), (256, 128, 64), (128, 256, 128)]
@@ -100,6 +97,10 @@ def test_nvfp4_gemm(
     torch.testing.assert_close(out, expected_out.to(dtype=dtype), atol=1e-1, rtol=1e-1)
 
 
+@pytest.mark.skipif(
+    not microarchitecture_is(10, 0),
+    reason="Does not support blackwell compute 12.0, other arches are not tested.",
+)
 @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
 @pytest.mark.parametrize(
     "shape", [(128, 128, 64), (128, 128, 128), (256, 128, 64), (128, 256, 128)]
@@ -175,6 +176,10 @@ def test_nvfp4_gemm_epilogue(
     )
 
 
+@pytest.mark.skipif(
+    not microarchitecture_is(10, 0),
+    reason="Does not support blackwell compute 12.0, other arches are not tested.",
+)
 @pytest.mark.parametrize("config", [[1024, 128, 256]])
 @pytest.mark.parametrize("tokens_per_expert_neg_one", [[115, 144, 8]])
 @pytest.mark.parametrize("out_dtype", [torch.bfloat16, torch.float16])