Skip to content

Commit cd02b1a

Browse files
committed
fix
1 parent ebea4bd commit cd02b1a

File tree

2 files changed

+6
-6
lines changed

2 files changed

+6
-6
lines changed

benchmarks/bench_groupwise_gemm_fp8_blackwell.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -167,10 +167,10 @@ def bench_groupwise_gemm_fp8_blackwell(m, n, k, in_dtype, out_dtype):
167167
b_scale = torch.rand((k // 128, n // 128), dtype=torch.float32, device="cuda")
168168

169169
out = torch.empty((m, n), dtype=out_dtype, device="cuda")
170-
gemm_fp8_nt_groupwise(a, b, a_scale, b_scale, out=out)
170+
gemm_fp8_nt_groupwise(a, b, a_scale, b_scale, out=out, scale_major_mode="MN")
171171

172172
measurements = bench_gpu_time(
173-
lambda: gemm_fp8_nt_groupwise(a, b, a_scale, b_scale, out=out)
173+
lambda: gemm_fp8_nt_groupwise(a, b, a_scale, b_scale, out=out, scale_major_mode="MN")
174174
)
175175
ms = np.median(measurements)
176176
tflops_per_second = 2 * m * n * k * 1e-9 / ms

benchmarks/bench_tgv_gemm.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ def test_tgv_gemm_bf16_sm100_perf():
6565

6666
for m, n, k, has_bias, description in test_cases:
6767
print(f"\n--- {description}: M={m}, N={n}, K={k}, has_bias={has_bias} ---")
68-
68+
flops = m * n * k * 2 / 1e12
6969
# Create tensors
7070
A = torch.randn(m, k, device="cuda", dtype=torch.bfloat16)
7171
B = torch.randn(n, k, device="cuda", dtype=torch.bfloat16).t()
@@ -99,7 +99,7 @@ def test_tgv_gemm_bf16_sm100_perf():
9999
torch.cuda.synchronize()
100100
end_time = time.time()
101101
cublas_avg_time = (end_time - start_time) / 100
102-
print(f"CUBLAS average time: {cublas_avg_time * 1000:.6f} ms")
102+
print(f"CUBLAS average time: {cublas_avg_time * 1000:.6f} ms, {flops/cublas_avg_time:.3f} TFLOPS")
103103

104104
# Warmup
105105
with autotune(tune_mode=True):
@@ -128,7 +128,7 @@ def test_tgv_gemm_bf16_sm100_perf():
128128

129129
tgv_avg_time = (end_time - start_time) / 100
130130
print(
131-
f"TGV average time: {tgv_avg_time * 1000:.6f} ms, speedup: {cublas_avg_time / tgv_avg_time:.2f}x"
131+
f"TGV average time: {tgv_avg_time * 1000:.6f} ms, {flops/tgv_avg_time:.3f} TFLOPS, speedup: {cublas_avg_time / tgv_avg_time:.2f}x"
132132
)
133133

134134
# Test with PDL
@@ -151,7 +151,7 @@ def test_tgv_gemm_bf16_sm100_perf():
151151

152152
pdl_avg_time = (end_time - start_time) / 100
153153
print(
154-
f"PDL average time: {pdl_avg_time * 1000:.6f} ms, speedup: {cublas_avg_time / pdl_avg_time:.2f}x"
154+
f"PDL average time: {pdl_avg_time * 1000:.6f} ms, {flops/pdl_avg_time:.3f} TFLOPS, speedup: {cublas_avg_time / pdl_avg_time:.2f}x"
155155
)
156156

157157
# Store results for CSV

0 commit comments

Comments
 (0)