
Commit 3e3dc6c

squashed changes for rebase
1 parent 075775e commit 3e3dc6c

File tree

14 files changed: +882, -33 lines changed


benchmarks/bench_mm_fp8.py

Lines changed: 98 additions & 0 deletions
@@ -0,0 +1,98 @@
"""
Copyright (c) 2025 by FlashInfer team.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

from typing import Dict

from flashinfer.autotuner import autotune
from flashinfer.trtllm_low_latency_gemm import prepare_low_latency_gemm_weights
import numpy as np
import torch

from flashinfer import mm_fp8
from flashinfer.testing.utils import bench_gpu_time

_cache_permute_indices: Dict[torch.Size, torch.Tensor] = {}


def to_float8(
    x: torch.Tensor, dtype=torch.float8_e4m3fn
) -> tuple[torch.Tensor, torch.Tensor]:
    """Quantize x to FP8 with a per-tensor scale; return (fp8 tensor, inverse scale)."""
    finfo = torch.finfo(dtype)
    min_val, max_val = x.aminmax()
    amax = torch.maximum(min_val.abs(), max_val.abs()).clamp(min=1e-12)
    scale = finfo.max / amax
    x_scl_sat = (x * scale).clamp(min=finfo.min, max=finfo.max)
    return x_scl_sat.to(dtype), scale.float().reciprocal()


def bench_mm_fp8(m, n, k, in_dtype, out_dtype):
    torch.manual_seed(123)
    input_tensor = torch.randn([m, k], device="cuda", dtype=torch.bfloat16)
    input_fp8, input_inv_s = to_float8(input_tensor, dtype=in_dtype)

    # mat2 row major -> column major
    mat2 = torch.randn([n, k], device="cuda", dtype=torch.bfloat16)
    mat2_fp8, mat2_inv_s = to_float8(mat2, dtype=in_dtype)

    res = torch.zeros([m, n], device="cuda", dtype=out_dtype)
    global_scale = input_inv_s * mat2_inv_s

    # Do row shuffling.
    prepared_weights = prepare_low_latency_gemm_weights(
        mat2_fp8, _cache_permute_indices
    )

    # Warm-up call under autotuning so the best kernel configuration is cached.
    with autotune(True):
        mm_fp8(
            input_fp8,
            prepared_weights,
            global_scale,
            out=res,
        )

    measurements = bench_gpu_time(
        lambda: mm_fp8(
            input_fp8,
            prepared_weights,
            global_scale,
            res,
        ),
        dry_run_time_ms=500,
        repeat_time_ms=2500,
        use_cuda_graph=True,
    )
    ms = np.median(measurements)
    # 2*m*n*k FLOPs per GEMM; ms is in milliseconds, so 1e-9 converts to TFLOPs/s.
    tflops_per_second = 2 * m * n * k * 1e-9 / ms

    # Bytes moved (A, B, and output) divided by runtime, reported in TB/s.
    bandwidth = (
        (
            input_fp8.numel() * input_fp8.element_size()
            + prepared_weights.numel() * prepared_weights.element_size()
            + res.numel() * res.element_size()
        )
        / ms
        / 1e9
    )

    print(
        f"mm_fp8 m={m} n={n} k={k} in_dtype={in_dtype} out_dtype={out_dtype}: {tflops_per_second:.2f} TFLOPs/s over {ms:.6f} ms, {bandwidth:.2f} TB/s"
    )


if __name__ == "__main__":
    for m in [1, 2, 4, 8, 16, 32, 64]:
        for n in [2560, 5120, 8192]:
            for k in [16384, 32768]:
                bench_mm_fp8(m, n, k, torch.float8_e4m3fn, torch.bfloat16)
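For reference, the per-tensor scaling convention that to_float8 implements (scale = finfo.max / amax at quantization time, with the reciprocal returned for dequantization) can be checked with a small standalone round trip. This is only an illustrative sketch, not part of the commit: it uses plain PyTorch rather than flashinfer, and it assumes a PyTorch build where casts to and from torch.float8_e4m3fn are supported on the chosen device.

    import torch

    def quantize_dequantize(x: torch.Tensor, dtype=torch.float8_e4m3fn) -> torch.Tensor:
        # Same per-tensor convention as to_float8 above: scale = finfo.max / amax.
        finfo = torch.finfo(dtype)
        amax = x.abs().max().clamp(min=1e-12)
        scale = finfo.max / amax
        x_fp8 = (x * scale).clamp(min=finfo.min, max=finfo.max).to(dtype)
        inv_scale = scale.float().reciprocal()  # what to_float8 returns as the scale
        # Dequantize by multiplying the FP8 values by the inverse scale.
        return x_fp8.to(torch.float32) * inv_scale

    x = torch.randn(4, 8)
    x_hat = quantize_dequantize(x)
    print((x - x_hat).abs().max())  # small per-tensor FP8 quantization error

In the benchmark itself, global_scale = input_inv_s * mat2_inv_s plays the role of this inverse scale for the product of the two quantized operands.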
