Skip to content

Commit 4123d53

Browse files
committed
add test for kl shampoo
Signed-off-by: Hao Wu <skyw@nvidia.com>
1 parent 8070fa6 commit 4123d53

File tree

2 files changed

+69
-1
lines changed

2 files changed

+69
-1
lines changed

emerging_optimizers/soap/soap.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -498,6 +498,8 @@ def update_kronecker_factors_kl_shampoo(
498498
approx_eigvals = utils.eig.conjugate(kronecker_factor, eigenbasis, diag=True)
499499
scale_factor = 1 / grad.shape[idx] * approx_eigvals.clamp_min(eps) ** eigval_exp
500500

501+
logging.debug(f"scale_factor[{idx}]: {scale_factor}")
502+
501503
correction = (eigenbasis * scale_factor[None, :]) @ eigenbasis.T
502504

503505
maybe_transpose_grad = grad.T if idx == 1 else grad

tests/test_soap.py

Lines changed: 67 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,8 @@
1313
# See the License for the specific language governing permissions and
1414
# limitations under the License.
1515
import math
16-
from typing import Any
16+
from functools import partial
17+
from typing import Any, List
1718

1819
import torch
1920
from absl.testing import absltest, parameterized
@@ -27,6 +28,43 @@
2728
from emerging_optimizers.utils.precondition_schedules import LinearSchedule
2829

2930

31+
def kl_shampoo_update_ref(
    kronecker_factor_list: List[torch.Tensor],
    grad: torch.Tensor,
    eigenbasis_list: List[torch.Tensor],
    shampoo_beta: float,
    eps: float,
    eigval_exp: float = -1.0,
) -> None:
    """Reference implementation of the KL-Shampoo Kronecker-factor update.

    Using same functionality implemented by different people as testing reference. The chance of two
    independent implementations having the same bug is very low.

    Mutates ``kronecker_factor_list`` in place: each factor is moved toward its
    KL-corrected gradient statistic via an exponential moving average (``lerp_``).

    Args:
        kronecker_factor_list: The two Kronecker factors ``[A_L, A_R]``; updated in place.
        grad: 2D gradient matrix ``G`` of shape ``(m, n)``.
        eigenbasis_list: Approximate eigenbases ``[Q_L, Q_R]`` of the corresponding factors.
        shampoo_beta: EMA coefficient; the new statistic is weighted by ``1 - shampoo_beta``.
        eps: Regularizer added to the approximate eigenvalues before exponentiation.
        eigval_exp: Exponent applied to the regularized approximate eigenvalues.

    Raises:
        ValueError: If ``grad`` is not a 2D tensor.
    """
    if grad.dim() != 2:
        raise ValueError("KL-Shampoo mathematical correction is only supported for 2D tensors")
    # Approximate eigenvalues are diag(Q.T @ A @ Q); scale the gradient matrix by them and the eigenbasis:
    # G@Q_R@λ_R^(−1)@Q_R.T@G.T/dim(GG.T) and G.T@Q_L@λ_L^(−1)@Q_L.T@G/dim(G.TG)
    scale_factors = [
        1
        / grad.shape[idx]
        * (torch.diag(eigenbasis_list[idx].T @ kronecker_factor_list[idx] @ eigenbasis_list[idx]) + eps) ** eigval_exp
        for idx in range(len(kronecker_factor_list))
    ]
    # NOTE(review): removed a leftover debug `print(scale_factors)` here — it spammed test output.
    kronecker_product_corrections = [
        (eigenbasis_list[idx] * scale_factors[idx][None, :]) @ eigenbasis_list[idx].T
        for idx in range(len(kronecker_factor_list))
    ]
    # The left factor is updated from the right correction and vice versa.
    kronecker_product_updates = [
        grad @ kronecker_product_corrections[1] @ grad.T,
        grad.T @ kronecker_product_corrections[0] @ grad,
    ]
    for idx in range(len(kronecker_factor_list)):
        kronecker_factor_list[idx].lerp_(kronecker_product_updates[idx], 1 - shampoo_beta)
3068
class SoapFunctionsTest(parameterized.TestCase):
3169
def test_init_preconditioner_multidim_tensor_shapes(self) -> None:
3270
"""Tests init_preconditioner with a multi-dimensional tensor."""
@@ -246,6 +284,34 @@ def test_clip_update_rms(self, max_rms: float) -> None:
246284
else:
247285
self.assertTrue(torch.linalg.norm(u_clipped) / math.sqrt(u.numel()) <= max_rms)
248286

287+
@parameterized.parameters(
288+
(4, 5),
289+
(3, 3),
290+
(5, 4),
291+
)
292+
def test_kl_shampoo_update(self, m, n):
293+
rand_exp_fn = partial(torch.randint, low=-4, high=-1, dtype=torch.float32, device="cuda")
294+
kronecker_factor_list = [
295+
2 ** rand_exp_fn(size=(m, m)),
296+
2 ** rand_exp_fn(size=(n, n)),
297+
]
298+
kronecker_factor_list_ref = [f.clone() for f in kronecker_factor_list]
299+
300+
test_grad = 2 ** rand_exp_fn(size=(m, n))
301+
eigenbasis_list = [2 ** rand_exp_fn(size=(m, m)), 2 ** rand_exp_fn(size=(n, n))]
302+
kwargs = dict(
303+
grad=test_grad,
304+
shampoo_beta=0.5,
305+
eps=1e-8,
306+
eigval_exp=-1.0,
307+
eigenbasis_list=eigenbasis_list,
308+
)
309+
kl_shampoo_update_ref(kronecker_factor_list_ref, **kwargs)
310+
soap.update_kronecker_factors_kl_shampoo(kronecker_factor_list, **kwargs)
311+
312+
torch.testing.assert_close(kronecker_factor_list[0], kronecker_factor_list_ref[0], atol=1e-6, rtol=1e-6)
313+
torch.testing.assert_close(kronecker_factor_list[1], kronecker_factor_list_ref[1], atol=1e-6, rtol=1e-6)
314+
249315

250316
class SoapTest(parameterized.TestCase):
251317
def setUp(self):

0 commit comments

Comments
 (0)