Skip to content

Commit 7b54565

Browse files
committed
added type hints for psgd
Signed-off-by: mikail <[email protected]>
1 parent f234e2f commit 7b54565

File tree

4 files changed

+33
-33
lines changed

4 files changed

+33
-33
lines changed

emerging_optimizers/psgd/procrustes_step.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
]
2424

2525

26-
def procrustes_step(Q, max_step_size=1 / 8):
26+
def procrustes_step(Q: torch.Tensor, max_step_size: float = 1 / 8) -> None:
2727
r"""One step of an in-place online solver for the orthogonal Procrustes problem.
2828
2929
The orthogonal Procrustes problem is min_U || U Q - I ||_F, s.t. U^H U = I

tests/test_procrustes_step.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -11,16 +11,16 @@
1111
class ProcrustesStepTest(parameterized.TestCase):
1212
"""Test cases for procrustes_step function."""
1313

14-
def setUp(self):
14+
def setUp(self) -> None:
1515
"""Set up test fixtures."""
1616
torch.manual_seed(42)
1717
self.device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
1818

19-
def _procrustes_objective(self, Q):
19+
def _procrustes_objective(self, Q: torch.Tensor) -> torch.Tensor:
2020
"""Helper function to compute Procrustes objective ||Q^H Q - I||_F^2."""
2121
return torch.linalg.matrix_norm(Q.H @ Q - torch.eye(Q.size(0), dtype=Q.dtype, device=Q.device), ord="fro") ** 2
2222

23-
def test_improves_orthogonality_simple_case(self):
23+
def test_improves_orthogonality_simple_case(self) -> None:
2424
"""Test that procrustes_step doesn't worsen orthogonality for a simple case."""
2525

2626
# Make a SPD non-orthogonal matrix
@@ -35,7 +35,7 @@ def test_improves_orthogonality_simple_case(self):
3535

3636
self.assertLessEqual(final_obj.item(), initial_obj.item() + 1e-6)
3737

38-
def test_modifies_matrix_in_place(self):
38+
def test_modifies_matrix_in_place(self) -> None:
3939
"""Test that procrustes_step modifies the matrix in place."""
4040
Q = torch.randn(3, 3, device=self.device)
4141
Q_original_id = id(Q)
@@ -49,7 +49,7 @@ def test_modifies_matrix_in_place(self):
4949
(128,),
5050
(1024,),
5151
)
52-
def test_minimal_change_when_already_orthogonal(self, size):
52+
def test_minimal_change_when_already_orthogonal(self, size: int) -> None:
5353
"""Test that procrustes_step makes minimal changes to an already orthogonal matrix."""
5454
# Create an orthogonal matrix using QR decomposition
5555
A = torch.randn(size, size, device=self.device, dtype=torch.float32)
@@ -66,7 +66,7 @@ def test_minimal_change_when_already_orthogonal(self, size):
6666
self.assertLess(final_obj.item(), 1e-5)
6767
self.assertLess(final_obj.item(), initial_obj.item() + 1e-5)
6868

69-
def test_handles_small_norm_gracefully(self):
69+
def test_handles_small_norm_gracefully(self) -> None:
7070
"""Test that procrustes_step handles matrices with small R norm improvement."""
7171
# Create a matrix very close to orthogonal
7272
A = torch.randn(3, 3, device=self.device, dtype=torch.float32)
@@ -90,7 +90,7 @@ def test_handles_small_norm_gracefully(self):
9090
(1 / 16,),
9191
(1 / 8,),
9292
)
93-
def test_different_step_sizes(self, max_step_size):
93+
def test_different_step_sizes(self, max_step_size: float) -> None:
9494
"""Test procrustes_step improvement with different step sizes."""
9595
perturbation = 1e-1 * torch.randn(10, 10, device=self.device, dtype=torch.float32) / math.sqrt(10)
9696
Q = torch.linalg.qr(torch.randn(10, 10, device=self.device, dtype=torch.float32)).Q + perturbation
@@ -108,7 +108,7 @@ def test_different_step_sizes(self, max_step_size):
108108
(512,),
109109
(8192,),
110110
)
111-
def test_different_matrix_sizes(self, size):
111+
def test_different_matrix_sizes(self, size: int) -> None:
112112
"""Test procrustes_step improvement with different matrix sizes."""
113113
# Create a non-orthogonal matrix by scaling an orthogonal one
114114
A = torch.randn(size, size, device=self.device, dtype=torch.float32)
@@ -124,7 +124,7 @@ def test_different_matrix_sizes(self, size):
124124

