+import os
+import sys
+from functools import lru_cache
+
 import torch
 import triton
 import triton.language as tl
 
+sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '../../..')))
+from triton_kernels_benchmark import Benchmark, do_bench, perf_report  # pylint: disable=C0413
+
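+# Triton element types mapped to the torch dtypes used to allocate test tensors.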
+TYPES = {
+    tl.float8e4nv: torch.float8_e4m3fn,
+    tl.float8e5: torch.float8_e5m2,
+    tl.float16: torch.float16,
+    tl.bfloat16: torch.bfloat16,
+    tl.float32: torch.float32,
+}
+
+
+@lru_cache
+def get_kernel(name):
+
+    def kernel(
+        x_ptr,
+        y_ptr,
+        n_elements,
+        BLOCK_SIZE: tl.constexpr,
+        x_type: tl.constexpr,
+        y_type: tl.constexpr,
+        rnd: tl.constexpr,
+    ):
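+        # Each program instance converts one BLOCK_SIZE tile of x from x_type
+        # to y_type, using rnd as the downcast rounding mode where applicable.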
+        pid = tl.program_id(axis=0)
+        block_start = pid * BLOCK_SIZE
+        offsets = block_start + tl.arange(0, BLOCK_SIZE)
+        mask = offsets < n_elements
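+        # Integer types of the same width as the float types, used for
+        # bit-level manipulation of the values below.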
+        x_itype = tl.int8 if x_type.itemsize == 1 else tl.int16 if x_type.itemsize == 2 else tl.int32
+        y_itype = tl.int8 if y_type.itemsize == 1 else tl.int16 if y_type.itemsize == 2 else tl.int32
+
+        x = tl.load(x_ptr + offsets, mask=mask)
+        converted = x.to(y_type, fp_downcast_rounding=rnd)
+        x = tl.cast(x, x_itype, bitcast=True)
+        y = tl.cast(converted, y_itype, bitcast=True)
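+        # Perturb the integer view of x on every iteration so each conversion
+        # is distinct and cannot be hoisted out of the loop; results accumulate
+        # in the integer view of y (successor to the old "+1" trick).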
+        for i in range(99):
+            x += tl.full(x.shape, i, x_itype)
+            converted = tl.cast(x, x_type, bitcast=True).to(y_type, fp_downcast_rounding=rnd)
+            y += tl.cast(converted, y_itype, bitcast=True)
+        y = tl.cast(y, y_type, bitcast=True)
+        tl.store(y_ptr + offsets, y, mask=mask)
+
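+    # Name each specialization uniquely so profiler traces and reports can
+    # tell the conversion directions and rounding modes apart.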
+    kernel.__name__ = kernel.__qualname__ = name
+    return triton.jit(kernel)
+
+
+def get_bench(x_type, y_type):
+    assert x_type.itemsize < y_type.itemsize
+    plot_name = f'{x_type}-{y_type}'
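+    # Benchmark both directions: the upcast (no rounding mode) and the
+    # downcast with RTNE; downcasts from float32 are also run with RTZ.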
+    line_vals = [(x_type, y_type, None), (y_type, x_type, 'rtne')]
+    line_names = [f'{x_type}->{y_type}', f'{y_type}->{x_type}-rtne']
+    if y_type == tl.float32:
+        line_vals.append((y_type, x_type, 'rtz'))
+        line_names.append(f'{y_type}->{x_type}-rtz')
+
+    @perf_report(
+        Benchmark(
+            x_names=['N'],
+            x_vals=[2**i for i in range(12, 28, 2)],
+            line_arg='args',
+            line_vals=line_vals,
+            line_names=line_names,
+            styles=[(c, s) for c in 'bgry' for s in ('-', '--', '-.', ':')],
+            ylabel=('GB/s', ),
+            plot_name=plot_name,
+            args={},
+        ))
+    def bench(N, args):
+        quantiles = [0.5, 0.2, 0.8]
+        x_type = args[0]
+        y_type = args[1]
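+        # torch.rand does not support float8 dtypes directly, so sample in
+        # float16 first and cast down for the 1-byte types.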
+        if x_type.itemsize == 1:
+            x = torch.rand(N, dtype=torch.float16, device='xpu', requires_grad=True).to(TYPES[x_type])
+        else:
+            x = torch.rand(N, dtype=TYPES[x_type], device='xpu', requires_grad=True)
+        y = torch.empty_like(x, dtype=TYPES[y_type], device='xpu')
+        rnd = args[2] if x_type.itemsize > y_type.itemsize else None
+        name = f'{x_type}_to_{y_type}_conversion_kernel'
+        if rnd:
+            name = f'{rnd}_{name}'
+        kernel = get_kernel(name)
+
+        def fwd():
+            BLOCK_SIZE = 4096
+            grid = lambda meta: (triton.cdiv(N, meta['BLOCK_SIZE']), )
+            kernel[grid](x, y, N, BLOCK_SIZE, x_type, y_type, rnd)
+            return x
+
+        _, min_ms, max_ms, mean_ms, cv = do_bench(fwd, n_warmup=10, n_repeat=10, quantiles=quantiles)
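+        # Reported bandwidth counts only the bytes of x read per kernel run,
+        # not the bytes written to y.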
+        gbps = lambda ms: (N * x.element_size() * 1e-9) / (ms * 1e-3)
+        return (gbps(mean_ms), gbps(max_ms), gbps(min_ms)), cv
+
+    return bench
+
+
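+# One benchmark per ordered pair of types with strictly increasing item size;
+# each bench then also covers the reverse (downcast) direction itself.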
+def get_benchmarks():
+    return [get_bench(s, t) for s in TYPES for t in TYPES if s.itemsize < t.itemsize]
+
 
-@triton.jit
-def float_trunc_kernel(
-    x_ptr,
-    n_elements,
-    BLOCK_SIZE: tl.constexpr,
-    target_type: tl.constexpr,
-):
-    pid = tl.program_id(axis=0)
-    block_start = pid * BLOCK_SIZE
-    offsets = block_start + tl.arange(0, BLOCK_SIZE)
-    mask = offsets < n_elements
-
-    x = tl.load(x_ptr + offsets, mask=mask)
-
-    as_target = x.to(target_type)
-    as_f32 = as_target.to(tl.float32)
-    for _ in range(100):
-        as_f32 += 1  # plus one ensures that there are no redundant conversions that can be removed
-        as_target = as_f32.to(target_type)
-        as_f32 = as_target.to(tl.float32)
-
-    tl.store(x_ptr + offsets, as_f32, mask=mask)
-
-
-def launch_conversion(x: torch.Tensor, target_type: type):
-    assert x.is_xpu
-    n_elements = x.numel()
-    grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']), )
-    float_trunc_kernel[grid](x, n_elements, BLOCK_SIZE=1024, target_type=target_type)
-    return x
-
-
-@triton.testing.perf_report(
-    triton.testing.Benchmark(
-        x_names=['N'],
-        x_vals=[2**i for i in range(12, 28, 2)],
-        line_arg='target_type',
-        line_vals=['bfloat16', 'float16'],
-        line_names=['BF16', 'FP16'],
-        styles=[('blue', '-'), ('green', '-'), ('orange', '-')],
-        ylabel='GB/s',
-        plot_name='float-conversion',
-        args={},
-    ))
-def benchmark(N, target_type):
-    quantiles = [0.5, 0.2, 0.8]
-    inputs = torch.rand(N, dtype=torch.float32, device='xpu', requires_grad=True)
-
-    if target_type == 'bfloat16':
-        fwd = lambda: launch_conversion(inputs, tl.bfloat16)
-    elif target_type == 'float16':
-        fwd = lambda: launch_conversion(inputs, tl.float16)
-    else:
-        raise NotImplementedError(f'Type {target_type} is not supported')
-
-    ms, min_ms, max_ms = triton.testing.do_bench(fwd, quantiles=quantiles)
-    gbps = lambda ms: (inputs.numel() * inputs.element_size() * 1e-9) / (ms * 1e-3)
-
-    return gbps(ms), gbps(max_ms), gbps(min_ms)
+def run_benchmarks():
+    for bench in get_benchmarks():
+        bench.run(print_data=True)
 
 
 if __name__ == '__main__':
-    benchmark.run(print_data=True)
+    run_benchmarks()