Commit df73f86

Add option to use triton tsyrk_ex kernels in muon (#50)
Also fixed some testing-related settings.

Signed-off-by: Hao Wu <[email protected]>

1 parent: bc58ee0

File tree

12 files changed: +240, -232 lines changed

docker/Dockerfile.ci

Lines changed: 0 additions & 1 deletion
@@ -34,7 +34,6 @@ RUN --mount=type=bind,source=pyproject.toml,target=/workspace/pyproject.toml \
     uv sync --link-mode symlink --locked --all-groups \
       --no-install-package absl-py \
       --no-install-package torch \
-      --no-install-package triton \
       --no-install-package nvidia-cublas-cu12 \
       --no-install-package nvidia-cuda-cupti-cu12 \
       --no-install-package nvidia-cuda-nvrtc-cu12 \
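
With the --no-install-package triton exclusion removed, uv sync now installs the locked Triton wheel into the CI image, matching the new triton>=3.4.0 pins added to pyproject.toml below (presumably the image previously relied on whatever Triton version the base image or the torch install provided).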

emerging_optimizers/orthogonalized_optimizers/muon.py

Lines changed: 20 additions & 1 deletion
@@ -16,8 +16,10 @@
 from typing import Callable

 import torch
+from absl import logging
 from torch.optim.optimizer import ParamsT

+from emerging_optimizers import triton_kernels
 from emerging_optimizers.orthogonalized_optimizers.muon_utils import newton_schulz
 from emerging_optimizers.orthogonalized_optimizers.orthogonalized_optimizer import OrthogonalizedOptimizer, _args_doc

@@ -56,6 +58,7 @@ class Muon(OrthogonalizedOptimizer):
         num_ns_steps: The number of iteration steps to use in the Newton-Schulz iteration.
         scale_mode: The type of scale factor to use for the update. Defaults to "spectral" style scaling.
         extra_scale_factor: The additional scale factor to use for the update.
+        use_syrk: Whether to use the Triton kernel for the Newton-Schulz iteration.
     """

     def __init__(
@@ -74,11 +77,27 @@ def __init__(
         num_ns_steps: int = 5,
         scale_mode: str = "spectral",
         extra_scale_factor: float = 1.0,
+        use_syrk: bool = False,
     ) -> None:
         if num_ns_steps < 1:
             raise ValueError(f"num_ns_steps must be at least 1, got {num_ns_steps}")

-        orthogonalize_fn = partial(newton_schulz, steps=num_ns_steps, coefficient_type=coefficient_type)
+        if use_syrk:
+            if torch.cuda.is_available():
+                sm_version = torch.cuda.get_device_capability()
+            else:
+                sm_version = (0, 0)
+            if not triton_kernels.HAS_TRITON_340:  # type: ignore[attr-defined]
+                logging.error("Triton 3.4.0 or higher is required for use_syrk to be True.")
+                use_syrk = False
+            elif sm_version not in ((8, 0), (9, 0), (10, 0), (10, 3)):
+                logging.error(
+                    f"Correctness of Triton kernel on SM {sm_version} cannot be guaranteed. Setting use_syrk to False."
+                )
+                use_syrk = False
+        orthogonalize_fn = partial(
+            newton_schulz, steps=num_ns_steps, coefficient_type=coefficient_type, use_syrk=use_syrk
+        )
         scale_factor_fn = partial(get_muon_scale_factor, mode=scale_mode, extra_scale_factor=extra_scale_factor)

         super().__init__(
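
For orientation, here is a minimal usage sketch of the new flag. Only num_ns_steps, scale_mode, extra_scale_factor, coefficient_type, and use_syrk are confirmed by this diff; the model and any other constructor arguments are illustrative assumptions:

import torch

# Hypothetical usage sketch; arguments not shown in the diff above are assumed.
model = torch.nn.Linear(4096, 4096)
optimizer = Muon(
    model.parameters(),
    num_ns_steps=5,        # Newton-Schulz iteration count (default per the diff)
    scale_mode="spectral",
    use_syrk=True,         # downgraded to False with an error log if Triton < 3.4.0
                           # or the GPU is not SM 8.0 / 9.0 / 10.0 / 10.3
)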

emerging_optimizers/orthogonalized_optimizers/muon_utils.py

Lines changed: 38 additions & 3 deletions
@@ -17,6 +17,8 @@
 import torch
 from absl import logging

+from emerging_optimizers import triton_kernels
+

 __all__ = ["newton_schulz", "newton_schulz_tp"]

@@ -70,6 +72,7 @@ def newton_schulz(
     eps: float = 1e-7,
     transpose: bool | None = None,
     tp_group: torch.distributed.ProcessGroup | None = None,
+    use_syrk: bool = False,
 ) -> torch.Tensor:
     """Use Newton-Schulz iteration to compute the zeroth power / orthogonalization of x.

@@ -97,6 +100,7 @@
         transpose: Whether to transpose the tensor to perform whitening on the smaller dimension.
             If None, will be determined based on the size of the tensor.
         tp_group: The process group for communication if input is distributed.
+        use_syrk: Whether to use the Triton kernel for the Newton-Schulz iteration.

     Returns:
         The orthogonalization of x.
@@ -131,6 +135,7 @@
     if steps % len(coefficient_sets) != 0:
         raise ValueError(f"steps ({steps}) must be multiple of len(coefficient_sets) ({len(coefficient_sets)}).")

+    ns_step_fn = newton_schulz_step
     # Perform the NS iterations
     if torch.get_float32_matmul_precision() == "medium":
         # PyTorch doesn't really have FP32 I/O BF16 compute kernels for precision "medium"
@@ -140,10 +145,12 @@
         # is always in FP32.
         X = X.to(torch.bfloat16)
         logging.log_first_n(logging.INFO, "Using BF16 I/O kernels for Newton-Schulz iteration.", 1)
+        if use_syrk:
+            ns_step_fn = newton_schulz_step_tsyrk

     for i in range(steps):
         a, b, c = coefficient_sets[i % len(coefficient_sets)]
-        X = newton_schulz_step(X, a, b, c, tp_group=tp_group)
+        X = ns_step_fn(X, a, b, c, tp_group=tp_group)

     # Convert back to FP32. This is a noop if X is already in FP32.
     X = X.to(torch.float32)
@@ -244,6 +251,34 @@ def newton_schulz_step(
     A = X @ X.mT
     if tp_group is not None:
         torch.distributed.all_reduce(A, op=torch.distributed.ReduceOp.SUM, group=tp_group)
-    B = torch.addmm(A, A, A, beta=b, alpha=c)
-    X = torch.addmm(X, B, X, beta=a, alpha=1.0)
+    B = torch.addmm(A, A, A, alpha=c, beta=b)
+    X = torch.addmm(X, B, X, alpha=1.0, beta=a)
+    return X
+
+
+def newton_schulz_step_tsyrk(
+    X: torch.Tensor, a: float, b: float, c: float, tp_group: torch.distributed.ProcessGroup | None = None
+) -> torch.Tensor:
+    """Perform a single Newton-Schulz iteration step.
+
+    This function performs a single Newton-Schulz iteration step using the Triton kernel for extended syrk.
+
+    Arguments:
+        X: The tensor to be orthogonalized. Must be bfloat16.
+        a: The a coefficient.
+        b: The b coefficient.
+        c: The c coefficient.
+        tp_group: The process group to use for the all-reduce.
+
+    Returns:
+        The orthogonalization of X.
+    """
+    assert triton_kernels.HAS_TRITON_340, (  # type: ignore[attr-defined]
+        "Triton version doesn't support tensor descriptor API. Minimum required version is 3.4.0."
+    )
+    A = triton_kernels.tsyrk_ex(X)  # type: ignore[attr-defined]
+    if tp_group is not None:
+        torch.distributed.all_reduce(A, op=torch.distributed.ReduceOp.SUM, group=tp_group)
+    B = triton_kernels.tsyrk_ex(A, A, alpha=c, beta=b)  # type: ignore[attr-defined]
+    X = torch.addmm(X, B, X, alpha=1.0, beta=a)
     return X
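
For reference, both step functions implement the same cubic polynomial update. Writing A = X Xᵀ, the addmm calls (where beta scales the matrix being added and alpha scales the matrix product) and their tsyrk_ex counterparts both evaluate

    B  = b·A + c·A²
    X' = a·X + B·X = (a·I + b·A + c·A²) X

so switching use_syrk on changes only how A = X Xᵀ and B are materialized (tsyrk_ex exploits the symmetry of the output), not the arithmetic of the iteration.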

emerging_optimizers/triton_kernels/syrk.py

Lines changed: 31 additions & 144 deletions
@@ -13,143 +13,23 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # type: ignore
+import sys
+
 import torch
 import triton
 import triton.language as tl
+from absl import logging


 try:
     from triton.tools.tensor_descriptor import TensorDescriptor
-except ImportError:
-    raise ImportError(
-        f"Triton version ({triton.__version__}) doesn't support tensor descriptor API. Minimum required version is 3.4.0."
-    )
-
-
-__all__ = ["ssyrk", "tsyrk_ex"]
-
-
-@triton.jit
-def cvt_tf32_rn(x: tl.tensor) -> tl.tensor:
-    return tl.inline_asm_elementwise("cvt.rna.tf32.f32 $0, $1;", "=r, r", [x], dtype=tl.float32, is_pure=True, pack=1)

+    HAS_TRITON_340 = True
+except ImportError:
+    HAS_TRITON_340 = False

-
-@triton.autotune(
-    configs=[
-        triton.Config({"TILE_N": tn, "TILE_K": tk}, num_warps=nw, num_stages=ns)
-        for tn in (64, 128)
-        for tk in (16, 32, 64)
-        for nw in (4, 8)
-        for ns in (3, 4)
-    ],
-    key=["N", "K", "ALLOW_TF32"],
-)
-@triton.jit
-def syrk_op_n_simple_kernel(
-    c_ptr,
-    a_ptr,
-    N: tl.constexpr,
-    K: tl.constexpr,
-    STRIDE_N: tl.constexpr,
-    STRIDE_K: tl.constexpr,
-    ALLOW_TF32: tl.constexpr,
-    TILE_N: tl.constexpr,
-    TILE_K: tl.constexpr,
-):
-    # receives tensor of shape (N, K)
-    # computes A * A^T (-> produces NxN)
-
-    pid_row = tl.program_id(0)
-    pid_col = tl.program_id(1)
-
-    IS_BELOW_DIAG = pid_row < pid_col
-    IS_ABOVE_DIAG = pid_row > pid_col
-
-    if IS_ABOVE_DIAG:
-        return
-
-    offs_row = pid_row * TILE_N + tl.arange(0, TILE_N)
-    offs_col = pid_col * TILE_N + tl.arange(0, TILE_N)
-    offs_k = tl.arange(0, TILE_K)
-
-    mask_row = offs_row < N
-    mask_col = offs_col < N
-
-    a_ptrs_x = a_ptr + offs_row[:, None] * STRIDE_N + offs_k[None, :] * STRIDE_K
-    a_ptrs_y = a_ptr + offs_col[None, :] * STRIDE_N + offs_k[:, None] * STRIDE_K
-
-    acc = tl.zeros((TILE_N, TILE_N), dtype=tl.float32)
-
-    num_tiles_k = tl.cdiv(K, TILE_K)
-    for k in range(0, num_tiles_k):
-        mask_k = offs_k < K - k * TILE_K
-        mask_x = mask_row[:, None] & mask_k[None, :]
-        mask_y = mask_col[None, :] & mask_k[:, None]
-        x = tl.load(a_ptrs_x, mask=mask_x, other=0.0)
-        y = tl.load(a_ptrs_y, mask=mask_y, other=0.0)
-
-        if ALLOW_TF32 == 0:
-            acc = tl.dot(x, y, acc=acc, input_precision="ieee")
-        elif ALLOW_TF32 == 1:
-            x = cvt_tf32_rn(x)
-            y = cvt_tf32_rn(y)
-            acc = tl.dot(x, y, acc=acc, input_precision="tf32")
-        else:
-            tl.static_assert(False, "Unsupported precision.")
-
-        a_ptrs_x += TILE_K * STRIDE_K
-        a_ptrs_y += TILE_K * STRIDE_K
-
-    # store diagonal or below diagonal values
-    c_ptrs = c_ptr + offs_row[:, None] * N + offs_col[None, :]
-    mask_c = mask_row[:, None] & mask_col[None, :]
-    tl.store(c_ptrs, acc, mask=mask_c)
-
-    # store replicated values above diagonal
-    if IS_BELOW_DIAG:
-        c_ptrs_diag = c_ptr + offs_col[None, :] * N + offs_row[:, None]
-        tl.store(c_ptrs_diag, acc, mask=mask_c)
-
-
-def ssyrk(a: torch.Tensor, trans: bool = False) -> torch.Tensor:
-    """Triton implementation of BLAS ssyrk operation.
-
-    Note:
-        This function assumes row major layout of the input tensor.
-
-    TODO(mstadler): Add support for alpha, beta and c.
-
-    Args:
-        a: Input tensor of shape (N, K) or (K, N)
-        trans: Whether to compute A * A^T (trans=False) or A^T * A (trans=True)
-
-    Returns:
-        Output tensor of shape (N, N)
-    """
-    assert a.dim() == 2, "Input tensor must be 2D"
-    N, K = a.shape
-    if trans:
-        raise NotImplementedError("Transpose is not supported yet.")
-
-    STRIDE_N = a.stride(0)
-    STRIDE_K = a.stride(1)
-
-    if (fp32_matmul_prec := torch.get_float32_matmul_precision()) == "highest":
-        ALLOW_TF32 = 0
-    elif fp32_matmul_prec == "high":
-        ALLOW_TF32 = 1
-    else:
-        raise ValueError(f"Unsupported precision {fp32_matmul_prec}, only 'highest' and 'high' are supported.")
-
-    c = torch.empty((N, N), dtype=a.dtype, device=a.device)
-
-    def grid(META):
-        return (triton.cdiv(N, META["TILE_N"]), triton.cdiv(N, META["TILE_N"]))
-
-    if not trans:
-        syrk_op_n_simple_kernel[grid](c, a, N, K, STRIDE_N, STRIDE_K, ALLOW_TF32)

-    return c
+__all__ = ["tsyrk_ex", "HAS_TRITON_340"]


 def prune_invalid_configs(configs: list[triton.Config], named_args: dict, **kwargs) -> list[triton.Config]:
@@ -199,23 +79,30 @@ def matmul_tma_set_block_size_hook(nargs: dict) -> None:
         nargs["d_t_desc"].block_shape = [TILE_N, TILE_M]


+_CONFIGS = [
+    triton.Config(
+        {"TILE_M": tm, "TILE_N": tn, "TILE_K": tk, "GROUP_SIZE_M": gm},
+        num_warps=nw,
+        num_stages=ns,
+        num_ctas=nc,
+        pre_hook=matmul_tma_set_block_size_hook,
+    )
+    for tm in (64, 128, 256)
+    for tn in (64, 128, 256)
+    for tk in (64, 128, 256)
+    for gm in (2, 4, 8)
+    for nw in (4, 8)
+    for ns in (2, 3, 4)
+    for nc in (1,)
+]
+
+if "absl.testing" in sys.modules.keys():
+    logging.warning("Running in absl.testing mode, disable autotune for triton.")
+    _CONFIGS = _CONFIGS[:1]
+
+
 @triton.autotune(
-    configs=[
-        triton.Config(
-            {"TILE_M": tm, "TILE_N": tn, "TILE_K": tk, "GROUP_SIZE_M": gm},
-            num_warps=nw,
-            num_stages=ns,
-            num_ctas=nc,
-            pre_hook=matmul_tma_set_block_size_hook,
-        )
-        for tm in (64, 128, 256)
-        for tn in (64, 128, 256)
-        for tk in (64, 128, 256)
-        for gm in (2, 4, 8)
-        for nw in (4, 8)
-        for ns in (2, 3, 4)
-        for nc in (1,)
-    ],
+    configs=_CONFIGS,
     key=["N", "K", "TRANS", "WARP_SPECIALIZE"],
     prune_configs_by={"early_config_prune": prune_invalid_configs},
 )
@@ -316,7 +203,7 @@ def tsyrk_ex(
     Returns:
         Output tensor of shape (N, N)
     """
-
+    assert a.dtype == torch.bfloat16, "Input tensor must be bfloat16"
     assert a.dim() == 2, "Input tensor must be 2D"
     assert a.is_contiguous() or a.T.is_contiguous(), "invalid input tensor layout. a or a.T must be contiguous."
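
Two behavioral notes on this rewrite: the module no longer raises at import time on older Triton, instead recording availability in HAS_TRITON_340 so callers can fall back; and hoisting the config list into _CONFIGS lets the module shrink it to a single configuration when absl.testing is already imported, so unit tests skip the slow autotuning sweep. A minimal sketch of the feature-detection pattern a caller might use (the matmul fallback is an illustration, not code from this commit):

import torch

from emerging_optimizers import triton_kernels


def gram_matrix(x: torch.Tensor) -> torch.Tensor:
    # tsyrk_ex needs Triton >= 3.4.0, a CUDA tensor, and bfloat16 input
    # (per the asserts added in this commit).
    if triton_kernels.HAS_TRITON_340 and x.is_cuda and x.dtype == torch.bfloat16:
        return triton_kernels.tsyrk_ex(x)  # X @ X^T via the symmetric rank-k kernel
    return x @ x.mT  # plain matmul fallback (illustrative)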

pyproject.toml

Lines changed: 7 additions & 1 deletion
@@ -79,11 +79,13 @@ test = [
     "coverage>=7.8.1",
     "flake8>=7.2.0",
     "pylint>=3.3.7",
+    "triton>=3.4.0",
 ]
 dev = [
     "pre-commit>=3.6.0",
     "ruff>=0.9.9",
     "mypy>=1.8.0",
+    "triton>=3.4.0",
 ]

 [tool.uv]
@@ -169,6 +171,10 @@ omit = ["/tmp/*"]
 relative_files = true
 source = ["emerging_optimizers"]

-
 [tool.coverage.paths]
 source = ["emerging_optimizers/", "/workspace/emerging_optimizers"]
+
+[tool.coverage.report]
+exclude_also = [
+    "@triton"
+]
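
For context: coverage.py's exclude_also entries are regular expressions matched against source lines, so the "@triton" pattern excludes functions decorated with @triton.jit or @triton.autotune from coverage reporting; their bodies compile to GPU kernels rather than executing as ordinary Python, so they would otherwise show up as permanently uncovered lines. (The committed section header read "[too.coverage.report]", corrected to "[tool.coverage.report]" above.)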

tests/ci/L0_Tests_CPU.sh

Lines changed: 8 additions & 5 deletions
@@ -12,8 +12,11 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 export TORCH_COMPILE_DISABLE=1
-set -o pipefail
-torchrun --nproc_per_node=8 --no-python coverage run -p tests/test_distributed_muon_utils_cpu.py
-torchrun --nproc_per_node=4 --no-python coverage run -p tests/test_distributed_muon_utils_cpu.py
-coverage run -p --source=emerging_optimizers tests/test_scalar_optimizers.py --device=cpu
-coverage run -p --source=emerging_optimizers tests/test_procrustes_step.py --device=cpu
+
+error=0
+torchrun --nproc_per_node=8 --no-python coverage run -p tests/test_distributed_muon_utils_cpu.py || error=1
+torchrun --nproc_per_node=4 --no-python coverage run -p tests/test_distributed_muon_utils_cpu.py || error=1
+coverage run -p --source=emerging_optimizers tests/test_scalar_optimizers.py --device=cpu || error=1
+coverage run -p --source=emerging_optimizers tests/test_procrustes_step.py --device=cpu || error=1
+
+exit "${error}"
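
Worth noting: set -o pipefail only affects pipelines, and none of these commands are pipelines, so previously the script's exit status was simply that of the last command and an earlier torchrun failure could pass CI. The new error accumulator runs every suite regardless of failures and reports any of them through the final exit status.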
