13 | 13 | # See the License for the specific language governing permissions and |
14 | 14 | # limitations under the License. |
15 | 15 | # type: ignore |
| 16 | +import sys |
| 17 | + |
16 | 18 | import torch |
17 | 19 | import triton |
18 | 20 | import triton.language as tl |
| 21 | +from absl import logging |
19 | 22 |
20 | 23 |
21 | 24 | try: |
22 | 25 | from triton.tools.tensor_descriptor import TensorDescriptor |
23 | | -except ImportError: |
24 | | - raise ImportError( |
25 | | - f"Triton version ({triton.__version__}) doesn't support tensor descriptor API. Minimum required version is 3.4.0." |
26 | | - ) |
27 | | - |
28 | | - |
29 | | -__all__ = ["ssyrk", "tsyrk_ex"] |
30 | | - |
31 | | - |
32 | | -@triton.jit |
33 | | -def cvt_tf32_rn(x: tl.tensor) -> tl.tensor: |
34 | | - return tl.inline_asm_elementwise("cvt.rna.tf32.f32 $0, $1;", "=r, r", [x], dtype=tl.float32, is_pure=True, pack=1) |
35 | 26 |
| 27 | + HAS_TRITON_340 = True |
| 28 | +except ImportError: |
| 29 | + HAS_TRITON_340 = False |
36 | 30 |
37 | | -@triton.autotune( |
38 | | - configs=[ |
39 | | - triton.Config({"TILE_N": tn, "TILE_K": tk}, num_warps=nw, num_stages=ns) |
40 | | - for tn in (64, 128) |
41 | | - for tk in (16, 32, 64) |
42 | | - for nw in (4, 8) |
43 | | - for ns in (3, 4) |
44 | | - ], |
45 | | - key=["N", "K", "ALLOW_TF32"], |
46 | | -) |
47 | | -@triton.jit |
48 | | -def syrk_op_n_simple_kernel( |
49 | | - c_ptr, |
50 | | - a_ptr, |
51 | | - N: tl.constexpr, |
52 | | - K: tl.constexpr, |
53 | | - STRIDE_N: tl.constexpr, |
54 | | - STRIDE_K: tl.constexpr, |
55 | | - ALLOW_TF32: tl.constexpr, |
56 | | - TILE_N: tl.constexpr, |
57 | | - TILE_K: tl.constexpr, |
58 | | -): |
59 | | - # receives tensor of shape (N, K) |
60 | | - # computes A * A^T (-> produces NxN) |
61 | | - |
62 | | - pid_row = tl.program_id(0) |
63 | | - pid_col = tl.program_id(1) |
64 | | - |
65 | | - IS_BELOW_DIAG = pid_row < pid_col |
66 | | - IS_ABOVE_DIAG = pid_row > pid_col |
67 | | - |
68 | | - if IS_ABOVE_DIAG: |
69 | | - return |
70 | | - |
71 | | - offs_row = pid_row * TILE_N + tl.arange(0, TILE_N) |
72 | | - offs_col = pid_col * TILE_N + tl.arange(0, TILE_N) |
73 | | - offs_k = tl.arange(0, TILE_K) |
74 | | - |
75 | | - mask_row = offs_row < N |
76 | | - mask_col = offs_col < N |
77 | | - |
78 | | - a_ptrs_x = a_ptr + offs_row[:, None] * STRIDE_N + offs_k[None, :] * STRIDE_K |
79 | | - a_ptrs_y = a_ptr + offs_col[None, :] * STRIDE_N + offs_k[:, None] * STRIDE_K |
80 | | - |
81 | | - acc = tl.zeros((TILE_N, TILE_N), dtype=tl.float32) |
82 | | - |
83 | | - num_tiles_k = tl.cdiv(K, TILE_K) |
84 | | - for k in range(0, num_tiles_k): |
85 | | - mask_k = offs_k < K - k * TILE_K |
86 | | - mask_x = mask_row[:, None] & mask_k[None, :] |
87 | | - mask_y = mask_col[None, :] & mask_k[:, None] |
88 | | - x = tl.load(a_ptrs_x, mask=mask_x, other=0.0) |
89 | | - y = tl.load(a_ptrs_y, mask=mask_y, other=0.0) |
90 | | - |
91 | | - if ALLOW_TF32 == 0: |
92 | | - acc = tl.dot(x, y, acc=acc, input_precision="ieee") |
93 | | - elif ALLOW_TF32 == 1: |
94 | | - x = cvt_tf32_rn(x) |
95 | | - y = cvt_tf32_rn(y) |
96 | | - acc = tl.dot(x, y, acc=acc, input_precision="tf32") |
97 | | - else: |
98 | | - tl.static_assert(False, "Unsupported precision.") |
99 | | - |
100 | | - a_ptrs_x += TILE_K * STRIDE_K |
101 | | - a_ptrs_y += TILE_K * STRIDE_K |
102 | | - |
103 | | - # store diagonal or below diagonal values |
104 | | - c_ptrs = c_ptr + offs_row[:, None] * N + offs_col[None, :] |
105 | | - mask_c = mask_row[:, None] & mask_col[None, :] |
106 | | - tl.store(c_ptrs, acc, mask=mask_c) |
107 | | - |
108 | | - # store replicated values above diagonal |
109 | | - if IS_BELOW_DIAG: |
110 | | - c_ptrs_diag = c_ptr + offs_col[None, :] * N + offs_row[:, None] |
111 | | - tl.store(c_ptrs_diag, acc, mask=mask_c) |
112 | | - |
113 | | - |
114 | | -def ssyrk(a: torch.Tensor, trans: bool = False) -> torch.Tensor: |
115 | | - """Triton implementation of BLAS ssyrk operation. |
116 | | -
117 | | - Note: |
118 | | - This function assumes row major layout of the input tensor. |
119 | | -
120 | | - TODO(mstadler): Add support for alpha, beta and c. |
121 | | -
122 | | - Args: |
123 | | - a: Input tensor of shape (N, K) or (K, N) |
124 | | - trans: Whether to compute A * A^T (trans=False) or A^T * A (trans=True) |
125 | | -
126 | | - Returns: |
127 | | - Output tensor of shape (N, N) |
128 | | - """ |
129 | | - assert a.dim() == 2, "Input tensor must be 2D" |
130 | | - N, K = a.shape |
131 | | - if trans: |
132 | | - raise NotImplementedError("Transpose is not supported yet.") |
133 | | - |
134 | | - STRIDE_N = a.stride(0) |
135 | | - STRIDE_K = a.stride(1) |
136 | | - |
137 | | - if (fp32_matmul_prec := torch.get_float32_matmul_precision()) == "highest": |
138 | | - ALLOW_TF32 = 0 |
139 | | - elif fp32_matmul_prec == "high": |
140 | | - ALLOW_TF32 = 1 |
141 | | - else: |
142 | | - raise ValueError(f"Unsupported precision {fp32_matmul_prec}, only 'highest' and 'high' are supported.") |
143 | | - |
144 | | - c = torch.empty((N, N), dtype=a.dtype, device=a.device) |
145 | | - |
146 | | - def grid(META): |
147 | | - return (triton.cdiv(N, META["TILE_N"]), triton.cdiv(N, META["TILE_N"])) |
148 | | - |
149 | | - if not trans: |
150 | | - syrk_op_n_simple_kernel[grid](c, a, N, K, STRIDE_N, STRIDE_K, ALLOW_TF32) |
151 | 31 |
152 | | - return c |
| 32 | +__all__ = ["tsyrk_ex", "HAS_TRITON_340"] |
153 | 33 |
154 | 34 |
155 | 35 | def prune_invalid_configs(configs: list[triton.Config], named_args: dict, **kwargs) -> list[triton.Config]: |
@@ -199,23 +79,30 @@ def matmul_tma_set_block_size_hook(nargs: dict) -> None: |
199 | 79 | nargs["d_t_desc"].block_shape = [TILE_N, TILE_M] |
200 | 80 |
201 | 81 |
| 82 | +_CONFIGS = [ |
| 83 | + triton.Config( |
| 84 | + {"TILE_M": tm, "TILE_N": tn, "TILE_K": tk, "GROUP_SIZE_M": gm}, |
| 85 | + num_warps=nw, |
| 86 | + num_stages=ns, |
| 87 | + num_ctas=nc, |
| 88 | + pre_hook=matmul_tma_set_block_size_hook, |
| 89 | + ) |
| 90 | + for tm in (64, 128, 256) |
| 91 | + for tn in (64, 128, 256) |
| 92 | + for tk in (64, 128, 256) |
| 93 | + for gm in (2, 4, 8) |
| 94 | + for nw in (4, 8) |
| 95 | + for ns in (2, 3, 4) |
| 96 | + for nc in (1,) |
| 97 | +] |
| 98 | + |
| 99 | +if "absl.testing" in sys.modules:
| 100 | +    logging.warning("Running under absl.testing; disabling Triton autotuning (keeping only the first config).")
| 101 | + _CONFIGS = _CONFIGS[:1] |
| 102 | + |
| 103 | + |
202 | 104 | @triton.autotune( |
203 | | - configs=[ |
204 | | - triton.Config( |
205 | | - {"TILE_M": tm, "TILE_N": tn, "TILE_K": tk, "GROUP_SIZE_M": gm}, |
206 | | - num_warps=nw, |
207 | | - num_stages=ns, |
208 | | - num_ctas=nc, |
209 | | - pre_hook=matmul_tma_set_block_size_hook, |
210 | | - ) |
211 | | - for tm in (64, 128, 256) |
212 | | - for tn in (64, 128, 256) |
213 | | - for tk in (64, 128, 256) |
214 | | - for gm in (2, 4, 8) |
215 | | - for nw in (4, 8) |
216 | | - for ns in (2, 3, 4) |
217 | | - for nc in (1,) |
218 | | - ], |
| 105 | + configs=_CONFIGS, |
219 | 106 | key=["N", "K", "TRANS", "WARP_SPECIALIZE"], |
220 | 107 | prune_configs_by={"early_config_prune": prune_invalid_configs}, |
221 | 108 | ) |
@@ -316,7 +203,7 @@ def tsyrk_ex( |
316 | 203 | Returns: |
317 | 204 | Output tensor of shape (N, N) |
318 | 205 | """ |
319 | | - |
| 206 | + assert a.dtype == torch.bfloat16, "Input tensor must be bfloat16" |
320 | 207 | assert a.dim() == 2, "Input tensor must be 2D" |
321 | 208 | assert a.is_contiguous() or a.T.is_contiguous(), "invalid input tensor layout. a or a.T must be contiguous." |
322 | 209 |
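For context, a minimal usage sketch of the API exported after this change. It is based only on what the diff shows (the `HAS_TRITON_340` gate, the bfloat16/2D/contiguity asserts, and the `(N, N)` result); the module path `syrk_triton` and the single-argument call are assumptions, since the full `tsyrk_ex` signature lies outside the visible hunks.

```python
import torch

# Hypothetical import path -- adjust to wherever this module actually lives.
from syrk_triton import HAS_TRITON_340, tsyrk_ex

if not HAS_TRITON_340:
    # TensorDescriptor (and hence tsyrk_ex) needs Triton >= 3.4.0.
    raise RuntimeError("tsyrk_ex requires Triton >= 3.4.0")

# Per the asserts in the diff: 2D bfloat16 input, with a or a.T contiguous.
a = torch.randn(4096, 512, device="cuda", dtype=torch.bfloat16)

# Symmetric rank-k update A @ A.T; any additional parameters (e.g. trans,
# warp specialization) are not shown in this hunk and are omitted here.
c = tsyrk_ex(a)
assert c.shape == (4096, 4096)
```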
|