2 | 2 | import triton |
3 | 3 | import triton.language as tl |
4 | 4 |
| 5 | +from triton_kernels_benchmark import Benchmark, do_bench, perf_report |
| 6 | + |
| 7 | +TYPES = { |
| 8 | + tl.float8e4nv: torch.float8_e4m3fn, tl.float8e5: torch.float8_e5m2, tl.float16: torch.float16, tl.bfloat16: |
| 9 | + torch.bfloat16, tl.float32: torch.float32 |
| 10 | +} |
| 11 | + |
5 | 12 |
6 | 13 | @triton.jit |
7 | | -def float_trunc_kernel( |
| 14 | +def float_conversion_kernel( |
8 | 15 | x_ptr, |
| 16 | + y_ptr, |
9 | 17 | n_elements, |
10 | 18 | BLOCK_SIZE: tl.constexpr, |
11 | | - target_type: tl.constexpr, |
| 19 | + x_type: tl.constexpr, |
| 20 | + y_type: tl.constexpr, |
| 21 | + rnd: tl.constexpr, |
12 | 22 | ): |
13 | 23 | pid = tl.program_id(axis=0) |
14 | 24 | block_start = pid * BLOCK_SIZE |
15 | 25 | offsets = block_start + tl.arange(0, BLOCK_SIZE) |
16 | 26 | mask = offsets < n_elements |
| 27 | + x_itype = tl.int8 if x_type.itemsize == 1 else tl.int16 if x_type.itemsize == 2 else tl.int32 |
| 28 | + y_itype = tl.int8 if y_type.itemsize == 1 else tl.int16 if y_type.itemsize == 2 else tl.int32 |
17 | 29 |
18 | 30 | x = tl.load(x_ptr + offsets, mask=mask) |
| 31 | + converted = x.to(y_type, fp_downcast_rounding=rnd) |
| 32 | + x = tl.cast(x, x_itype, bitcast=True) |
| 33 | + y = tl.cast(converted, y_itype, bitcast=True) |
| 34 | + for i in range(99): |
| 35 | + x += tl.full(x.shape, i, x_itype) |
| 36 | + converted = tl.cast(x, x_type, bitcast=True).to(y_type, fp_downcast_rounding=rnd) |
| 37 | + y += tl.cast(converted, y_itype, bitcast=True) |
| 38 | + y = tl.cast(y, y_type, bitcast=True) |
| 39 | + tl.store(y_ptr + offsets, y, mask=mask) |
| 40 | + |
| 41 | + |
| 42 | +def get_bench(x_type, y_type): |
| 43 | + assert x_type.itemsize < y_type.itemsize |
| 44 | + plot_name = f'{x_type}-{y_type}' |
| 45 | + line_vals = [(x_type, y_type, None), (y_type, x_type, 'rtne')] |
| 46 | + line_names = [f'{x_type}->{y_type}', f'{y_type}->{x_type}-rtne'] |
| 47 | + if y_type == tl.float32: |
| 48 | + line_vals.append((y_type, x_type, 'rtz')) |
| 49 | + line_names.append(f'{y_type}->{x_type}-rtz') |
| 50 | + |
| 51 | + @perf_report( |
| 52 | + Benchmark( |
| 53 | + x_names=['N'], |
| 54 | + x_vals=[2**i for i in range(12, 28, 2)], |
| 55 | + line_arg='args', |
| 56 | + line_vals=line_vals, |
| 57 | + line_names=line_names, |
| 58 | + styles=[(c, s) for c in 'bgry' for s in ('-', '--', '-.', ':')], |
| 59 | + ylabel=('GB/s', ), |
| 60 | + plot_name=plot_name, |
| 61 | + args={}, |
| 62 | + )) |
| 63 | + def bench(N, args): |
| 64 | + quantiles = [0.5, 0.2, 0.8] |
| 65 | + x_type = args[0] |
| 66 | + y_type = args[1] |
| 67 | + if x_type.itemsize == 1: |
| 68 | + x = torch.rand(N, dtype=torch.float16, device='xpu', requires_grad=True).to(TYPES[x_type]) |
| 69 | + else: |
| 70 | + x = torch.rand(N, dtype=TYPES[x_type], device='xpu', requires_grad=True) |
| 71 | + y = torch.empty_like(x, dtype=TYPES[y_type], device='xpu') |
| 72 | + rnd = args[2] if x_type.itemsize > y_type.itemsize else None |
| 73 | + |
| 74 | + def fwd(): |
| 75 | + BLOCK_SIZE = 4096 |
| 76 | + grid = lambda meta: (triton.cdiv(N, meta['BLOCK_SIZE']), ) |
| 77 | + float_conversion_kernel[grid](x, y, N, BLOCK_SIZE, x_type, y_type, rnd) |
| 78 | + return x |
| 79 | + |
| 80 | + _, min_ms, max_ms, mean_ms, cv = do_bench(fwd, n_warmup=10, n_repeat=10, quantiles=quantiles) |
| 81 | + gbps = lambda ms: (N * x.element_size() * 1e-9) / (ms * 1e-3) |
| 82 | + return (gbps(mean_ms), gbps(max_ms), gbps(min_ms)), cv |
| 83 | + |
| 84 | + return bench |
| 85 | + |
| 86 | + |
| 87 | +def get_benchmarks(): |
| 88 | + return [get_bench(s, t) for s in TYPES for t in TYPES if s.itemsize < t.itemsize] |
| 89 | + |
19 | 90 |
20 | | - as_target = x.to(target_type) |
21 | | - as_f32 = as_target.to(tl.float32) |
22 | | - for _ in range(100): |
23 | | - as_f32 += 1 # plus one ensures that there are no redundant conversions that can be removed |
24 | | - as_target = as_f32.to(target_type) |
25 | | - as_f32 = as_target.to(tl.float32) |
26 | | - |
27 | | - tl.store(x_ptr + offsets, as_f32, mask=mask) |
28 | | - |
29 | | - |
30 | | -def launch_conversion(x: torch.Tensor, target_type: type): |
31 | | - assert x.is_xpu |
32 | | - n_elements = x.numel() |
33 | | - grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']), ) |
34 | | - float_trunc_kernel[grid](x, n_elements, BLOCK_SIZE=1024, target_type=target_type) |
35 | | - return x |
36 | | - |
37 | | - |
38 | | -@triton.testing.perf_report( |
39 | | - triton.testing.Benchmark( |
40 | | - x_names=['N'], |
41 | | - x_vals=[2**i for i in range(12, 28, 2)], |
42 | | - line_arg='target_type', |
43 | | - line_vals=['bfloat16', 'float16'], |
44 | | - line_names=['BF16', 'FP16'], |
45 | | - styles=[('blue', '-'), ('green', '-'), ('orange', '-')], |
46 | | - ylabel='GB/s', |
47 | | - plot_name='float-conversion', |
48 | | - args={}, |
49 | | - )) |
50 | | -def benchmark(N, target_type): |
51 | | - quantiles = [0.5, 0.2, 0.8] |
52 | | - inputs = torch.rand(N, dtype=torch.float32, device='xpu', requires_grad=True) |
53 | | - |
54 | | - if target_type == 'bfloat16': |
55 | | - fwd = lambda: launch_conversion(inputs, tl.bfloat16) |
56 | | - elif target_type == 'float16': |
57 | | - fwd = lambda: launch_conversion(inputs, tl.float16) |
58 | | - else: |
59 | | - raise NotImplementedError(f'Type {target_type} is not supported') |
60 | | - |
61 | | - ms, min_ms, max_ms = triton.testing.do_bench(fwd, quantiles=quantiles) |
62 | | - gbps = lambda ms: (inputs.numel() * inputs.element_size() * 1e-9) / (ms * 1e-3) |
63 | | - |
64 | | - return gbps(ms), gbps(max_ms), gbps(min_ms) |
| 91 | +def run_benchmarks(): |
| 92 | + for bench in get_benchmarks(): |
| 93 | + bench.run(print_data=True) |
65 | 94 |
66 | 95 |
67 | 96 | if __name__ == '__main__': |
68 | | - benchmark.run(print_data=True) |
| 97 | + run_benchmarks() |
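
For reference, here is a minimal sketch of launching the new `float_conversion_kernel` once, outside the `perf_report` harness. It is illustrative only and not part of the change: it assumes an XPU-enabled PyTorch build and that the kernel is importable from the patched benchmark module (the module path is not shown in this diff). Note that the kernel performs 100 perturbed conversions per element, so the output is benchmark work rather than a plain cast of the input.

```python
# Hypothetical smoke-test driver (not part of the patch): launch
# float_conversion_kernel once for a float32 -> bfloat16 downcast.
# float_conversion_kernel is assumed to be in scope, e.g. imported
# from the patched benchmark module.
import torch
import triton
import triton.language as tl

N = 1 << 20
BLOCK_SIZE = 4096

x = torch.rand(N, dtype=torch.float32, device='xpu')
y = torch.empty(N, dtype=torch.bfloat16, device='xpu')

# Downcasts take an explicit rounding mode ('rtne' or 'rtz'); an upcast
# (e.g. bfloat16 -> float32) would pass None instead, as bench() does.
grid = lambda meta: (triton.cdiv(N, meta['BLOCK_SIZE']), )
float_conversion_kernel[grid](x, y, N, BLOCK_SIZE, tl.float32, tl.bfloat16, 'rtne')

# y holds the kernel's accumulated result (100 bit-perturbed conversions),
# which is only meaningful as benchmark work, not as a reference cast of x.
print(y.dtype, y.shape)
```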