125125
self.assertLessEqual(final_obj.item(), initial_obj.item() + 1e-3)
126126

127-
def test_preserves_determinant_sign_for_real_matrices(self):
127+
def test_preserves_determinant_sign_for_real_matrices(self) -> None:
128128
"""Test that procrustes_step preserves the sign of determinant for real matrices."""
129129
# Create real matrices with positive and negative determinants
130130
Q_pos = torch.tensor([[2.0, 0.1], [0.1, 1.5]], device=self.device, dtype=torch.float32) # det > 0

tests/test_psgd_contractions.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
class TestPSGDKronContractions(parameterized.TestCase):
1515
"""Test cases for PSGD Kronecker contractions."""
1616

17-
def setUp(self):
17+
def setUp(self) -> None:
1818
"""Set up test fixtures."""
1919
torch.manual_seed(42)
2020
self.device = torch.device("cuda")
@@ -24,7 +24,7 @@ def setUp(self):
2424
(2, 3, 4),
2525
(2, 3, 5),
2626
)
27-
def test_partial_contraction_matches_reconstructed(self, size1, size2, size3):
27+
def test_partial_contraction_matches_reconstructed(self, size1: int, size2: int, size3: int) -> None:
2828
"""Test partial_contraction matches reconstructed."""
2929
G1 = torch.randn(size1, size2, size3, device=self.device)
3030
G2 = torch.randn(size1, size2, size3, device=self.device)
@@ -33,7 +33,7 @@ def test_partial_contraction_matches_reconstructed(self, size1, size2, size3):
3333
reconstructed = torch.tensordot(G1, G2, dims=([0, 2], [0, 2]))
3434
torch.testing.assert_close(result, reconstructed)
3535

36-
def test_apply_kronecker_factors_matches_reconstructed(self):
36+
def test_apply_kronecker_factors_matches_reconstructed(self) -> None:
3737
"""Test apply_kronecker_factors matches reconstructed."""
3838
Q_list = [
3939
torch.triu(torch.randn(2, 2, device=self.device)),
@@ -62,7 +62,7 @@ def test_apply_kronecker_factors_matches_reconstructed(self):
6262

6363
torch.testing.assert_close(result, reconstructed)
6464

65-
def test_apply_preconditioner_matches_reconstructed(self):
65+
def test_apply_preconditioner_matches_reconstructed(self) -> None:
6666
"""Test apply_preconditioner matches manual reconstruction for 2D tensor."""
6767
Q_list = [torch.triu(torch.randn(3, 3, device=self.device)), torch.triu(torch.randn(4, 4, device=self.device))]
6868
X = torch.randn(3, 4, device=self.device)
@@ -97,7 +97,7 @@ def test_apply_preconditioner_matches_reconstructed(self):
9797
(2, 3, 5, 2),
9898
(4, 6, 2, 1),
9999
)
100-
def test_mode_n_mul_and_permute_shapes(self, dim0, dim1, dim2, mode):
100+
def test_mode_n_mul_and_permute_shapes(self, dim0: int, dim1: int, dim2: int, mode: int) -> None:
101101
"""Test `_mode_n_mul_and_permute` with non-uniform shapes and different modes."""
102102
X = torch.randn(dim0, dim1, dim2, device=self.device)
103103
input_shape = X.shape

tests/test_psgd_utils.py

Lines changed: 18 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -12,25 +12,25 @@
1212
class BalanceQTest(parameterized.TestCase):
1313
"""Test cases for balance_Q function."""
1414

15-
def setUp(self):
15+
def setUp(self) -> None:
1616
"""Set up test fixtures."""
1717
self.device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
1818

