
Commit fcd5b01

port test over
Signed-off-by: Hao Wu <skyw@nvidia.com>
1 parent 52d6824 commit fcd5b01

File tree

14 files changed, +261 -100 lines changed

.github/workflows/cicd-main.yml

Lines changed: 3 additions & 3 deletions
@@ -42,7 +42,7 @@ jobs:
     uses: ./.github/workflows/_build_container.yml
     needs: cicd-wait-in-queue
     with:
-      image-name: llm_shower
+      image-name: emerging_optimizers
       dockerfile: docker/Dockerfile.ci
       runner: self-hosted-nemo
     secrets:
@@ -72,7 +72,7 @@ jobs:
       script: ${{ matrix.script }}
       timeout: ${{ matrix.timeout || 10 }}
       is_unit_test: "true"
-      image: llm_shower
+      image: emerging_optimizers
       cpu-only: ${{ matrix.cpu-only || false }}
       has-azure-credentials: "true"
       azure-client-id: ${{ secrets.AZURE_CLIENT_ID }}
@@ -100,7 +100,7 @@ jobs:
       runner: ${{ runner.name }}
       script: ${{ matrix.script }}
       timeout: ${{ matrix.timeout || 10 }}
-      image: llm_shower
+      image: emerging_optimizers
       cpu-only: ${{ matrix.cpu-only || false }}
       has-azure-credentials: "true"
       azure-client-id: ${{ secrets.AZURE_CLIENT_ID }}

docker/Dockerfile.ci

Lines changed: 2 additions & 2 deletions
@@ -25,8 +25,8 @@ ENV PATH="$UV_PROJECT_ENVIRONMENT/bin:$PATH"
 
 WORKDIR /workspace
 RUN --mount=type=bind,source=pyproject.toml,target=/workspace/pyproject.toml \
-    --mount=type=bind,source=llm_shower/__init__.py,target=/workspace/llm_shower/__init__.py \
-    --mount=type=bind,source=llm_shower/package_info.py,target=/workspace/llm_shower/package_info.py \
+    --mount=type=bind,source=emerging_optimizers/__init__.py,target=/workspace/emerging_optimizers/__init__.py \
+    --mount=type=bind,source=emerging_optimizers/package_info.py,target=/workspace/emerging_optimizers/package_info.py \
     --mount=type=bind,source=uv.lock,target=/workspace/uv.lock bash -exu <<"EOF"
 
 # Use the container's torch installation rather than reinstall it

emerging_optimizers/orthogonalized_optimizers/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -12,5 +12,5 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from emerging_optimizers.orthogonalized_optimizers.orthogonalized_optimizer import *
 from emerging_optimizers.orthogonalized_optimizers.muon import *
+from emerging_optimizers.orthogonalized_optimizers.orthogonalized_optimizer import *

emerging_optimizers/orthogonalized_optimizers/muon.py

Lines changed: 2 additions & 2 deletions
@@ -12,14 +12,14 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Callable
 from functools import partial
+from typing import Callable
 
 import torch
 from torch.optim.optimizer import ParamsT
 
-from emerging_optimizers.orthogonalized_optimizers.orthogonalized_optimizer import OrthogonalizedOptimizer, _args_doc
 from emerging_optimizers.orthogonalized_optimizers.muon_utils import newton_schulz
+from emerging_optimizers.orthogonalized_optimizers.orthogonalized_optimizer import OrthogonalizedOptimizer, _args_doc
 
 
 class Muon(OrthogonalizedOptimizer):

emerging_optimizers/orthogonalized_optimizers/muon_utils.py

Lines changed: 31 additions & 0 deletions
@@ -17,6 +17,7 @@
 import torch
 from absl import logging
 
+
 __all__ = ["newton_schulz", "newton_schulz_tp"]
 
 _COEFFICIENT_SETS = {
@@ -216,3 +217,33 @@ def newton_schulz_tp(
         raise ValueError(f"Invalid mode: {mode}")
 
     return output
+
+
+def newton_schulz_step(
+    X: torch.Tensor, a: float, b: float, c: float, tp_group: torch.distributed.ProcessGroup | None = None
+) -> torch.Tensor:
+    """Perform a single Newton-Schulz iteration step.
+
+    This function performs a single Newton-Schulz iteration step. It supports distributed input that is sharded
+    along the smaller (orthogonalize) dimension.
+
+    Warning:
+        If distributed, this function doesn't have the information to verify that X is sharded along the smaller
+        (orthogonalize) dimension. It is the user's responsibility to ensure that X is sharded correctly.
+
+    Arguments:
+        X: The tensor to be orthogonalized.
+        a: The a coefficient.
+        b: The b coefficient.
+        c: The c coefficient.
+        tp_group: The process group to use for the all-reduce.
+
+    Returns:
+        The orthogonalization of X.
+    """
+    A = X @ X.mT
+    if tp_group is not None:
+        torch.distributed.all_reduce(A, op=torch.distributed.ReduceOp.SUM, group=tp_group)
+    B = torch.addmm(A, A, A, beta=b, alpha=c)
+    X = torch.addmm(X, B, X, beta=a, alpha=1.0)
+    return X
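
For context, here is a minimal single-process sketch of driving the newly added newton_schulz_step. The coefficients (3.4445, -4.7750, 2.0315) are the ones used in the new test file below; the input shape, the 5-step loop, and the variable names are illustrative assumptions, not part of this commit.

    import torch

    from emerging_optimizers.orthogonalized_optimizers import muon_utils

    # Keep the first dimension the smaller one so A = X @ X.mT is formed over the
    # "orthogonalize" dimension, as the docstring above expects.
    g = torch.randn(128, 256)
    x = torch.nn.functional.normalize(g, dim=(-2, -1))
    for _ in range(5):
        # Without tp_group this is the plain, non-distributed iteration step.
        x = muon_utils.newton_schulz_step(x, 3.4445, -4.7750, 2.0315)
    # x now approximates an orthogonalized version of g (singular values pushed toward 1).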

emerging_optimizers/orthogonalized_optimizers/orthogonalized_optimizer.py

Lines changed: 1 addition & 2 deletions
@@ -14,10 +14,9 @@
 # limitations under the License.
 from typing import Any, Callable, override
 
-from absl import logging
-
 import torch
 import torch.optim as optim
+from absl import logging
 from torch.optim.optimizer import ParamsT
 
 from emerging_optimizers import utils

emerging_optimizers/utils/__init__.py

Lines changed: 3 additions & 5 deletions
@@ -12,15 +12,13 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from .eig import *
-
 from contextlib import contextmanager
 from typing import Generator
+
 import torch
 
-__all__ = [
-    "fp32_matmul_precision", "get_pg_size", "get_pg_rank"
-]
+
+__all__ = ["fp32_matmul_precision", "get_pg_size", "get_pg_rank"]
 
 
 @contextmanager

pyproject.toml

Lines changed: 4 additions & 1 deletion
@@ -194,6 +194,9 @@ line-ending = "auto"
 [tool.coverage.run]
 concurrency = ["thread", "multiprocessing"]
 omit = ["/tmp/*"]
+relative_files = true
+source = ["emerging_optimizers"]
+
 
 [tool.coverage.paths]
-source = ["llm_shower/", "/workspace/llm_shower"]
+source = ["emerging_optimizers/", "/workspace/emerging_optimizers"]
Lines changed: 202 additions & 0 deletions
@@ -0,0 +1,202 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import os
+
+import numpy as np
+import torch
+from absl.testing import absltest, parameterized
+
+from emerging_optimizers.orthogonalized_optimizers import muon_utils
+
+
+class DistributedNewtonSchulzStepCpuTest(parameterized.TestCase):
+    def setUp(self):
+        self.coefs = 3.4445, -4.7750, 2.0315
+
+    @parameterized.parameters(
+        {"shape": (21, 16)},
+        {"shape": (16, 32)},
+    )
+    def test_close_to_non_distributed(self, shape):
+        x = torch.nn.functional.normalize(torch.randint(-5, 5, shape, device="cpu", dtype=torch.float32), dim=(-2, -1))
+        # All-reduce ensures that every rank gets the same x
+        torch.distributed.all_reduce(x, op=torch.distributed.ReduceOp.SUM)
+
+        world_size = torch.distributed.get_world_size()
+        rank = torch.distributed.get_rank()
+        local_x = x.chunk(world_size, dim=1)[rank]
+
+        dist_out = muon_utils.newton_schulz_step(local_x, *self.coefs, tp_group=torch.distributed.group.WORLD)
+
+        ref_out = muon_utils.newton_schulz_step(x, *self.coefs)
+
+        torch.testing.assert_close(ref_out.chunk(world_size, dim=1)[rank], dist_out)
+
+    @absltest.skipIf(int(os.environ.get("WORLD_SIZE", 1)) < 4, "test requires at least 4 ranks")
+    @parameterized.product(
+        shape=((21, 16), (16, 32)),
+        tp_size=(2, 4),
+    )
+    def test_with_partial_tp(self, shape, tp_size):
+        x = torch.nn.functional.normalize(torch.randint(-5, 5, shape, device="cpu", dtype=torch.float32), dim=(-2, -1))
+        # All-reduce ensures that every rank gets the same x
+        torch.distributed.all_reduce(x, op=torch.distributed.ReduceOp.SUM)
+
+        num_tp_groups = torch.distributed.get_world_size() // tp_size
+        tp_group, _ = torch.distributed.new_subgroups_by_enumeration(
+            np.split(np.arange(torch.distributed.get_world_size()), num_tp_groups)
+        )
+        assert tp_group.size() == tp_size
+        local_x = x.chunk(tp_group.size(), dim=1)[tp_group.rank()]
+
+        dist_out = muon_utils.newton_schulz_step(local_x, *self.coefs, tp_group=tp_group)
+        ref_out = muon_utils.newton_schulz_step(x, *self.coefs)
+        torch.testing.assert_close(ref_out.chunk(tp_group.size(), dim=1)[tp_group.rank()], dist_out)
+
+
+class DistributedNewtonSchulzCpuTest(parameterized.TestCase):
+    @parameterized.parameters(
+        {"shape": (21, 16)},
+        {"shape": (16, 32)},
+    )
+    def test_distributed_normalize_close_to_non_distributed(self, shape):
+        x = torch.randint(-5, 5, shape, device="cpu", dtype=torch.float32)
+        torch.distributed.all_reduce(x, op=torch.distributed.ReduceOp.SUM)
+
+        world_size = torch.distributed.get_world_size()
+        rank = torch.distributed.get_rank()
+        local_x = x.chunk(world_size, dim=1)[rank]
+
+        dist_out = muon_utils.distributed_normalize_p2(local_x, eps=1e-7, group=torch.distributed.group.WORLD)
+        ref_out = torch.nn.functional.normalize(x, dim=(-2, -1), eps=1e-7)
+
+        torch.testing.assert_close(ref_out.chunk(world_size, dim=1)[rank], dist_out)
+
+    @parameterized.parameters(
+        {"shape": (3, 32)},
+        {"shape": (5, 100)},
+    )
+    def test_1step_close_to_non_distributed(self, shape):
+        x = torch.randint(-5, 5, shape, device="cpu", dtype=torch.float32)
+        # All-reduce ensures that every rank gets the same x
+        torch.distributed.all_reduce(x, op=torch.distributed.ReduceOp.SUM)
+
+        world_size = torch.distributed.get_world_size()
+        rank = torch.distributed.get_rank()
+        local_x = x.chunk(world_size, dim=1)[rank]
+
+        dist_out = muon_utils.newton_schulz(
+            local_x, steps=1, coefficient_type="simple", tp_group=torch.distributed.group.WORLD
+        )
+        ref_out = muon_utils.newton_schulz(x, steps=1, coefficient_type="simple")
+        torch.testing.assert_close(ref_out.chunk(world_size, dim=1)[rank], dist_out)
+
+    @parameterized.parameters(
+        {"shape": (32, 3), "transpose": True},
+        {"shape": (5, 100), "transpose": False},
+    )
+    def test_5steps_with_transpose_close_to_non_distributed(self, shape, transpose):
+        x = torch.randint(-5, 5, shape, device="cpu", dtype=torch.float32)
+        # All-reduce ensures that every rank gets the same x
+        torch.distributed.all_reduce(x, op=torch.distributed.ReduceOp.SUM)
+
+        world_size = torch.distributed.get_world_size()
+        rank = torch.distributed.get_rank()
+
+        chunk_dim = 0 if transpose else 1
+        local_x = x.chunk(world_size, dim=chunk_dim)[rank]
+
+        dist_out = muon_utils.newton_schulz(
+            local_x, steps=5, tp_group=torch.distributed.group.WORLD, transpose=transpose
+        )
+        ref_out = muon_utils.newton_schulz(x, steps=5, transpose=transpose)
+        torch.testing.assert_close(ref_out.chunk(world_size, dim=chunk_dim)[rank], dist_out)
+
+    @parameterized.parameters(
+        {"shape": (32, 3), "transpose": True, "tp_size": 2},
+        {"shape": (5, 100), "transpose": False, "tp_size": 4},
+    )
+    def test_1step_with_partial_tp_close_to_non_distributed(self, shape, transpose, tp_size):
+        x = torch.randint(-5, 5, shape, device="cpu", dtype=torch.float32)
+        # All-reduce ensures that every rank gets the same x
+        torch.distributed.all_reduce(x, op=torch.distributed.ReduceOp.SUM)
+
+        num_tp_groups = torch.distributed.get_world_size() // tp_size
+        tp_group, _ = torch.distributed.new_subgroups_by_enumeration(
+            np.split(np.arange(torch.distributed.get_world_size()), num_tp_groups)
+        )
+        assert tp_group.size() == tp_size
+
+        chunk_dim = 0 if transpose else 1
+        local_x = x.chunk(tp_group.size(), dim=chunk_dim)[tp_group.rank()]
+
+        dist_out = muon_utils.newton_schulz(
+            local_x, steps=1, coefficient_type="simple", tp_group=tp_group, transpose=transpose
+        )
+        ref_out = muon_utils.newton_schulz(x, steps=1, coefficient_type="simple", transpose=transpose)
+        torch.testing.assert_close(ref_out.chunk(tp_group.size(), dim=chunk_dim)[tp_group.rank()], dist_out)
+
+
+class TestTensorParallelNewtonSchulz(parameterized.TestCase):
+    @parameterized.parameters(
+        {"shape": (21, 16)},
+        {"shape": (16, 32)},
+    )
+    def test_fall_back_to_non_tp(self, shape):
+        x = torch.randint(-5, 5, shape, device="cpu", dtype=torch.float32)
+
+        test_out = muon_utils.newton_schulz_tp(
+            x, steps=5, coefficient_type="quintic", partition_dim=None, tp_group=None
+        )
+        ref_out = muon_utils.newton_schulz(x, steps=5, coefficient_type="quintic")
+
+        torch.testing.assert_close(test_out, ref_out, atol=0, rtol=0)
+
+    @parameterized.product(
+        shape=((20, 16), (16, 32)),
+        partition_dim=(0, 1),
+        mode=("distributed", "duplicated"),
+    )
+    def test_1step_close_to_non_distributed(self, shape, partition_dim, mode):
+        if shape[partition_dim] % torch.distributed.get_world_size() != 0:
+            self.skipTest("Skipping because shape and world size are incompatible")
+        x = torch.randint(-5, 5, shape, device="cpu", dtype=torch.float32)
+        # All-reduce ensures that every rank gets the same x
+        torch.distributed.all_reduce(x, op=torch.distributed.ReduceOp.SUM)
+
+        world_size = torch.distributed.get_world_size()
+        rank = torch.distributed.get_rank()
+        local_x = x.chunk(world_size, dim=partition_dim)[rank]
+
+        dist_out = muon_utils.newton_schulz_tp(
+            local_x,
+            steps=1,
+            coefficient_type="simple",
+            tp_group=torch.distributed.group.WORLD,
+            partition_dim=partition_dim,
+            mode=mode,
+        )
+
+        ref_out = muon_utils.newton_schulz(x, steps=1, coefficient_type="simple")
+
+        torch.testing.assert_close(ref_out.chunk(world_size, dim=partition_dim)[rank], dist_out, atol=1e-6, rtol=0)
+
+
+if __name__ == "__main__":
+    torch.distributed.init_process_group(backend="gloo")
+    torch.set_float32_matmul_precision("highest")
+    absltest.main()
+
+    torch.distributed.destroy_process_group()
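
As a side note on the partial tensor-parallel cases above: they split the world into equally sized TP groups with torch.distributed.new_subgroups_by_enumeration and then shard the replicated input by chunking. A minimal sketch of that pattern follows, assuming a 4-rank gloo run with tp_size 2; both values and the tensor contents are illustrative, not part of this commit.

    import numpy as np
    import torch

    # Assumes torch.distributed is already initialized (e.g. gloo backend) with 4 ranks.
    world_size = torch.distributed.get_world_size()
    tp_size = 2
    # Split ranks [0, 1, 2, 3] into [[0, 1], [2, 3]]; every rank constructs all
    # subgroups but receives back only the one it belongs to.
    tp_group, _ = torch.distributed.new_subgroups_by_enumeration(
        np.split(np.arange(world_size), world_size // tp_size)
    )
    # tp_group.rank() is this rank's index within its TP group, so chunking a
    # replicated tensor along the sharded dimension gives each rank its own shard.
    full = torch.arange(8.0).reshape(2, 4)
    local = full.chunk(tp_group.size(), dim=1)[tp_group.rank()]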

tests/test_muon_utils.py

Lines changed: 5 additions & 6 deletions
@@ -14,14 +14,13 @@
 # limitations under the License.
 import math
 
-from absl import logging
-from absl.testing import parameterized, absltest
-
 import torch
+from absl import logging
+from absl.testing import absltest, parameterized
 
-from llm_shower.orthogonalized_optimizers.muon_utils import newton_schulz, _COEFFICIENT_SETS
-from llm_shower.orthogonalized_optimizers.muon import Muon, get_muon_scale_factor
-from llm_shower import utils
+from emerging_optimizers import utils
+from emerging_optimizers.orthogonalized_optimizers.muon import Muon, get_muon_scale_factor
+from emerging_optimizers.orthogonalized_optimizers.muon_utils import _COEFFICIENT_SETS, newton_schulz
 
 
 def newton_schulz_ref(x: torch.Tensor, coefficient_sets: list[tuple[float, float, float]]) -> torch.Tensor:
