
Commit 421433e

feat: Add support for bmm mxfp8 (flashinfer-ai#2256)
## πŸ“Œ Description

Add support for batched GEMM with MXFP8 inputs (`bmm_mxfp8`). At this time only the cuDNN backend is supported.

Added test `tests/gemm/test_bmm_mxfp8.py` and added the `bmm_mxfp8` routine to `flashinfer_benchmark`.

Benchmark results for `bmm_mxfp8` (on a B200 GPU):

```
python benchmarks/flashinfer_benchmark.py \
    --routine bmm_mxfp8 -vv \
    --num_iters 30 \
    --batch_size 128 \
    --m 512 --n 512 --k 4096 \
    --out_dtype bfloat16 \
    --backends cudnn \
    --refcheck

[PERF] cudnn :: median time 0.117 ms; std 0.001 ms; achieved tflops 2347.650 TFLOPs/sec; achieved tb_per_sec 0.040 TB/sec
```

And `bmm_fp8` for comparison:

```
python benchmarks/flashinfer_benchmark.py \
    --routine bmm_fp8 -vv \
    --num_iters 30 \
    --batch_size 128 \
    --m 512 --n 512 --k 4096 \
    --input_dtype fp8_e4m3 \
    --mat2_dtype fp8_e4m3 \
    --out_dtype bfloat16 \
    --backends cudnn \
    --refcheck

[PERF] cudnn :: median time 0.116 ms; std 0.001 ms; achieved tflops 2369.049 TFLOPs/sec; achieved tb_per_sec 0.041 TB/sec
```

When profiling with `ncu`, the kernel `nvjet_sm100_qqtst_128x256_128x6_2x1_2cta_v_bz_Avec32UE8M0_Bvec32UE8M0_NNT` appears to be launched.

## πŸ” Related Issues

flashinfer-ai#2209

## πŸš€ Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

### βœ… Pre-commit Checks

- [x] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues.

> If you are unsure about how to set up `pre-commit`, see [the pre-commit documentation](https://pre-commit.com/).

## πŸ§ͺ Tests

- [ ] Tests have been added or updated as needed.
- [ ] All tests are passing (`unittest`, etc.).

## Reviewer Notes

## Summary by CodeRabbit

* **New Features**
  * Added MXFP8 (microscaling FP8) batched matrix multiplication with cuDNN acceleration and package-level export.
* **Tests**
  * Added parameterized tests validating MXFP8 BMM against reference results across shapes, dtypes, layouts, backends, and autotune modes.
* **Chores**
  * Updated benchmark catalog and backend-support mappings to include MXFP8 BMM.

---------

Signed-off-by: Daniel Serebrenik <[email protected]>
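For context, a minimal usage sketch of the new API, condensed from the benchmark routine added in `benchmarks/routines/gemm.py` below. The argument names (`A`, `B`, `A_scale`, `B_scale`, `dtype`, `backend`) and the `mxfp8_quantize` helper are taken from the diff; shapes and the cosine-similarity check are illustrative only, and an SM100-class GPU with the cuDNN backend is assumed.

```python
# Sketch only: mirrors the testBmmMxfp8 benchmark routine in this PR.
import torch
import flashinfer
from flashinfer.fp8_quantization import mxfp8_quantize

batch, m, n, k = 128, 512, 512, 4096
a = torch.randn([batch, m, k], device="cuda", dtype=torch.bfloat16)
b = (
    torch.randn([batch, n, k], device="cuda", dtype=torch.bfloat16)
    .transpose(-2, -1)
    .contiguous()
)  # [batch, k, n]

# Quantize both operands to MXFP8 (fp8_e4m3 data plus per-block scales).
a_mxfp8, a_scale = mxfp8_quantize(a, is_sf_swizzled_layout=True)
b_mxfp8, b_scale = mxfp8_quantize(b, is_sf_swizzled_layout=True)

# Batched GEMM on the quantized operands; output is returned in bfloat16.
out = flashinfer.gemm.bmm_mxfp8(
    A=a_mxfp8,
    B=b_mxfp8,
    A_scale=a_scale,
    B_scale=b_scale,
    dtype=torch.bfloat16,
    backend="cudnn",
)

# Loose sanity check against a bfloat16 reference, as in the benchmark refcheck.
ref = torch.bmm(a, b)
cos = torch.nn.functional.cosine_similarity(ref.reshape(-1), out.reshape(-1), dim=0)
print(f"cosine similarity vs bf16 reference: {cos.item():.4f}")
```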
1 parent 43ec6c7 commit 421433e

6 files changed: +744 βˆ’0 lines changed


benchmarks/routines/flashinfer_benchmark_utils.py

Lines changed: 11 additions & 0 deletions

```diff
@@ -100,6 +100,7 @@
         "gemm_fp8_nt_groupwise",
         "group_gemm_fp8_nt_groupwise",
         "bmm_fp8",
+        "bmm_mxfp8",
         "mm_fp4",
     ],
     "moe": [
@@ -236,6 +237,16 @@ def dtype_str_to_torch_dtype(dtype_str):
         "10.3": ["cudnn", "cublas", "cutlass"],
         "12.0": ["cudnn", "cublas"],
     },
+    "bmm_mxfp8": {
+        "7.5": [],
+        "8.0": [],
+        "8.6": [],
+        "8.9": [],
+        "9.0": [],
+        "10.0": ["cudnn"],
+        "10.3": ["cudnn"],
+        "12.0": [],
+    },
     # Note: mm_fp4 uses support checkers to filter backends, so it is not listed here
     # MOE
     "trtllm_fp4_block_scale_moe": {
```

benchmarks/routines/gemm.py

Lines changed: 209 additions & 0 deletions

```diff
@@ -7,6 +7,7 @@

 import flashinfer
 from flashinfer.autotuner import autotune
+from flashinfer.fp8_quantization import mxfp8_quantize
 from flashinfer.testing.utils import (
     bench_gpu_time,
     dequantize_fp8,
@@ -38,6 +39,8 @@ def run_gemm_test(args):
         return testGroupGemmFp8NtGroupwise(args)
     elif args.routine == "bmm_fp8":
         return testBmmFp8(args)
+    elif args.routine == "bmm_mxfp8":
+        return testBmmMxfp8(args)
     elif args.routine == "mm_fp4":
         return testMmFp4(args)
     else:
@@ -144,6 +147,7 @@ def parse_gemm_args(line, parser):
         action="store_true",
         help="In mm_fp4, whether to use nvfp4 quantization or mxfp4 quantization, defaults to False.",
     )
+    # TODO: add bmm_mxfp8 ?
     parser.add_argument(
         "--autotune",
         action="store_true",
@@ -757,6 +761,211 @@ def run_backend(backend, input_fp8, mat2_fp8, input_inv_s, mat2_inv_s):
     return res


+def testBmmMxfp8(args):
+    """
+    Test bmm_mxfp8 API.
+
+    This test:
+    1. Generates random input tensors
+    2. Quantizes input tensors to MXFP8
+    3. Runs bmm_mxfp8
+    4. Runs reference check
+    5. Measures performance metrics (TFLOPS, TB/sec)
+
+    Args:
+        args: Parsed command line arguments containing test configuration
+
+    Returns:
+        dict: List of dictionaries containing performance results
+    """
+    if args.verbose >= 1:
+        print("[INFO] Running testBmmMxfp8")
+        print(f"[INFO] FlashInfer version: {flashinfer.__version__}")
+
+    device = get_device(args)
+    if args.generate_repro_command:
+        print(
+            f"[INFO] To reproduce this test case, run the following command: {args.repro_command}"
+        )
+
+    ## Parse input arguments
+    backends = args.backends
+    batch_size = args.batch_size
+    m = args.m
+    n = args.n
+    k = args.k
+    res_dtype = args.out_dtype
+    is_cuda_graph_compatible = not args.no_cuda_graph
+    run_refcheck = args.refcheck
+    autotune_supported_backends = [
+        "cudnn",
+    ]
+    res = []
+
+    backends = filter_backends_by_compute_capability(backends, args.routine, device)
+    if len(backends) == 0:
+        print("[ERROR] No backends to test. Exiting.")
+        return res
+
+    res_dtype = dtype_str_to_torch_dtype(args.out_dtype)
+    if res_dtype not in [torch.bfloat16, torch.float16]:
+        raise ValueError(
+            f"Unsupported res dtype: {res_dtype}. Supported dtypes are bfloat16 and float16."
+        )
+    ## Done parsing input arguments
+
+    if getattr(args, "autotune", False):
+        backends_to_remove = []
+        for cur_backend in backends:
+            if cur_backend not in autotune_supported_backends:
+                print(f"[INFO] {cur_backend} backend does not support autotune")
+                backends_to_remove.append(cur_backend)
+        for cur_backend in backends_to_remove:
+            backends.remove(cur_backend)
+
+    if len(backends) == 0:
+        print("[ERROR] No backends to test. Exiting.")
+        return res
+
+    ## Prepare input tensors
+    input = torch.randn([batch_size, m, k], device=device, dtype=torch.bfloat16)
+    input_mxfp8, input_scale = mxfp8_quantize(input, is_sf_swizzled_layout=True)
+
+    mat2 = (
+        torch.randn([batch_size, n, k], device=device, dtype=torch.bfloat16)
+        .transpose(-2, -1)
+        .contiguous()
+    )
+    mat2_mxfp8, mat2_scale = mxfp8_quantize(mat2, is_sf_swizzled_layout=True)
+
+    if args.verbose >= 2:
+        print(f"[VVERBOSE] {input_mxfp8.shape = }")
+        print(f"[VVERBOSE] {input_mxfp8.dtype = }")
+        print(f"[VVERBOSE] {mat2_mxfp8.shape = }")
+        print(f"[VVERBOSE] {mat2_mxfp8.dtype = }")
+        print(f"[VVERBOSE] {input_scale.shape = }")
+        print(f"[VVERBOSE] {input_scale.dtype = }")
+        print(f"[VVERBOSE] {mat2_scale.shape = }")
+        print(f"[VVERBOSE] {mat2_scale.dtype = }")
+
+    def run_backend(backend, input_mxfp8, mat2_mxfp8, input_scale, mat2_scale):
+        if backend == "cudnn":
+            return flashinfer.gemm.bmm_mxfp8(
+                A=input_mxfp8,
+                B=mat2_mxfp8,
+                A_scale=input_scale,
+                B_scale=mat2_scale,
+                dtype=res_dtype,
+                backend=backend,
+            )
+        else:
+            raise ValueError(f"Unsupported backend: {backend}")
+
+    has_reference_output = False
+    if run_refcheck:
+        reference_output = torch.bmm(input, mat2)
+        has_reference_output = True
+
+    if getattr(args, "autotune", False):
+        warmup_iters = (
+            args.dry_run_iters if args.dry_run_iters and args.dry_run_iters > 0 else 10
+        )
+        for cur_backend in backends:
+            if cur_backend in autotune_supported_backends:
+                if args.verbose >= 1:
+                    print(f"[INFO] Autotune warmup for bmm_mxfp8: {warmup_iters} iters")
+                with autotune(True):
+                    for _ in range(warmup_iters):
+                        run_backend(
+                            cur_backend,
+                            input_mxfp8,
+                            mat2_mxfp8,
+                            input_scale,
+                            mat2_scale,
+                        )
+
+    # Storage for timing results and outputs
+    backend_times = {backend: [] for backend in backends}
+    outputs = {}
+    for cur_backend in backends:
+        if run_refcheck:
+            outputs[cur_backend] = run_backend(
+                cur_backend, input_mxfp8, mat2_mxfp8, input_scale, mat2_scale
+            ).detach()
+        backend_times[cur_backend] = bench_gpu_time(
+            fn=run_backend,
+            dry_run_iters=args.dry_run_iters,
+            repeat_iters=args.num_iters,
+            sleep_after_run=True,
+            enable_cupti=args.use_cupti,
+            use_cuda_graph=is_cuda_graph_compatible,
+            cold_l2_cache=True,
+            input_args=(cur_backend, input_mxfp8, mat2_mxfp8, input_scale, mat2_scale),
+        )
+
+    min_cos_sim = 0.9  # TODO: check if can be increased
+
+    tested_backends = list(outputs.keys())
+    tested_outputs = list(outputs.values())
+    if len(tested_backends) > 0:
+        if run_refcheck and has_reference_output:
+            for i in range(len(tested_backends)):
+                cos_sim = F.cosine_similarity(
+                    reference_output.reshape(-1),
+                    tested_outputs[i].reshape(-1),
+                    dim=0,
+                )
+                if cos_sim < min_cos_sim:
+                    print(
+                        f"[ERROR] Output tensor mismatch between backends {tested_backends[0]} and {tested_backends[i]}"
+                    )
+                    if not args.allow_output_mismatch:
+                        raise AssertionError(
+                            f"[ERROR] Backend {tested_backends[i]} output mismatch with cos_sim={cos_sim}"
+                        )
+
+    for backend in backends:
+        backend_name = backend + (
+            "_autotune"
+            if (
+                getattr(args, "autotune", False)
+                and backend in autotune_supported_backends
+            )
+            else ""
+        )
+        if len(backend_times[backend]) > 0:
+            median_time = np.median(backend_times[backend])
+            std_time = np.std(backend_times[backend])
+            problem_flops = 2 * m * n * k * batch_size
+            # MXFP8 uses fp8_e4m3fn for data (1 byte) and uint8 for scales
+            # Scale tensors are much smaller, so approximate as 1 byte per element for simplicity
+            problem_bytes = (
+                m * k * torch.float8_e4m3fn.itemsize
+                + n * k * torch.float8_e4m3fn.itemsize
+                + m * n * res_dtype.itemsize
+            ) * batch_size
+            tflops = problem_flops / (10**9 * median_time)  # in TFLOPs/sec
+            tb_per_sec = problem_bytes / (10**9 * median_time)  # in TB/sec
+            print_perf_metrics(backend_name, median_time, std_time, tflops, tb_per_sec)
+
+            if args.output_path is not None:
+                cur_res = defaultdict(str)
+                cur_res["batch_size"] = batch_size
+                cur_res["routine"] = args.routine
+                cur_res["median_time"] = median_time
+                cur_res["std_time"] = std_time
+                cur_res["tflops"] = tflops
+                cur_res["tb_per_sec"] = tb_per_sec
+                cur_res["m"] = m
+                cur_res["n"] = n
+                cur_res["k"] = k
+                cur_res["out_dtype"] = res_dtype
+                cur_res["backend"] = backend_name
+                cur_res["case_tag"] = args.case_tag
+                res.append(cur_res)
+    return res
+
+
 def testMmFp4(args):
     """
     Test mm_fp4 API.
```
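As a quick sanity check on the metric formula in `testBmmMxfp8` above (assuming `median_time` from `bench_gpu_time` is in milliseconds, which the printed `[PERF]` output suggests, dividing FLOPs by `10**9 * median_time` yields TFLOPs/sec), the benchmark configuration quoted in the PR description works out to roughly the reported number:

```python
# Worked example for the perf formula in testBmmMxfp8 (time in milliseconds).
batch_size, m, n, k = 128, 512, 512, 4096
median_time_ms = 0.117  # measured value quoted in the PR description

problem_flops = 2 * m * n * k * batch_size         # ~2.75e11 FLOP per call
# FLOP / (median_time_ms * 1e-3 s) / 1e12 == FLOP / (1e9 * median_time_ms)
tflops = problem_flops / (10**9 * median_time_ms)
print(f"{tflops:.1f} TFLOPs/sec")                  # ~2349, vs. 2347.650 reported (rounded median)
```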

flashinfer/__init__.py

Lines changed: 1 addition & 0 deletions

```diff
@@ -86,6 +86,7 @@
 )
 from .gemm import SegmentGEMMWrapper as SegmentGEMMWrapper
 from .gemm import bmm_fp8 as bmm_fp8
+from .gemm import bmm_mxfp8 as bmm_mxfp8
 from .gemm import mm_fp4 as mm_fp4
 from .gemm import mm_fp8 as mm_fp8
 from .gemm import tgv_gemm_sm100 as tgv_gemm_sm100
```

flashinfer/gemm/__init__.py

Lines changed: 2 additions & 0 deletions

```diff
@@ -1,5 +1,6 @@
 from .gemm_base import SegmentGEMMWrapper as SegmentGEMMWrapper
 from .gemm_base import bmm_fp8 as bmm_fp8
+from .gemm_base import bmm_mxfp8 as bmm_mxfp8
 from .gemm_base import mm_fp4 as mm_fp4
 from .gemm_base import mm_fp8 as mm_fp8
 from .gemm_base import tgv_gemm_sm100 as tgv_gemm_sm100
@@ -22,6 +23,7 @@
 __all__ = [
     "SegmentGEMMWrapper",
     "bmm_fp8",
+    "bmm_mxfp8",
     "mm_fp4",
     "mm_fp8",
     "tgv_gemm_sm100",
```
