
Commit 06d0bca

Add SYRK Triton kernel for Muon optimizer (#47)

* add syrk

Signed-off-by: gdeng <[email protected]>
1 parent 5679362 commit 06d0bca

File tree

5 files changed: +474 −2 lines

Lines changed: 15 additions & 0 deletions
@@ -0,0 +1,15 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .syrk import *
Lines changed: 371 additions & 0 deletions
@@ -0,0 +1,371 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# type: ignore
import torch
import triton
import triton.language as tl


try:
    from triton.tools.tensor_descriptor import TensorDescriptor
except ImportError:
    raise ImportError(
        f"Triton version ({triton.__version__}) doesn't support tensor descriptor API. Minimum required version is 3.4.0."
    )


__all__ = ["ssyrk", "tsyrk_ex"]


@triton.jit
def cvt_tf32_rn(x: tl.tensor) -> tl.tensor:
    return tl.inline_asm_elementwise("cvt.rna.tf32.f32 $0, $1;", "=r, r", [x], dtype=tl.float32, is_pure=True, pack=1)
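
For reference, the cvt.rna.tf32.f32 instruction used above rounds an fp32 value to the nearest TF32-representable value (10 mantissa bits, ties rounded away from zero). A rough host-side sketch of the same rounding, not part of the commit and ignoring NaN/Inf edge cases (tf32_round_reference is a hypothetical helper):

import torch

def tf32_round_reference(x: torch.Tensor) -> torch.Tensor:
    # TF32 keeps 10 of fp32's 23 mantissa bits; emulate round-to-nearest
    # by adding half of the truncated ULP, then clearing the low 13 bits
    bits = x.view(torch.int32)
    rounded = (bits + (1 << 12)) & ~((1 << 13) - 1)
    return rounded.view(torch.float32)
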
@triton.autotune(
    configs=[
        triton.Config({"TILE_N": tn, "TILE_K": tk}, num_warps=nw, num_stages=ns)
        for tn in (64, 128)
        for tk in (16, 32, 64)
        for nw in (4, 8)
        for ns in (3, 4)
    ],
    key=["N", "K", "ALLOW_TF32"],
)
@triton.jit
def syrk_op_n_simple_kernel(
    c_ptr,
    a_ptr,
    N: tl.constexpr,
    K: tl.constexpr,
    STRIDE_N: tl.constexpr,
    STRIDE_K: tl.constexpr,
    ALLOW_TF32: tl.constexpr,
    TILE_N: tl.constexpr,
    TILE_K: tl.constexpr,
):
    # receives tensor of shape (N, K)
    # computes A * A^T (-> produces NxN)

    pid_row = tl.program_id(0)
    pid_col = tl.program_id(1)

    IS_BELOW_DIAG = pid_row < pid_col
    IS_ABOVE_DIAG = pid_row > pid_col

    if IS_ABOVE_DIAG:
        return

    offs_row = pid_row * TILE_N + tl.arange(0, TILE_N)
    offs_col = pid_col * TILE_N + tl.arange(0, TILE_N)
    offs_k = tl.arange(0, TILE_K)

    mask_row = offs_row < N
    mask_col = offs_col < N

    a_ptrs_x = a_ptr + offs_row[:, None] * STRIDE_N + offs_k[None, :] * STRIDE_K
    a_ptrs_y = a_ptr + offs_col[None, :] * STRIDE_N + offs_k[:, None] * STRIDE_K

    acc = tl.zeros((TILE_N, TILE_N), dtype=tl.float32)

    num_tiles_k = tl.cdiv(K, TILE_K)
    for k in range(0, num_tiles_k):
        mask_k = offs_k < K - k * TILE_K
        mask_x = mask_row[:, None] & mask_k[None, :]
        mask_y = mask_col[None, :] & mask_k[:, None]
        x = tl.load(a_ptrs_x, mask=mask_x, other=0.0)
        y = tl.load(a_ptrs_y, mask=mask_y, other=0.0)

        if ALLOW_TF32 == 0:
            acc = tl.dot(x, y, acc=acc, input_precision="ieee")
        elif ALLOW_TF32 == 1:
            x = cvt_tf32_rn(x)
            y = cvt_tf32_rn(y)
            acc = tl.dot(x, y, acc=acc, input_precision="tf32")
        else:
            tl.static_assert(False, "Unsupported precision.")

        a_ptrs_x += TILE_K * STRIDE_K
        a_ptrs_y += TILE_K * STRIDE_K

    # store the computed tile at (row, col)
    c_ptrs = c_ptr + offs_row[:, None] * N + offs_col[None, :]
    mask_c = mask_row[:, None] & mask_col[None, :]
    tl.store(c_ptrs, acc, mask=mask_c)

    # store the mirrored copy at (col, row) to fill the opposite triangle
    if IS_BELOW_DIAG:
        c_ptrs_diag = c_ptr + offs_col[None, :] * N + offs_row[:, None]
        tl.store(c_ptrs_diag, acc, mask=mask_c)
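
The launch grid for this kernel is two-dimensional over (N / TILE_N) x (N / TILE_N) tiles; programs strictly on one side of the diagonal exit early, and every surviving off-diagonal program computes its tile once and stores it twice (at (row, col) and transposed at (col, row)). A minimal PyTorch sketch of the compute-one-triangle-then-mirror idea, for intuition only (syrk_mirror_reference is a hypothetical helper, and a real implementation would not compute the full product):

import torch

def syrk_mirror_reference(a: torch.Tensor) -> torch.Tensor:
    # keep one triangle of A @ A.T and replicate it across the diagonal,
    # mimicking the kernel's two stores per off-diagonal tile
    full = a @ a.T
    lower = torch.tril(full)
    return lower + torch.tril(full, diagonal=-1).T
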
def ssyrk(a: torch.Tensor, trans: bool = False) -> torch.Tensor:
    """Triton implementation of BLAS ssyrk operation.

    Note:
        This function assumes row-major layout of the input tensor.

    TODO(mstadler): Add support for alpha, beta and c.

    Args:
        a: Input tensor of shape (N, K) or (K, N)
        trans: Whether to compute A * A^T (trans=False) or A^T * A (trans=True)

    Returns:
        Output tensor of shape (N, N)
    """
    assert a.dim() == 2, "Input tensor must be 2D"
    N, K = a.shape
    if trans:
        raise NotImplementedError("Transpose is not supported yet.")

    STRIDE_N = a.stride(0)
    STRIDE_K = a.stride(1)

    if (fp32_matmul_prec := torch.get_float32_matmul_precision()) == "highest":
        ALLOW_TF32 = 0
    elif fp32_matmul_prec == "high":
        ALLOW_TF32 = 1
    else:
        raise ValueError(f"Unsupported precision {fp32_matmul_prec}, only 'highest' and 'high' are supported.")

    c = torch.empty((N, N), dtype=a.dtype, device=a.device)

    def grid(META):
        return (triton.cdiv(N, META["TILE_N"]), triton.cdiv(N, META["TILE_N"]))

    if not trans:
        syrk_op_n_simple_kernel[grid](c, a, N, K, STRIDE_N, STRIDE_K, ALLOW_TF32)

    return c
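
A usage sketch for ssyrk (sizes and tolerances are illustrative, not from the commit):

import torch

a = torch.randn(1024, 256, device="cuda", dtype=torch.float32)

torch.set_float32_matmul_precision("highest")  # selects the IEEE fp32 path
c = ssyrk(a)  # (1024, 1024); agrees with a @ a.T up to fp32 accumulation order
torch.testing.assert_close(c, a @ a.T)

torch.set_float32_matmul_precision("high")  # selects the TF32 path
c = ssyrk(a)  # matches a @ a.T only up to TF32 rounding error
torch.testing.assert_close(c, a @ a.T, rtol=1e-2, atol=1e-2)
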
def prune_invalid_configs(configs: list[triton.Config], named_args: dict, **kwargs) -> list[triton.Config]:
    """Prune invalid Triton kernel configs based on input size and tile parameters.

    Args:
        configs: List of Triton kernel configs.
        named_args: Named arguments for the kernel.
        **kwargs: Additional keyword arguments.

    Returns:
        List of valid Triton kernel configs.
    """
    N = named_args["N"]

    conf = []
    for c in configs:
        TILE_M = c.kwargs.get("TILE_M", 0)
        TILE_N = c.kwargs.get("TILE_N", 0)
        TILE_K = c.kwargs.get("TILE_K", 0)

        # 5000 is an empirically determined threshold from a size shmoo to select the best config
        if N >= 5000:
            if TILE_M == 128 and TILE_N == 256 and TILE_K == 64:
                conf.append(c)
        else:
            if TILE_M <= 128 and TILE_N >= TILE_M and TILE_K <= 128:
                conf.append(c)
    return conf
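
As a quick sanity check of the pruning rules, considering only the tile dimensions (the real config list also varies GROUP_SIZE_M, num_warps, and num_stages), a small host-side sketch:

tiles = [
    (tm, tn, tk)
    for tm in (64, 128, 256)
    for tn in (64, 128, 256)
    for tk in (64, 128, 256)
]
# N >= 5000: only (TILE_M, TILE_N, TILE_K) == (128, 256, 64) survives
large_n = [t for t in tiles if t == (128, 256, 64)]
# N < 5000: TILE_M <= 128, TILE_N >= TILE_M, TILE_K <= 128
small_n = [t for t in tiles if t[0] <= 128 and t[1] >= t[0] and t[2] <= 128]
print(len(tiles), len(large_n), len(small_n))  # 27 1 10
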
def matmul_tma_set_block_size_hook(nargs: dict) -> None:
    """Sets the block shapes for tensor descriptors based on tile sizes.

    Args:
        nargs: Named arguments for the kernel.
    """
    TILE_M = nargs["TILE_M"]
    TILE_N = nargs["TILE_N"]
    TILE_K = nargs["TILE_K"]
    TRANS = nargs["TRANS"]
    nargs["a_desc"].block_shape = [TILE_K, TILE_M] if TRANS else [TILE_M, TILE_K]
    nargs["a_t_desc"].block_shape = [TILE_K, TILE_N] if TRANS else [TILE_N, TILE_K]
    if nargs["c_desc"] is not None:
        nargs["c_desc"].block_shape = [TILE_M, TILE_N]
    nargs["d_desc"].block_shape = [TILE_M, TILE_N]
    nargs["d_t_desc"].block_shape = [TILE_N, TILE_M]
@triton.autotune(
    configs=[
        triton.Config(
            {"TILE_M": tm, "TILE_N": tn, "TILE_K": tk, "GROUP_SIZE_M": gm},
            num_warps=nw,
            num_stages=ns,
            num_ctas=nc,
            pre_hook=matmul_tma_set_block_size_hook,
        )
        for tm in (64, 128, 256)
        for tn in (64, 128, 256)
        for tk in (64, 128, 256)
        for gm in (2, 4, 8)
        for nw in (4, 8)
        for ns in (2, 3, 4)
        for nc in (1,)
    ],
    key=["N", "K", "TRANS", "WARP_SPECIALIZE"],
    prune_configs_by={"early_config_prune": prune_invalid_configs},
)
@triton.jit
def syrk_kernel_bf16(
    d_desc,
    d_t_desc,
    a_desc,
    a_t_desc,
    c_desc,
    alpha: tl.constexpr,
    beta: tl.constexpr,
    SKIP_UPPER_TRIANGLE: tl.constexpr,
    TRANS: tl.constexpr,
    N: tl.constexpr,
    K: tl.constexpr,
    TILE_M: tl.constexpr,
    TILE_N: tl.constexpr,
    TILE_K: tl.constexpr,
    GROUP_SIZE_M: tl.constexpr,
    WARP_SPECIALIZE: tl.constexpr,
):
    # input A tensor of shape (N, K)
    # computes D = alpha * A * A^T + beta * C (-> produces NxN)
    # NOTE: If beta != 0, then C must be a symmetric matrix (i.e., C == C^T)

    pid = tl.program_id(axis=0)
    num_pid_m = tl.cdiv(N, TILE_M)
    num_pid_n = tl.cdiv(N, TILE_N)
    num_pid_in_group = GROUP_SIZE_M * num_pid_n
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * GROUP_SIZE_M
    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
    pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
    pid_n = (pid % num_pid_in_group) // group_size_m

    IS_BELOW_DIAG = pid_m * TILE_M >= pid_n * TILE_N + TILE_N
    IS_ABOVE_DIAG = pid_m * TILE_M + TILE_M <= pid_n * TILE_N
    IS_SQUARE_TILE = TILE_M == TILE_N

    if IS_ABOVE_DIAG:
        return

    # hints for the compiler
    tl.assume(pid_m >= 0)
    tl.assume(pid_n >= 0)

    offs_row = pid_m * TILE_M
    offs_col = pid_n * TILE_N

    acc = tl.zeros((TILE_M, TILE_N), dtype=tl.float32)

    num_tiles_k = tl.cdiv(K, TILE_K)
    for k in tl.range(num_tiles_k, warp_specialize=WARP_SPECIALIZE):
        offs_k = k * TILE_K
        if TRANS:
            x = a_desc.load([offs_k, offs_row])
            y = a_t_desc.load([offs_k, offs_col])
            acc = tl.dot(x.T, y, acc=acc)
        else:
            x = a_desc.load([offs_row, offs_k])
            y = a_t_desc.load([offs_col, offs_k])
            acc = tl.dot(x, y.T, acc=acc)

    if alpha != 1.0:
        acc = alpha * acc
    if beta != 0.0:
        z = c_desc.load([offs_row, offs_col]).to(tl.float32)
        acc = beta * z + acc

    d = acc.to(tl.bfloat16)

    offs_row = pid_m * TILE_M
    offs_col = pid_n * TILE_N
    d_desc.store([offs_row, offs_col], d)

    # Store the replicated values above the diagonal. If SKIP_UPPER_TRIANGLE is
    # True, only the values on and below the diagonal are stored.
    if (IS_SQUARE_TILE and IS_BELOW_DIAG) or (not IS_SQUARE_TILE and not IS_ABOVE_DIAG):
        if not SKIP_UPPER_TRIANGLE:
            d_t_desc.store([offs_col, offs_row], d.T)
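
The index arithmetic at the top of this kernel is the standard grouped (swizzled) launch order: a group of GROUP_SIZE_M row-tiles is walked column by column so that recently loaded tiles stay resident in L2. A host-side Python mirror of the mapping, useful for checking which (pid_m, pid_n) tile a given program id lands on (pid_to_tile is a hypothetical helper):

def pid_to_tile(pid: int, num_pid_m: int, num_pid_n: int, group_size_m: int) -> tuple[int, int]:
    # mirrors the kernel's program-id swizzle, line for line
    num_pid_in_group = group_size_m * num_pid_n
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * group_size_m
    group_size = min(num_pid_m - first_pid_m, group_size_m)
    pid_m = first_pid_m + (pid % num_pid_in_group) % group_size
    pid_n = (pid % num_pid_in_group) // group_size
    return pid_m, pid_n

# with a 4x4 tile grid and GROUP_SIZE_M=2, the first programs visit
# (0,0), (1,0), (0,1), (1,1), (0,2), (1,2), ...
print([pid_to_tile(p, 4, 4, 2) for p in range(6)])
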
def tsyrk_ex(
    a: torch.Tensor, c: torch.Tensor | None = None, alpha: float = 1.0, beta: float = 0.0, skip_upper_triangle: bool = False
) -> torch.Tensor:
    """Triton implementation of bf16 syrk operation, following cuBLAS naming conventions with 't' denoting bf16.

    Note:
        If beta != 0, then c must be a symmetric matrix (i.e., c == c.T)

    Args:
        a: Input tensor of shape (N, K)
        c: None or symmetric input tensor of shape (N, N)
        alpha: Scaling factor for the matrix multiplication
        beta: Scaling factor for the matrix addition
        skip_upper_triangle: Whether to skip the upper triangle part of the output

    Returns:
        Output tensor of shape (N, N)
    """

    assert a.dim() == 2, "Input tensor must be 2D"
    assert a.is_contiguous() or a.T.is_contiguous(), "invalid input tensor layout. a or a.T must be contiguous."

    N, K = a.shape
    assert (c is None and beta == 0.0) or (c is not None and c.shape == (N, N)), (
        "if c is provided, c must be of shape (N, N)"
    )
    assert c is None or c.is_contiguous() or c.T.is_contiguous(), "if c is provided, c or c.T must be contiguous"

    d = torch.empty((N, N), device=a.device, dtype=a.dtype)

    dummy_block = [1, 1]

    is_trans = a.T.is_contiguous()

    if is_trans:
        # the descriptor relies on a contiguous tensor to load the data
        a = a.T
    # descriptor to load [TILE_M, TILE_K] from a
    a_desc = TensorDescriptor(a, a.shape, a.stride(), dummy_block)
    # descriptor to load [TILE_K, TILE_N] from a.T
    a_t_desc = TensorDescriptor(a, a.shape, a.stride(), dummy_block)
    # descriptor to store [TILE_M, TILE_N] to d
    d_desc = TensorDescriptor(d, d.shape, d.stride(), dummy_block)
    # descriptor to store [TILE_N, TILE_M] to the transposed position in d
    d_t_desc = TensorDescriptor(d, d.shape, d.stride(), dummy_block)

    if beta != 0.0:
        c = c.T if c.T.is_contiguous() else c
        # descriptor to load [TILE_M, TILE_N] from c
        c_desc = TensorDescriptor(c, c.shape, c.stride(), dummy_block)
    else:
        c_desc = None

    def grid(META):
        return (triton.cdiv(N, META["TILE_M"]) * triton.cdiv(N, META["TILE_N"]),)

    syrk_kernel_bf16[grid](
        d_desc,
        d_t_desc,
        a_desc,
        a_t_desc,
        c_desc,
        alpha,
        beta,
        skip_upper_triangle,
        is_trans,
        N,
        K,
        WARP_SPECIALIZE=False,
    )
    return d
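
A usage sketch for the bf16 entry point (shapes, scalars, and tolerances are illustrative; the kernel accumulates in fp32 and casts the result to bf16):

import torch

a = torch.randn(2048, 512, device="cuda", dtype=torch.bfloat16)
c = torch.randn(2048, 2048, device="cuda", dtype=torch.bfloat16)
c = ((c + c.T) / 2).contiguous()  # c must be symmetric when beta != 0

d = tsyrk_ex(a, c=c, alpha=0.5, beta=1.0)

ref = 0.5 * (a.float() @ a.float().T) + 1.0 * c.float()
torch.testing.assert_close(d.float(), ref, rtol=2e-2, atol=2e-2)
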

tests/ci/L0_Tests_GPU.sh

Lines changed: 2 additions & 1 deletion
@@ -22,7 +22,8 @@ coverage run -p --source=emerging_optimizers tests/soap_smoke_test.py
 coverage run -p --source=emerging_optimizers tests/soap_mnist_test.py
 coverage run -p --source=emerging_optimizers tests/test_scalar_optimizers.py --device=cuda
 coverage run -p --source=emerging_optimizers tests/test_spectral_clipping_utils.py
+coverage run -p --source=emerging_optimizers tests/test_triton_kernels.py TritonKernelsIntegerInputTest
 coverage run -p --source=emerging_optimizers tests/test_normalized_optimizer.py --device=cuda
 coverage run -p --source=emerging_optimizers tests/normalized_optimizer_convergence_test.py --device=cuda
 coverage run -p --source=emerging_optimizers tests/test_psgd_contractions.py --device=cuda
-coverage run -p --source=emerging_optimizers tests/test_psgd_utils.py --device=cuda
+coverage run -p --source=emerging_optimizers tests/test_psgd_utils.py --device=cuda
