Commit b5381cf

[benchmarks] Reworked the conversion benchmark and added more tests for up/down casts
1 parent bae3356 commit b5381cf

3 files changed: +95 −57 lines
Lines changed: 1 addition & 1 deletion

@@ -1 +1 @@
-from .float_conversion import benchmark  # type: ignore  # noqa: F401
+from .float_conversion import bench_upcast, bench_downcast  # type: ignore  # noqa: F401
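With this re-export change, callers import the two perf-report objects instead of the single `benchmark`. A minimal sketch of an updated call site; the enclosing package path is not shown in the diff, so `conversion_benchmark` below is only a placeholder name:

# Hypothetical call site after this commit; `conversion_benchmark` stands in
# for the package whose __init__.py is edited above.
from conversion_benchmark import bench_upcast, bench_downcast

bench_upcast.run(print_data=True)    # sweeps the widening (up-cast) dtype pairs
bench_downcast.run(print_data=True)  # sweeps the narrowing (down-cast) pairs and rounding modes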
Lines changed: 92 additions & 55 deletions

@@ -1,68 +1,105 @@
+from functools import lru_cache
+
 import torch
 import triton
 import triton.language as tl
 
+TYPES = {
+    tl.float8e4nv: torch.float8_e4m3fn, tl.float8e5: torch.float8_e5m2, tl.float16: torch.float16, tl.bfloat16:
+    torch.bfloat16, tl.float32: torch.float32
+}
+UP_VALS = [(s, t) for s in TYPES for t in TYPES if s.itemsize < t.itemsize]
+DOWN_VALS = [(s, t, r) for s in TYPES for t in TYPES if s.itemsize > t.itemsize for r in ('rtne', 'rtz')
+             if r == 'rtne' or s == tl.float32]
+
+
+@lru_cache
+def _make_kernel(name):
+
+    def kernel(
+        x_ptr,
+        y_ptr,
+        n_elements,
+        BLOCK_SIZE: tl.constexpr,
+        x_type: tl.constexpr,
+        y_type: tl.constexpr,
+        rnd: tl.constexpr,
+    ):
+        pid = tl.program_id(axis=0)
+        block_start = pid * BLOCK_SIZE
+        offsets = block_start + tl.arange(0, BLOCK_SIZE)
+        mask = offsets < n_elements
+        x_itype = tl.int8 if x_type.itemsize == 1 else tl.int16 if x_type.itemsize == 2 else tl.int32
+        y_itype = tl.int8 if y_type.itemsize == 1 else tl.int16 if y_type.itemsize == 2 else tl.int32
+
+        x = tl.load(x_ptr + offsets, mask=mask)
+        converted = x.to(y_type, fp_downcast_rounding=rnd)
+        x = tl.cast(x, x_itype, bitcast=True)
+        y = tl.cast(converted, y_itype, bitcast=True)
+        for i in range(99):
+            x += tl.full(x.shape, i, x_itype)
+            converted = tl.cast(x, x_type, bitcast=True).to(y_type, fp_downcast_rounding=rnd)
+            y += tl.cast(converted, y_itype, bitcast=True)
+        y = tl.cast(y, y_type, bitcast=True)
+        tl.store(y_ptr + offsets, y, mask=mask)
+
+    kernel.__name__ = kernel.__qualname__ = name
+    return triton.jit(kernel)
 
-@triton.jit
-def float_trunc_kernel(
-    x_ptr,
-    n_elements,
-    BLOCK_SIZE: tl.constexpr,
-    target_type: tl.constexpr,
-):
-    pid = tl.program_id(axis=0)
-    block_start = pid * BLOCK_SIZE
-    offsets = block_start + tl.arange(0, BLOCK_SIZE)
-    mask = offsets < n_elements
-
-    x = tl.load(x_ptr + offsets, mask=mask)
-
-    as_target = x.to(target_type)
-    as_f32 = as_target.to(tl.float32)
-    for _ in range(100):
-        as_f32 += 1  # plus one ensures that there are no redundant conversions that can be removed
-        as_target = as_f32.to(target_type)
-        as_f32 = as_target.to(tl.float32)
-
-    tl.store(x_ptr + offsets, as_f32, mask=mask)
-
-
-def launch_conversion(x: torch.Tensor, target_type: type):
-    assert x.is_xpu
-    n_elements = x.numel()
-    grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']), )
-    float_trunc_kernel[grid](x, n_elements, BLOCK_SIZE=1024, target_type=target_type)
-    return x
-
-
-@triton.testing.perf_report(
-    triton.testing.Benchmark(
-        x_names=['N'],
-        x_vals=[2**i for i in range(12, 28, 2)],
-        line_arg='target_type',
-        line_vals=['bfloat16', 'float16'],
-        line_names=['BF16', 'FP16'],
-        styles=[('blue', '-'), ('green', '-'), ('orange', '-')],
-        ylabel='GB/s',
-        plot_name='float-conversion',
-        args={},
-    ))
-def benchmark(N, target_type):
-    quantiles = [0.5, 0.2, 0.8]
-    inputs = torch.rand(N, dtype=torch.float32, device='xpu', requires_grad=True)
 
-    if target_type == 'bfloat16':
-        fwd = lambda: launch_conversion(inputs, tl.bfloat16)
-    elif target_type == 'float16':
-        fwd = lambda: launch_conversion(inputs, tl.float16)
+def _benchmark(N, args):
+    quantiles = [0.5, 0.2, 0.8]
+    x_type = args[0]
+    y_type = args[1]
+    if x_type.itemsize == 1:
+        x = torch.rand(N, dtype=torch.float16, device='xpu', requires_grad=True).to(TYPES[x_type])
     else:
-        raise NotImplementedError(f'Type {target_type} is not supported')
+        x = torch.rand(N, dtype=TYPES[x_type], device='xpu', requires_grad=True)
+    y = torch.empty_like(x, dtype=TYPES[y_type], device='xpu')
+    rnd = args[2] if x_type.itemsize > y_type.itemsize else None
+    name = f"{x_type}_to_{y_type}_conversion_kernel"
+    if rnd:
+        name = f"{rnd}_{name}"
+    kernel = _make_kernel(name)
 
-    ms, min_ms, max_ms = triton.testing.do_bench(fwd, quantiles=quantiles)
-    gbps = lambda ms: (inputs.numel() * inputs.element_size() * 1e-9) / (ms * 1e-3)
+    def fwd():
+        BLOCK_SIZE = 4096
+        grid = lambda meta: (triton.cdiv(N, meta['BLOCK_SIZE']), )
+        kernel[grid](x, y, N, BLOCK_SIZE, x_type, y_type, rnd)
+        return x
 
+    ms, min_ms, max_ms = triton.testing.do_bench(fwd, quantiles=quantiles)
+    gbps = lambda ms: (N * x.element_size() * 1e-9) / (ms * 1e-3)
     return gbps(ms), gbps(max_ms), gbps(min_ms)
 
 
+def _report(plot_name, line_names, line_vals):
+    report = triton.testing.perf_report(
+        triton.testing.Benchmark(
+            x_names=['N'],
+            x_vals=[2**i for i in range(12, 28, 2)],
+            line_arg='args',
+            line_vals=line_vals,
+            line_names=line_names,
+            styles=[(c, s) for c in 'bgry' for s in ('-', '--', '-.', ':')],
+            ylabel='GB/s',
+            plot_name=plot_name,
+            args={},
+        ))
+    return report(_benchmark)
+
+
+bench_upcast = _report(
+    plot_name='float-upcast',
+    line_names=[f"{s}->{t}" for s, t in UP_VALS],
+    line_vals=UP_VALS,
+)
+bench_downcast = _report(
+    plot_name='float-downcast',
+    line_names=[f"{s}->{t}/{r}" for s, t, r in DOWN_VALS],
+    line_vals=DOWN_VALS,
+)
+
 if __name__ == '__main__':
-    benchmark.run(print_data=True)
+    bench_upcast.run(print_data=True)
+    bench_downcast.run(print_data=True)
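For reference, the `UP_VALS` and `DOWN_VALS` comprehensions above enumerate every ordered dtype pair by item size: a pair is an up-cast when the source is narrower than the target, and a down-cast otherwise, with the rounding mode fixed to 'rtne' except for float32 sources, which are also measured with 'rtz'. A plain-Python sketch of the same enumeration, assuming the Triton dtype item sizes are 1, 1, 2, 2 and 4 bytes, so it runs without Triton or an XPU:

# Illustrative only: reproduce the UP_VALS / DOWN_VALS enumeration in plain
# Python to list the conversion pairs the two benchmarks cover.
ITEMSIZE = {'float8e4nv': 1, 'float8e5': 1, 'float16': 2, 'bfloat16': 2, 'float32': 4}

up_vals = [(s, t) for s in ITEMSIZE for t in ITEMSIZE if ITEMSIZE[s] < ITEMSIZE[t]]
down_vals = [(s, t, r) for s in ITEMSIZE for t in ITEMSIZE if ITEMSIZE[s] > ITEMSIZE[t]
             for r in ('rtne', 'rtz') if r == 'rtne' or s == 'float32']

print(f"{len(up_vals)} up-cast lines, e.g. {up_vals[:2]}")
print(f"{len(down_vals)} down-cast lines, e.g. {down_vals[:2]}")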

benchmarks/micro_benchmarks/run_benchmarks.py

Lines changed: 2 additions & 1 deletion

@@ -12,5 +12,6 @@
     help='directory to save reports',
 )
 args = parser.parse_args()
-float_conversion.benchmark.run(print_data=True, save_path=args.reports)
+float_conversion.bench_upcast.run(print_data=True, save_path=args.reports)
+float_conversion.bench_downcast.run(print_data=True, save_path=args.reports)
 dot_scaled.benchmark.run(print_data=True, save_path=args.reports)
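The GB/s figures that this runner and the module's `__main__` block print come from the `gbps` lambda in `_benchmark`, which counts only the bytes of the input tensor. A worked example of that arithmetic under assumed numbers (2**26 is the largest point in the sweep; the 1 ms timing is hypothetical):

# Illustrative arithmetic for the gbps() formula in _benchmark; the timing
# value below is made up purely to show the unit conversion.
N = 2**26           # largest N in the sweep: x_vals = [2**i for i in range(12, 28, 2)]
element_size = 4    # bytes per element for a float32 source tensor
ms = 1.0            # hypothetical median kernel time from do_bench, in milliseconds

gbps = (N * element_size * 1e-9) / (ms * 1e-3)
print(f"{gbps:.1f} GB/s")  # ~268.4 GB/s for these assumed numbers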
