Commit c1d4de6

[benchmarks] Reworked the conversion benchmark and added more tests for up/down casts
1 parent 0ab03be · commit c1d4de6

File tree

4 files changed: +122, -79 lines
Lines changed: 1 addition & 1 deletion

@@ -1 +1 @@
-from .float_conversion import benchmark  # type: ignore # noqa: F401
+from .float_conversion import get_benchmarks, run_benchmarks  # type: ignore # noqa: F401
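The package now re-exports two helpers instead of a single benchmark object. A minimal sketch of the resulting call surface, with names taken from this diff (the conversion.float_conversion import path is assumed from the runner change later in this commit):

from conversion import float_conversion

# get_benchmarks() returns one perf_report-decorated benchmark per dtype pair
benches = float_conversion.get_benchmarks()

# run_benchmarks() runs each of them and prints the data tables
float_conversion.run_benchmarks()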
Lines changed: 103 additions & 60 deletions

@@ -1,68 +1,111 @@
+import os
+import sys
+from functools import lru_cache
+
 import torch
 import triton
 import triton.language as tl
 
+sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '../../..')))
+from triton_kernels_benchmark import Benchmark, do_bench, perf_report  # pylint: disable=C0413
+
+TYPES = {
+    tl.float8e4nv: torch.float8_e4m3fn, tl.float8e5: torch.float8_e5m2, tl.float16: torch.float16, tl.bfloat16:
+    torch.bfloat16, tl.float32: torch.float32
+}
+
+
+@lru_cache
+def get_kernel(name):
+
+    def kernel(
+        x_ptr,
+        y_ptr,
+        n_elements,
+        BLOCK_SIZE: tl.constexpr,
+        x_type: tl.constexpr,
+        y_type: tl.constexpr,
+        rnd: tl.constexpr,
+    ):
+        pid = tl.program_id(axis=0)
+        block_start = pid * BLOCK_SIZE
+        offsets = block_start + tl.arange(0, BLOCK_SIZE)
+        mask = offsets < n_elements
+        x_itype = tl.int8 if x_type.itemsize == 1 else tl.int16 if x_type.itemsize == 2 else tl.int32
+        y_itype = tl.int8 if y_type.itemsize == 1 else tl.int16 if y_type.itemsize == 2 else tl.int32
+
+        x = tl.load(x_ptr + offsets, mask=mask)
+        converted = x.to(y_type, fp_downcast_rounding=rnd)
+        x = tl.cast(x, x_itype, bitcast=True)
+        y = tl.cast(converted, y_itype, bitcast=True)
+        for i in range(99):
+            x += tl.full(x.shape, i, x_itype)
+            converted = tl.cast(x, x_type, bitcast=True).to(y_type, fp_downcast_rounding=rnd)
+            y += tl.cast(converted, y_itype, bitcast=True)
+        y = tl.cast(y, y_type, bitcast=True)
+        tl.store(y_ptr + offsets, y, mask=mask)
+
+    kernel.__name__ = kernel.__qualname__ = name
+    return triton.jit(kernel)
+
+
+def get_bench(x_type, y_type):
+    assert x_type.itemsize < y_type.itemsize
+    plot_name = f'{x_type}-{y_type}'
+    line_vals = [(x_type, y_type, None), (y_type, x_type, 'rtne')]
+    line_names = [f'{x_type}->{y_type}', f'{y_type}->{x_type}-rtne']
+    if y_type == tl.float32:
+        line_vals.append((y_type, x_type, 'rtz'))
+        line_names.append(f'{y_type}->{x_type}-rtz')
+
+    @perf_report(
+        Benchmark(
+            x_names=['N'],
+            x_vals=[2**i for i in range(12, 28, 2)],
+            line_arg='args',
+            line_vals=line_vals,
+            line_names=line_names,
+            styles=[(c, s) for c in 'bgry' for s in ('-', '--', '-.', ':')],
+            ylabel=('GB/s', ),
+            plot_name=plot_name,
+            args={},
+        ))
+    def bench(N, args):
+        quantiles = [0.5, 0.2, 0.8]
+        x_type = args[0]
+        y_type = args[1]
+        if x_type.itemsize == 1:
+            x = torch.rand(N, dtype=torch.float16, device='xpu', requires_grad=True).to(TYPES[x_type])
+        else:
+            x = torch.rand(N, dtype=TYPES[x_type], device='xpu', requires_grad=True)
+        y = torch.empty_like(x, dtype=TYPES[y_type], device='xpu')
+        rnd = args[2] if x_type.itemsize > y_type.itemsize else None
+        name = f'{x_type}_to_{y_type}_conversion_kernel'
+        if rnd:
+            name = f'{rnd}_{name}'
+        kernel = get_kernel(name)
+
+        def fwd():
+            BLOCK_SIZE = 4096
+            grid = lambda meta: (triton.cdiv(N, meta['BLOCK_SIZE']), )
+            kernel[grid](x, y, N, BLOCK_SIZE, x_type, y_type, rnd)
+            return x
+
+        _, min_ms, max_ms, mean_ms, cv = do_bench(fwd, n_warmup=10, n_repeat=10, quantiles=quantiles)
+        gbps = lambda ms: (N * x.element_size() * 1e-9) / (ms * 1e-3)
+        return (gbps(mean_ms), gbps(max_ms), gbps(min_ms)), cv
+
+    return bench
+
+
+def get_benchmarks():
+    return [get_bench(s, t) for s in TYPES for t in TYPES if s.itemsize < t.itemsize]
+
 
-@triton.jit
-def float_trunc_kernel(
-    x_ptr,
-    n_elements,
-    BLOCK_SIZE: tl.constexpr,
-    target_type: tl.constexpr,
-):
-    pid = tl.program_id(axis=0)
-    block_start = pid * BLOCK_SIZE
-    offsets = block_start + tl.arange(0, BLOCK_SIZE)
-    mask = offsets < n_elements
-
-    x = tl.load(x_ptr + offsets, mask=mask)
-
-    as_target = x.to(target_type)
-    as_f32 = as_target.to(tl.float32)
-    for _ in range(100):
-        as_f32 += 1  # plus one ensures that there are no redundant conversions that can be removed
-        as_target = as_f32.to(target_type)
-        as_f32 = as_target.to(tl.float32)
-
-    tl.store(x_ptr + offsets, as_f32, mask=mask)
-
-
-def launch_conversion(x: torch.Tensor, target_type: type):
-    assert x.is_xpu
-    n_elements = x.numel()
-    grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']), )
-    float_trunc_kernel[grid](x, n_elements, BLOCK_SIZE=1024, target_type=target_type)
-    return x
-
-
-@triton.testing.perf_report(
-    triton.testing.Benchmark(
-        x_names=['N'],
-        x_vals=[2**i for i in range(12, 28, 2)],
-        line_arg='target_type',
-        line_vals=['bfloat16', 'float16'],
-        line_names=['BF16', 'FP16'],
-        styles=[('blue', '-'), ('green', '-'), ('orange', '-')],
-        ylabel='GB/s',
-        plot_name='float-conversion',
-        args={},
-    ))
-def benchmark(N, target_type):
-    quantiles = [0.5, 0.2, 0.8]
-    inputs = torch.rand(N, dtype=torch.float32, device='xpu', requires_grad=True)
-
-    if target_type == 'bfloat16':
-        fwd = lambda: launch_conversion(inputs, tl.bfloat16)
-    elif target_type == 'float16':
-        fwd = lambda: launch_conversion(inputs, tl.float16)
-    else:
-        raise NotImplementedError(f'Type {target_type} is not supported')
-
-    ms, min_ms, max_ms = triton.testing.do_bench(fwd, quantiles=quantiles)
-    gbps = lambda ms: (inputs.numel() * inputs.element_size() * 1e-9) / (ms * 1e-3)
-
-    return gbps(ms), gbps(max_ms), gbps(min_ms)
+def run_benchmarks():
+    for bench in get_benchmarks():
+        bench.run(print_data=True)
 
 
 if __name__ == '__main__':
-    benchmark.run(print_data=True)
+    run_benchmarks()
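To make "more tests for up/down casts" concrete, here is a standalone sketch (not part of the commit) that enumerates the pairs the new get_benchmarks() covers, using the same TYPES ordering and itemsize filter as the code above. Each narrow/wide pair gets one plot with an upcast line, an rtne downcast line, and an extra rtz downcast line when the wide type is float32.

import triton.language as tl

TYPES = [tl.float8e4nv, tl.float8e5, tl.float16, tl.bfloat16, tl.float32]

for s in TYPES:
    for t in TYPES:
        if s.itemsize < t.itemsize:  # only strictly widening pairs get a plot
            lines = [f'{s}->{t}', f'{t}->{s}-rtne']
            if t == tl.float32:
                lines.append(f'{t}->{s}-rtz')
            print(f'{s}-{t}: {lines}')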

benchmarks/micro_benchmarks/core_ops/dot_scaled.py

Lines changed: 16 additions & 6 deletions

@@ -1,7 +1,13 @@
+import os
+import sys
+
 import torch
 import triton
 import triton.language as tl
 
+sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '../..')))
+from triton_kernels_benchmark import Benchmark, do_bench, perf_report  # pylint: disable=C0413
+
 
 @triton.jit
 def dot_scale_kernel(a_base, stride_a0, stride_a1, a_scale, b_base, stride_b0, stride_b1, b_scale, out,
@@ -37,8 +43,8 @@ def dot_scaled(M, N, K, x, y, z, scale_x, scale_y, type_a, type_b, num_warps):
 
 
 # Benchmark Performance
-@triton.testing.perf_report(
-    triton.testing.Benchmark(
+@perf_report(
+    Benchmark(
         # argument names to use as an x-axis for the plot
         x_names=['M', 'K', 'N', 'col_a', 'col_b', 'rhs_scale', 'mxfp_type', 'normal_type'],
         x_vals=[(M, N, K, col_a, col_b, rhs_scale, mxfp_type, normal_type)
@@ -55,7 +61,7 @@ def dot_scaled(M, N, K, x, y, z, scale_x, scale_y, type_a, type_b, num_warps):
         line_names=['Triton'],
         # line styles
         styles=[('green', '-'), ('green', '--'), ('blue', '-'), ('blue', '--')],
-        ylabel=['GB/s', 'TFlops'],  # label name for the y-axis
+        ylabel=('GB/s', ),  # label name for the y-axis
         plot_name='scaled-dot',
         # name for the plot. Used also as a file name for saving the plot.
         args={},
@@ -122,7 +128,7 @@ def make_finite(x, dtype):
     if provider == 'triton':
         triton_fn = lambda: dot_scaled(M, N, K, x, y, z, scale_x, scale_y, type_a, type_b, num_warps)
 
-        ms, min_ms, max_ms = triton.testing.do_bench(triton_fn, quantiles=quantiles)
+        _, min_ms, max_ms, mean_ms, cv = do_bench(triton_fn, n_warmup=10, n_repeat=10, quantiles=quantiles)
     else:
         raise NotImplementedError(f'Unsupported provider {provider}')
 
@@ -141,8 +147,12 @@ def size_x(m, n, ty):
         scale_size = (M * K // 32) if rhs_scale else (N * K // 32)
         return (tensor_size + scale_size + 4.0 * (M * N)) * (1e-9) / (ms * 1e-3)
 
-    return gbps(ms), gbps(max_ms), gbps(min_ms)
+    return (gbps(mean_ms), gbps(max_ms), gbps(min_ms)), cv
 
 
-if __name__ == '__main__':
+def run_benchmarks():
     benchmark.run(show_plots=False, print_data=True)
+
+
+if __name__ == '__main__':
+    run_benchmarks()
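Both files switch from triton.testing.do_bench to the do_bench wrapper imported from triton_kernels_benchmark. A minimal sketch of the measurement pattern the commit standardizes on; the wrapper's return contract (a leading value, then min/max/mean times and a coefficient of variation) is inferred from how the diff unpacks it, and measure_gbps/nbytes are illustrative names, not part of the commit.

from triton_kernels_benchmark import do_bench

def measure_gbps(fn, nbytes):
    # quantiles kept identical to the old triton.testing.do_bench call sites
    quantiles = [0.5, 0.2, 0.8]
    # unpacking order as used in this commit; the first value is discarded here
    _, min_ms, max_ms, mean_ms, cv = do_bench(fn, n_warmup=10, n_repeat=10, quantiles=quantiles)
    gbps = lambda ms: (nbytes * 1e-9) / (ms * 1e-3)
    # perf_report now receives ((mean, max, min) throughput, cv) instead of (median, max, min)
    return (gbps(mean_ms), gbps(max_ms), gbps(min_ms)), cv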
Lines changed: 2 additions & 12 deletions

@@ -1,16 +1,6 @@
-import argparse
-
 from conversion import float_conversion
 from core_ops import dot_scaled
 
 if __name__ == '__main__':
-    parser = argparse.ArgumentParser()
-    parser.add_argument(
-        '--reports',
-        type=str,
-        default='',
-        help='directory to save reports',
-    )
-    args = parser.parse_args()
-    float_conversion.benchmark.run(print_data=True, save_path=args.reports)
-    dot_scaled.benchmark.run(print_data=True, save_path=args.reports)
+    for mod in (float_conversion, dot_scaled):
+        mod.run_benchmarks()
