From 4b7d91b6eacbf5a8e8d335f62474b3b2f5eab508 Mon Sep 17 00:00:00 2001 From: Liqiang Lu Date: Mon, 5 Jan 2026 07:15:14 -0800 Subject: [PATCH 1/2] fix test_grouped_mm --- tests/python/direct/test_cutlass_gemm.py | 14 ++++++++------ tests/python/direct_utils/utils.py | 5 +++++ 2 files changed, 13 insertions(+), 6 deletions(-) diff --git a/tests/python/direct/test_cutlass_gemm.py b/tests/python/direct/test_cutlass_gemm.py index 95b4bd5be33..18bb69c1588 100644 --- a/tests/python/direct/test_cutlass_gemm.py +++ b/tests/python/direct/test_cutlass_gemm.py @@ -5,16 +5,18 @@ import pytest import torch -from python.direct_utils import is_pre_blackwell -from python.direct_utils import microarchitecture_is_pre +from python.direct_utils import microarchitecture_is from nvfuser_direct import nvf_cutlass +# GPU Compute Capability: https://developer.nvidia.com/cuda-gpus +# tested on blackwell compute 10.0 (B200 and GB200) +# doesn't support 12.0 (RTX PRO 6000 and RTX 50XX) +# Not tested on 10.3 (B300 and GB300) +# Not tested on 12.1 (DGX Spark) @pytest.mark.skipif( - is_pre_blackwell(), reason="Only supported on blackwell and newer devices." -) -@pytest.mark.skipif( - not microarchitecture_is_pre(12), reason="Does not support blackwell compute 12.0." 
+ not microarchitecture_is(10, 0), + reason="Does not support blackwell compute 12.0, other arches are not tested.", ) @pytest.mark.parametrize("config", [[1024, 128, 256], [267, 128, 256]]) @pytest.mark.parametrize("tokens_per_expert_neg_one", [[115, 144, 8], [5, 7, 9]]) diff --git a/tests/python/direct_utils/utils.py b/tests/python/direct_utils/utils.py index 651591fcacd..f5eb652a39e 100644 --- a/tests/python/direct_utils/utils.py +++ b/tests/python/direct_utils/utils.py @@ -9,6 +9,11 @@ from looseversion import LooseVersion +def microarchitecture_is(major, minor): + prop = torch.cuda.get_device_properties(torch.cuda.current_device()) + return prop.major == major and prop.minor == minor + + def microarchitecture_is_pre(major): prop = torch.cuda.get_device_properties(torch.cuda.current_device()) return prop.major < major From 235f16185143acb1ed6835941d95862893de5589 Mon Sep 17 00:00:00 2001 From: Liqiang Lu Date: Mon, 5 Jan 2026 07:24:04 -0800 Subject: [PATCH 2/2] skip nvfp4 --- .../python/direct/test_cutlass_nvfp4_gemm.py | 21 ++++++++++++------- 1 file changed, 13 insertions(+), 8 deletions(-) diff --git a/tests/python/direct/test_cutlass_nvfp4_gemm.py b/tests/python/direct/test_cutlass_nvfp4_gemm.py index 8b73f384338..22cdb7aaca4 100644 --- a/tests/python/direct/test_cutlass_nvfp4_gemm.py +++ b/tests/python/direct/test_cutlass_nvfp4_gemm.py @@ -6,15 +6,8 @@ import pytest import torch from nvfuser_direct import nvf_cutlass - -compute_cap = torch.cuda.get_device_capability() -if compute_cap < (10, 0) or compute_cap >= (12, 0): - pytest.skip( - reason="Nvfp4 Requires compute capability 10.", - allow_module_level=True, - ) - from python.direct_utils import ( + microarchitecture_is, FLOAT4_E2M1_MAX, FLOAT8_E4M3_MAX, dequantize_to_dtype, @@ -52,6 +45,10 @@ def get_ref_results( return torch.matmul(a_in_dtype, b_in_dtype.t()) +@pytest.mark.skipif( + not microarchitecture_is(10, 0), + reason="Does not support blackwell compute 12.0, other arches are not tested.", +) 
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) @pytest.mark.parametrize( "shape", [(128, 128, 64), (128, 128, 128), (256, 128, 64), (128, 256, 128)] @@ -100,6 +97,10 @@ def test_nvfp4_gemm( torch.testing.assert_close(out, expected_out.to(dtype=dtype), atol=1e-1, rtol=1e-1) +@pytest.mark.skipif( + not microarchitecture_is(10, 0), + reason="Does not support blackwell compute 12.0, other arches are not tested.", +) @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) @pytest.mark.parametrize( "shape", [(128, 128, 64), (128, 128, 128), (256, 128, 64), (128, 256, 128)] @@ -175,6 +176,10 @@ def test_nvfp4_gemm_epilogue( ) +@pytest.mark.skipif( + not microarchitecture_is(10, 0), + reason="Does not support blackwell compute 12.0, other arches are not tested.", +) @pytest.mark.parametrize("config", [[1024, 128, 256]]) @pytest.mark.parametrize("tokens_per_expert_neg_one", [[115, 144, 8]]) @pytest.mark.parametrize("out_dtype", [torch.bfloat16, torch.float16])