Skip to content

Commit 3b056b3

Browse files
committed
pre-commit
1 parent 877ac53 commit 3b056b3

File tree

4 files changed

+30
-17
lines changed

4 files changed

+30
-17
lines changed

benchmarks/bench_blackwell_attention.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -75,9 +75,10 @@ def flops(ms):
7575
else:
7676
return batch_size * qkv_len * qkv_len * num_heads * head_dim * 4 / ms / 1e9
7777

78-
print(
79-
f"bench_fmha_blackwell (batch_size={batch_size}, qkv_len={qkv_len}, num_heads={num_heads}, head_dim={head_dim}, causal={causal}), flops: {flops(ms):.3f} TFLOPs/s"
80-
)
78+
msg = f"bench_fmha_blackwell (batch_size={batch_size}, qkv_len={qkv_len}, num_heads={num_heads}, head_dim={head_dim}, causal={causal}), flops: {flops(ms):.3f} TFLOPs/s"
79+
print(msg)
80+
with open("bench_fmha_blackwell.txt", "a") as f:
81+
f.write(msg + "\n")
8182

8283

8384
if __name__ == "__main__":

benchmarks/bench_deepgemm_blackwell.py

Lines changed: 18 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -67,11 +67,15 @@ def bench_deepgemm_grouped_fp8_blackwell(batch_size, m, n, k, in_dtype, out_dtyp
6767
* 1e-9
6868
/ ms
6969
)
70-
print(
71-
f"group_deepgemm_fp8_nt_groupwise batch_size={batch_size} m={m} n={n} k={k} "
72-
f"in_dtype={in_dtype} out_dtype={out_dtype}: {tflops_per_second:.2f} TFLOPs/s"
73-
f"memory_bandwidth: {memory_bandwidth_per_second:.2f} TB/s"
74-
)
70+
# print(
71+
# f"group_deepgemm_fp8_nt_groupwise batch_size={batch_size} m={m} n={n} k={k} "
72+
# f"in_dtype={in_dtype} out_dtype={out_dtype}: {tflops_per_second:.2f} TFLOPs/s"
73+
# f"memory_bandwidth: {memory_bandwidth_per_second:.2f} TB/s"
74+
# )
75+
msg = f"deepgemm_grouped_fp8_blackwell batch_size={batch_size} m={m} n={n} k={k} in_dtype={in_dtype} out_dtype={out_dtype}: {tflops_per_second:.2f} TFLOPs/s memory_bandwidth: {memory_bandwidth_per_second:.2f} TB/s"
76+
print(msg)
77+
with open("bench_deepgemm_fp8_blackwell.txt", "a") as f:
78+
f.write(msg + "\n")
7579

7680
return tflops_per_second
7781

@@ -118,11 +122,15 @@ def bench_deepgemm_batch_fp8_blackwell(batch_size, m, n, k, in_dtype, out_dtype)
118122
* 1e-9
119123
/ ms
120124
)
121-
print(
122-
f"group_deepgemm_fp8_nt_groupwise batch_size={batch_size} m={m} n={n} k={k} "
123-
f"in_dtype={in_dtype} out_dtype={out_dtype}: {tflops_per_second:.2f} TFLOPs/s"
124-
f"memory_bandwidth: {memory_bandwidth_per_second:.2f} TB/s"
125-
)
125+
msg = f"bench_deepgemm_fp8_blackwell batch_size={batch_size} m={m} n={n} k={k} in_dtype={in_dtype} out_dtype={out_dtype}: {tflops_per_second:.2f} TFLOPs/s memory_bandwidth: {memory_bandwidth_per_second:.2f} TB/s"
126+
print(msg)
127+
with open("bench_deepgemm_fp8_blackwell.txt", "a") as f:
128+
f.write(msg + "\n")
129+
# print(
130+
# f"group_deepgemm_fp8_nt_groupwise batch_size={batch_size} m={m} n={n} k={k} "
131+
# f"in_dtype={in_dtype} out_dtype={out_dtype}: {tflops_per_second:.2f} TFLOPs/s"
132+
# f"memory_bandwidth: {memory_bandwidth_per_second:.2f} TB/s"
133+
# )
126134

127135
return tflops_per_second
128136

benchmarks/bench_groupwise_gemm_fp8_blackwell.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -170,7 +170,9 @@ def bench_groupwise_gemm_fp8_blackwell(m, n, k, in_dtype, out_dtype):
170170
gemm_fp8_nt_groupwise(a, b, a_scale, b_scale, out=out, scale_major_mode="MN")
171171

172172
measurements = bench_gpu_time(
173-
lambda: gemm_fp8_nt_groupwise(a, b, a_scale, b_scale, out=out, scale_major_mode="MN")
173+
lambda: gemm_fp8_nt_groupwise(
174+
a, b, a_scale, b_scale, out=out, scale_major_mode="MN"
175+
)
174176
)
175177
ms = np.median(measurements)
176178
tflops_per_second = 2 * m * n * k * 1e-9 / ms

benchmarks/bench_tgv_gemm.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,9 @@ def test_tgv_gemm_bf16_sm100_perf():
9999
torch.cuda.synchronize()
100100
end_time = time.time()
101101
cublas_avg_time = (end_time - start_time) / 100
102-
print(f"CUBLAS average time: {cublas_avg_time * 1000:.6f} ms, {flops/cublas_avg_time:.3f} TFLOPS")
102+
print(
103+
f"CUBLAS average time: {cublas_avg_time * 1000:.6f} ms, {flops / cublas_avg_time:.3f} TFLOPS"
104+
)
103105

104106
# Warmup
105107
with autotune(tune_mode=True):
@@ -128,7 +130,7 @@ def test_tgv_gemm_bf16_sm100_perf():
128130

129131
tgv_avg_time = (end_time - start_time) / 100
130132
print(
131-
f"TGV average time: {tgv_avg_time * 1000:.6f} ms, {flops/tgv_avg_time:.3f} TFLOPS, speedup: {cublas_avg_time / tgv_avg_time:.2f}x"
133+
f"TGV average time: {tgv_avg_time * 1000:.6f} ms, {flops / tgv_avg_time:.3f} TFLOPS, speedup: {cublas_avg_time / tgv_avg_time:.2f}x"
132134
)
133135

134136
# Test with PDL
@@ -151,7 +153,7 @@ def test_tgv_gemm_bf16_sm100_perf():
151153

152154
pdl_avg_time = (end_time - start_time) / 100
153155
print(
154-
f"PDL average time: {pdl_avg_time * 1000:.6f} ms, {flops/pdl_avg_time:.3f} TFLOPS, speedup: {cublas_avg_time / pdl_avg_time:.2f}x"
156+
f"PDL average time: {pdl_avg_time * 1000:.6f} ms, {flops / pdl_avg_time:.3f} TFLOPS, speedup: {cublas_avg_time / pdl_avg_time:.2f}x"
155157
)
156158

157159
# Store results for CSV

0 commit comments

Comments
 (0)