
Commit 1659ad9

AndreyPavlenko, etiotto, and anmyachev authored
[benchmarks] Reworked the conversion benchmark and added more tests for up/down casts (#4800)
Here is an example of the benchmark report:

```TSV
float-upcast:
   N  fp8e4nv->fp16  fp8e4nv->bf16  fp8e4nv->fp32  fp8e5->fp16  fp8e5->bf16  fp8e5->fp32  fp16->fp32  bf16->fp32
0      4096.0  0.010524  0.010252  0.010187   0.080630  0.010488   0.043948   0.139510   0.193939
1     16384.0  0.042079  0.041001  0.040740   0.322013  0.041959   0.175794   0.556522   0.771375
2     65536.0  0.168248  0.163971  0.162928   1.286028  0.167731   0.701971   2.223066   3.079699
3    262144.0  0.507677  0.501730  0.493866   5.088199  0.503581   2.786395   8.820458  12.158813
4   1048576.0  1.202495  1.163017  1.182960  12.700775  1.169868   8.868200  30.411137  26.532794
5   4194304.0  1.623786  1.587861  1.594501  21.531335  1.587212  15.298745  50.582537  42.538580
6  16777216.0  1.961367  1.927211  1.924611  33.078107  1.923428  23.340590  83.022644  69.788752
7  67108864.0  2.121850  2.079672  2.081887  35.847220  2.076763  25.215246  95.001223  78.473379

float-downcast:
   N  fp16->fp8e4nv/rtne  fp16->fp8e5/rtne  bf16->fp8e4nv/rtne  bf16->fp8e5/rtne  fp32->fp8e4nv/rtne  fp32->fp8e4nv/rtz  fp32->fp8e5/rtne  fp32->fp8e5/rtz  fp32->fp16/rtne  fp32->fp16/rtz  fp32->bf16/rtne  fp32->bf16/rtz
0      4096.0   0.023953   0.071309  0.020427  0.020877   0.028330   0.202172   0.051561    0.304762    0.436674    0.508189    0.432981   0.135899
1     16384.0   0.095768   0.285237  0.081626  0.083490   0.113290   0.807094   0.206036    1.213630    1.731924    2.017734    1.724632   0.542517
2     65536.0   0.382893   1.138173  0.326310  0.333618   0.452534   3.222026   0.823523    4.825920    6.898526    8.031373    6.855230   2.167196
3    262144.0   1.490979   3.903276  1.226577  1.254518   1.786452  11.270163   3.193762   17.558205   26.969547   29.721542   26.586613   8.594885
4   1048576.0   4.111418   5.797081  2.281100  2.298904   6.448014  29.355431  10.123344   53.335504   60.401843   72.067079   57.174264  18.035363
5   4194304.0   7.612720  12.685410  4.564980  4.603662  12.267993  45.939803  18.057103   85.667974  103.768036  114.161786   97.000555  28.142136
6  16777216.0  10.127255  14.296733  5.533675  5.576271  18.747797  55.520604  26.028944  111.877941  179.781569  162.617195  168.074694  46.045713
7  67108864.0  10.861711  15.523535  5.929267  5.961534  18.644044  57.701748  27.834568  117.730718  204.201753  181.876698  191.695795  49.386877
```

---------

Co-authored-by: Ettore Tiotto <[email protected]>
Co-authored-by: Anatoly Myachev <[email protected]>
1 parent 869ed8a commit 1659ad9
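The report values are throughput figures: judging from the `gbps` lambda in the reworked `float_conversion.py` below, each cell is GB/s of source-tensor traffic (`N * element_size` bytes per launch). A minimal sketch, assuming that formula, decoding one cell of the up-cast table:

```python
# Minimal sketch: back out the kernel time from one reported GB/s cell, assuming the
# gbps formula from the reworked float_conversion.py:
#   gbps = (N * element_size * 1e-9) / (ms * 1e-3)
N = 67_108_864                 # last row of the upcast table
element_size = 2               # bytes per fp16 source element
gbps_reported = 95.001223      # fp16->fp32 cell at that N

ms = (N * element_size * 1e-9) / (gbps_reported * 1e-3)
print(f'{ms:.3f} ms per launch')  # ~1.413 ms of nominal source traffic
```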

4 files changed (+95, −66 lines)
Lines changed: 1 addition & 1 deletion

```diff
@@ -1 +1 @@
-from .float_conversion import benchmark  # type: ignore  # noqa: F401
+from .float_conversion import get_benchmarks, run_benchmarks  # type: ignore  # noqa: F401
```

benchmarks/micro_benchmarks/conversion/float_conversion/float_conversion.py

Lines changed: 77 additions & 48 deletions

```diff
@@ -2,67 +2,96 @@
 import triton
 import triton.language as tl
 
+from triton_kernels_benchmark import Benchmark, do_bench, perf_report
+
+TYPES = {
+    tl.float8e4nv: torch.float8_e4m3fn, tl.float8e5: torch.float8_e5m2, tl.float16: torch.float16, tl.bfloat16:
+    torch.bfloat16, tl.float32: torch.float32
+}
+
 
 @triton.jit
-def float_trunc_kernel(
+def float_conversion_kernel(
     x_ptr,
+    y_ptr,
     n_elements,
     BLOCK_SIZE: tl.constexpr,
-    target_type: tl.constexpr,
+    x_type: tl.constexpr,
+    y_type: tl.constexpr,
+    rnd: tl.constexpr,
 ):
     pid = tl.program_id(axis=0)
     block_start = pid * BLOCK_SIZE
     offsets = block_start + tl.arange(0, BLOCK_SIZE)
     mask = offsets < n_elements
+    x_itype = tl.int8 if x_type.itemsize == 1 else tl.int16 if x_type.itemsize == 2 else tl.int32
+    y_itype = tl.int8 if y_type.itemsize == 1 else tl.int16 if y_type.itemsize == 2 else tl.int32
 
     x = tl.load(x_ptr + offsets, mask=mask)
+    converted = x.to(y_type, fp_downcast_rounding=rnd)
+    x = tl.cast(x, x_itype, bitcast=True)
+    y = tl.cast(converted, y_itype, bitcast=True)
+    for i in range(99):
+        x += tl.full(x.shape, i, x_itype)
+        converted = tl.cast(x, x_type, bitcast=True).to(y_type, fp_downcast_rounding=rnd)
+        y += tl.cast(converted, y_itype, bitcast=True)
+    y = tl.cast(y, y_type, bitcast=True)
+    tl.store(y_ptr + offsets, y, mask=mask)
+
+
+def get_bench(x_type, y_type):
+    assert x_type.itemsize < y_type.itemsize
+    plot_name = f'{x_type}-{y_type}'
+    line_vals = [(x_type, y_type, None), (y_type, x_type, 'rtne')]
+    line_names = [f'{x_type}->{y_type}', f'{y_type}->{x_type}-rtne']
+    if y_type == tl.float32:
+        line_vals.append((y_type, x_type, 'rtz'))
+        line_names.append(f'{y_type}->{x_type}-rtz')
+
+    @perf_report(
+        Benchmark(
+            x_names=['N'],
+            x_vals=[2**i for i in range(12, 28, 2)],
+            line_arg='args',
+            line_vals=line_vals,
+            line_names=line_names,
+            styles=[(c, s) for c in 'bgry' for s in ('-', '--', '-.', ':')],
+            ylabel=('GB/s', ),
+            plot_name=plot_name,
+            args={},
+        ))
+    def bench(N, args):
+        quantiles = [0.5, 0.2, 0.8]
+        x_type = args[0]
+        y_type = args[1]
+        if x_type.itemsize == 1:
+            x = torch.rand(N, dtype=torch.float16, device='xpu', requires_grad=True).to(TYPES[x_type])
+        else:
+            x = torch.rand(N, dtype=TYPES[x_type], device='xpu', requires_grad=True)
+        y = torch.empty_like(x, dtype=TYPES[y_type], device='xpu')
+        rnd = args[2] if x_type.itemsize > y_type.itemsize else None
+
+        def fwd():
+            BLOCK_SIZE = 4096
+            grid = lambda meta: (triton.cdiv(N, meta['BLOCK_SIZE']), )
+            float_conversion_kernel[grid](x, y, N, BLOCK_SIZE, x_type, y_type, rnd)
+            return x
+
+        _, min_ms, max_ms, mean_ms, cv = do_bench(fwd, n_warmup=10, n_repeat=10, quantiles=quantiles)
+        gbps = lambda ms: (N * x.element_size() * 1e-9) / (ms * 1e-3)
+        return (gbps(mean_ms), gbps(max_ms), gbps(min_ms)), cv
+
+    return bench
+
+
+def get_benchmarks():
+    return [get_bench(s, t) for s in TYPES for t in TYPES if s.itemsize < t.itemsize]
+
 
-    as_target = x.to(target_type)
-    as_f32 = as_target.to(tl.float32)
-    for _ in range(100):
-        as_f32 += 1  # plus one ensures that there are no redundant conversions that can be removed
-        as_target = as_f32.to(target_type)
-        as_f32 = as_target.to(tl.float32)
-
-    tl.store(x_ptr + offsets, as_f32, mask=mask)
-
-
-def launch_conversion(x: torch.Tensor, target_type: type):
-    assert x.is_xpu
-    n_elements = x.numel()
-    grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']), )
-    float_trunc_kernel[grid](x, n_elements, BLOCK_SIZE=1024, target_type=target_type)
-    return x
-
-
-@triton.testing.perf_report(
-    triton.testing.Benchmark(
-        x_names=['N'],
-        x_vals=[2**i for i in range(12, 28, 2)],
-        line_arg='target_type',
-        line_vals=['bfloat16', 'float16'],
-        line_names=['BF16', 'FP16'],
-        styles=[('blue', '-'), ('green', '-'), ('orange', '-')],
-        ylabel='GB/s',
-        plot_name='float-conversion',
-        args={},
-    ))
-def benchmark(N, target_type):
-    quantiles = [0.5, 0.2, 0.8]
-    inputs = torch.rand(N, dtype=torch.float32, device='xpu', requires_grad=True)
-
-    if target_type == 'bfloat16':
-        fwd = lambda: launch_conversion(inputs, tl.bfloat16)
-    elif target_type == 'float16':
-        fwd = lambda: launch_conversion(inputs, tl.float16)
-    else:
-        raise NotImplementedError(f'Type {target_type} is not supported')
-
-    ms, min_ms, max_ms = triton.testing.do_bench(fwd, quantiles=quantiles)
-    gbps = lambda ms: (inputs.numel() * inputs.element_size() * 1e-9) / (ms * 1e-3)
-
-    return gbps(ms), gbps(max_ms), gbps(min_ms)
+def run_benchmarks():
+    for bench in get_benchmarks():
+        bench.run(print_data=True)
 
 
 if __name__ == '__main__':
-    benchmark.run(print_data=True)
+    run_benchmarks()
```
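For reference, the up-cast columns in the report above correspond to the (narrower, wider) pairs that `get_benchmarks()` enumerates from `TYPES`; the down-cast columns are the reverse directions, always with `rtne` rounding plus `rtz` whenever the source is `fp32`. A minimal, Triton-free sketch of that enumeration (the `SIZES` mapping below is an illustrative stand-in for the dtype itemsizes, not part of the benchmark code):

```python
# Plain-Python sketch of the pair enumeration performed by get_benchmarks() above.
# SIZES stands in for the itemsize of each Triton dtype in TYPES.
SIZES = {'fp8e4nv': 1, 'fp8e5': 1, 'fp16': 2, 'bf16': 2, 'fp32': 4}

upcasts = [f'{src}->{dst}' for src in SIZES for dst in SIZES if SIZES[src] < SIZES[dst]]
print(upcasts)
# ['fp8e4nv->fp16', 'fp8e4nv->bf16', 'fp8e4nv->fp32', 'fp8e5->fp16', 'fp8e5->bf16',
#  'fp8e5->fp32', 'fp16->fp32', 'bf16->fp32']
```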

benchmarks/micro_benchmarks/core_ops/dot_scaled.py

Lines changed: 15 additions & 5 deletions

```diff
@@ -2,6 +2,8 @@
 import triton
 import triton.language as tl
 
+from triton_kernels_benchmark import Benchmark, do_bench, perf_report
+
 
 @triton.jit
 def dot_scale_kernel(a_base, stride_a0, stride_a1, a_scale, b_base, stride_b0, stride_b1, b_scale, out,
@@ -37,8 +39,8 @@ def dot_scaled(M, N, K, x, y, z, scale_x, scale_y, type_a, type_b, num_warps):
 
 
 # Benchmark Performance
-@triton.testing.perf_report(
-    triton.testing.Benchmark(
+@perf_report(
+    Benchmark(
         # argument names to use as an x-axis for the plot
         x_names=['M', 'K', 'N', 'col_a', 'col_b', 'rhs_scale', 'mxfp_type', 'normal_type'],
         x_vals=[(M, N, K, col_a, col_b, rhs_scale, mxfp_type, normal_type)
@@ -122,7 +124,7 @@ def make_finite(x, dtype):
     if provider == 'triton':
         triton_fn = lambda: dot_scaled(M, N, K, x, y, z, scale_x, scale_y, type_a, type_b, num_warps)
 
-        ms, min_ms, max_ms = triton.testing.do_bench(triton_fn, quantiles=quantiles)
+        _, min_ms, max_ms, mean_ms, cv = do_bench(triton_fn, n_warmup=10, n_repeat=10, quantiles=quantiles)
     else:
         raise NotImplementedError(f'Unsupported provider {provider}')
 
@@ -141,8 +143,16 @@ def size_x(m, n, ty):
         scale_size = (M * K // 32) if rhs_scale else (N * K // 32)
         return (tensor_size + scale_size + 4.0 * (M * N)) * (1e-9) / (ms * 1e-3)
 
-    return gbps(ms), gbps(max_ms), gbps(min_ms)
+    def tflops(ms):
+        scale_size = (M * K // 32) if rhs_scale else (N * K // 32)
+        return (2 * M * N * K + scale_size) * (1e-12) / (ms * 1e-3)
+
+    return (gbps(mean_ms), gbps(max_ms), gbps(min_ms)), (tflops(mean_ms), tflops(max_ms), tflops(min_ms)), cv
 
 
-if __name__ == '__main__':
+def run_benchmarks():
     benchmark.run(show_plots=False, print_data=True)
+
+
+if __name__ == '__main__':
+    run_benchmarks()
```
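The new `tflops` metric counts `2*M*N*K` flops for the dot product plus the scale elements touched. A minimal worked example of the added formula, with hypothetical shapes and timing (not measured values):

```python
# Worked example of the tflops(ms) metric added above; the shapes and the 1.0 ms
# timing are hypothetical, chosen only to make the arithmetic concrete.
M = N = K = 1024
rhs_scale = True
scale_size = (M * K // 32) if rhs_scale else (N * K // 32)

def tflops(ms):
    return (2 * M * N * K + scale_size) * 1e-12 / (ms * 1e-3)

print(f'{tflops(1.0):.3f} TFLOPS')  # ~2.148 TFLOPS for 1024^3 at 1.0 ms
```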
Lines changed: 2 additions & 12 deletions

```diff
@@ -1,16 +1,6 @@
-import argparse
-
 from conversion import float_conversion
 from core_ops import dot_scaled
 
 if __name__ == '__main__':
-    parser = argparse.ArgumentParser()
-    parser.add_argument(
-        '--reports',
-        type=str,
-        default='',
-        help='directory to save reports',
-    )
-    args = parser.parse_args()
-    float_conversion.benchmark.run(print_data=True, save_path=args.reports)
-    dot_scaled.benchmark.run(print_data=True, save_path=args.reports)
+    for mod in (float_conversion, dot_scaled):
+        mod.run_benchmarks()
```
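With the `--reports`/`save_path` plumbing gone, individual benchmark families can still be run selectively through the new `get_benchmarks()` hook. A minimal usage sketch, assuming the module layout shown in the diffs above:

```python
# Usage sketch based on the reworked conversion benchmark; the import path assumes
# the benchmarks/micro_benchmarks package layout shown in the diffs above.
from conversion import float_conversion

# Run just the first conversion pair (fp8e4nv <-> fp16) instead of the whole suite;
# float_conversion.run_benchmarks() would iterate over every pair.
float_conversion.get_benchmarks()[0].run(print_data=True)
```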

0 commit comments