1616from absl .testing import absltest , parameterized
1717
1818from emerging_optimizers import utils
19- from emerging_optimizers .soap .soap_utils import (
20- _adaptive_criteria_met ,
21- _orthogonal_iteration ,
22- get_eigenbasis_eigh ,
23- get_eigenbasis_qr ,
24- )
19+ from emerging_optimizers .soap import soap_utils
2520
2621
2722# Base class for tests requiring seeding for determinism
@@ -44,7 +39,7 @@ def test_adaptive_criteria_met(self) -> None:
4439
4540 # Test with small tolerance - should not update since matrix is diagonal
4641 self .assertFalse (
47- _adaptive_criteria_met (
42+ soap_utils . _adaptive_criteria_met (
4843 approx_eigenvalue_matrix = diagonal_matrix ,
4944 tolerance = 0.1 ,
5045 ),
@@ -63,7 +58,7 @@ def test_adaptive_criteria_met(self) -> None:
6358
6459 # Test with small tolerance - should update since matrix has significant off-diagonal elements
6560 self .assertTrue (
66- _adaptive_criteria_met (
61+ soap_utils . _adaptive_criteria_met (
6762 approx_eigenvalue_matrix = off_diagonal_matrix ,
6863 tolerance = 0.1 ,
6964 ),
@@ -72,7 +67,7 @@ def test_adaptive_criteria_met(self) -> None:
7267
7368 # Test with large tolerance - should not update even with off-diagonal elements
7469 self .assertFalse (
75- _adaptive_criteria_met (
70+ soap_utils . _adaptive_criteria_met (
7671 approx_eigenvalue_matrix = off_diagonal_matrix ,
7772 tolerance = 10.0 ,
7873 ),
@@ -102,7 +97,7 @@ def test_get_eigenbasis_qr(self, N: int, M: int) -> None:
10297 }
10398
10499 # We'll call get_eigenbasis_qr
105- Q_new_list , exp_avg_sq_new = get_eigenbasis_qr (
100+ Q_new_list , exp_avg_sq_new = soap_utils . get_eigenbasis_qr (
106101 kronecker_factor_list = state ["GG" ],
107102 eigenbasis_list = state ["Q" ],
108103 exp_avg_sq = state ["exp_avg_sq" ],
@@ -171,11 +166,11 @@ def test_update_eigenbasis_with_QR(self, N: int, power_iter_steps: int) -> None:
171166 # Create estimated eigenvalue matrix by projecting kronecker_factor onto eigenbasis's basis
172167 approx_eigenvalue_matrix = eigenbasis .T .mm (kronecker_factor ).mm (eigenbasis )
173168 # Extract eigenvalues from the diagonal of the estimated eigenvalue matrix
174- est_eigvals = torch .diag (approx_eigenvalue_matrix )
169+ approx_eigvals = torch .diag (approx_eigenvalue_matrix )
175170
176171 # Call the QR function to update the eigenbases and re-order the inner adam second moment
177- Q_new , exp_avg_sq_new = _orthogonal_iteration (
178- approx_eigenvalue_matrix = approx_eigenvalue_matrix ,
172+ Q_new , exp_avg_sq_new = soap_utils . _orthogonal_iteration (
173+ approx_eigvals = approx_eigvals ,
179174 kronecker_factor = kronecker_factor ,
180175 eigenbasis = eigenbasis ,
181176 ind = 0 , # Test with first dimension
@@ -200,7 +195,7 @@ def test_update_eigenbasis_with_QR(self, N: int, power_iter_steps: int) -> None:
200195
201196 # Test 3: Check that exp_avg_sq is properly sorted based on eigenvalues
202197 # The sorting should be based on the diagonal elements of estimated_eigenvalue_matrix
203- sort_idx = torch .argsort (est_eigvals , descending = True )
198+ sort_idx = torch .argsort (approx_eigvals , descending = True )
204199 expected_exp_avg_sq = exp_avg_sq .index_select (0 , sort_idx )
205200 torch .testing .assert_close (
206201 exp_avg_sq_new ,
@@ -230,7 +225,7 @@ def test_get_eigenbasis_eigh(self, dims: list[int]) -> None:
230225 k_factor = k_factor @ k_factor .T + torch .eye (dim , device = "cuda" ) * 1e-5
231226 kronecker_factor_list .append (k_factor )
232227
233- Q_list = get_eigenbasis_eigh (kronecker_factor_list , convert_to_float = True )
228+ Q_list = soap_utils . get_eigenbasis_eigh (kronecker_factor_list , convert_to_float = True )
234229
235230 self .assertEqual (len (Q_list ), len (kronecker_factor_list ))
236231
@@ -265,6 +260,20 @@ def test_get_eigenbasis_eigh(self, dims: list[int]) -> None:
265260 msg = f"Matrix { i } was not properly diagonalized. Off-diagonal norm: { off_diagonal_norm } " ,
266261 )
267262
263+ def test_conjugate_assert_2d_input (self ) -> None :
264+ """Tests the conjugate function."""
265+ a = torch .randn (2 , 3 , 4 , device = "cuda" )
266+ with self .assertRaises (TypeError ):
267+ soap_utils ._conjugate (a , a )
268+
269+ def test_conjugate_match_reference (self ) -> None :
270+ x = torch .randn (15 , 17 , device = "cuda" )
271+ a = x @ x .T
272+ _ , p = torch .linalg .eigh (a )
273+
274+ ref = p .T @ a @ p
275+ torch .testing .assert_close (soap_utils ._conjugate (a , p ), ref , atol = 0 , rtol = 0 )
276+
268277
269278if __name__ == "__main__" :
270279 absltest .main ()
0 commit comments