Commit 9c6c265

misc: fix some B200 GEMM bench (#1883)
## 📌 Description

### Before

The benchmark could not run due to the missing scale layout argument (error screenshot attached in the PR).

### After

Both benchmarks run; the TGV benchmark now also reports achieved TFLOPS alongside its timings (output screenshots attached in the PR).
1 parent ebea4bd commit 9c6c265

File tree

2 files changed: +10 −6 lines changed

benchmarks/bench_groupwise_gemm_fp8_blackwell.py

Lines changed: 4 additions & 2 deletions
```diff
@@ -167,10 +167,12 @@ def bench_groupwise_gemm_fp8_blackwell(m, n, k, in_dtype, out_dtype):
     b_scale = torch.rand((k // 128, n // 128), dtype=torch.float32, device="cuda")

     out = torch.empty((m, n), dtype=out_dtype, device="cuda")
-    gemm_fp8_nt_groupwise(a, b, a_scale, b_scale, out=out)
+    gemm_fp8_nt_groupwise(a, b, a_scale, b_scale, out=out, scale_major_mode="MN")

     measurements = bench_gpu_time(
-        lambda: gemm_fp8_nt_groupwise(a, b, a_scale, b_scale, out=out)
+        lambda: gemm_fp8_nt_groupwise(
+            a, b, a_scale, b_scale, out=out, scale_major_mode="MN"
+        )
     )
     ms = np.median(measurements)
     tflops_per_second = 2 * m * n * k * 1e-9 / ms
```
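For readers outside the diff context, a minimal self-contained sketch of the fixed call, assuming a Blackwell (SM100) GPU and a flashinfer build that exposes `gemm_fp8_nt_groupwise`; the `a_scale` shape and the output dtype are assumptions, since only `b_scale` is visible in the hunk above:

```python
# Minimal sketch of the fixed call; assumes an SM100 GPU and a flashinfer
# build exposing gemm_fp8_nt_groupwise. a_scale's shape is an assumption;
# only b_scale's shape appears in the diff above.
import torch
from flashinfer.gemm import gemm_fp8_nt_groupwise

m, n, k = 256, 4096, 4096
a = torch.randn(m, k, device="cuda").to(torch.float8_e4m3fn)
b = torch.randn(n, k, device="cuda").to(torch.float8_e4m3fn)  # NT layout: b is (n, k)
a_scale = torch.rand((k // 128, m), dtype=torch.float32, device="cuda")  # assumed shape
b_scale = torch.rand((k // 128, n // 128), dtype=torch.float32, device="cuda")  # as in the benchmark
out = torch.empty((m, n), dtype=torch.bfloat16, device="cuda")  # assumed out dtype

# The fix: pass the scale layout explicitly; "MN" matches the per-128-group
# scale tensors constructed above.
gemm_fp8_nt_groupwise(a, b, a_scale, b_scale, out=out, scale_major_mode="MN")
```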

benchmarks/bench_tgv_gemm.py

Lines changed: 6 additions & 4 deletions
```diff
@@ -65,7 +65,7 @@ def test_tgv_gemm_bf16_sm100_perf():

     for m, n, k, has_bias, description in test_cases:
         print(f"\n--- {description}: M={m}, N={n}, K={k}, has_bias={has_bias} ---")
-
+        flops = m * n * k * 2 / 1e12
         # Create tensors
         A = torch.randn(m, k, device="cuda", dtype=torch.bfloat16)
         B = torch.randn(n, k, device="cuda", dtype=torch.bfloat16).t()
@@ -99,7 +99,9 @@ def test_tgv_gemm_bf16_sm100_perf():
         torch.cuda.synchronize()
         end_time = time.time()
         cublas_avg_time = (end_time - start_time) / 100
-        print(f"CUBLAS average time: {cublas_avg_time * 1000:.6f} ms")
+        print(
+            f"CUBLAS average time: {cublas_avg_time * 1000:.6f} ms, {flops / cublas_avg_time:.3f} TFLOPS"
+        )

         # Warmup
         with autotune(tune_mode=True):
@@ -128,7 +130,7 @@ def test_tgv_gemm_bf16_sm100_perf():

         tgv_avg_time = (end_time - start_time) / 100
         print(
-            f"TGV average time: {tgv_avg_time * 1000:.6f} ms, speedup: {cublas_avg_time / tgv_avg_time:.2f}x"
+            f"TGV average time: {tgv_avg_time * 1000:.6f} ms, {flops / tgv_avg_time:.3f} TFLOPS, speedup: {cublas_avg_time / tgv_avg_time:.2f}x"
         )

         # Test with PDL
@@ -151,7 +153,7 @@ def test_tgv_gemm_bf16_sm100_perf():

         pdl_avg_time = (end_time - start_time) / 100
         print(
-            f"PDL average time: {pdl_avg_time * 1000:.6f} ms, speedup: {cublas_avg_time / pdl_avg_time:.2f}x"
+            f"PDL average time: {pdl_avg_time * 1000:.6f} ms, {flops / pdl_avg_time:.3f} TFLOPS, speedup: {cublas_avg_time / pdl_avg_time:.2f}x"
         )

         # Store results for CSV
```