19-
def test_normalization_on_empty_list(self):
19+
def test_normalization_on_empty_list(self) -> None:
2020
"""Test balance_Q with empty list."""
2121
Q_list = []
2222
balance_q_in_place(Q_list) # Should not raise any errors
2323
self.assertEqual(len(Q_list), 0)
2424

25-
def test_normalization_on_single_tensor(self):
25+
def test_normalization_on_single_tensor(self) -> None:
2626
"""Test balance_Q with single tensor."""
2727
Q = torch.randn(3, 3, device=self.device)
2828
original_Q = Q.clone()
2929
balance_q_in_place([Q])
3030
# for a single tensor, the result should be the same as the original
3131
torch.testing.assert_close(Q, original_Q)
3232

33-
def test_normalization_on_two_tensors(self):
33+
def test_normalization_on_two_tensors(self) -> None:
3434
"""Test balance_Q with two tensors."""
3535
Q1 = torch.tensor([[1.0, 2.0], [3.0, 4.0]], device=self.device)
3636
Q2 = torch.tensor([[0.1, 0.2], [0.3, 0.4]], device=self.device)
@@ -53,7 +53,7 @@ def test_normalization_on_two_tensors(self):
5353
(256, 256, 256),
5454
(4096, 4096, 4096),
5555
)
56-
def test_normalization_on_three_tensors(self, size1, size2, size3):
56+
def test_normalization_on_three_tensors(self, size1: int, size2: int, size3: int) -> None:
5757
"""Test balance_Q with multiple tensors of different dynamic ranges."""
5858
Q1 = torch.randn(size1, size1, device=self.device) * 10.0
5959
Q2 = torch.randn(size2, size2, device=self.device) * 0.01
@@ -76,7 +76,7 @@ def test_normalization_on_three_tensors(self, size1, size2, size3):
7676
self.assertAlmostEqual(new_max2.item(), expected_max.item(), places=5)
7777
self.assertAlmostEqual(new_max3.item(), expected_max.item(), places=5)
7878

79-
def test_modifies_in_place_on_three_tensors(self):
79+
def test_modifies_in_place_on_three_tensors(self) -> None:
8080
"""Test that balance_Q modifies tensors in place."""
8181
Q = torch.randn(3, 3, device=self.device)
8282
original_id = id(Q)
@@ -89,12 +89,12 @@ def test_modifies_in_place_on_three_tensors(self):
8989
class NormLowerBoundSpdTest(parameterized.TestCase):
9090
"""Test cases for norm_lower_bound_spd function."""
9191

92-
def setUp(self):
92+
def setUp(self) -> None:
9393
"""Set up test fixtures."""
9494
torch.manual_seed(42) # For reproducible tests
9595
self.device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
9696

97-
def test_diagonal_matrix(self):
97+
def test_diagonal_matrix(self) -> None:
9898
"""Test norm_lower_bound_spd with diagonal matrix."""
9999
# For diagonal matrix, spectral norm equals largest diagonal entry
100100
diag_values = torch.tensor([1.0, 3.0, 2.0], device=self.device)
@@ -108,15 +108,15 @@ def test_diagonal_matrix(self):
108108
# For diagonal matrix, bound should be reasonably tight
109109
self.assertGreater(bound.item(), 0.5 * actual_norm.item())
110110

111-
def test_identity_matrix(self):
111+
def test_identity_matrix(self) -> None:
112112
"""Test norm_lower_bound_spd with identity matrix."""
113113
A = torch.eye(3, device=self.device)
114114
bound = norm_lower_bound_spd(A)
115115

116116
# For identity matrix, spectral norm is 1
117117
self.assertAlmostEqual(bound.item(), 1.0, places=5)
118118

