
Commit 35e099e

tests: Update support for tgv_gemm to SM100 only and add to ut (#1810)
## 📌 Description

Add `tgv_gemm` to the tests and update the support surface.

## 🔍 Related Issues

## 🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

### ✅ Pre-commit Checks

- [x] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- [ ] I have installed the hooks with `pre-commit install`.
- [ ] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues.

> If you are unsure about how to set up `pre-commit`, see [the pre-commit documentation](https://pre-commit.com/).

## 🧪 Tests

- [ ] Tests have been added or updated as needed.
- [ ] All tests are passing (`unittest`, etc.).

## Reviewer Notes
1 parent f765a2a commit 35e099e

File tree: 6 files changed (+45, -10 lines)


flashinfer/aot.py (15 additions & 3 deletions)

```diff
@@ -53,7 +53,7 @@
     gen_gemm_sm100_module,
     gen_gemm_sm100_module_cutlass_fp4,
     gen_gemm_sm100_module_cutlass_fp8,
-    gen_gemm_sm100_module_tgv,
+    gen_tgv_gemm_sm10x_module,
     gen_gemm_sm120_module,
     gen_gemm_sm120_module_cutlass_fp4,
     gen_trtllm_gen_gemm_module,
@@ -412,6 +412,7 @@ def gen_all_modules(
     jit_specs: List[JitSpec] = []
     has_sm90 = sm_capabilities.get("sm90", False)
     has_sm100 = sm_capabilities.get("sm100", False)
+    has_sm100f = sm_capabilities.get("sm100f", False)
     has_sm103 = sm_capabilities.get("sm103", False)
     has_sm110 = sm_capabilities.get("sm110", False)
     has_sm120 = sm_capabilities.get("sm120", False)
@@ -449,11 +450,21 @@ def gen_all_modules(
         jit_specs.append(gen_gemm_sm100_module_cutlass_fp4())
         jit_specs.append(gen_gemm_sm100_module_cutlass_fp8())
         # Add TGV GEMM modules for both bf16 and fp16
-        jit_specs.append(gen_gemm_sm100_module_tgv(torch.bfloat16))
-        jit_specs.append(gen_gemm_sm100_module_tgv(torch.float16))
+        jit_specs.append(
+            gen_tgv_gemm_sm10x_module(torch.bfloat16, use_sm_100f=False)
+        )
+        jit_specs.append(
+            gen_tgv_gemm_sm10x_module(torch.float16, use_sm_100f=False)
+        )
         jit_specs.append(gen_mxfp8_quantization_sm100_module())
         jit_specs.append(gen_trtllm_gen_gemm_module())
         jit_specs.append(gen_trtllm_gen_fused_moe_sm100_module())
+    if has_sm100f:
+        # Add TGV GEMM modules compiled with SM100f flags for both bf16 and fp16
+        jit_specs.append(
+            gen_tgv_gemm_sm10x_module(torch.bfloat16, use_sm_100f=True)
+        )
+        jit_specs.append(gen_tgv_gemm_sm10x_module(torch.float16, use_sm_100f=True))
     if has_sm103:
         jit_specs.append(gen_fp4_quantization_sm103_module())
     if has_sm110:
@@ -588,6 +599,7 @@ def has_sm(compute: str, version: str) -> bool:
     return {
         "sm90": has_sm("compute_90", "12.3"),
         "sm100": has_sm("compute_100", "12.8"),
+        "sm100f": has_sm("compute_100", "12.9"),
         "sm103": has_sm("compute_103", "12.8"),
         "sm110": has_sm("compute_110", "12.9"),
         "sm120": has_sm("compute_120", "13.0"),
```

flashinfer/gemm.py (18 additions & 7 deletions)

```diff
@@ -40,6 +40,7 @@
 from .jit.cubin_loader import get_cubin
 from .utils import (
     is_sm100a_supported,
+    is_sm100f_supported,
     is_sm120a_supported,
     is_sm121a_supported,
     LibraryError,
@@ -65,6 +66,7 @@
     gen_jit_spec,
     sm90a_nvcc_flags,
     sm100a_nvcc_flags,
+    sm100f_nvcc_flags,
     current_compilation_context,
 )
 from .jit.cubin_loader import setup_cubin_loader
@@ -869,12 +871,16 @@ def get_gemm_sm120_module_cutlass_fp4():
     )
 
 
-def gen_gemm_sm100_module_tgv(dtype: torch.dtype = torch.bfloat16) -> JitSpec:
+def gen_tgv_gemm_sm10x_module(
+    dtype: torch.dtype = torch.bfloat16, use_sm_100f: bool = False
+) -> JitSpec:
     """
     Generate TGV GEMM module for SM100 architecture.
 
     Args:
         dtype: Data type for the GEMM operation (torch.bfloat16 or torch.float16)
+        use_sm_100f: Whether to compile with SM100f flags (default: False), which makes the compiled kernel
+            compatible with both B200 and B300 GPUs. However, it's only available with CUDA 12.9+.
 
     Returns:
         JitSpec for the TGV GEMM module
@@ -926,7 +932,7 @@ def gen_gemm_sm100_module_tgv(dtype: torch.dtype = torch.bfloat16) -> JitSpec:
     return gen_jit_spec(
         module_name,
         source_paths,
-        extra_cuda_cflags=sm100a_nvcc_flags,
+        extra_cuda_cflags=sm100f_nvcc_flags if use_sm_100f else sm100a_nvcc_flags,
         extra_include_paths=[
             jit_env.FLASHINFER_INCLUDE_DIR,
             jit_env.FLASHINFER_CSRC_DIR,
@@ -935,17 +941,21 @@ def gen_gemm_sm100_module_tgv(dtype: torch.dtype = torch.bfloat16) -> JitSpec:
 
 
 @functools.cache
-def get_gemm_sm100_module_tgv(dtype: torch.dtype = torch.bfloat16):
+def get_tgv_gemm_sm10x_module(
+    dtype: torch.dtype = torch.bfloat16, use_sm_100f: bool = False
+):
     """
     Get and build the TGV GEMM module for the specified dtype.
 
     Args:
         dtype: Data type for the GEMM operation (torch.bfloat16 or torch.float16)
+        use_sm_100f: Whether to compile with SM100f flags (default: False), which makes the compiled kernel
+            compatible with both B200 and B300 GPUs. However, it's only available with CUDA 12.9+.
 
     Returns:
         SimpleNamespace with the runner function
     """
-    module = gen_gemm_sm100_module_tgv(dtype).build_and_load()
+    module = gen_tgv_gemm_sm10x_module(dtype, use_sm_100f).build_and_load()
 
     def tgv_gemm_runner():
         class TGVGemmRunner(TunableRunner):
@@ -1013,8 +1023,8 @@ def tgv_gemm_sm100(
     - Tensor b is expected to be in column-major layout (transposed from typical PyTorch row-major)
     """
     # Verify SM100 architecture support
-    if not _match_sm_version(a.device, ["100", "103", "110"]):
-        raise ValueError("TGV GEMM requires SM100, SM103, or SM110 architecture")
+    if not _match_sm_version(a.device, ["100", "103"]):
+        raise ValueError("TGV GEMM requires SM100, SM103 architecture")
 
     # Verify dtype support
     if a.dtype not in [torch.bfloat16, torch.float16]:
@@ -1028,7 +1038,8 @@
     )
 
     runners = []
-    runners.append(get_gemm_sm100_module_tgv(a.dtype).tgv_gemm_runner())
+    use_sm_100f = is_sm100f_supported(a.device)
+    runners.append(get_tgv_gemm_sm10x_module(a.dtype, use_sm_100f).tgv_gemm_runner())
 
     tuner = AutoTuner.get()
     a_tensor_index = 0
```
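Callers who want to pre-build the renamed module can follow the same pattern that `tgv_gemm_sm100` now uses internally. A minimal sketch, assuming a CUDA-capable SM 10.x device and the import paths shown in this diff; anything beyond `build_and_load()` and `tgv_gemm_runner()` is out of scope here.

```python
# Sketch: selecting and pre-building the TGV GEMM module for the current device.
import torch

from flashinfer.gemm import gen_tgv_gemm_sm10x_module, get_tgv_gemm_sm10x_module
from flashinfer.utils import is_sm100f_supported

device = torch.device("cuda:0")

# Prefer the SM100f (B200/B300-compatible) build when CUDA 12.9+ is available.
use_sm_100f = is_sm100f_supported(device)

# Build the JitSpec directly ...
module = gen_tgv_gemm_sm10x_module(torch.bfloat16, use_sm_100f=use_sm_100f).build_and_load()

# ... or go through the cached accessor that tgv_gemm_sm100 uses internally.
runner = get_tgv_gemm_sm10x_module(torch.bfloat16, use_sm_100f).tgv_gemm_runner()
```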

flashinfer/jit/__init__.py (1 addition & 0 deletions)

```diff
@@ -63,6 +63,7 @@
 from .core import gen_jit_spec as gen_jit_spec
 from .core import sm90a_nvcc_flags as sm90a_nvcc_flags
 from .core import sm100a_nvcc_flags as sm100a_nvcc_flags
+from .core import sm100f_nvcc_flags as sm100f_nvcc_flags
 from .core import sm103a_nvcc_flags as sm103a_nvcc_flags
 from .core import sm110a_nvcc_flags as sm110a_nvcc_flags
 from .core import sm120a_nvcc_flags as sm120a_nvcc_flags
```

flashinfer/jit/core.py (1 addition & 0 deletions)

```diff
@@ -75,6 +75,7 @@ def clear_cache_dir():
 sm90a_nvcc_flags = ["-gencode=arch=compute_90a,code=sm_90a"] + common_nvcc_flags
 sm100a_nvcc_flags = ["-gencode=arch=compute_100a,code=sm_100a"] + common_nvcc_flags
 sm103a_nvcc_flags = ["-gencode=arch=compute_103a,code=sm_103a"] + common_nvcc_flags
+sm100f_nvcc_flags = ["-gencode=arch=compute_100f,code=sm_100f"] + common_nvcc_flags
 sm110a_nvcc_flags = ["-gencode=arch=compute_110a,code=sm_110a"] + common_nvcc_flags
 sm120a_nvcc_flags = ["-gencode=arch=compute_120a,code=sm_120a"] + common_nvcc_flags
 sm121a_nvcc_flags = ["-gencode=arch=compute_121a,code=sm_121a"] + common_nvcc_flags
```

flashinfer/utils.py (5 additions & 0 deletions)

```diff
@@ -466,6 +466,11 @@ def is_sm100a_supported(device: torch.device) -> bool:
     return major == 10 and version_at_least(torch.version.cuda, "12.8")
 
 
+def is_sm100f_supported(device: torch.device) -> bool:
+    major, _ = get_compute_capability(device)
+    return major == 10 and version_at_least(torch.version.cuda, "12.9")
+
+
 def is_sm110a_supported(device: torch.device) -> bool:
     major, _ = get_compute_capability(device)
     return major == 11 and version_at_least(torch.version.cuda, "13.0")
```
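The new helper differs from `is_sm100a_supported` only in the CUDA floor (12.9 instead of 12.8), so a runtime dispatch can fall back cleanly. A hedged sketch; the dispatch itself is illustrative and not part of this commit.

```python
# Sketch: picking a TGV GEMM build flavor from the capability helpers.
import torch

from flashinfer.utils import is_sm100a_supported, is_sm100f_supported

device = torch.device("cuda:0")

if is_sm100f_supported(device):       # SM 10.x GPU and CUDA >= 12.9
    flavor = "sm100f (B200/B300-compatible)"
elif is_sm100a_supported(device):     # SM 10.x GPU and CUDA >= 12.8
    flavor = "sm100a"
else:
    flavor = "TGV GEMM not supported on this device"
print(flavor)
```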

tests/unlisted/test_tgv_gemm.py renamed to tests/GEMM/test_tgv_gemm.py (5 additions & 0 deletions)

```diff
@@ -6,6 +6,8 @@
     tgv_gemm_sm100,
 )
 
+from flashinfer.gemm import _match_sm_version
+
 
 @pytest.mark.parametrize("m", [1, 8, 16, 32, 64])
 @pytest.mark.parametrize("n", [1024, 2048, 4096])
@@ -17,6 +19,9 @@ def test_tgv_gemm_sm100(m, n, k, dtype):
     B = torch.randn(n, k, device="cuda", dtype=dtype).t()  # column major
     bias = torch.randn(n, device="cuda", dtype=dtype)
 
+    if not _match_sm_version(A.device, ["100", "103"]):
+        pytest.skip("TGV GEMM requires SM100, SM103 architecture")
+
     print(
         f"Input tensors: A {A.shape}, B {B.shape}, bias {bias.shape}, dtype: {A.dtype}",
         flush=True,
```
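With the file now under `tests/GEMM/`, it can be invoked directly; a minimal sketch using pytest's Python API (the path comes from the rename above, the verbosity flag is just an example).

```python
# Sketch: running the relocated TGV GEMM test file on its own.
import pytest

if __name__ == "__main__":
    pytest.main(["-v", "tests/GEMM/test_tgv_gemm.py"])
```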
