|
| 1 | +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. |
| 2 | +# SPDX-License-Identifier: Apache-2.0 |
| 3 | +# |
| 4 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 5 | +# you may not use this file except in compliance with the License. |
| 6 | +# You may obtain a copy of the License at |
| 7 | +# |
| 8 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 9 | +# |
| 10 | +# Unless required by applicable law or agreed to in writing, software |
| 11 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 12 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 13 | +# See the License for the specific language governing permissions and |
| 14 | +# limitations under the License. |
| 15 | +# type: ignore |
| 16 | +import torch |
| 17 | +import triton |
| 18 | +import triton.language as tl |
| 19 | + |
| 20 | + |
| 21 | +try: |
| 22 | + from triton.tools.tensor_descriptor import TensorDescriptor |
| 23 | +except ImportError: |
| 24 | + raise ImportError( |
| 25 | + f"Triton version ({triton.__version__}) doesn't support tensor descriptor API. Minimum required version is 3.4.0." |
| 26 | + ) |
| 27 | + |
| 28 | + |
| 29 | +__all__ = ["ssyrk", "tsyrk_ex"] |
| 30 | + |
| 31 | + |
| 32 | +@triton.jit |
| 33 | +def cvt_tf32_rn(x: tl.tensor) -> tl.tensor: |
| 34 | + return tl.inline_asm_elementwise("cvt.rna.tf32.f32 $0, $1;", "=r, r", [x], dtype=tl.float32, is_pure=True, pack=1) |
| 35 | + |
| 36 | + |
| 37 | +@triton.autotune( |
| 38 | + configs=[ |
| 39 | + triton.Config({"TILE_N": tn, "TILE_K": tk}, num_warps=nw, num_stages=ns) |
| 40 | + for tn in (64, 128) |
| 41 | + for tk in (16, 32, 64) |
| 42 | + for nw in (4, 8) |
| 43 | + for ns in (3, 4) |
| 44 | + ], |
| 45 | + key=["N", "K", "ALLOW_TF32"], |
| 46 | +) |
| 47 | +@triton.jit |
| 48 | +def syrk_op_n_simple_kernel( |
| 49 | + c_ptr, |
| 50 | + a_ptr, |
| 51 | + N: tl.constexpr, |
| 52 | + K: tl.constexpr, |
| 53 | + STRIDE_N: tl.constexpr, |
| 54 | + STRIDE_K: tl.constexpr, |
| 55 | + ALLOW_TF32: tl.constexpr, |
| 56 | + TILE_N: tl.constexpr, |
| 57 | + TILE_K: tl.constexpr, |
| 58 | +): |
| 59 | + # receives tensor of shape (N, K) |
| 60 | + # computes A * A^T (-> produces NxN) |
| 61 | + |
| 62 | + pid_row = tl.program_id(0) |
| 63 | + pid_col = tl.program_id(1) |
| 64 | + |
| 65 | + IS_BELOW_DIAG = pid_row < pid_col |
| 66 | + IS_ABOVE_DIAG = pid_row > pid_col |
| 67 | + |
| 68 | + if IS_ABOVE_DIAG: |
| 69 | + return |
| 70 | + |
| 71 | + offs_row = pid_row * TILE_N + tl.arange(0, TILE_N) |
| 72 | + offs_col = pid_col * TILE_N + tl.arange(0, TILE_N) |
| 73 | + offs_k = tl.arange(0, TILE_K) |
| 74 | + |
| 75 | + mask_row = offs_row < N |
| 76 | + mask_col = offs_col < N |
| 77 | + |
| 78 | + a_ptrs_x = a_ptr + offs_row[:, None] * STRIDE_N + offs_k[None, :] * STRIDE_K |
| 79 | + a_ptrs_y = a_ptr + offs_col[None, :] * STRIDE_N + offs_k[:, None] * STRIDE_K |
| 80 | + |
| 81 | + acc = tl.zeros((TILE_N, TILE_N), dtype=tl.float32) |
| 82 | + |
| 83 | + num_tiles_k = tl.cdiv(K, TILE_K) |
| 84 | + for k in range(0, num_tiles_k): |
| 85 | + mask_k = offs_k < K - k * TILE_K |
| 86 | + mask_x = mask_row[:, None] & mask_k[None, :] |
| 87 | + mask_y = mask_col[None, :] & mask_k[:, None] |
| 88 | + x = tl.load(a_ptrs_x, mask=mask_x, other=0.0) |
| 89 | + y = tl.load(a_ptrs_y, mask=mask_y, other=0.0) |
| 90 | + |
| 91 | + if ALLOW_TF32 == 0: |
| 92 | + acc = tl.dot(x, y, acc=acc, input_precision="ieee") |
| 93 | + elif ALLOW_TF32 == 1: |
| 94 | + x = cvt_tf32_rn(x) |
| 95 | + y = cvt_tf32_rn(y) |
| 96 | + acc = tl.dot(x, y, acc=acc, input_precision="tf32") |
| 97 | + else: |
| 98 | + tl.static_assert(False, "Unsupported precision.") |
| 99 | + |
| 100 | + a_ptrs_x += TILE_K * STRIDE_K |
| 101 | + a_ptrs_y += TILE_K * STRIDE_K |
| 102 | + |
| 103 | + # store diagonal or below diagonal values |
| 104 | + c_ptrs = c_ptr + offs_row[:, None] * N + offs_col[None, :] |
| 105 | + mask_c = mask_row[:, None] & mask_col[None, :] |
| 106 | + tl.store(c_ptrs, acc, mask=mask_c) |
| 107 | + |
| 108 | + # store replicated values above diagonal |
| 109 | + if IS_BELOW_DIAG: |
| 110 | + c_ptrs_diag = c_ptr + offs_col[None, :] * N + offs_row[:, None] |
| 111 | + tl.store(c_ptrs_diag, acc, mask=mask_c) |
| 112 | + |
| 113 | + |
| 114 | +def ssyrk(a: torch.Tensor, trans: bool = False) -> torch.Tensor: |
| 115 | + """Triton implementation of BLAS ssyrk operation. |
| 116 | +
|
| 117 | + Note: |
| 118 | + This function assumes row major layout of the input tensor. |
| 119 | +
|
| 120 | + TODO(mstadler): Add support for alpha, beta and c. |
| 121 | +
|
| 122 | + Args: |
| 123 | + a: Input tensor of shape (N, K) or (K, N) |
| 124 | + trans: Whether to compute A * A^T (trans=False) or A^T * A (trans=True) |
| 125 | +
|
| 126 | + Returns: |
| 127 | + Output tensor of shape (N, N) |
| 128 | + """ |
| 129 | + assert a.dim() == 2, "Input tensor must be 2D" |
| 130 | + N, K = a.shape |
| 131 | + if trans: |
| 132 | + raise NotImplementedError("Transpose is not supported yet.") |
| 133 | + |
| 134 | + STRIDE_N = a.stride(0) |
| 135 | + STRIDE_K = a.stride(1) |
| 136 | + |
| 137 | + if (fp32_matmul_prec := torch.get_float32_matmul_precision()) == "highest": |
| 138 | + ALLOW_TF32 = 0 |
| 139 | + elif fp32_matmul_prec == "high": |
| 140 | + ALLOW_TF32 = 1 |
| 141 | + else: |
| 142 | + raise ValueError(f"Unsupported precision {fp32_matmul_prec}, only 'highest' and 'high' are supported.") |
| 143 | + |
| 144 | + c = torch.empty((N, N), dtype=a.dtype, device=a.device) |
| 145 | + |
| 146 | + def grid(META): |
| 147 | + return (triton.cdiv(N, META["TILE_N"]), triton.cdiv(N, META["TILE_N"])) |
| 148 | + |
| 149 | + if not trans: |
| 150 | + syrk_op_n_simple_kernel[grid](c, a, N, K, STRIDE_N, STRIDE_K, ALLOW_TF32) |
| 151 | + |
| 152 | + return c |
| 153 | + |
| 154 | + |
| 155 | +def prune_invalid_configs(configs: list[triton.Config], named_args: dict, **kwargs) -> list[triton.Config]: |
| 156 | + """Prune invalid Triton kernel configs based on input size and tile parameters. |
| 157 | +
|
| 158 | + Args: |
| 159 | + configs: List of Triton kernel configs. |
| 160 | + named_args: Named arguments for the kernel. |
| 161 | + **kwargs: Additional keyword arguments. |
| 162 | +
|
| 163 | + Returns: |
| 164 | + List of valid Triton kernel configs. |
| 165 | + """ |
| 166 | + N = named_args["N"] |
| 167 | + |
| 168 | + conf = [] |
| 169 | + for c in configs: |
| 170 | + TILE_M = c.kwargs.get("TILE_M", 0) |
| 171 | + TILE_N = c.kwargs.get("TILE_N", 0) |
| 172 | + TILE_K = c.kwargs.get("TILE_K", 0) |
| 173 | + |
| 174 | + # 5000 is an empirically determined threshold from size shmoo to select the best config |
| 175 | + if N >= 5000: |
| 176 | + if TILE_M == 128 and TILE_N == 256 and TILE_K == 64: |
| 177 | + conf.append(c) |
| 178 | + else: |
| 179 | + if TILE_M <= 128 and TILE_N >= TILE_M and TILE_K <= 128: |
| 180 | + conf.append(c) |
| 181 | + return conf |
| 182 | + |
| 183 | + |
| 184 | +def matmul_tma_set_block_size_hook(nargs: dict) -> None: |
| 185 | + """Sets the block shapes for tensor descriptors based on tile sizes. |
| 186 | +
|
| 187 | + Args: |
| 188 | + nargs: Named arguments for the kernel. |
| 189 | + """ |
| 190 | + TILE_M = nargs["TILE_M"] |
| 191 | + TILE_N = nargs["TILE_N"] |
| 192 | + TILE_K = nargs["TILE_K"] |
| 193 | + TRANS = nargs["TRANS"] |
| 194 | + nargs["a_desc"].block_shape = [TILE_K, TILE_M] if TRANS else [TILE_M, TILE_K] |
| 195 | + nargs["a_t_desc"].block_shape = [TILE_K, TILE_N] if TRANS else [TILE_N, TILE_K] |
| 196 | + if nargs["c_desc"] is not None: |
| 197 | + nargs["c_desc"].block_shape = [TILE_M, TILE_N] |
| 198 | + nargs["d_desc"].block_shape = [TILE_M, TILE_N] |
| 199 | + nargs["d_t_desc"].block_shape = [TILE_N, TILE_M] |
| 200 | + |
| 201 | + |
| 202 | +@triton.autotune( |
| 203 | + configs=[ |
| 204 | + triton.Config( |
| 205 | + {"TILE_M": tm, "TILE_N": tn, "TILE_K": tk, "GROUP_SIZE_M": gm}, |
| 206 | + num_warps=nw, |
| 207 | + num_stages=ns, |
| 208 | + num_ctas=nc, |
| 209 | + pre_hook=matmul_tma_set_block_size_hook, |
| 210 | + ) |
| 211 | + for tm in (64, 128, 256) |
| 212 | + for tn in (64, 128, 256) |
| 213 | + for tk in (64, 128, 256) |
| 214 | + for gm in (2, 4, 8) |
| 215 | + for nw in (4, 8) |
| 216 | + for ns in (2, 3, 4) |
| 217 | + for nc in (1,) |
| 218 | + ], |
| 219 | + key=["N", "K", "TRANS", "WARP_SPECIALIZE"], |
| 220 | + prune_configs_by={"early_config_prune": prune_invalid_configs}, |
| 221 | +) |
| 222 | +@triton.jit |
| 223 | +def syrk_kernel_bf16( |
| 224 | + d_desc, |
| 225 | + d_t_desc, |
| 226 | + a_desc, |
| 227 | + a_t_desc, |
| 228 | + c_desc, |
| 229 | + alpha: tl.constexpr, |
| 230 | + beta: tl.constexpr, |
| 231 | + SKIP_UPPER_TRIANGLE: tl.constexpr, |
| 232 | + TRANS: tl.constexpr, |
| 233 | + N: tl.constexpr, |
| 234 | + K: tl.constexpr, |
| 235 | + TILE_M: tl.constexpr, |
| 236 | + TILE_N: tl.constexpr, |
| 237 | + TILE_K: tl.constexpr, |
| 238 | + GROUP_SIZE_M: tl.constexpr, |
| 239 | + WARP_SPECIALIZE: tl.constexpr, |
| 240 | +): |
| 241 | + # input A tensor of shape (N, K) |
| 242 | + # computes D = alpha * A * A^T + beta * C (-> produces NxN) |
| 243 | + # NOTE: If beta != 0, then C must be a symmetric matrix (i.e., C == C^T) |
| 244 | + |
| 245 | + pid = tl.program_id(axis=0) |
| 246 | + num_pid_m = tl.cdiv(N, TILE_M) |
| 247 | + num_pid_n = tl.cdiv(N, TILE_N) |
| 248 | + num_pid_in_group = GROUP_SIZE_M * num_pid_n |
| 249 | + group_id = pid // num_pid_in_group |
| 250 | + first_pid_m = group_id * GROUP_SIZE_M |
| 251 | + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) |
| 252 | + pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m) |
| 253 | + pid_n = (pid % num_pid_in_group) // group_size_m |
| 254 | + |
| 255 | + IS_BELOW_DIAG = pid_m * TILE_M >= pid_n * TILE_N + TILE_N |
| 256 | + IS_ABOVE_DIAG = pid_m * TILE_M + TILE_M <= pid_n * TILE_N |
| 257 | + IS_SQUARE_TILE = TILE_M == TILE_N |
| 258 | + |
| 259 | + if IS_ABOVE_DIAG: |
| 260 | + return |
| 261 | + |
| 262 | + # hints for the compiler |
| 263 | + tl.assume(pid_m >= 0) |
| 264 | + tl.assume(pid_n >= 0) |
| 265 | + |
| 266 | + offs_row = pid_m * TILE_M |
| 267 | + offs_col = pid_n * TILE_N |
| 268 | + |
| 269 | + acc = tl.zeros((TILE_M, TILE_N), dtype=tl.float32) |
| 270 | + |
| 271 | + num_tiles_k = tl.cdiv(K, TILE_K) |
| 272 | + for k in tl.range(num_tiles_k, warp_specialize=WARP_SPECIALIZE): |
| 273 | + offs_k = k * TILE_K |
| 274 | + if TRANS: |
| 275 | + x = a_desc.load([offs_k, offs_row]) |
| 276 | + y = a_t_desc.load([offs_k, offs_col]) |
| 277 | + acc = tl.dot(x.T, y, acc=acc) |
| 278 | + else: |
| 279 | + x = a_desc.load([offs_row, offs_k]) |
| 280 | + y = a_t_desc.load([offs_col, offs_k]) |
| 281 | + acc = tl.dot(x, y.T, acc=acc) |
| 282 | + |
| 283 | + if alpha != 1.0: |
| 284 | + acc = alpha * acc |
| 285 | + if beta != 0.0: |
| 286 | + z = c_desc.load([offs_row, offs_col]).to(tl.float32) |
| 287 | + acc = beta * z + acc |
| 288 | + |
| 289 | + d = acc.to(tl.bfloat16) |
| 290 | + |
| 291 | + offs_row = pid_m * TILE_M |
| 292 | + offs_col = pid_n * TILE_N |
| 293 | + d_desc.store([offs_row, offs_col], d) |
| 294 | + |
| 295 | + # store replicated values above diagonal. if skip_upper_triangle is True, we only store the values below the diagonal. |
| 296 | + if (IS_SQUARE_TILE and IS_BELOW_DIAG) or (not IS_SQUARE_TILE and not IS_ABOVE_DIAG): |
| 297 | + if not SKIP_UPPER_TRIANGLE: |
| 298 | + d_t_desc.store([offs_col, offs_row], d.T) |
| 299 | + |
| 300 | + |
| 301 | +def tsyrk_ex( |
| 302 | + a: torch.Tensor, c: torch.Tensor = None, alpha: float = 1.0, beta: float = 0.0, skip_upper_triangle: bool = False |
| 303 | +) -> torch.Tensor: |
| 304 | + """Triton implementation of bf16 syrk operation, following cuBLAS naming conventions with 't' denoting bf16. |
| 305 | +
|
| 306 | + Note: |
| 307 | + If beta != 0, then a must be a symmetric matrix (i.e., a == a.T) |
| 308 | +
|
| 309 | + Args: |
| 310 | + a: Input tensor of shape (N, K) |
| 311 | + c: None or symmetric input tensor of shape (N, N) |
| 312 | + alpha: Scaling factor for the matrix multiplication |
| 313 | + beta: Scaling factor for the matrix addition |
| 314 | + skip_upper_triangle: Whether to skip the upper triangle part of the output |
| 315 | +
|
| 316 | + Returns: |
| 317 | + Output tensor of shape (N, N) |
| 318 | + """ |
| 319 | + |
| 320 | + assert a.dim() == 2, "Input tensor must be 2D" |
| 321 | + assert a.is_contiguous() or a.T.is_contiguous(), "invalid input tensor layout. a or a.T must be contiguous." |
| 322 | + |
| 323 | + N, K = a.shape |
| 324 | + assert (c is None and beta == 0.0) or (c is not None and c.shape == (N, N)), ( |
| 325 | + "if c is provided, c must be of shape (N, N)" |
| 326 | + ) |
| 327 | + assert c is None or c.is_contiguous() or c.T.is_contiguous(), "if c is provided, c or c.T must be contiguous" |
| 328 | + |
| 329 | + d = torch.empty((N, N), device=a.device, dtype=a.dtype) |
| 330 | + |
| 331 | + dummy_block = [1, 1] |
| 332 | + |
| 333 | + is_trans = a.T.is_contiguous() |
| 334 | + |
| 335 | + if is_trans: |
| 336 | + # the descriptor relys on contiguous tensor to load the data |
| 337 | + a = a.T |
| 338 | + # descriptor to load [TILE_M, TILE_K] from a |
| 339 | + a_desc = TensorDescriptor(a, a.shape, a.stride(), dummy_block) |
| 340 | + # descriptor to load [TILE_K, TILE_N] from a.T |
| 341 | + a_t_desc = TensorDescriptor(a, a.shape, a.stride(), dummy_block) |
| 342 | + # descriptor to store [TILE_M, TILE_N] to d |
| 343 | + d_desc = TensorDescriptor(d, d.shape, d.stride(), dummy_block) |
| 344 | + # descriptor to store [TILE_M, TILE_N] to d.T |
| 345 | + d_t_desc = TensorDescriptor(d, d.shape, d.stride(), dummy_block) |
| 346 | + |
| 347 | + if beta != 0.0: |
| 348 | + c = c.T if c.T.is_contiguous() else c |
| 349 | + # descriptor to load [TILE_M, TILE_N] from a |
| 350 | + c_desc = TensorDescriptor(c, c.shape, c.stride(), dummy_block) |
| 351 | + else: |
| 352 | + c_desc = None |
| 353 | + |
| 354 | + def grid(META): |
| 355 | + return (triton.cdiv(N, META["TILE_M"]) * triton.cdiv(N, META["TILE_N"]),) |
| 356 | + |
| 357 | + syrk_kernel_bf16[grid]( |
| 358 | + d_desc, |
| 359 | + d_t_desc, |
| 360 | + a_desc, |
| 361 | + a_t_desc, |
| 362 | + c_desc, |
| 363 | + alpha, |
| 364 | + beta, |
| 365 | + skip_upper_triangle, |
| 366 | + is_trans, |
| 367 | + N, |
| 368 | + K, |
| 369 | + WARP_SPECIALIZE=False, |
| 370 | + ) |
| 371 | + return d |
0 commit comments