12 | 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | 13 | # See the License for the specific language governing permissions and |
14 | 14 | # limitations under the License. |
| 15 | +import math |
| 16 | +from typing import Any |
| 17 | + |
15 | 18 | import torch |
16 | 19 | from absl.testing import absltest, parameterized |
17 | 20 |
18 | 21 | from emerging_optimizers.soap import soap |
| 22 | +from emerging_optimizers.soap.soap import ( |
| 23 | + _clip_update_rms_in_place, |
| 24 | + _get_precondition_frequency, |
| 25 | + _is_eigenbasis_update_step, |
| 26 | +) |
| 27 | +from emerging_optimizers.utils.precondition_schedules import LinearSchedule |
| 28 | + |
| 29 | + |
| 30 | +class SoapFunctionsTest(parameterized.TestCase): |
| 31 | + def test_init_preconditioner_multidim_tensor_shapes(self) -> None: |
| 32 | + """Tests init_preconditioner with a multi-dimensional tensor.""" |
| 33 | + grad = torch.randn(3, 4, 5) |
| 34 | + state: dict[str, Any] = {} |
| 35 | + state["GG"] = soap.init_kronecker_factors(grad, precondition_1d=False) |
| 36 | + self.assertEqual(len(state["GG"]), grad.dim()) |
| 37 | + self.assertEqual(state["GG"][0].shape, (3, 3)) |
| 38 | + self.assertEqual(state["GG"][1].shape, (4, 4)) |
| 39 | + self.assertEqual(state["GG"][2].shape, (5, 5)) |
| 40 | + |
| 41 | + @parameterized.parameters( |
| 42 | + (1,), |
| 43 | + (2,), |
| 44 | + (3,), |
| 45 | + ) |
| 46 | + def test_adam_warmup_steps(self, adam_warmup_steps: int) -> None: |
| 47 | + """Tests that adam_warmup_steps causes state["Q"] to be None until the specified steps are completed.""" |
| 48 | + |
| 49 | + param = torch.randn(5, 3, requires_grad=True, device="cuda") |
| 50 | + |
| 51 | + optimizer = soap.SOAP( |
| 52 | + [param], |
| 53 | + lr=0.001, |
| 54 | + weight_decay=0.01, |
| 55 | + adam_warmup_steps=adam_warmup_steps, |
| 56 | + precondition_frequency=1, |
| 57 | + ) |
| 58 | + |
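|  | + # During the Adam warmup phase the eigenbasis Q is expected to stay at identity for every parameter dim.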
| 59 | + dummy_Q = [torch.eye(shape, device=param.device) for shape in param.shape] |
| 60 | + for step in range(adam_warmup_steps - 1): |
| 61 | + param.grad = torch.randn_like(param) |
| 62 | + optimizer.step() |
| 63 | + state = optimizer.state[param] |
| 64 | + |
| 65 | + torch.testing.assert_close( |
| 66 | + state["Q"], dummy_Q, atol=0, rtol=0, msg=f"Q should stay identity at step {step}" |
| 67 | + ) |
| 68 | + |
| 69 | + for step in range(adam_warmup_steps - 1, adam_warmup_steps + 3): |
| 70 | + param.grad = torch.randn_like(param) |
| 71 | + optimizer.step() |
| 72 | + state = optimizer.state[param] |
| 73 | + |
| 74 | + # Verify Q is a list with one tensor per parameter dim
| 75 | + self.assertIsInstance(state["Q"], list) |
| 76 | + self.assertEqual(len(state["Q"]), param.dim()) |
| 77 | + # Verify Q has the right shape (a list with square eigenvector matrices for each dim) |
| 78 | + self.assertEqual(state["Q"][0].shape, (5, 5)) |
| 79 | + self.assertEqual(state["Q"][1].shape, (3, 3)) |
| 80 | + |
| 81 | + def test_update_kronecker_factors(self) -> None: |
| 82 | + max_dim = 8 |
| 83 | + shampoo_beta = 0.9 |
| 84 | + dim0, dim1, dim2 = 3, max_dim + 2, 5 |
| 85 | + grad = torch.randn(dim0, dim1, dim2) |
| 86 | + |
| 87 | + # Initialize factors |
| 88 | + initial_factors = soap.init_kronecker_factors(grad, precondition_1d=False) |
| 89 | + |
| 90 | + kronecker_factors = [f.clone() for f in initial_factors] |
| 91 | + |
| 92 | + soap.update_kronecker_factors( |
| 93 | + kronecker_factor_list=kronecker_factors, |
| 94 | + grad=grad, |
| 95 | + shampoo_beta=shampoo_beta, |
| 96 | + precondition_1d=False, |
| 97 | + ) |
| 98 | + |
| 99 | + self.assertEqual(len(kronecker_factors), grad.dim()) |
| 100 | + |
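|  | + # Each factor i should follow the EMA update: beta * factor + (1 - beta) * (grad contracted with itself over all dims except i).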
| 101 | + contract_dims_0 = [1, 2] |
| 102 | + outer_product_0 = torch.tensordot(grad, grad, dims=[contract_dims_0] * 2) |
| 103 | + expected_factor_0 = initial_factors[0] * shampoo_beta + outer_product_0 * (1 - shampoo_beta) |
| 104 | + torch.testing.assert_close(kronecker_factors[0], expected_factor_0, atol=1e-6, rtol=1e-6) |
| 105 | + |
| 106 | + contract_dims_2 = [0, 1] |
| 107 | + outer_product_2 = torch.tensordot(grad, grad, dims=[contract_dims_2] * 2) |
| 108 | + expected_factor_2 = initial_factors[2] * shampoo_beta + outer_product_2 * (1 - shampoo_beta) |
| 109 | + torch.testing.assert_close(kronecker_factors[2], expected_factor_2, atol=1e-6, rtol=1e-6) |
| 110 | + |
| 111 | + @parameterized.parameters( |
| 112 | + (4, 5), |
| 113 | + (3, 3), |
| 114 | + (5, 4), |
| 115 | + ) |
| 116 | + def test_tensordot_vs_matmul(self, m: int, n: int) -> None:
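|  | + """Tests that sequential tensordot projections match the explicit matmul formulation."""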
| 117 | + # Create tensors with random eigenvectors for rotation matrices QL and QR |
| 118 | + grad = torch.randn(m, n) |
| 119 | + left_matrix = torch.randn(m, m) |
| 120 | + Q_L = torch.linalg.qr(left_matrix + left_matrix.T).Q |
| 121 | + right_matrix = torch.randn(n, n) |
| 122 | + Q_R = torch.linalg.qr(right_matrix + right_matrix.T).Q |
| 123 | + |
| 124 | + # Test that project operation to eigenbasis is correct |
| 125 | + # Calculate using sequential tensordot as used by the code |
| 126 | + grad_intermediate = torch.tensordot(grad, Q_L, dims=([0], [0])) |
| 127 | + # Check that grad_intermediate has the transposed shape
| 128 | + self.assertEqual(grad_intermediate.shape, grad.transpose(0, 1).shape)
| 129 | + grad_td = torch.tensordot(grad_intermediate, Q_R, dims=([0], [0])) |
| 130 | + # Calculate using pure sequential matmul |
| 131 | + grad_pt = Q_L.transpose(0, 1).matmul(grad).matmul(Q_R) |
| 132 | + self.assertTrue(torch.allclose(grad_td, grad_pt, atol=1e-6)) |
| 133 | + |
| 134 | + # Test that project_back operation out of eigenbasis is correct |
| 135 | + # Calculate using sequential tensordot as used by the code |
| 136 | + grad_intermediate = torch.tensordot(grad, Q_L, dims=([0], [1])) |
| 137 | + # Check that grad_intermediate has the transposed shape
| 138 | + self.assertEqual(grad_intermediate.shape, grad.transpose(0, 1).shape)
| 139 | + grad_td = torch.tensordot(grad_intermediate, Q_R, dims=([0], [1])) |
| 140 | + # Calculate using pure sequential matmul |
| 141 | + grad_pt = Q_L.matmul(grad).matmul(Q_R.transpose(0, 1)) |
| 142 | + self.assertTrue(torch.allclose(grad_td, grad_pt, atol=1e-6)) |
| 143 | + |
| 144 | + @parameterized.parameters( # type: ignore[misc] |
| 145 | + {"N": 4, "M": 8}, |
| 146 | + {"N": 16, "M": 8}, |
| 147 | + {"N": 32, "M": 8}, |
| 148 | + ) |
| 149 | + def test_project_and_project_back(self, N: int, M: int) -> None: |
| 150 | + """Tests that projecting a tensor to eigenbasis of QL and QR and back |
| 151 | +
|
| 152 | + The projected tensor should approximately recover the original tensor. |
| 153 | + """ |
| 154 | + torch.manual_seed(0) |
| 155 | + # Create a random tensor to project in and out of eigenbasis |
| 156 | + grad = torch.randn(M, N) |
| 157 | + # Create two orthonormal matrices to serve as the eigenbasis list.
| 158 | + Q_L = torch.linalg.qr(torch.randn(M, M))[0] |
| 159 | + Q_R = torch.linalg.qr(torch.randn(N, N))[0] |
| 160 | + orthonormal_matrix_list = [Q_L, Q_R] |
| 161 | + |
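|  | + # dims=[[0], [0]] is assumed to project into the eigenbasis; dims=[[0], [1]] to project back out.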
| 162 | + projected = soap.precondition( |
| 163 | + grad=grad, |
| 164 | + eigenbasis_list=orthonormal_matrix_list, |
| 165 | + dims=[[0], [0]], |
| 166 | + ) |
| 167 | + recov = soap.precondition( |
| 168 | + grad=projected, |
| 169 | + eigenbasis_list=orthonormal_matrix_list, |
| 170 | + dims=[[0], [1]], |
| 171 | + ) |
| 172 | + # Check that the recovered tensor is close to the original. |
| 173 | + torch.testing.assert_close( |
| 174 | + grad, |
| 175 | + recov, |
| 176 | + atol=1e-6, |
| 177 | + rtol=1e-6, |
| 178 | + msg="Project and project_back did not recover the original tensor.", |
| 179 | + ) |
| 180 | + |
| 181 | + def test_get_precondition_frequency_fixed(self) -> None: |
| 182 | + """Test that _get_precondition_frequency works with fixed frequency (default case).""" |
| 183 | + freq = _get_precondition_frequency(10, 100) |
| 184 | + self.assertEqual(freq, 10) |
| 185 | + |
| 186 | + @parameterized.parameters( |
| 187 | + (5, 10, 20, 10, False), |
| 188 | + (15, 10, 20, 10, True), |
| 189 | + (20, 10, 15, 10, True), |
| 190 | + (21, 10, 15, 10, False), |
| 191 | + (30, 10, 15, 10, True), |
| 192 | + (31, 10, 15, 10, False), |
| 193 | + ) |
| 194 | + def test_is_eigenbasis_update_step_fixed_frequency( |
| 195 | + self, step: int, adam_warmup_steps: int, precondition_warmup: int, precondition_frequency: int, expected: bool |
| 196 | + ) -> None: |
| 197 | + """Test _is_eigenbasis_update_step with fixed frequency.""" |
| 198 | + result = _is_eigenbasis_update_step(step, adam_warmup_steps, precondition_warmup, precondition_frequency) |
| 199 | + self.assertEqual(result, expected) |
| 200 | + |
| 201 | + def test_soap_optimizer_fixed_frequency(self) -> None: |
| 202 | + """Test that SOAP optimizer can be created with fixed precondition frequency (default case).""" |
| 203 | + param = torch.randn(10, 5, requires_grad=True) |
| 204 | + optimizer = soap.SOAP([param], lr=1e-3, precondition_frequency=10) |
| 205 | + self.assertEqual(optimizer.param_groups[0]["precondition_frequency"], 10) |
| 206 | + |
| 207 | + def test_soap_optimizer_class_based_schedule(self) -> None: |
| 208 | + """Test that SOAP optimizer can be created with class-based precondition frequency schedule.""" |
| 209 | + param = torch.randn(10, 5, requires_grad=True) |
| 210 | + schedule = LinearSchedule(min_freq=2, max_freq=10, transition_steps=100) |
| 211 | + optimizer = soap.SOAP([param], lr=1e-3, precondition_frequency=schedule) |
| 212 | + self.assertEqual(optimizer.param_groups[0]["precondition_frequency"], schedule)
| 213 | + |
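|  | + # LinearSchedule is expected to interpolate linearly from min_freq to max_freq over transition_steps.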
| 214 | + self.assertEqual(schedule(0), 2) |
| 215 | + self.assertEqual(schedule(50), 6) |
| 216 | + self.assertEqual(schedule(100), 10) |
| 217 | + |
| 218 | + adam_warmup = 1 |
| 219 | + precondition_warmup = 0 |
| 220 | + |
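|  | + # Eigenbasis updates should land on the schedule-determined steps, and the step immediately after each update should not trigger one.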
| 221 | + self.assertTrue(_is_eigenbasis_update_step(10, adam_warmup, precondition_warmup, schedule)) |
| 222 | + self.assertFalse(_is_eigenbasis_update_step(11, adam_warmup, precondition_warmup, schedule)) |
| 223 | + self.assertTrue(_is_eigenbasis_update_step(60, adam_warmup, precondition_warmup, schedule)) |
| 224 | + self.assertFalse(_is_eigenbasis_update_step(61, adam_warmup, precondition_warmup, schedule)) |
| 225 | + self.assertTrue(_is_eigenbasis_update_step(120, adam_warmup, precondition_warmup, schedule)) |
| 226 | + self.assertFalse(_is_eigenbasis_update_step(121, adam_warmup, precondition_warmup, schedule)) |
| 227 | + |
| 228 | + @parameterized.parameters( |
| 229 | + (1.0,), |
| 230 | + (0.0,), |
| 231 | + (0.5,), |
| 232 | + ) |
| 233 | + def test_clip_update_rms(self, max_rms: float) -> None: |
| 234 | + """Test that _clip_update_rms works by clipping the update RMS to max_rms in place.""" |
| 235 | + # test for 3 different update tensors
| 236 | + u_s = [ |
| 237 | + torch.tensor([4.0, -1.0, 1.0, -1.0, 1.0], device="cuda"), |
| 238 | + torch.tensor([0.2, 0.2, 0.2, 0.2, 0.0], device="cuda"), |
| 239 | + torch.tensor([0.8, 0.0, 0.0, 0.0, 0.0], device="cuda"), |
| 240 | + ] |
| 241 | + for u in u_s: |
| 242 | + u_clipped = u.clone() |
| 243 | + _clip_update_rms_in_place(u_clipped, max_rms=max_rms) |
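|  | + # A max_rms of 0 is expected to disable clipping, leaving the update norm unchanged.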
| 244 | + if max_rms == 0: |
| 245 | + self.assertTrue(torch.linalg.norm(u_clipped) == torch.linalg.norm(u)) |
| 246 | + else: |
| 247 | + self.assertTrue(torch.linalg.norm(u_clipped) / math.sqrt(u.numel()) <= max_rms) |
19 | 248 |
20 | 249 |
21 | 250 | class SoapTest(parameterized.TestCase): |