
Can't run fp4_gemm unittest with 5090 #3059

@kcabkooltonod

Description

System Info

GPU: RTX 5090
CUDA version: 12.8
tensorrt-llm version: 0.19.0.dev2025032500, built from source

Commands to run the fp4_quant unittests

First, I tried to run the fp4_quant unittests in TensorRT-LLM/tests/unittest/_torch/test_fp4_gemm_quantize.py with the command pytest test_fp4_gemm_quantize.py -v. The fp4_gemm tests were skipped because of the statement pytest.skip("https://nvbugs/5100633"), but all of the quantization tests passed:

pytest test_fp4_gemm_quantize.py -v
/usr/local/lib/python3.12/dist-packages/pytest_asyncio/plugin.py:207: PytestDeprecationWarning: The configuration option "asyncio_default_fixture_loop_scope" is unset.
The event loop scope for asynchronous fixtures will default to the fixture caching scope. Future versions of pytest-asyncio will default the loop scope for asynchronous fixtures to function scope. Set the default fixture loop scope explicitly in order to avoid unexpected behavior in the future. Valid fixture loop scopes are: "function", "class", "module", "package", "session"

warnings.warn(PytestDeprecationWarning(_DEFAULT_FIXTURE_LOOP_SCOPE_UNSET))
==================================================================================================== test session starts =====================================================================================================
platform linux -- Python 3.12.3, pytest-8.3.5, pluggy-1.5.0 -- /usr/bin/python3
cachedir: .pytest_cache
hypothesis profile 'default' -> database=DirectoryBasedExampleDatabase('/code/TensorRT-LLM/tests/unittest/_torch/.hypothesis/examples')
rootdir: /code/TensorRT-LLM/tests/unittest
configfile: pytest.ini
plugins: env-1.1.5, forked-1.6.0, timeout-2.3.1, split-0.10.0, cov-6.0.0, csv-3.0.0, mock-3.14.0, asyncio-0.25.3, anyio-4.8.0, hypothesis-5.35.1, xdoctest-1.0.2, shard-0.1.2, rerunfailures-15.0, flakefinder-1.1.0, xdist-3.6.1
asyncio: mode=Mode.STRICT, asyncio_default_fixture_loop_scope=None
collected 10 items
Running 10 items in this shard: _torch/test_fp4_gemm_quantize.py::TestFunctional::test_fp4_quantize_gemm_torch_1024_1024_1024, _torch/test_fp4_gemm_quantize.py::TestFunctional::test_fp4_quantize_gemm_torch_7_32_32, _torch/test_fp4_gemm_quantize.py::TestFunctional::test_fp4_quantize_torch_1024_1024_torch_float16_False, _torch/test_fp4_gemm_quantize.py::TestFunctional::test_fp4_quantize_torch_13_16_torch_float16_True, _torch/test_fp4_gemm_quantize.py::TestFunctional::test_fp4_quantize_torch_2_512_torch_bfloat16_False, _torch/test_fp4_gemm_quantize.py::TestFunctional::test_fp4_quantize_torch_fp8_13_16_torch_float8_e4m3fn_True, _torch/test_fp4_gemm_quantize.py::TestFunctional::test_fp4_quantize_torch_fp8_64_64_torch_float8_e4m3fn_False, _torch/test_fp4_gemm_quantize.py::TestProfiling::test_fp4_quantize_gemm_torch_profiling_1024_1024_1024, _torch/test_fp4_gemm_quantize.py::TestProfiling::test_fp4_quantize_gemm_torch_profiling_512_32_64, _torch/test_fp4_gemm_quantize.py::TestProfiling::test_fp4_quantize_gemm_torch_profiling_7_32_32

test_fp4_gemm_quantize.py::TestFunctional::test_fp4_quantize_gemm_torch_1024_1024_1024 SKIPPED (https://nvbugs/5100633) [ 10%]
test_fp4_gemm_quantize.py::TestFunctional::test_fp4_quantize_gemm_torch_7_32_32 SKIPPED (https://nvbugs/5100633) [ 20%]
test_fp4_gemm_quantize.py::TestFunctional::test_fp4_quantize_torch_1024_1024_torch_float16_False PASSED [ 30%]
test_fp4_gemm_quantize.py::TestFunctional::test_fp4_quantize_torch_13_16_torch_float16_True PASSED [ 40%]
test_fp4_gemm_quantize.py::TestFunctional::test_fp4_quantize_torch_2_512_torch_bfloat16_False PASSED [ 50%]
test_fp4_gemm_quantize.py::TestFunctional::test_fp4_quantize_torch_fp8_13_16_torch_float8_e4m3fn_True PASSED [ 60%]
test_fp4_gemm_quantize.py::TestFunctional::test_fp4_quantize_torch_fp8_64_64_torch_float8_e4m3fn_False PASSED [ 70%]
test_fp4_gemm_quantize.py::TestProfiling::test_fp4_quantize_gemm_torch_profiling_1024_1024_1024 SKIPPED (https://nvbugs/5100633) [ 80%]
test_fp4_gemm_quantize.py::TestProfiling::test_fp4_quantize_gemm_torch_profiling_512_32_64 SKIPPED (https://nvbugs/5100633) [ 90%]
test_fp4_gemm_quantize.py::TestProfiling::test_fp4_quantize_gemm_torch_profiling_7_32_32 SKIPPED (https://nvbugs/5100633) [100%]

===================================================================================================== slowest durations ======================================================================================================
0.18s call _torch/test_fp4_gemm_quantize.py::TestFunctional::test_fp4_quantize_torch_1024_1024_torch_float16_False

(29 durations < 0.005s hidden. Use -vv to show these durations.)
=============================================================================================== 5 passed, 5 skipped in 10.24s ===============================================================================================

which means fp4 quantization itself works correctly on my GPU.
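
For reference, the quantization path can also be exercised directly outside pytest. This is a minimal sketch based on the fp4_quantize call used in the tests below; scaling_vector_size = 16 is my assumption for the NVFP4 block size, and importing tensorrt_llm is assumed to register the torch.ops.trtllm custom ops:

import torch
import tensorrt_llm  # assumed to register the torch.ops.trtllm custom ops

scaling_vector_size = 16  # assumed NVFP4 block size, matching the tests
dtype = torch.float16

x = torch.randn((1024, 1024), dtype=dtype).cuda()
x_sf_global = (448 * 6) / x.abs().max().float()

# Same call pattern as in test_fp4_linear.py: returns the packed fp4 tensor
# and the per-block scaling factors.
x_fp4, x_sf_block = torch.ops.trtllm.fp4_quantize(x, x_sf_global,
                                                  scaling_vector_size, False)
print(x_fp4.shape, x_fp4.dtype, x_sf_block.shape)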

Errors occurred when running the fp4_gemm unittest

After the fp4_quant unittests, I tried the fp4_gemm test in TensorRT-LLM/tests/unittest/_torch/test_fp4_linear.py, but it ran into an error:

../../../tensorrt_llm/_torch/modules/linear.py:358: RuntimeError
__________________________________________________________________________________________________ test_fp4_linear[dtype1] ___________________________________________________________________________________________________

dtype = torch.bfloat16

@skip_pre_blackwell
@pytest.mark.parametrize(
    "dtype", [torch.float16, torch.bfloat16]
)  # TODO: Do we need float32 test case? fp4_quantize only supports fp16, bf16, fp8_e4m3
def test_fp4_linear(dtype):
    SEQ_LEN = 10
    HIDDEN_SIZE = 128
    torch.manual_seed(0)

    x = torch.randn((SEQ_LEN, HIDDEN_SIZE), dtype=dtype).cuda()
    x_sf_global = (448 * 6) / x.abs().max().float()

    w = torch.randn((HIDDEN_SIZE, HIDDEN_SIZE), dtype=dtype).cuda()
    w_sf_global = (448 * 6) / w.abs().max().float()
    w_fp4, w_sf_block = torch.ops.trtllm.fp4_quantize(w, w_sf_global,
                                                      scaling_vector_size,
                                                      False)

    qc = QuantConfig(quant_algo=QuantAlgo.NVFP4)
    l_fp4 = Linear(in_features=HIDDEN_SIZE,
                   out_features=HIDDEN_SIZE,
                   bias=False,
                   dtype=dtype,
                   quant_config=qc)

    assert l_fp4.weight.dtype == fp4_utils.float4_e2m1x2
    assert l_fp4.weight_scale.dtype == fp4_utils.float4_sf_dtype

    w_sf_block_unswizzled = (
        torch.ops.tensorrt_llm.nvfp4_block_scale_interleave_reverse(
            w_sf_block.cpu().view(HIDDEN_SIZE, -1)))

    l_fp4.load_weights([{
        'input_scale':
        1.0 / x_sf_global.cpu(),  # Simulates amax/(448*6) in modelopt ckpt
        'weight':
        w_fp4.cpu(),
        'weight_scale':
        w_sf_block_unswizzled.view(
            torch.float8_e4m3fn),  # Simulates float8_e4m3fn in modelopt ckpt
        'weight_scale_2':
        1.0 / w_sf_global.cpu()  # Simulates amax/(448*6) in modelopt ckpt
    }])
    l_fp4 = l_fp4.cuda()

    torch.testing.assert_close(l_fp4.weight, w_fp4)
    torch.testing.assert_close(l_fp4.input_scale[0], x_sf_global)
    torch.testing.assert_close(l_fp4.weight_scale, w_sf_block)
    alpha_ref = 1.0 / (w_sf_global * x_sf_global)
    torch.testing.assert_close(l_fp4.alpha[0], alpha_ref)

    with torch.inference_mode():
        output = l_fp4.forward(x)

test_fp4_linear.py:69:


../../../tensorrt_llm/_torch/modules/linear.py:409: in forward
output = self.apply_linear(input, self.weight, self.bias)


self = Linear(
(all_reduce): AllReduce()
)
input = tensor([[ 1.4609, 0.2910, 0.9609, ..., 1.2188, 1.9375, 1.5938],
[-0.4844, 0.9062, -1.1250, ..., 0.875...],
[ 0.0408, 1.8047, -1.1172, ..., -0.1885, 0.2832, 1.3828]],
device='cuda:0', dtype=torch.bfloat16)
weight = Parameter containing:
tensor([[107, 103, 68, ..., 182, 75, 204],
[ 17, 190, 136, ..., 181, 151, 43],
...0, 113, ..., 132, 215, 178],
[213, 176, 175, ..., 170, 89, 51]], device='cuda:0',
dtype=torch.uint8)
bias = None

def apply_linear(self, input, weight, bias):
    if self.has_any_quant:
        qc = self.quant_config
        if self.has_fp8_qdq:
            if input.dtype != torch.float8_e4m3fn:
                qinput, _ = torch.ops.tensorrt_llm.static_quantize_e4m3_per_tensor(
                    input, self.input_scale)
            else:
                qinput = input
            # This op does not support bias now.
            output = torch.ops.trtllm.cublas_scaled_mm(
                qinput,
                weight.t(),
                scale_a=self.input_scale,
                scale_b=self.weight_scale,
                bias=None,
                out_dtype=self.dtype or input.dtype,
                userbuffers_id=-1,
            )
            if bias is not None:
                output = output + bias
        elif self.has_fp8_block_scales:
            if input.dtype == torch.float8_e4m3fn:
                input = input.to(torch.bfloat16) * self.input_scale
            assert input.dtype == torch.bfloat16

            act_input_fp8, act_input_sf = torch.ops.trtllm.fp8_quantize_1x128(
                input)

            output = torch.ops.trtllm.fp8_block_scaling_gemm(
                act_input_fp8, self.weight, act_input_sf, self.weight_scale)

        elif self.has_nv_fp4:
            if isinstance(input, Fp4QuantizedTensor):
                act_fp4, act_sf = input.fp4_tensor, input.scaling_factor
            else:
                act_fp4, act_sf = torch.ops.trtllm.fp4_quantize(
                    input, self.input_scale, self.scaling_vector_size,
                    False)

            # This is a workaround to avoid the issue that torch compile cannot handle the profiler.
            if is_torch_compiling():
                output = torch.ops.trtllm.fp4_gemm(act_fp4, self.weight,
                                                   act_sf,
                                                   self.weight_scale,
                                                   self.alpha, False,
                                                   self.dtype)
            else:
                m = math.prod(act_fp4.shape[:-1])
                n = self.weight.shape[0]
                k = self.weight.shape[1] * 2

                if self.needs_profiling:
                    self.needs_profiling = False
                    self.profiler.run_profile(n, k, fp4_utils.fp4_buckets)

                best_config_id = self.profiler.get_best_config_id(m, n, k)
                output = self.profiler.run_gemm(act_fp4, self.weight,
                                                act_sf, self.weight_scale,
                                                self.alpha, False,
                                                best_config_id)

E RuntimeError: [TensorRT-LLM Error][CutlassFp4GemmRunner][GEMM Dispatch] Arch unsupported for CUTLASS FP4 GEMM

../../../tensorrt_llm/_torch/modules/linear.py:358: RuntimeError
===================================================================================================== slowest durations ======================================================================================================
0.28s call _torch/test_fp4_linear.py::test_fp4_linear[dtype0]

(5 durations < 0.005s hidden. Use -vv to show these durations.)
================================================================================================== short test summary info ===================================================================================================
FAILED test_fp4_linear.py::test_fp4_linear[dtype0] - RuntimeError: [TensorRT-LLM Error][CutlassFp4GemmRunner][GEMM Dispatch] Arch unsupported for CUTLASS FP4 GEMM
FAILED test_fp4_linear.py::test_fp4_linear[dtype1] - RuntimeError: [TensorRT-LLM Error][CutlassFp4GemmRunner][GEMM Dispatch] Arch unsupported for CUTLASS FP4 GEMM

Does that mean the RTX 5090 doesn't support FP4 GEMM?
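
As a quick sanity check, the compute capability PyTorch reports for the card can be printed like this; the dispatch error above suggests the prebuilt CUTLASS FP4 GEMM kernels only cover certain SM versions, so this is the first thing to compare:

import torch

# Print the SM (compute capability) that PyTorch sees for the GPU, so it can
# be compared against whatever architectures the CUTLASS FP4 GEMM runner was
# built for.
major, minor = torch.cuda.get_device_capability(0)
print(f"{torch.cuda.get_device_name(0)}: sm_{major}{minor}")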

Trying another fp4_gemm unittest, which also failed

Finally, I tried to run TensorRT-LLM/tests/unittest/functional/test_fp4_gemm.py, but this time it hit a different error (see the sketch after the log below):

pytest test_fp4_gemm.py -v
/usr/local/lib/python3.12/dist-packages/pytest_asyncio/plugin.py:207: PytestDeprecationWarning: The configuration option "asyncio_default_fixture_loop_scope" is unset.
The event loop scope for asynchronous fixtures will default to the fixture caching scope. Future versions of pytest-asyncio will default the loop scope for asynchronous fixtures to function scope. Set the default fixture loop scope explicitly in order to avoid unexpected behavior in the future. Valid fixture loop scopes are: "function", "class", "module", "package", "session"

warnings.warn(PytestDeprecationWarning(_DEFAULT_FIXTURE_LOOP_SCOPE_UNSET))
==================================================================================================== test session starts =====================================================================================================
platform linux -- Python 3.12.3, pytest-8.3.5, pluggy-1.5.0 -- /usr/bin/python3
cachedir: .pytest_cache
hypothesis profile 'default' -> database=DirectoryBasedExampleDatabase('/code/TensorRT-LLM/tests/unittest/functional/.hypothesis/examples')
rootdir: /code/TensorRT-LLM/tests/unittest
configfile: pytest.ini
plugins: env-1.1.5, forked-1.6.0, timeout-2.3.1, split-0.10.0, cov-6.0.0, csv-3.0.0, mock-3.14.0, asyncio-0.25.3, anyio-4.8.0, hypothesis-5.35.1, xdoctest-1.0.2, shard-0.1.2, rerunfailures-15.0, flakefinder-1.1.0, xdist-3.6.1
asyncio: mode=Mode.STRICT, asyncio_default_fixture_loop_scope=None
collected 66 items
Running 66 items in this shard: functional/test_fp4_gemm.py::TestFunctional::test_fp4_gemm_1024_1024_1023_float16_16_1_0, functional/test_fp4_gemm.py::TestFunctional::test_fp4_gemm_1024_1024_1023_float16_16_2_0, functional/test_fp4_gemm.py::TestFunctional::test_fp4_gemm_1024_1024_128_float16_16_1_0, functional/test_fp4_gemm.py::TestFunctional::test_fp4_gemm_1024_1024_128_float16_16_2_0, functional/test_fp4_gemm.py::TestFunctional::test_fp4_gemm_1024_1024_1_float16_16_1_0, functional/test_fp4_gemm.py::TestFunctional::test_fp4_gemm_1024_1024_1_float16_16_2_0, functional/test_fp4_gemm.py::TestFunctional::test_fp4_gemm_1024_1024_8_float16_16_1_0, functional/test_fp4_gemm.py::TestFunctional::test_fp4_gemm_1024_1024_8_float16_16_2_0, functional/test_fp4_gemm.py::TestFunctional::test_fp4_gemm_1024_2048_1023_float16_16_1_0, functional/test_fp4_gemm.py::TestFunctional::test_fp4_gemm_1024_2048_1023_float16_16_2_0, functional/test_fp4_gemm.py::TestFunctional::test_fp4_gemm_1024_2048_128_float16_16_1_0, functional/test_fp4_gemm.py::TestFunctional::test_fp4_gemm_1024_2048_128_float16_16_2_0, functional/test_fp4_gemm.py::TestFunctional::test_fp4_gemm_1024_2048_1_float16_16_1_0, functional/test_fp4_gemm.py::TestFunctional::test_fp4_gemm_1024_2048_1_float16_16_2_0, functional/test_fp4_gemm.py::TestFunctional::test_fp4_gemm_1024_2048_8_float16_16_1_0, functional/test_fp4_gemm.py::TestFunctional::test_fp4_gemm_1024_2048_8_float16_16_2_0, functional/test_fp4_gemm.py::TestFunctional::test_fp4_gemm_2048_1024_1023_float16_16_1_0, functional/test_fp4_gemm.py::TestFunctional::test_fp4_gemm_2048_1024_1023_float16_16_2_0, functional/test_fp4_gemm.py::TestFunctional::test_fp4_gemm_2048_1024_128_float16_16_1_0, functional/test_fp4_gemm.py::TestFunctional::test_fp4_gemm_2048_1024_128_float16_16_2_0, functional/test_fp4_gemm.py::TestFunctional::test_fp4_gemm_2048_1024_1_float16_16_1_0, functional/test_fp4_gemm.py::TestFunctional::test_fp4_gemm_2048_1024_1_float16_16_2_0, functional/test_fp4_gemm.py::TestFunctional::test_fp4_gemm_2048_1024_8_float16_16_1_0, functional/test_fp4_gemm.py::TestFunctional::test_fp4_gemm_2048_1024_8_float16_16_2_0, functional/test_fp4_gemm.py::TestFunctional::test_fp4_gemm_2048_2048_1023_float16_16_1_0, functional/test_fp4_gemm.py::TestFunctional::test_fp4_gemm_2048_2048_1023_float16_16_2_0, functional/test_fp4_gemm.py::TestFunctional::test_fp4_gemm_2048_2048_128_float16_16_1_0, functional/test_fp4_gemm.py::TestFunctional::test_fp4_gemm_2048_2048_128_float16_16_2_0, functional/test_fp4_gemm.py::TestFunctional::test_fp4_gemm_2048_2048_1_float16_16_1_0, functional/test_fp4_gemm.py::TestFunctional::test_fp4_gemm_2048_2048_1_float16_16_2_0, functional/test_fp4_gemm.py::TestFunctional::test_fp4_gemm_2048_2048_8_float16_16_1_0, functional/test_fp4_gemm.py::TestFunctional::test_fp4_gemm_2048_2048_8_float16_16_2_0, functional/test_fp4_gemm.py::TestFunctional::test_input_quant_and_fp4_gemm_10240_10240_1_16, functional/test_fp4_gemm.py::TestFunctional::test_input_quant_and_fp4_gemm_10240_10240_20_16, functional/test_fp4_gemm.py::TestFunctional::test_input_quant_and_fp4_gemm_10240_28672_1_16, functional/test_fp4_gemm.py::TestFunctional::test_input_quant_and_fp4_gemm_10240_28672_20_16, functional/test_fp4_gemm.py::TestFunctional::test_input_quant_and_fp4_gemm_10240_8192_1_16, functional/test_fp4_gemm.py::TestFunctional::test_input_quant_and_fp4_gemm_10240_8192_20_16, functional/test_fp4_gemm.py::TestFunctional::test_input_quant_and_fp4_gemm_1024_1024_1023_16, 
functional/test_fp4_gemm.py::TestFunctional::test_input_quant_and_fp4_gemm_1024_1024_128_16, functional/test_fp4_gemm.py::TestFunctional::test_input_quant_and_fp4_gemm_1024_1024_1_16, functional/test_fp4_gemm.py::TestFunctional::test_input_quant_and_fp4_gemm_1024_1024_8_16, functional/test_fp4_gemm.py::TestFunctional::test_input_quant_and_fp4_gemm_1024_2048_1023_16, functional/test_fp4_gemm.py::TestFunctional::test_input_quant_and_fp4_gemm_1024_2048_128_16, functional/test_fp4_gemm.py::TestFunctional::test_input_quant_and_fp4_gemm_1024_2048_1_16, functional/test_fp4_gemm.py::TestFunctional::test_input_quant_and_fp4_gemm_1024_2048_8_16, functional/test_fp4_gemm.py::TestFunctional::test_input_quant_and_fp4_gemm_2048_1024_1023_16, functional/test_fp4_gemm.py::TestFunctional::test_input_quant_and_fp4_gemm_2048_1024_128_16, functional/test_fp4_gemm.py::TestFunctional::test_input_quant_and_fp4_gemm_2048_1024_1_16, functional/test_fp4_gemm.py::TestFunctional::test_input_quant_and_fp4_gemm_2048_1024_8_16, functional/test_fp4_gemm.py::TestFunctional::test_input_quant_and_fp4_gemm_2048_2048_1023_16, functional/test_fp4_gemm.py::TestFunctional::test_input_quant_and_fp4_gemm_2048_2048_128_16, functional/test_fp4_gemm.py::TestFunctional::test_input_quant_and_fp4_gemm_2048_2048_1_16, functional/test_fp4_gemm.py::TestFunctional::test_input_quant_and_fp4_gemm_2048_2048_8_16, functional/test_fp4_gemm.py::TestFunctional::test_input_quant_and_fp4_gemm_28672_10240_1_16, functional/test_fp4_gemm.py::TestFunctional::test_input_quant_and_fp4_gemm_28672_10240_20_16, functional/test_fp4_gemm.py::TestFunctional::test_input_quant_and_fp4_gemm_28672_28672_1_16, functional/test_fp4_gemm.py::TestFunctional::test_input_quant_and_fp4_gemm_28672_28672_20_16, functional/test_fp4_gemm.py::TestFunctional::test_input_quant_and_fp4_gemm_28672_8192_1_16, functional/test_fp4_gemm.py::TestFunctional::test_input_quant_and_fp4_gemm_28672_8192_20_16, functional/test_fp4_gemm.py::TestFunctional::test_input_quant_and_fp4_gemm_8192_10240_1_16, functional/test_fp4_gemm.py::TestFunctional::test_input_quant_and_fp4_gemm_8192_10240_20_16, functional/test_fp4_gemm.py::TestFunctional::test_input_quant_and_fp4_gemm_8192_28672_1_16, functional/test_fp4_gemm.py::TestFunctional::test_input_quant_and_fp4_gemm_8192_28672_20_16, functional/test_fp4_gemm.py::TestFunctional::test_input_quant_and_fp4_gemm_8192_8192_1_16, functional/test_fp4_gemm.py::TestFunctional::test_input_quant_and_fp4_gemm_8192_8192_20_16

test_fp4_gemm.py::TestFunctional::test_fp4_gemm_1024_1024_1023_float16_16_1_0 Fatal Python error: Aborted

Thread 0x00007f251ef836c0 (most recent call first):
File "/usr/local/lib/python3.12/dist-packages/torch/_inductor/compile_worker/subproc_pool.py", line 54 in _recv_msg
File "/usr/local/lib/python3.12/dist-packages/torch/_inductor/compile_worker/subproc_pool.py", line 162 in _read_thread
File "/usr/lib/python3.12/threading.py", line 1010 in run
File "/usr/lib/python3.12/threading.py", line 1073 in _bootstrap_inner
File "/usr/lib/python3.12/threading.py", line 1030 in _bootstrap

Thread 0x00007f2c502a36c0 (most recent call first):
File "/usr/lib/python3.12/socket.py", line 295 in accept
File "/usr/local/lib/python3.12/dist-packages/pytest_rerunfailures.py", line 440 in run_server
File "/usr/lib/python3.12/threading.py", line 1010 in run
File "/usr/lib/python3.12/threading.py", line 1073 in _bootstrap_inner
File "/usr/lib/python3.12/threading.py", line 1030 in _bootstrap

Current thread 0x00007f2c524a2540 (most recent call first):
File "/code/TensorRT-LLM/tensorrt_llm/runtime/session.py", line 297 in run
File "/code/TensorRT-LLM/tests/unittest/functional/test_fp4_gemm.py", line 251 in test_fp4_gemm
File "/usr/local/lib/python3.12/dist-packages/parameterized/parameterized.py", line 620 in standalone_func
File "/usr/lib/python3.12/unittest/case.py", line 589 in _callTestMethod
File "/usr/lib/python3.12/unittest/case.py", line 634 in run
File "/usr/lib/python3.12/unittest/case.py", line 690 in call
File "/usr/local/lib/python3.12/dist-packages/_pytest/unittest.py", line 351 in runtest
File "/usr/local/lib/python3.12/dist-packages/_pytest/runner.py", line 174 in pytest_runtest_call
File "/usr/local/lib/python3.12/dist-packages/pluggy/_callers.py", line 103 in _multicall
File "/usr/local/lib/python3.12/dist-packages/pluggy/_manager.py", line 120 in _hookexec
File "/usr/local/lib/python3.12/dist-packages/pluggy/_hooks.py", line 513 in call
File "/usr/local/lib/python3.12/dist-packages/_pytest/runner.py", line 242 in
File "/usr/local/lib/python3.12/dist-packages/_pytest/runner.py", line 341 in from_call
File "/usr/local/lib/python3.12/dist-packages/_pytest/runner.py", line 241 in call_and_report
File "/usr/local/lib/python3.12/dist-packages/_pytest/runner.py", line 132 in runtestprotocol
File "/usr/local/lib/python3.12/dist-packages/_pytest/runner.py", line 113 in pytest_runtest_protocol
File "/usr/local/lib/python3.12/dist-packages/pluggy/_callers.py", line 103 in _multicall
File "/usr/local/lib/python3.12/dist-packages/pluggy/_manager.py", line 120 in _hookexec
File "/usr/local/lib/python3.12/dist-packages/pluggy/_hooks.py", line 513 in call
File "/usr/local/lib/python3.12/dist-packages/_pytest/main.py", line 362 in pytest_runtestloop
File "/usr/local/lib/python3.12/dist-packages/pluggy/_callers.py", line 103 in _multicall
File "/usr/local/lib/python3.12/dist-packages/pluggy/_manager.py", line 120 in _hookexec
File "/usr/local/lib/python3.12/dist-packages/pluggy/_hooks.py", line 513 in call
File "/usr/local/lib/python3.12/dist-packages/_pytest/main.py", line 337 in _main
File "/usr/local/lib/python3.12/dist-packages/_pytest/main.py", line 283 in wrap_session
File "/usr/local/lib/python3.12/dist-packages/_pytest/main.py", line 330 in pytest_cmdline_main
File "/usr/local/lib/python3.12/dist-packages/pluggy/_callers.py", line 103 in _multicall
File "/usr/local/lib/python3.12/dist-packages/pluggy/_manager.py", line 120 in _hookexec
File "/usr/local/lib/python3.12/dist-packages/pluggy/_hooks.py", line 513 in call
File "/usr/local/lib/python3.12/dist-packages/_pytest/config/init.py", line 175 in main
File "/usr/local/lib/python3.12/dist-packages/_pytest/config/init.py", line 201 in console_main
File "/usr/local/bin/pytest", line 8 in

Extension modules: numpy.core._multiarray_umath, numpy.core._multiarray_tests, numpy.linalg._umath_linalg, numpy.fft._pocketfft_internal, numpy.random._common, numpy.random.bit_generator, numpy.random._bounded_integers, numpy.random._mt19937, numpy.random.mtrand, numpy.random._philox, numpy.random._pcg64, numpy.random._sfc64, numpy.random._generator, torch._C, torch._C._dynamo.autograd_compiler, torch._C._dynamo.eval_frame, torch._C._dynamo.guards, torch._C._dynamo.utils, torch._C._fft, torch._C._linalg, torch._C._nested, torch._C._nn, torch._C._sparse, torch._C._special, cuda.bindings._lib.utils, cuda.bindings._bindings.cydriver, cuda.bindings.cydriver, cuda.bindings.driver, cuda.cuda, mpi4py.MPI, google._upb._message, cuda.bindings._lib.cyruntime.utils, cuda.bindings._lib.cyruntime.cyruntime, cuda.bindings.cyruntime, cuda.bindings.runtime, cuda.cudart, zstandard.backend_c, charset_normalizer.md, yaml._yaml, pyarrow.lib, pandas._libs.tslibs.ccalendar, pandas._libs.tslibs.np_datetime, pandas._libs.tslibs.dtypes, pandas._libs.tslibs.base, pandas._libs.tslibs.nattype, pandas._libs.tslibs.timezones, pandas._libs.tslibs.fields, pandas._libs.tslibs.timedeltas, pandas._libs.tslibs.tzconversion, pandas._libs.tslibs.timestamps, pandas._libs.properties, pandas._libs.tslibs.offsets, pandas._libs.tslibs.strptime, pandas._libs.tslibs.parsing, pandas._libs.tslibs.conversion, pandas._libs.tslibs.period, pandas._libs.tslibs.vectorized, pandas._libs.ops_dispatch, pandas._libs.missing, pandas._libs.hashtable, pandas._libs.algos, pandas._libs.interval, pandas._libs.lib, pyarrow._compute, pandas._libs.ops, numexpr.interpreter, pandas._libs.hashing, pandas._libs.arrays, pandas._libs.tslib, pandas._libs.sparse, pandas._libs.internals, pandas._libs.indexing, pandas._libs.index, pandas._libs.writers, pandas._libs.join, pandas._libs.window.aggregations, pandas._libs.window.indexers, pandas._libs.reshape, pandas._libs.groupby, pandas._libs.json, pandas._libs.parsers, pandas._libs.testing, _cffi_backend, pyarrow._parquet, pyarrow._fs, pyarrow._azurefs, pyarrow._hdfs, pyarrow._gcsfs, pyarrow._s3fs, multidict._multidict, yarl._quoting_c, propcache._helpers_c, aiohttp._http_writer, aiohttp._http_parser, aiohttp._websocket.mask, aiohttp._websocket.reader_c, frozenlist._frozenlist, xxhash._xxhash, pyarrow._json, cython.cimports.libc.math, Cython.Utils, Cython.Plex.Actions, Cython.Plex.Transitions, Cython.Plex.Machines, Cython.Plex.DFA, Cython.Plex.Scanners, Cython.Compiler.Scanning, Cython.StringIOTree, Cython.Compiler.Code, markupsafe._speedups, PIL._imaging, sklearn.__check_build._check_build, psutil._psutil_linux, psutil._psutil_posix, scipy._lib._ccallback_c, scipy.sparse._sparsetools, _csparsetools, scipy.sparse._csparsetools, scipy.linalg._fblas, scipy.linalg._flapack, scipy.linalg.cython_lapack, scipy.linalg._cythonized_array_utils, scipy.linalg._solve_toeplitz, scipy.linalg._decomp_lu_cython, scipy.linalg._matfuncs_sqrtm_triu, scipy.linalg.cython_blas, scipy.linalg._matfuncs_expm, scipy.linalg._decomp_update, scipy.sparse.linalg._dsolve._superlu, scipy.sparse.linalg._eigen.arpack._arpack, scipy.sparse.linalg._propack._spropack, scipy.sparse.linalg._propack._dpropack, scipy.sparse.linalg._propack._cpropack, scipy.sparse.linalg._propack._zpropack, scipy.sparse.csgraph._tools, scipy.sparse.csgraph._shortest_path, scipy.sparse.csgraph._traversal, scipy.sparse.csgraph._min_spanning_tree, scipy.sparse.csgraph._flow, scipy.sparse.csgraph._matching, scipy.sparse.csgraph._reordering, scipy.special._ufuncs_cxx, 
scipy.special._ufuncs, scipy.special._specfun, scipy.special._comb, scipy.special._ellip_harm_2, scipy.spatial._ckdtree, scipy._lib.messagestream, scipy.spatial._qhull, scipy.spatial._voronoi, scipy.spatial._distance_wrap, scipy.spatial._hausdorff, scipy.spatial.transform._rotation, scipy.optimize._group_columns, scipy.optimize._trlib._trlib, scipy.optimize._lbfgsb, _moduleTNC, scipy.optimize._moduleTNC, scipy.optimize._cobyla, scipy.optimize._slsqp, scipy.optimize._minpack, scipy.optimize._lsq.givens_elimination, scipy.optimize._zeros, scipy.optimize._highs.cython.src._highs_wrapper, scipy.optimize._highs._highs_wrapper, scipy.optimize._highs.cython.src._highs_constants, scipy.optimize._highs._highs_constants, scipy.linalg._interpolative, scipy.optimize._bglu_dense, scipy.optimize._lsap, scipy.optimize._direct, scipy.integrate._odepack, scipy.integrate._quadpack, scipy.integrate._vode, scipy.integrate._dop, scipy.integrate._lsoda, scipy.interpolate._fitpack, scipy.interpolate._dfitpack, scipy.interpolate._bspl, scipy.interpolate._ppoly, scipy.interpolate.interpnd, scipy.interpolate._rbfinterp_pythran, scipy.interpolate._rgi_cython, scipy.special.cython_special, scipy.stats._stats, scipy.stats._biasedurn, scipy.stats._levy_stable.levyst, scipy.stats._stats_pythran, scipy._lib._uarray._uarray, scipy.stats._ansari_swilk_statistics, scipy.stats._sobol, scipy.stats._qmc_cy, scipy.stats._mvn, scipy.stats._rcont.rcont, scipy.stats._unuran.unuran_wrapper, scipy.ndimage._nd_image, _ni_label, scipy.ndimage._ni_label, sklearn.utils._isfinite, sklearn.utils.sparsefuncs_fast, sklearn.utils.murmurhash, sklearn.utils._openmp_helpers, sklearn.metrics.cluster._expected_mutual_info_fast, sklearn.preprocessing._csr_polynomial_expansion, sklearn.preprocessing._target_encoder_fast, sklearn.metrics._dist_metrics, sklearn.metrics._pairwise_distances_reduction._datasets_pair, sklearn.utils._cython_blas, sklearn.metrics._pairwise_distances_reduction._base, sklearn.metrics._pairwise_distances_reduction._middle_term_computer, sklearn.utils._heap, sklearn.utils._sorting, sklearn.metrics._pairwise_distances_reduction._argkmin, sklearn.metrics._pairwise_distances_reduction._argkmin_classmode, sklearn.utils._vector_sentinel, sklearn.metrics._pairwise_distances_reduction._radius_neighbors, sklearn.metrics._pairwise_distances_reduction._radius_neighbors_classmode, sklearn.metrics._pairwise_fast, PIL._imagingft, modelopt_core.torch.nas.plugins.megatron, modelopt_core.torch.quantization.config, modelopt.core.torch.quantization.config, modelopt_core.torch.quantization.qtensor.nvfp4_tensor, modelopt.core.torch.quantization.qtensor.nvfp4_tensor, modelopt_core.torch.quantization.algorithms, regex._regex, modelopt.core.torch.quantization.algorithms, h5py._errors, h5py.defs, h5py._objects, h5py.h5, h5py.utils, h5py.h5t, h5py.h5s, h5py.h5ac, h5py.h5p, h5py.h5r, h5py._proxy, h5py._conv, h5py.h5z, h5py.h5a, h5py.h5d, h5py.h5ds, h5py.h5g, h5py.h5i, h5py.h5o, h5py.h5f, h5py.h5fd, h5py.h5pl, h5py.h5l, h5py._selector, sentencepiece._sentencepiece, nvtx._lib.lib, nvtx._lib.profiler, zmq.backend.cython._zmq, cuda.bindings._bindings.cynvrtc, cuda.bindings.cynvrtc, cuda.bindings.nvrtc, cuda.nvrtc (total: 259)
Aborted (core dumped)
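
To separate a real architecture limitation from a problem in the test harness, the underlying ops can also be called directly, bypassing the TensorRT session. This sketch follows the argument order of the fp4_gemm call shown in the apply_linear traceback above; the shapes and scaling_vector_size = 16 are my own choices:

import torch
import tensorrt_llm  # assumed to register the torch.ops.trtllm custom ops

dtype = torch.bfloat16
scaling_vector_size = 16  # assumed NVFP4 block size
m, n, k = 128, 128, 128

x = torch.randn((m, k), dtype=dtype).cuda()
w = torch.randn((n, k), dtype=dtype).cuda()
x_sf_global = (448 * 6) / x.abs().max().float()
w_sf_global = (448 * 6) / w.abs().max().float()

x_fp4, x_sf = torch.ops.trtllm.fp4_quantize(x, x_sf_global,
                                            scaling_vector_size, False)
w_fp4, w_sf = torch.ops.trtllm.fp4_quantize(w, w_sf_global,
                                            scaling_vector_size, False)

alpha = 1.0 / (x_sf_global * w_sf_global)
# Same argument order as torch.ops.trtllm.fp4_gemm in linear.py above; on this
# setup I would expect it to raise the same "Arch unsupported" RuntimeError.
out = torch.ops.trtllm.fp4_gemm(x_fp4, w_fp4, x_sf, w_sf, alpha, False, dtype)
print(out.shape, out.dtype)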

How can I solve these problems?

Who can help?

No response

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

Run the fp4_gemm unittests on an RTX 5090.

Expected behavior

All the tests should pass, and the fp4_gemm CUTLASS kernel should run.

actual behavior

fp4_gemm cannot be run on the RTX 5090.

additional notes

no


    Labels

    not a bug (Some known limitation, but not a bug.)
