|
13 | 13 | # See the License for the specific language governing permissions and |
14 | 14 | # limitations under the License. |
15 | 15 | import math |
16 | | -from typing import Any |
| 16 | +from functools import partial |
| 17 | +from typing import Any, List |
17 | 18 |
|
18 | 19 | import torch |
19 | 20 | from absl.testing import absltest, parameterized |
|
27 | 28 | from emerging_optimizers.utils.precondition_schedules import LinearSchedule |
28 | 29 |
|
29 | 30 |
|
| 31 | +def kl_shampoo_update_ref( |
| 32 | + kronecker_factor_list: List[torch.Tensor], |
| 33 | + grad: torch.Tensor, |
| 34 | + eigenbasis_list: List[torch.Tensor], |
| 35 | + shampoo_beta: float, |
| 36 | + eps: float, |
| 37 | + eigval_exp: float = -1.0, |
| 38 | +) -> None: |
| 39 | + """Reference implementation of KL-Shampoo update. |
| 40 | +
|
| 41 | + Using same functionality implemented by different people as testing reference. The chance of two |
| 42 | + independent implementations having the same bug is very low. |
| 43 | +
|
| 44 | + """ |
| 45 | + if grad.dim() != 2: |
| 46 | + raise ValueError("KL-Shampoo mathematical correction is only supported for 2D tensors") |
| 47 | + # scale each factor update by the inverse approximate eigenvalues in the corresponding eigenbasis:
| 48 | + # G@Q_R@λ_R^(−1)@Q_R.T@G.T/dim(G.T@G) for the left factor, G.T@Q_L@λ_L^(−1)@Q_L.T@G/dim(G@G.T) for the right
| 49 | + scale_factors = [ |
| 50 | + 1 |
| 51 | + / grad.shape[idx] |
| 52 | + * (torch.diag(eigenbasis_list[idx].T @ kronecker_factor_list[idx] @ eigenbasis_list[idx]) + eps) ** eigval_exp |
| 53 | + for idx in range(len(kronecker_factor_list)) |
| 54 | + ] |
| 55 | + # rotate the scales back out of the eigenbasis: Q @ diag(scale) @ Q.T
| 56 | + kronecker_product_corrections = [ |
| 57 | + (eigenbasis_list[idx] * scale_factors[idx][None, :]) @ eigenbasis_list[idx].T |
| 58 | + for idx in range(len(kronecker_factor_list)) |
| 59 | + ] |
| 60 | + kronecker_product_updates = [ |
| 61 | + grad @ kronecker_product_corrections[1] @ grad.T, |
| 62 | + grad.T @ kronecker_product_corrections[0] @ grad, |
| 63 | + ] |
| 64 | + for idx in range(len(kronecker_factor_list)): |
| 65 | + kronecker_factor_list[idx].lerp_(kronecker_product_updates[idx], 1 - shampoo_beta) |
| 66 | + |
| 67 | + |
30 | 68 | class SoapFunctionsTest(parameterized.TestCase): |
31 | 69 | def test_init_preconditioner_multidim_tensor_shapes(self) -> None: |
32 | 70 | """Tests init_preconditioner with a multi-dimensional tensor.""" |
@@ -246,6 +284,34 @@ def test_clip_update_rms(self, max_rms: float) -> None: |
246 | 284 | else: |
247 | 285 | self.assertTrue(torch.linalg.norm(u_clipped) / math.sqrt(u.numel()) <= max_rms) |
248 | 286 |
|
| 287 | + @parameterized.parameters( |
| 288 | + (4, 5), |
| 289 | + (3, 3), |
| 290 | + (5, 4), |
| 291 | + ) |
| 292 | + def test_kl_shampoo_update(self, m: int, n: int) -> None:
| 293 | + rand_exp_fn = partial(torch.randint, low=-4, high=-1, dtype=torch.float32, device="cuda")  # exact powers of two (see note below)
| 294 | + kronecker_factor_list = [ |
| 295 | + 2 ** rand_exp_fn(size=(m, m)), |
| 296 | + 2 ** rand_exp_fn(size=(n, n)), |
| 297 | + ] |
| 298 | + kronecker_factor_list_ref = [f.clone() for f in kronecker_factor_list] |
| 299 | + |
| 300 | + test_grad = 2 ** rand_exp_fn(size=(m, n)) |
| 301 | + eigenbasis_list = [2 ** rand_exp_fn(size=(m, m)), 2 ** rand_exp_fn(size=(n, n))] |
| 302 | + kwargs = dict( |
| 303 | + grad=test_grad, |
| 304 | + shampoo_beta=0.5, |
| 305 | + eps=1e-8, |
| 306 | + eigval_exp=-1.0, |
| 307 | + eigenbasis_list=eigenbasis_list, |
| 308 | + ) |
| 309 | + kl_shampoo_update_ref(kronecker_factor_list_ref, **kwargs) |
| 310 | + soap.update_kronecker_factors_kl_shampoo(kronecker_factor_list, **kwargs) |
| 311 | + |
| 312 | + torch.testing.assert_close(kronecker_factor_list[0], kronecker_factor_list_ref[0], atol=1e-6, rtol=1e-6) |
| 313 | + torch.testing.assert_close(kronecker_factor_list[1], kronecker_factor_list_ref[1], atol=1e-6, rtol=1e-6) |
| 314 | + |
249 | 315 |
|
250 | 316 | class SoapTest(parameterized.TestCase): |
251 | 317 | def setUp(self): |
|
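A note on the power-of-two test values above: with exponents drawn from [-4, -1), every product in the matrix multiplications is itself a power of two and every partial sum stays exactly representable in float32, so the two independent implementations agree to tight tolerances even though they order the floating-point operations differently. A minimal sketch of this property (plain PyTorch, independent of the code above; the matrix names are illustrative only):

    import torch

    # random power-of-two matrices, mirroring the test's construction
    a = 2.0 ** torch.randint(-4, -1, (8, 8), dtype=torch.float32)
    b = 2.0 ** torch.randint(-4, -1, (8, 8), dtype=torch.float32)

    # (a @ b).T and b.T @ a.T sum the same products in different orders;
    # with power-of-two entries every intermediate value is exact, so the
    # two results are bitwise identical
    torch.testing.assert_close((a @ b).T, b.T @ a.T, atol=0.0, rtol=0.0)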