
Commit 8070fa6

reorder soap tests

Signed-off-by: Hao Wu <skyw@nvidia.com>

1 parent 4592326

File tree

2 files changed: +229 −254 lines

tests/test_soap.py

Lines changed: 229 additions & 0 deletions
@@ -12,10 +12,239 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from typing import Any

import torch
from absl.testing import absltest, parameterized

from emerging_optimizers.soap import soap
from emerging_optimizers.soap.soap import (
    _clip_update_rms_in_place,
    _get_precondition_frequency,
    _is_eigenbasis_update_step,
)
from emerging_optimizers.utils.precondition_schedules import LinearSchedule


class SoapFunctionsTest(parameterized.TestCase):
    def test_init_preconditioner_multidim_tensor_shapes(self) -> None:
        """Tests init_preconditioner with a multi-dimensional tensor."""
        grad = torch.randn(3, 4, 5)
        state: dict[str, Any] = {}
        state["GG"] = soap.init_kronecker_factors(grad, precondition_1d=False)
        self.assertEqual(len(state["GG"]), grad.dim())
        self.assertEqual(state["GG"][0].shape, (3, 3))
        self.assertEqual(state["GG"][1].shape, (4, 4))
        self.assertEqual(state["GG"][2].shape, (5, 5))

    @parameterized.parameters(
        (1,),
        (2,),
        (3,),
    )
    def test_adam_warmup_steps(self, adam_warmup_steps: int) -> None:
        """Tests that state["Q"] stays at identity until adam_warmup_steps steps are completed."""

        param = torch.randn(5, 3, requires_grad=True, device="cuda")

        optimizer = soap.SOAP(
            [param],
            lr=0.001,
            weight_decay=0.01,
            adam_warmup_steps=adam_warmup_steps,
            precondition_frequency=1,
        )

        dummy_Q = [torch.eye(shape, device=param.device) for shape in param.shape]
        for step in range(adam_warmup_steps - 1):
            param.grad = torch.randn_like(param)
            optimizer.step()
            state = optimizer.state[param]

            torch.testing.assert_close(
                state["Q"], dummy_Q, atol=0, rtol=0, msg=f"Q should stay identity at step {step}"
            )

        for step in range(adam_warmup_steps - 1, adam_warmup_steps + 3):
            param.grad = torch.randn_like(param)
            optimizer.step()
            state = optimizer.state[param]

            # Verify Q has the right structure (a list with one tensor per dim)
            self.assertIsInstance(state["Q"], list)
            self.assertEqual(len(state["Q"]), param.dim())
            # Verify each entry is a square eigenvector matrix for its dim
            self.assertEqual(state["Q"][0].shape, (5, 5))
            self.assertEqual(state["Q"][1].shape, (3, 3))

    def test_update_kronecker_factors(self) -> None:
        max_dim = 8
        shampoo_beta = 0.9
        dim0, dim1, dim2 = 3, max_dim + 2, 5
        grad = torch.randn(dim0, dim1, dim2)

        # Initialize factors
        initial_factors = soap.init_kronecker_factors(grad, precondition_1d=False)

        kronecker_factors = [f.clone() for f in initial_factors]

        soap.update_kronecker_factors(
            kronecker_factor_list=kronecker_factors,
            grad=grad,
            shampoo_beta=shampoo_beta,
            precondition_1d=False,
        )

        self.assertEqual(len(kronecker_factors), grad.dim())

        contract_dims_0 = [1, 2]
        outer_product_0 = torch.tensordot(grad, grad, dims=[contract_dims_0] * 2)
        expected_factor_0 = initial_factors[0] * shampoo_beta + outer_product_0 * (1 - shampoo_beta)
        torch.testing.assert_close(kronecker_factors[0], expected_factor_0, atol=1e-6, rtol=1e-6)

        contract_dims_2 = [0, 1]
        outer_product_2 = torch.tensordot(grad, grad, dims=[contract_dims_2] * 2)
        expected_factor_2 = initial_factors[2] * shampoo_beta + outer_product_2 * (1 - shampoo_beta)
        torch.testing.assert_close(kronecker_factors[2], expected_factor_2, atol=1e-6, rtol=1e-6)
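
    # For reference, the EMA recurrence the test above checks: each Kronecker factor
    # accumulates gradient outer products contracted over all other dims. A minimal
    # sketch of that expectation (an illustration of the recurrence, not
    # update_kronecker_factors itself):
    #
    #     for i in range(grad.dim()):
    #         other_dims = [d for d in range(grad.dim()) if d != i]
    #         outer = torch.tensordot(grad, grad, dims=[other_dims] * 2)
    #         factors[i] = shampoo_beta * factors[i] + (1 - shampoo_beta) * outer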

    @parameterized.parameters(
        (4, 5),
        (3, 3),
        (5, 4),
    )
    def test_tensordot_vs_matmul(self, m, n):
        # Create random orthonormal rotation matrices Q_L and Q_R
        grad = torch.randn(m, n)
        left_matrix = torch.randn(m, m)
        Q_L = torch.linalg.qr(left_matrix + left_matrix.T).Q
        right_matrix = torch.randn(n, n)
        Q_R = torch.linalg.qr(right_matrix + right_matrix.T).Q

        # Test that the project operation into the eigenbasis is correct.
        # Calculate using sequential tensordot as used by the code.
        grad_intermediate = torch.tensordot(grad, Q_L, dims=([0], [0]))
        # Check that grad_intermediate has the transposed shape
        self.assertEqual(grad_intermediate.shape, grad.transpose(0, 1).shape)
        grad_td = torch.tensordot(grad_intermediate, Q_R, dims=([0], [0]))
        # Calculate using pure sequential matmul
        grad_pt = Q_L.transpose(0, 1).matmul(grad).matmul(Q_R)
        self.assertTrue(torch.allclose(grad_td, grad_pt, atol=1e-6))

        # Test that the project_back operation out of the eigenbasis is correct.
        # Calculate using sequential tensordot as used by the code.
        grad_intermediate = torch.tensordot(grad, Q_L, dims=([0], [1]))
        # Check that grad_intermediate has the transposed shape
        self.assertEqual(grad_intermediate.shape, grad.transpose(0, 1).shape)
        grad_td = torch.tensordot(grad_intermediate, Q_R, dims=([0], [1]))
        # Calculate using pure sequential matmul
        grad_pt = Q_L.matmul(grad).matmul(Q_R.transpose(0, 1))
        self.assertTrue(torch.allclose(grad_td, grad_pt, atol=1e-6))
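
    # Why the intermediates above come out transposed: tensordot(grad, Q_L, dims=([0], [0]))
    # computes grad_intermediate[j, k] = sum_i grad[i, j] * Q_L[i, k], i.e. (Q_L^T @ grad)^T.
    # Contracting dim 0 again with Q_R therefore restores the original dim order,
    # giving grad_td = Q_L^T @ grad @ Q_R; the project_back variant with dims=([0], [1])
    # applies the un-transposed factors instead, giving Q_L @ grad @ Q_R^T.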

    @parameterized.parameters(  # type: ignore[misc]
        {"N": 4, "M": 8},
        {"N": 16, "M": 8},
        {"N": 32, "M": 8},
    )
    def test_project_and_project_back(self, N: int, M: int) -> None:
        """Tests projecting a tensor into the eigenbasis of Q_L and Q_R and back out.

        The projected tensor should approximately recover the original tensor.
        """
        torch.manual_seed(0)
        # Create a random tensor to project in and out of the eigenbasis
        grad = torch.randn(M, N)
        # Create a state with 2 orthonormal matrices.
        Q_L = torch.linalg.qr(torch.randn(M, M))[0]
        Q_R = torch.linalg.qr(torch.randn(N, N))[0]
        orthonormal_matrix_list = [Q_L, Q_R]

        projected = soap.precondition(
            grad=grad,
            eigenbasis_list=orthonormal_matrix_list,
            dims=[[0], [0]],
        )
        recov = soap.precondition(
            grad=projected,
            eigenbasis_list=orthonormal_matrix_list,
            dims=[[0], [1]],
        )
        # Check that the recovered tensor is close to the original.
        torch.testing.assert_close(
            grad,
            recov,
            atol=1e-6,
            rtol=1e-6,
            msg="Project and project_back did not recover the original tensor.",
        )
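
    # The round trip above relies on the Q factors being orthonormal: projecting with
    # dims=[[0], [0]] applies Q^T along each mode and projecting back with dims=[[0], [1]]
    # applies Q, so each mode sees Q @ Q^T = I up to floating-point error.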

    def test_get_precondition_frequency_fixed(self) -> None:
        """Test that _get_precondition_frequency works with a fixed frequency (default case)."""
        freq = _get_precondition_frequency(10, 100)
        self.assertEqual(freq, 10)

    @parameterized.parameters(
        (5, 10, 20, 10, False),
        (15, 10, 20, 10, True),
        (20, 10, 15, 10, True),
        (21, 10, 15, 10, False),
        (30, 10, 15, 10, True),
        (31, 10, 15, 10, False),
    )
    def test_is_eigenbasis_update_step_fixed_frequency(
        self, step: int, adam_warmup_steps: int, precondition_warmup: int, precondition_frequency: int, expected: bool
    ) -> None:
        """Test _is_eigenbasis_update_step with a fixed frequency."""
        result = _is_eigenbasis_update_step(step, adam_warmup_steps, precondition_warmup, precondition_frequency)
        self.assertEqual(result, expected)
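
    # The parameterized cases above are consistent with a rule along these lines
    # (inferred from the cases, not a spec of _is_eigenbasis_update_step): no
    # eigenbasis updates before adam_warmup_steps, an update every step until
    # precondition_warmup is passed, then an update whenever
    # step % precondition_frequency == 0.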

    def test_soap_optimizer_fixed_frequency(self) -> None:
        """Test that the SOAP optimizer can be created with a fixed precondition frequency (default case)."""
        param = torch.randn(10, 5, requires_grad=True)
        optimizer = soap.SOAP([param], lr=1e-3, precondition_frequency=10)
        self.assertEqual(optimizer.param_groups[0]["precondition_frequency"], 10)

    def test_soap_optimizer_class_based_schedule(self) -> None:
        """Test that the SOAP optimizer can be created with a class-based precondition frequency schedule."""
        param = torch.randn(10, 5, requires_grad=True)
        schedule = LinearSchedule(min_freq=2, max_freq=10, transition_steps=100)
        optimizer = soap.SOAP([param], lr=1e-3, precondition_frequency=schedule)
        self.assertEqual(optimizer.param_groups[0]["precondition_frequency"], schedule)

        self.assertEqual(schedule(0), 2)
        self.assertEqual(schedule(50), 6)
        self.assertEqual(schedule(100), 10)

        adam_warmup = 1
        precondition_warmup = 0

        self.assertTrue(_is_eigenbasis_update_step(10, adam_warmup, precondition_warmup, schedule))
        self.assertFalse(_is_eigenbasis_update_step(11, adam_warmup, precondition_warmup, schedule))
        self.assertTrue(_is_eigenbasis_update_step(60, adam_warmup, precondition_warmup, schedule))
        self.assertFalse(_is_eigenbasis_update_step(61, adam_warmup, precondition_warmup, schedule))
        self.assertTrue(_is_eigenbasis_update_step(120, adam_warmup, precondition_warmup, schedule))
        self.assertFalse(_is_eigenbasis_update_step(121, adam_warmup, precondition_warmup, schedule))
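
    # The asserted schedule values are consistent with linear interpolation of the
    # frequency (a sketch inferred from this test, not LinearSchedule's actual code):
    #
    #     def linear_freq(step, min_freq=2, max_freq=10, transition_steps=100):
    #         frac = min(step / transition_steps, 1.0)
    #         return round(min_freq + frac * (max_freq - min_freq))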

    @parameterized.parameters(
        (1.0,),
        (0.0,),
        (0.5,),
    )
    def test_clip_update_rms(self, max_rms: float) -> None:
        """Test that _clip_update_rms_in_place clips the update RMS to max_rms in place."""
        # Test for three different u values
        u_s = [
            torch.tensor([4.0, -1.0, 1.0, -1.0, 1.0], device="cuda"),
            torch.tensor([0.2, 0.2, 0.2, 0.2, 0.0], device="cuda"),
            torch.tensor([0.8, 0.0, 0.0, 0.0, 0.0], device="cuda"),
        ]
        for u in u_s:
            u_clipped = u.clone()
            _clip_update_rms_in_place(u_clipped, max_rms=max_rms)
            if max_rms == 0:
                # max_rms == 0 disables clipping, so the norm is unchanged
                self.assertTrue(torch.linalg.norm(u_clipped) == torch.linalg.norm(u))
            else:
                self.assertTrue(torch.linalg.norm(u_clipped) / math.sqrt(u.numel()) <= max_rms)
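
    # A common form of RMS clipping consistent with these assertions (an assumed
    # sketch, not necessarily _clip_update_rms_in_place's exact implementation):
    #
    #     def clip_rms_(u: torch.Tensor, max_rms: float) -> None:
    #         if max_rms == 0:  # 0 is treated as "clipping disabled"
    #             return
    #         rms = torch.linalg.norm(u) / math.sqrt(u.numel())
    #         u.mul_(min(1.0, max_rms / max(rms.item(), 1e-12)))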


class SoapTest(parameterized.TestCase):
