import torch
import triton
import triton.language as tl

import triton_kernels_benchmark as benchmark_suit


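# Single-block scaled-dot kernel: loads one BLOCK_M x BLOCK_K tile of A and one
# BLOCK_K x BLOCK_N tile of B (e2m1 operands are packed two values per byte, so
# their K extent is halved), optionally loads one uint8 scale per 32 elements
# along K for either operand, and stores C = dot_scaled(A, scale_A, B, scale_B)
# as bf16.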
@triton.jit
def dot_scale_kernel(a_base, stride_a0, stride_a1, a_scale, b_base, stride_b0, stride_b1, b_scale, out,
                     BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, type_a: tl.constexpr,
                     type_b: tl.constexpr):
    DIV_FACTOR_A: tl.constexpr = 2 if type_a == 'e2m1' else 1
    DIV_FACTOR_B: tl.constexpr = 2 if type_b == 'e2m1' else 1
    PACKED_BLOCK_K_A: tl.constexpr = BLOCK_K // DIV_FACTOR_A
    PACKED_BLOCK_K_B: tl.constexpr = BLOCK_K // DIV_FACTOR_B
    a_ptr = a_base + tl.arange(0, BLOCK_M)[:, None] * stride_a0 + tl.arange(0, PACKED_BLOCK_K_A)[None, :] * stride_a1
    b_ptr = b_base + tl.arange(0, PACKED_BLOCK_K_B)[:, None] * stride_b0 + tl.arange(0, BLOCK_N)[None, :] * stride_b1

    a = tl.load(a_ptr)
    b = tl.load(b_ptr)
    SCALE_BLOCK_K: tl.constexpr = BLOCK_K // 32
    if a_scale is not None:
        scale_a_ptr = a_scale + tl.arange(0, BLOCK_M)[:, None] * SCALE_BLOCK_K + tl.arange(0, SCALE_BLOCK_K)[None, :]
        a_scale = tl.load(scale_a_ptr)
    if b_scale is not None:
        scale_b_ptr = b_scale + tl.arange(0, BLOCK_N)[:, None] * SCALE_BLOCK_K + tl.arange(0, SCALE_BLOCK_K)[None, :]
        b_scale = tl.load(scale_b_ptr)
    c = tl.dot_scaled(a, a_scale, type_a, b, b_scale, type_b)
    out_ptr = out + \
        tl.arange(0, BLOCK_M)[:, None] * BLOCK_N + \
        tl.arange(0, BLOCK_N)[None, :]
    tl.store(out_ptr, c.to(tl.bfloat16))


# Host-side wrapper: launches a single program instance that covers the whole
# (M, N, K) problem with one scaled-dot block.
def dot_scaled(M, N, K, x, y, z, scale_x, scale_y, type_a, type_b, num_warps):
    kernel_kwargs = {'num_warps': num_warps}
    dot_scale_kernel[(1, )](x, *x.stride(), scale_x, y, *y.stride(), scale_y, z, M, N, K, type_a, type_b,
                            **kernel_kwargs)


# Benchmark Performance
@benchmark_suit.perf_report(
    benchmark_suit.Benchmark(
        # argument names to use as an x-axis for the plot
        x_names=['M', 'N', 'K', 'col_a', 'col_b', 'rhs_scale', 'mxfp_type', 'normal_type'],
        x_vals=[(M, N, K, col_a, col_b, rhs_scale, mxfp_type, normal_type)
                for M, N, K in [(128, 128, 128)]
                for col_a, col_b in [(True, True), (False, False)]
                for rhs_scale in [True, False]
                for mxfp_type in ['e2m1', 'e4m3']
                for normal_type in ['bf16']],
        # argument name whose value corresponds to a different line in the plot
        line_arg='provider',
        # possible values for `line_arg`
        line_vals=['triton'],
        # label name for the lines
        line_names=['Triton'],
        # line styles
        styles=[('green', '-'), ('green', '--'), ('blue', '-'), ('blue', '--')],
        ylabel=['GB/s', 'TFlops'],  # label name for the y-axis
        # name for the plot; also used as a file name when saving it
        plot_name='scaled-dot',
        args={},
    ))
def benchmark(M, N, K, col_a, col_b, rhs_scale, mxfp_type, normal_type, provider):

    device = 'xpu'
    num_warps = 4
    quantiles = [0.5, 0.0, 1.0]

    comp_dtype = torch.float16 if normal_type == 'fp16' else torch.bfloat16
    # Maximum exponent used when initializing the x/y data and the associated scale
    # tensors, chosen to avoid overflow when the scales are applied.
    comp_dtype_max_exp = 6 if normal_type == 'fp16' else 15

    torch.manual_seed(0)

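    # Build an operand tensor: random fp16/bf16 values for the "normal" operand,
    # random uint8 bit patterns for the fp8/fp4 (mxfp) operand.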
    def make_arg(shape, ty, col_major=False):
        if col_major:
            shape = shape[:-2] + (shape[-1], shape[-2])
        if ty in ['fp16', 'bf16']:
            ret = torch.randn(shape, dtype=comp_dtype, device=device)
            # Clamp to avoid relative error issues
            ret.clamp_(-2**comp_dtype_max_exp, 2**comp_dtype_max_exp - 1)
        else:
            ret = torch.randint(256, shape, dtype=torch.uint8, device=device)
        if col_major:
            ret = ret.mT
        return ret

    type_a = normal_type if rhs_scale else mxfp_type
    type_b = mxfp_type if rhs_scale else normal_type

    DIV_FACTOR_A = 2 if type_a == 'e2m1' else 1
    DIV_FACTOR_B = 2 if type_b == 'e2m1' else 1
    x = make_arg((M, K // DIV_FACTOR_A), type_a, col_major=col_a)
    y = make_arg((K // DIV_FACTOR_B, N), type_b, col_major=col_b)

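    # One uint8 scale per 32 elements along K, per row of A and per column of B;
    # the ranges below limit the scale magnitude to avoid overflow when scaling
    # (see comp_dtype_max_exp above).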
    min_scale, max_scale = (0, 142) if comp_dtype == torch.bfloat16 else (124, 131)
    scale_x = torch.randint(min_scale, max_scale + 1, (M, K // 32), dtype=torch.uint8, device=device)
    scale_y = torch.randint(min_scale, max_scale + 1, (N, K // 32), dtype=torch.uint8, device=device)

    def make_finite(x, dtype):
        # Uniformly sampled e5m2 bit patterns hit non-finite encodings too often (1 / 32),
        # and Fp8E5M2_to_Bf16 doesn't preserve NaNs (FIXME), so remap them to finite values.
        if dtype not in ('e5m2', 'e4m3'):
            return x
        if dtype == 'e5m2' and comp_dtype == torch.float16:
            x = x & 0xB
        mask = 0x7C if dtype == 'e5m2' else 0x7F
        finite = torch.arange(x.numel(), device=device, dtype=torch.uint8).reshape_as(x) % mask
        x_finite = torch.where(x & mask == mask, finite | (0x80 & x), x)
        x.copy_(x_finite)
        return x

    x = make_finite(x, type_a)
    y = make_finite(y, type_b)
    z = x.new_empty((M, N), dtype=comp_dtype)
    # Only the mxfp operand carries a scale; drop the scale for the fp16/bf16 side.
    if rhs_scale:
        scale_x = None
    else:
        scale_y = None

    if provider == 'triton':
        triton_fn = lambda: dot_scaled(M, N, K, x, y, z, scale_x, scale_y, type_a, type_b, num_warps)

        _, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(triton_fn, n_warmup=10, n_repeat=10,
                                                                 quantiles=quantiles)
    else:
        raise NotImplementedError(f'Unsupported provider {provider}')

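    # Effective throughput: 2*M*N*K multiply-accumulate flops for the dot product,
    # plus one scale multiplication per element of the scaled operand.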
    def tflops(ms):
        scale_ops = N * K if rhs_scale else M * K
        return (2 * M * N * K + scale_ops) * (1e-12) / (ms * 1e-3)

    def gbps(ms):

        def size_x(m, n, ty):
            if ty in ['e2m1']:
                return m * n // 2
            if ty in ['e4m3', 'e5m2']:
                return m * n
            if ty in ['fp16', 'bf16']:
                return m * n * 2
            raise NotImplementedError(f'Unsupported type {ty} for scaledot operand')

        tensor_size = size_x(M, K, type_a) + size_x(K, N, type_b)
        # Bytes read for the scale tensor of whichever operand is scaled.
        scale_size = (N * K // 32) if rhs_scale else (M * K // 32)
        return (tensor_size + scale_size + 4.0 * (M * N)) * (1e-9) / (ms * 1e-3)

    return (gbps(mean_ms), gbps(max_ms), gbps(min_ms)), (tflops(mean_ms), tflops(max_ms), tflops(min_ms)), cv


if __name__ == '__main__':
    benchmark.run(show_plots=False, print_data=True)