|
| 1 | +import torch |
| 2 | +import triton |
| 3 | +import triton.language as tl |
| 4 | +from functools import partial |
| 5 | + |
| 6 | +device = 'xpu' |
| 7 | +backend = getattr(torch, device) |
| 8 | + |
| 9 | + |
| 10 | +def compute_time( |
| 11 | + fn, |
| 12 | + warmup=1, |
| 13 | + rep=5, |
| 14 | + grad_to_none=None, |
| 15 | + quantiles=None, |
| 16 | + fast_flush=True, |
| 17 | + return_mode="mean", |
| 18 | +): |
| 19 | + assert return_mode in ["min", "max", "mean", "median"] |
| 20 | + |
| 21 | + """ |
| 22 | + Benchmark the runtime of the provided function. By default, return the median runtime of :code:`fn` along with |
| 23 | + the 20-th and 80-th performance percentile. |
| 24 | +
|
| 25 | + :param fn: Function to benchmark |
| 26 | + :type fn: Callable |
| 27 | + :param warmup: Warmup time (in ms) |
| 28 | + :type warmup: int |
| 29 | + :param rep: Repetition time (in ms) |
| 30 | + :type rep: int |
| 31 | + :param grad_to_none: Reset the gradient of the provided tensor to None |
| 32 | + :type grad_to_none: torch.tensor, optional |
| 33 | + :param quantiles: Performance percentile to return in addition to the median. |
| 34 | + :type quantiles: list[float] |
| 35 | + :param fast_flush: Use faster kernel to flush L2 between measurements |
| 36 | + :type fast_flush: bool |
| 37 | + """ |
| 38 | + backend.synchronize() |
| 39 | + |
| 40 | + # We maintain a buffer of 256 MB that we clear |
| 41 | + # before each kernel call to make sure that the L2 |
| 42 | + # doesn't contain any input data before the run |
| 43 | + if fast_flush: |
| 44 | + cache = torch.empty(int(256e6 // 4), dtype=torch.int, device=device) |
| 45 | + else: |
| 46 | + cache = torch.empty(int(256e6), dtype=torch.int8, device=device) |
| 47 | + |
| 48 | + # compute number of warmup and repeat |
| 49 | + |
| 50 | + start_event = [backend.Event(enable_timing=True) for i in range(rep)] |
| 51 | + end_event = [backend.Event(enable_timing=True) for i in range(rep)] |
| 52 | + # Warm-up |
| 53 | + for _ in range(warmup): |
| 54 | + fn() |
| 55 | + # Benchmark |
| 56 | + for i in range(rep): |
| 57 | + # we don't want `fn` to accumulate gradient values |
| 58 | + # if it contains a backward pass. So we clear the |
| 59 | + # provided gradients |
| 60 | + if grad_to_none is not None: |
| 61 | + for x in grad_to_none: |
| 62 | + if hasattr(x, 'grad'): |
| 63 | + x.grad = None |
| 64 | + # we clear the L2 cache before each run |
| 65 | + cache.zero_() |
| 66 | + # record time of `fn` |
| 67 | + start_event[i].record() |
| 68 | + fn() |
| 69 | + end_event[i].record() |
| 70 | + # Record clocks |
| 71 | + backend.synchronize() |
| 72 | + times = torch.tensor( |
| 73 | + [s.elapsed_time(e) for s, e in zip(start_event, end_event)], dtype=torch.float |
| 74 | + ) |
| 75 | + if quantiles is not None: |
| 76 | + ret = torch.quantile(times, torch.tensor(quantiles, dtype=torch.float)).tolist() |
| 77 | + if len(ret) == 1: |
| 78 | + ret = ret[0] |
| 79 | + return ret |
| 80 | + return getattr(torch, return_mode)(times).item() |
| 81 | + |
| 82 | + |
| 83 | +@triton.autotune( |
| 84 | + configs=[ |
| 85 | + triton.Config(kwargs={'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 4, 'grf_mode': 'large'}, num_stages=2, num_warps=32), |
| 86 | + # triton.Config(kwargs={'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 4}, num_stages=3, num_warps=32), |
| 87 | + # triton.Config(kwargs={'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 4}, num_stages=2, num_warps=32), |
| 88 | + # triton.Config(kwargs={'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 4}, num_stages=2, num_warps=32), |
| 89 | + # triton.Config(kwargs={'BLOCK_SIZE_M': 8, 'BLOCK_SIZE_N': 512, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 1}, num_stages=2, num_warps=32), |
| 90 | + ], |
| 91 | + key=['M', 'N', 'K'],) |
| 92 | +@triton.jit |
| 93 | +def matmul_kernel_with_block_pointers( |
| 94 | + # Pointers to matrices |
| 95 | + a_ptr, b_ptr, bias_ptr, c_ptr, |
| 96 | + # Matrix dimensions |
| 97 | + M: tl.constexpr, N: tl.constexpr, K: tl.constexpr, |
| 98 | + # The stride variables represent how much to increase the ptr by when moving by 1 |
| 99 | + # element in a particular dimension. E.g. `stride_am` is how much to increase `a_ptr` |
| 100 | + # by to get the element one row down (A has M rows). |
| 101 | + stride_am: tl.constexpr, stride_ak: tl.constexpr, # |
| 102 | + stride_bk: tl.constexpr, stride_bn: tl.constexpr, # |
| 103 | + stride_cm: tl.constexpr, stride_cn: tl.constexpr, |
| 104 | + BIAS_REQD: tl.constexpr, |
| 105 | + # Meta-parameters |
| 106 | + BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, GROUP_SIZE_M: tl.constexpr): |
| 107 | + """Kernel for computing the matmul C = A x B. |
| 108 | + A has shape (M, K), B has shape (K, N) and C has shape (M, N) |
| 109 | + """ |
| 110 | + # ----------------------------------------------------------- |
| 111 | + # Map program ids `pid` to the block of C it should compute. |
| 112 | + # This is done in a grouped ordering to promote L2 data reuse. |
| 113 | + # See the matrix multiplication tutorial for details. |
| 114 | + pid = tl.program_id(axis=0) |
| 115 | + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) |
| 116 | + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) |
| 117 | + num_pid_in_group = GROUP_SIZE_M * num_pid_n |
| 118 | + group_id = pid // num_pid_in_group |
| 119 | + first_pid_m = group_id * GROUP_SIZE_M |
| 120 | + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) |
| 121 | + pid_m = first_pid_m + (pid % group_size_m) |
| 122 | + pid_n = (pid % num_pid_in_group) // group_size_m |
| 123 | + #tl.device_print("pid", pid_m) |
| 124 | + |
| 125 | + # ---------------------------------------------------------- |
| 126 | + # Create block pointers for the first blocks of A and B. |
| 127 | + # We will advance this pointer as we move in the K direction and accumulate. |
| 128 | + # See above `Make a Block Pointer` section for details. |
| 129 | + a_block_ptr = tl.make_block_ptr(base=a_ptr, shape=(M, K), strides=(stride_am, stride_ak), |
| 130 | + offsets=(pid_m * BLOCK_SIZE_M, 0), block_shape=(BLOCK_SIZE_M, BLOCK_SIZE_K), |
| 131 | + order=(1, 0)) |
| 132 | + b_block_ptr = tl.make_block_ptr(base=b_ptr, shape=(K, N), strides=(stride_bk, stride_bn), |
| 133 | + offsets=(0, pid_n * BLOCK_SIZE_N), block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_N), |
| 134 | + order=(1, 0)) |
| 135 | + |
| 136 | + # ----------------------------------------------------------- |
| 137 | + # Iterate to compute a block of the C matrix. |
| 138 | + # We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block. |
| 139 | + # of fp32 values for higher accuracy. |
| 140 | + # `accumulator` will be converted back to fp16 after the loop. |
| 141 | + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) |
| 142 | + for k in range(0, K, BLOCK_SIZE_K): |
| 143 | + # Load with boundary checks, no need to calculate the mask manually. |
| 144 | + # For better performance, you may remove some axis from the boundary |
| 145 | + # check, if you can guarantee that the access is always in-bound in |
| 146 | + # that axis. |
| 147 | + # See above `Load/Store a Block Pointer` section for details. |
| 148 | + a = tl.load(a_block_ptr, boundary_check=(0, 1)) |
| 149 | + b = tl.load(b_block_ptr, boundary_check=(0, 1)) |
| 150 | + # We accumulate along the K dimension. |
| 151 | + accumulator += tl.dot(a, b) |
| 152 | + # Advance the block pointer to the next K block. |
| 153 | + # See above `Advance a Block Pointer` section for details. |
| 154 | + a_block_ptr = tl.advance(a_block_ptr, (0, BLOCK_SIZE_K)) |
| 155 | + b_block_ptr = tl.advance(b_block_ptr, (BLOCK_SIZE_K, 0)) |
| 156 | + c = accumulator.to(tl.float32) |
| 157 | + # add bias to accumulator |
| 158 | + if BIAS_REQD: |
| 159 | + offs_yn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N |
| 160 | + bias = tl.load(bias_ptr + offs_yn, mask=offs_yn < N, other=0.0).to(tl.float32) |
| 161 | + c += bias[None, :] |
| 162 | + # ---------------------------------------------------------------- |
| 163 | + # Write back the block of the output matrix C with boundary checks. |
| 164 | + # See above `Load/Store a Block Pointer` section for details. |
| 165 | + c_block_ptr = tl.make_block_ptr(base=c_ptr, shape=(M, N), strides=(stride_cm, stride_cn), |
| 166 | + offsets=(pid_m * BLOCK_SIZE_M, pid_n * BLOCK_SIZE_N), |
| 167 | + block_shape=(BLOCK_SIZE_M, BLOCK_SIZE_N), order=(1, 0)) |
| 168 | + tl.store(c_block_ptr, c.to(tl.float16), boundary_check=(0, 1)) |
| 169 | + |
| 170 | + |
| 171 | +def triton_mm(X, Y, b=None, transpose_x=False, transpose_y=False): |
| 172 | + if transpose_x: |
| 173 | + K, M = X.shape |
| 174 | + Xstride0, Xstride1 = X.stride(1), X.stride(0) |
| 175 | + else: |
| 176 | + M, K = X.shape |
| 177 | + Xstride0, Xstride1 = X.stride(0), X.stride(1) |
| 178 | + if transpose_y: |
| 179 | + N, _ = Y.shape |
| 180 | + Wstride0, Wstride1 = Y.stride(1), Y.stride(0) |
| 181 | + else: |
| 182 | + _, N = Y.shape |
| 183 | + Wstride0, Wstride1 = Y.stride(0), Y.stride(1) |
| 184 | + # Allocates output. |
| 185 | + Z = torch.empty((M, N), device=X.device, dtype=X.dtype) |
| 186 | + # 1D launch kernel where each block gets its own program. |
| 187 | + grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), ) |
| 188 | + |
| 189 | + matmul_kernel_with_block_pointers[grid]( |
| 190 | + X, Y, b, Z, |
| 191 | + M, N, K, |
| 192 | + Xstride0, Xstride1, |
| 193 | + Wstride0, Wstride1, |
| 194 | + Z.stride(0), Z.stride(1), |
| 195 | + BIAS_REQD=b is not None, |
| 196 | + ) |
| 197 | + |
| 198 | + return Z |
| 199 | + |
| 200 | + |
| 201 | +M = 1024 |
| 202 | +K = 5120 |
| 203 | +N = 4096 |
| 204 | +dtype = torch.float16 |
| 205 | +torch.manual_seed(0) |
| 206 | + |
| 207 | +AxB = True |
| 208 | +AxBT = True |
| 209 | +ATxB = True |
| 210 | +ATxBT = True |
| 211 | + |
| 212 | +if AxB: |
| 213 | + print('Compute A x B') |
| 214 | + X = torch.randn((M, K), device=device, dtype=dtype, requires_grad=False) |
| 215 | + Y = torch.randn((K, N), device=device, dtype=dtype, requires_grad=False) |
| 216 | + |
| 217 | + fn_tor = partial(torch.mm, X, Y) |
| 218 | + fn_tri = partial(triton_mm, X, Y) |
| 219 | + |
| 220 | + rtol = 1e-3 |
| 221 | + result_tor = fn_tor() |
| 222 | + result_tri = fn_tri() |
| 223 | + if torch.allclose(result_tri, result_tor, atol=1e-2, rtol=rtol): |
| 224 | + print("✅ Triton and Torch match") |
| 225 | + else: |
| 226 | + exit("❌ Triton and Torch differ") |
| 227 | + |
| 228 | + t_tor = compute_time(fn_tor, warmup=5, rep=100) |
| 229 | + t_tri = compute_time(fn_tri, warmup=5, rep=100) |
| 230 | + print(f"Time for torch: {t_tor} ms") |
| 231 | + print(f"Time for triton: {t_tri} ms") |
| 232 | + |
| 233 | + |
| 234 | +if AxBT: |
| 235 | + torch.manual_seed(0) |
| 236 | + print('Compute A x B.T') |
| 237 | + X = torch.randn((M, K), device=device, dtype=dtype, requires_grad=False) |
| 238 | + Y = torch.randn((N, K), device=device, dtype=dtype, requires_grad=False) |
| 239 | + |
| 240 | + fn_tor = partial(torch.mm, X, Y.T) |
| 241 | + fn_tri = partial(triton_mm, X, Y, transpose_y=True) |
| 242 | + |
| 243 | + rtol = 1e-3 |
| 244 | + result_tor = fn_tor() |
| 245 | + result_tri = fn_tri() |
| 246 | + if torch.allclose(result_tri, result_tor, atol=1e-2, rtol=rtol): |
| 247 | + print("✅ Triton and Torch match") |
| 248 | + else: |
| 249 | + exit("❌ Triton and Torch differ") |
| 250 | + |
| 251 | + t_tor = compute_time(fn_tor, warmup=5, rep=100) |
| 252 | + t_tri = compute_time(fn_tri, warmup=5, rep=100) |
| 253 | + print(f"Time for torch: {t_tor} ms") |
| 254 | + print(f"Time for triton: {t_tri} ms") |
| 255 | + |
| 256 | +if ATxB: |
| 257 | + torch.manual_seed(0) |
| 258 | + print('Compute A.T x B') |
| 259 | + X = torch.randn((K, M), device=device, dtype=dtype, requires_grad=False) |
| 260 | + Y = torch.randn((K, N), device=device, dtype=dtype, requires_grad=False) |
| 261 | + |
| 262 | + fn_tor = partial(torch.mm, X.T, Y) |
| 263 | + fn_tri = partial(triton_mm, X, Y, transpose_x=True) |
| 264 | + |
| 265 | + rtol = 1e-3 |
| 266 | + result_tor = fn_tor() |
| 267 | + result_tri = fn_tri() |
| 268 | + if torch.allclose(result_tri, result_tor, atol=1e-2, rtol=rtol): |
| 269 | + print("✅ Triton and Torch match") |
| 270 | + else: |
| 271 | + exit("❌ Triton and Torch differ") |
| 272 | + |
| 273 | + t_tor = compute_time(fn_tor, warmup=5, rep=100) |
| 274 | + t_tri = compute_time(fn_tri, warmup=5, rep=100) |
| 275 | + print(f"Time for torch: {t_tor} ms") |
| 276 | + print(f"Time for triton: {t_tri} ms") |
| 277 | + |
| 278 | +if ATxBT: |
| 279 | + torch.manual_seed(0) |
| 280 | + print('Compute A.T x B.T') |
| 281 | + X = torch.randn((K, M), device=device, dtype=dtype, requires_grad=False) |
| 282 | + Y = torch.randn((N, K), device=device, dtype=dtype, requires_grad=False) |
| 283 | + |
| 284 | + fn_tor = partial(torch.mm, X.T, Y.T) |
| 285 | + fn_tri = partial(triton_mm, X, Y, transpose_x=True, transpose_y=True) |
| 286 | + |
| 287 | + rtol = 1e-3 |
| 288 | + result_tor = fn_tor() |
| 289 | + result_tri = fn_tri() |
| 290 | + if torch.allclose(result_tri, result_tor, atol=1e-2, rtol=rtol): |
| 291 | + print("✅ Triton and Torch match") |
| 292 | + else: |
| 293 | + exit("❌ Triton and Torch differ") |
| 294 | + |
| 295 | + t_tor = compute_time(fn_tor, warmup=5, rep=100) |
| 296 | + t_tri = compute_time(fn_tri, warmup=5, rep=100) |
| 297 | + print(f"Time for torch: {t_tor} ms") |
| 298 | + print(f"Time for triton: {t_tri} ms") |
0 commit comments