119-
def test_zero_matrix(self):
119+
def test_zero_matrix(self) -> None:
120120
"""Test norm_lower_bound_spd with zero matrix."""
121121
A = torch.zeros(3, 3, device=self.device)
122122
bound = norm_lower_bound_spd(A)
@@ -128,7 +128,7 @@ def test_zero_matrix(self):
128128
dtype=[torch.float32, torch.bfloat16],
129129
size=[32, 256, 4096],
130130
)
131-
def test_norm_lower_bound_spd_is_lower_bound(self, dtype, size):
131+
def test_norm_lower_bound_spd_is_lower_bound(self, dtype: torch.dtype, size: int) -> None:
132132
"""Test that norm_lower_bound_spd provides a valid lower bound."""
133133
# Create a random SPD matrix
134134
B = torch.randn(size, size, dtype=dtype, device=self.device)
@@ -150,20 +150,20 @@ def test_norm_lower_bound_spd_is_lower_bound(self, dtype, size):
150150
class NormLowerBoundSkewTest(parameterized.TestCase):
151151
"""Test cases for norm_lower_bound_skew function."""
152152

153-
def setUp(self):
153+
def setUp(self) -> None:
154154
"""Set up test fixtures."""
155155
torch.manual_seed(42) # For reproducible tests
156156
self.device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
157157

158-
def test_zero_matrix(self):
158+
def test_zero_matrix(self) -> None:
159159
"""Test norm_lower_bound_skew with zero matrix."""
160160
A = torch.zeros(3, 3, device=self.device)
161161
bound = norm_lower_bound_skew(A)
162162

163163
# For zero matrix, bound should be 0
164164
self.assertAlmostEqual(bound.item(), 0.0, places=5)
165165

166-
def test_small_skew_symmetric_matrix(self):
166+
def test_small_skew_symmetric_matrix(self) -> None:
167167
"""Test norm_lower_bound_skew with a simple skew-symmetric matrix."""
168168
# Create a simple 3x3 skew-symmetric matrix
169169
A = torch.tensor([[0.0, 1.0, -2.0], [-1.0, 0.0, 3.0], [2.0, -3.0, 0.0]], device=self.device)
@@ -177,7 +177,7 @@ def test_small_skew_symmetric_matrix(self):
177177
# Bound should be positive for non-zero matrix
178178
self.assertGreater(bound.item(), 0.0)
179179

180-
def test_identity_based_skew_matrix(self):
180+
def test_identity_based_skew_matrix(self) -> None:
181181
"""Test norm_lower_bound_skew with matrix based on identity structure."""
182182
# Create skew-symmetric matrix from anti-symmetric part of random matrix
183183
n = 4
@@ -194,7 +194,7 @@ def test_identity_based_skew_matrix(self):
194194
dtype=[torch.float32, torch.float64],
195195
size=[32, 128, 256],
196196
)
197-
def test_norm_lower_bound_skew_is_lower_bound(self, dtype, size):
197+
def test_norm_lower_bound_skew_is_lower_bound(self, dtype: torch.dtype, size: int) -> None:
198198
"""Test that norm_lower_bound_skew provides a valid lower bound."""
199199
# Create a random skew-symmetric matrix
200200
B = torch.randn(size, size, dtype=dtype, device=self.device)
@@ -211,7 +211,7 @@ def test_norm_lower_bound_skew_is_lower_bound(self, dtype, size):
211211
self.assertGreaterEqual(bound.item(), 0.0)
212212

213213
@parameterized.parameters([4, 16, 32])
214-
def test_different_subspace_dimensions(self, rank):
214+
def test_different_subspace_dimensions(self, rank: int) -> None:
215215
"""Test norm_lower_bound_skew with different subspace dimensions."""
216216
# Create a skew-symmetric matrix
217217
B = torch.randn(64, 64, device=self.device)
@@ -222,7 +222,7 @@ def test_different_subspace_dimensions(self, rank):
222222
self.assertGreaterEqual(bound.item(), 0.0)
223223

224224
actual_norm = torch.linalg.matrix_norm(A, ord=2)
225-
self.assertLessEqual(bound.item(), actual_norm.item() + 1e-4)
225+
self.assertLessEqual(bound.item(), actual_norm.item() + 1e-5)
226226

227227

228228
if __name__ == "__main__":

0 commit comments

Comments
 (0)