
Commit cb50446

added compute capability check in test and document support for trtllm_low_latency_gemm
1 parent a4e8f34


2 files changed: +6 -0 lines changed


flashinfer/trtllm_low_latency_gemm.py (1 addition, 0 deletions)

@@ -124,6 +124,7 @@ def trtllm_low_latency_gemm(
     out: torch.Tensor,
 ) -> None:
     r"""GEMM optimized for low M dimension. B needs to be shuffled and its layout needs to be adjusted.
+    Only supported on Blackwell GPUs.
 
     Parameters
     ----------
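For context, the added docstring note reflects a hardware requirement: Blackwell GPUs report compute capability major version 10, so a caller can gate its use of trtllm_low_latency_gemm at runtime. The following is a minimal sketch, not part of this commit; the helper name supports_trtllm_low_latency_gemm is hypothetical, and it assumes get_compute_capability returns a (major, minor) tuple as used in the test change below.

import torch
from flashinfer.utils import get_compute_capability

def supports_trtllm_low_latency_gemm() -> bool:
    # Hypothetical helper: Blackwell GPUs report compute capability 10.x,
    # which is what the new docstring note and the test skip check for.
    major, _minor = get_compute_capability(torch.device("cuda"))
    return major == 10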

tests/gemm/test_mm_fp8.py (5 additions, 0 deletions)

@@ -1,4 +1,5 @@
 from typing import Dict
+from flashinfer.utils import get_compute_capability
 import pytest
 import torch
 import torch.nn.functional as F
@@ -24,6 +25,10 @@ def test_mm_fp8(
     mat2_dtype: torch.dtype,
     res_dtype: torch.dtype,
 ):
+    compute_capability = get_compute_capability(torch.device(device="cuda"))
+    if compute_capability[0] not in [10]:
+        pytest.skip("mm_fp8 is only supported on Blackwell GPUs.")
+
     torch.manual_seed(123)
     input = torch.randn([m, k], device="cuda", dtype=torch.bfloat16)
     input_fp8, input_inv_s = to_float8(input, dtype=input_dtype)
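If more FP8 GEMM tests end up needing the same guard, the inline check could be factored into a small helper. A hypothetical sketch, not part of this commit (skip_unless_blackwell is an assumed name):

import pytest
import torch
from flashinfer.utils import get_compute_capability

def skip_unless_blackwell() -> None:
    # Skip the calling test unless it runs on a Blackwell GPU (compute capability 10.x).
    major, _minor = get_compute_capability(torch.device("cuda"))
    if major != 10:
        pytest.skip("mm_fp8 is only supported on Blackwell GPUs.")

test_mm_fp8 would then start with a single skip_unless_blackwell() call instead of the inline check.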
