|
| 1 | +""" |
| 2 | +Gemm with A@B^t benchmark |
| 3 | +==================================== |
| 4 | +
|
| 5 | +This benchmark is modified from gemm_benchmark.py with added transpose. |
| 6 | +""" |
| 7 | + |
| 8 | +import torch |
| 9 | +import triton |
| 10 | +import triton.language as tl |
| 11 | + |
| 12 | +import triton_kernels_benchmark as benchmark_suit |
| 13 | +import xetla_kernel |
| 14 | + |
| 15 | +if benchmark_suit.USE_IPEX_OPTION: |
| 16 | + import intel_extension_for_pytorch # type: ignore # noqa: F401 |
| 17 | + |
| 18 | + |
| 19 | +@triton.autotune( |
| 20 | + configs=[ |
| 21 | + triton.Config( |
| 22 | + {'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 4, 'grf_mode': 'large'}, |
| 23 | + num_stages=s, num_warps=32) for s in [1, 2, 3] |
| 24 | + ] + [ |
| 25 | + triton.Config( |
| 26 | + {'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 4, 'grf_mode': 'large'}, |
| 27 | + num_stages=s, num_warps=32) for s in [2, 3] |
| 28 | + ] + [ |
| 29 | + triton.Config( |
| 30 | + {'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 4, 'grf_mode': 'large'}, |
| 31 | + num_stages=s, num_warps=32) for s in [2] |
| 32 | + ] + [ |
| 33 | + triton.Config( |
| 34 | + {'BLOCK_SIZE_M': 8, 'BLOCK_SIZE_N': 512, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 1, 'grf_mode': 'large'}, |
| 35 | + num_stages=s, num_warps=32) for s in [2, 3] |
| 36 | + ], |
| 37 | + key=['M', 'N', 'K'], |
| 38 | +) |
| 39 | +@triton.jit |
| 40 | +def matmul_kernel_with_block_pointers( |
| 41 | + # Pointers to matrices |
| 42 | + a_ptr, b_ptr, c_ptr, |
| 43 | + # Matrix dimensions |
| 44 | + M: tl.constexpr, N: tl.constexpr, K: tl.constexpr, |
| 45 | + # Stride variables |
| 46 | + stride_am: tl.constexpr, stride_ak: tl.constexpr, # |
| 47 | + stride_bk: tl.constexpr, stride_bn: tl.constexpr, # |
| 48 | + stride_cm: tl.constexpr, stride_cn: tl.constexpr, |
| 49 | + # Meta-parameters |
| 50 | + BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, GROUP_SIZE_M: tl.constexpr): |
| 51 | + pid = tl.program_id(axis=0) |
| 52 | + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) |
| 53 | + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) |
| 54 | + num_pid_in_group = GROUP_SIZE_M * num_pid_n |
| 55 | + group_id = pid // num_pid_in_group |
| 56 | + first_pid_m = group_id * GROUP_SIZE_M |
| 57 | + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) |
| 58 | + pid_m = first_pid_m + (pid % group_size_m) |
| 59 | + pid_n = (pid % num_pid_in_group) // group_size_m |
| 60 | + |
| 61 | + a_block_ptr = tl.make_block_ptr(base=a_ptr, shape=(M, K), strides=(stride_am, stride_ak), |
| 62 | + offsets=(pid_m * BLOCK_SIZE_M, 0), block_shape=(BLOCK_SIZE_M, BLOCK_SIZE_K), |
| 63 | + order=(1, 0)) |
| 64 | + b_block_ptr = tl.make_block_ptr(base=b_ptr, shape=(K, N), strides=(stride_bk, stride_bn), |
| 65 | + offsets=(0, pid_n * BLOCK_SIZE_N), block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_N), |
| 66 | + order=(1, 0)) |
| 67 | + |
| 68 | + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) |
| 69 | + for _ in range(0, K, BLOCK_SIZE_K): |
| 70 | + a = tl.load(a_block_ptr, boundary_check=(0, 1)) |
| 71 | + b = tl.load(b_block_ptr, boundary_check=(0, 1)) |
| 72 | + accumulator += tl.dot(a, b) |
| 73 | + a_block_ptr = tl.advance(a_block_ptr, (0, BLOCK_SIZE_K)) |
| 74 | + b_block_ptr = tl.advance(b_block_ptr, (BLOCK_SIZE_K, 0)) |
| 75 | + c = accumulator.to(tl.float32) |
| 76 | + |
| 77 | + c_block_ptr = tl.make_block_ptr(base=c_ptr, shape=(M, N), strides=(stride_cm, stride_cn), |
| 78 | + offsets=(pid_m * BLOCK_SIZE_M, pid_n * BLOCK_SIZE_N), |
| 79 | + block_shape=(BLOCK_SIZE_M, BLOCK_SIZE_N), order=(1, 0)) |
| 80 | + tl.store(c_block_ptr, c, boundary_check=(0, 1)) |
| 81 | + |
| 82 | + |
| 83 | +# pylint: disable=unused-argument |
| 84 | +@triton.autotune( |
| 85 | + configs=[ |
| 86 | + triton.Config( |
| 87 | + {'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 4, 'grf_mode': 'large'}, |
| 88 | + num_stages=s, num_warps=32) for s in [2, 3] |
| 89 | + ] + [ |
| 90 | + triton.Config( |
| 91 | + {'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 4, 'grf_mode': 'large'}, |
| 92 | + num_stages=s, num_warps=32) for s in [2] |
| 93 | + ] + [ |
| 94 | + triton.Config( |
| 95 | + {'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 4, 'grf_mode': 'large'}, |
| 96 | + num_stages=s, num_warps=32) for s in [2] |
| 97 | + ] + [ |
| 98 | + triton.Config( |
| 99 | + {'BLOCK_SIZE_M': 8, 'BLOCK_SIZE_N': 512, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 1, 'grf_mode': 'large'}, |
| 100 | + num_stages=s, num_warps=32) for s in [2] |
| 101 | + ] + [ |
| 102 | + triton.Config( |
| 103 | + {'BLOCK_SIZE_M': 8, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 1, 'grf_mode': 'large'}, |
| 104 | + num_stages=s, num_warps=4) for s in [2] |
| 105 | + ], |
| 106 | + key=['M', 'N', 'K'], |
| 107 | +) |
| 108 | +@triton.jit |
| 109 | +def matmul_kernel_with_block_pointers_batched( |
| 110 | + # Pointers to matrices |
| 111 | + a_ptr, b_ptr, c_ptr, |
| 112 | + # Matrix dimensions |
| 113 | + B: tl.constexpr, M: tl.constexpr, N: tl.constexpr, K: tl.constexpr, |
| 114 | + # Stride variables |
| 115 | + stride_az: tl.constexpr, stride_am: tl.constexpr, stride_ak: tl.constexpr, # |
| 116 | + stride_bz: tl.constexpr, stride_bk: tl.constexpr, stride_bn: tl.constexpr, # |
| 117 | + stride_cz: tl.constexpr, stride_cm: tl.constexpr, stride_cn: tl.constexpr, |
| 118 | + # Meta-parameters |
| 119 | + BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, GROUP_SIZE_M: tl.constexpr): |
| 120 | + bid = tl.program_id(axis=0) |
| 121 | + pid = tl.program_id(axis=1) |
| 122 | + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) |
| 123 | + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) |
| 124 | + num_pid_in_group = GROUP_SIZE_M * num_pid_n |
| 125 | + group_id = pid // num_pid_in_group |
| 126 | + first_pid_m = group_id * GROUP_SIZE_M |
| 127 | + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) |
| 128 | + pid_m = first_pid_m + (pid % group_size_m) |
| 129 | + pid_n = (pid % num_pid_in_group) // group_size_m |
| 130 | + |
| 131 | + offset_a = bid.to(tl.int64) * stride_az |
| 132 | + offset_b = bid.to(tl.int64) * stride_bz |
| 133 | + |
| 134 | + a_block_ptr = tl.make_block_ptr(base=a_ptr + offset_a, shape=(M, K), strides=(stride_am, stride_ak), |
| 135 | + offsets=(pid_m * BLOCK_SIZE_M, 0), block_shape=(BLOCK_SIZE_M, BLOCK_SIZE_K), |
| 136 | + order=(1, 0)) |
| 137 | + b_block_ptr = tl.make_block_ptr(base=b_ptr + offset_b, shape=(K, N), strides=(stride_bk, stride_bn), |
| 138 | + offsets=(0, pid_n * BLOCK_SIZE_N), block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_N), |
| 139 | + order=(1, 0)) |
| 140 | + |
| 141 | + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) |
| 142 | + for _ in range(0, K, BLOCK_SIZE_K): |
| 143 | + a = tl.load(a_block_ptr, boundary_check=(0, 1)) |
| 144 | + b = tl.load(b_block_ptr, boundary_check=(0, 1)) |
| 145 | + accumulator += tl.dot(a, b) |
| 146 | + a_block_ptr = tl.advance(a_block_ptr, (0, BLOCK_SIZE_K)) |
| 147 | + b_block_ptr = tl.advance(b_block_ptr, (BLOCK_SIZE_K, 0)) |
| 148 | + c = accumulator.to(tl.float32) |
| 149 | + |
| 150 | + offset_c = bid.to(tl.int64) * stride_cz |
| 151 | + c_block_ptr = tl.make_block_ptr(base=c_ptr + offset_c, shape=(M, N), strides=(stride_cm, stride_cn), |
| 152 | + offsets=(pid_m * BLOCK_SIZE_M, pid_n * BLOCK_SIZE_N), |
| 153 | + block_shape=(BLOCK_SIZE_M, BLOCK_SIZE_N), order=(1, 0)) |
| 154 | + tl.store(c_block_ptr, c, boundary_check=(0, 1)) |
| 155 | + |
| 156 | + |
| 157 | +# We can now create a convenience wrapper function that only takes two input tensors, |
| 158 | +# and (1) checks any shape constraint; (2) launches the above kernel. |
| 159 | +def matmul(a, b, c): |
| 160 | + # Check constraints. |
| 161 | + if len(a.shape) == 3 and len(b.shape) == 3: |
| 162 | + assert a.shape[0] == b.shape[0], 'Incompatible Batch dimension' |
| 163 | + assert a.shape[2] == b.shape[2], 'Incompatible dimensions' |
| 164 | + assert a.is_contiguous(), 'Matrix A must be contiguous' |
| 165 | + assert b.is_contiguous(), 'Matrix B must be contiguous' |
| 166 | + B, M, K = a.shape |
| 167 | + B, N, K = b.shape |
| 168 | + # 1D launch kernel where each block gets its own program. |
| 169 | + grid = lambda META: ( |
| 170 | + B, |
| 171 | + triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), |
| 172 | + ) |
| 173 | + matmul_kernel_with_block_pointers_batched[grid]( |
| 174 | + a, b, c, # |
| 175 | + B, M, N, K, # |
| 176 | + a.stride(0), a.stride(1), a.stride(2), # |
| 177 | + b.stride(0), b.stride(2), b.stride(1), # |
| 178 | + c.stride(0), c.stride(1), c.stride(2)) |
| 179 | + elif len(a.shape) == 2 and len(b.shape) == 2: |
| 180 | + assert a.shape[1] == b.shape[1], 'Incompatible dimensions' |
| 181 | + assert a.is_contiguous(), 'Matrix A must be contiguous' |
| 182 | + assert b.is_contiguous(), 'Matrix B must be contiguous' |
| 183 | + M, K = a.shape |
| 184 | + N, K = b.shape |
| 185 | + grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), ) |
| 186 | + matmul_kernel_with_block_pointers[grid]( |
| 187 | + a, b, c, # |
| 188 | + M, N, K, # |
| 189 | + a.stride(0), a.stride(1), # |
| 190 | + b.stride(1), b.stride(0), # |
| 191 | + c.stride(0), c.stride(1)) |
| 192 | + else: |
| 193 | + assert False, 'Input matrixs dimensions mismatch' |
| 194 | + return c |
| 195 | + |
| 196 | + |
| 197 | +# Benchmark Performance |
| 198 | +@benchmark_suit.perf_report( |
| 199 | + benchmark_suit.Benchmark( |
| 200 | + # argument names to use as an x-axis for the plot |
| 201 | + x_names=['B', 'M', 'K', 'N'], |
| 202 | + # different possible values for `x_name` |
| 203 | + x_vals=[[1, 1024 * i, 1024 * i, 1024 * i] for i in [1, 2, 4, 8]] + # |
| 204 | + [ # |
| 205 | + [1, 1, 5120, 13824], # |
| 206 | + [1, 4, 4096, 12288], # |
| 207 | + [1, 512, 8192, 8192], # |
| 208 | + [1, 512, 8192, 32768], # |
| 209 | + [1, 512, 32768, 8192], # |
| 210 | + [1, 1024, 16384, 8192], # |
| 211 | + [1, 1024, 28672, 8192], # |
| 212 | + [1, 3072, 4096, 3072], # FIXME: Remove this case when gemm_streamk_benchmark works |
| 213 | + [1, 4096, 16384, 8192], # |
| 214 | + [1, 8192, 16384, 1024], # |
| 215 | + [1, 8192, 16384, 4096], # |
| 216 | + [1, 16384, 1024, 8192], # |
| 217 | + [1, 16384, 4096, 8192], # |
| 218 | + [1, 16384, 8192, 1024], # |
| 219 | + [1, 16384, 8192, 4096], # |
| 220 | + [4, 32768, 128, 4096], # |
| 221 | + [4, 32768, 4096, 128], # |
| 222 | + [32, 4096, 4096, 128], # |
| 223 | + [4096, 8, 128, 16384], # |
| 224 | + [4096, 8, 16384, 128] |
| 225 | + ], |
| 226 | + line_arg='provider', |
| 227 | + # argument name whose value corresponds to a different line in the plot |
| 228 | + # possible values for `line_arg`` |
| 229 | + line_vals=['triton'], |
| 230 | + # label name for the lines |
| 231 | + line_names=['Triton'], |
| 232 | + # line styles |
| 233 | + styles=[('green', '-'), ('green', '--'), ('blue', '-'), ('blue', '--')], |
| 234 | + ylabel=['GB/s', 'TFlops'], # label name for the y-axis |
| 235 | + plot_name='matmul-performance', |
| 236 | + # name for the plot. Used also as a file name for saving the plot. |
| 237 | + args={}, |
| 238 | + )) |
| 239 | +def benchmark(B, M, N, K, provider): |
| 240 | + if B == 1: |
| 241 | + a = torch.rand((M, K), device='xpu', dtype=torch.bfloat16) |
| 242 | + b = torch.rand((N, K), device='xpu', dtype=torch.bfloat16) |
| 243 | + else: |
| 244 | + a = torch.rand((B, M, K), device='xpu', dtype=torch.bfloat16) |
| 245 | + b = torch.rand((B, N, K), device='xpu', dtype=torch.bfloat16) |
| 246 | + |
| 247 | + quantiles = [0.5, 0.0, 1.0] |
| 248 | + |
| 249 | + if provider == 'onednn': |
| 250 | + _, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(lambda: torch.matmul(a, torch.transpose(b, -1, -2)), |
| 251 | + warmup=10, rep=10, quantiles=quantiles, |
| 252 | + fast_flush=False) |
| 253 | + elif provider == 'triton': |
| 254 | + assert len(a.shape) == len(b.shape), 'Incompatible sizes' |
| 255 | + if len(a.shape) == 3: |
| 256 | + c = torch.empty((B, M, N), device='xpu', dtype=torch.float32) |
| 257 | + else: |
| 258 | + assert len(a.shape) == 2, 'Expecting shape of length 2' |
| 259 | + c = torch.empty((M, N), device='xpu', dtype=torch.float32) |
| 260 | + triton_fn = lambda: matmul(a, b, c) |
| 261 | + torch_fn = lambda: torch.matmul(a, torch.transpose(b, -1, -2)).to(torch.float32) |
| 262 | + rtol = 1e-2 if a.dtype == torch.bfloat16 else 1e-3 |
| 263 | + benchmark_suit.assert_close(triton_fn(), torch_fn(), atol=1e-4, rtol=rtol, err_msg='triton to torch') |
| 264 | + _, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(triton_fn, warmup=10, rep=10, quantiles=quantiles, |
| 265 | + fast_flush=False) |
| 266 | + elif provider == 'xetla': |
| 267 | + if B == 1: |
| 268 | + c = torch.empty((M, N), device='xpu', dtype=torch.float32) |
| 269 | + acc = torch.empty((M, N), device='xpu', dtype=torch.float32) |
| 270 | + cnt = torch.empty((M, N), device='xpu', dtype=torch.int32) |
| 271 | + else: |
| 272 | + c = torch.empty((B, M, N), device='xpu', dtype=torch.float32) |
| 273 | + acc = torch.empty((B, M, N), device='xpu', dtype=torch.float32) |
| 274 | + cnt = torch.empty((B, M, N), device='xpu', dtype=torch.int32) |
| 275 | + name = f'gemm_shape_{B}_{M}_{K}_{N}' |
| 276 | + func = getattr(xetla_kernel, name) |
| 277 | + xetla_fn = lambda: func(a, torch.transpose(b, -1, -2), c, acc, cnt) |
| 278 | + torch_fn = lambda: torch.matmul(a, torch.tranpose(b, -1, -2)).to(torch.float32) |
| 279 | + # benchmark_suit.assert_close(xetla_fn(), torch_fn(), atol=1e-4, rtol=1.0, err_msg='xetla to torch') |
| 280 | + _, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(xetla_fn, warmup=10, rep=10, quantiles=quantiles, |
| 281 | + fast_flush=False) |
| 282 | + else: |
| 283 | + raise NotImplementedError(f'Unsupported provider {provider}') |
| 284 | + |
| 285 | + tflops = lambda ms: 2 * B * M * N * K * (1e-12) / (ms * 1e-3) |
| 286 | + gbps = lambda ms: B * (2 * (M * K + K * N) + 4.0 * (M * N)) * (1e-9) / (ms * 1e-3) |
| 287 | + |
| 288 | + return (gbps(mean_ms), gbps(max_ms), gbps(min_ms)), (tflops(mean_ms), tflops(max_ms), tflops(min_ms)), cv |
| 289 | + |
| 290 | + |
| 291 | +if __name__ == '__main__': |
| 292 | + benchmark.run(show_plots=False, print_data=True) |
0 commit comments