1212class BalanceQTest (parameterized .TestCase ):
1313 """Test cases for balance_Q function."""
1414
15- def setUp (self ):
15+ def setUp (self ) -> None :
1616 """Set up test fixtures."""
1717 self .device = torch .device ("cuda" ) if torch .cuda .is_available () else torch .device ("cpu" )
1818
19- def test_normalization_on_empty_list (self ):
19+ def test_normalization_on_empty_list (self ) -> None :
2020 """Test balance_Q with empty list."""
2121 Q_list = []
2222 balance_q_in_place (Q_list ) # Should not raise any errors
2323 self .assertEqual (len (Q_list ), 0 )
2424
25- def test_normalization_on_single_tensor (self ):
25+ def test_normalization_on_single_tensor (self ) -> None :
2626 """Test balance_Q with single tensor."""
2727 Q = torch .randn (3 , 3 , device = self .device )
2828 original_Q = Q .clone ()
2929 balance_q_in_place ([Q ])
3030 # for a single tensor, the result should be the same as the original
3131 torch .testing .assert_close (Q , original_Q )
3232
33- def test_normalization_on_two_tensors (self ):
33+ def test_normalization_on_two_tensors (self ) -> None :
3434 """Test balance_Q with two tensors."""
3535 Q1 = torch .tensor ([[1.0 , 2.0 ], [3.0 , 4.0 ]], device = self .device )
3636 Q2 = torch .tensor ([[0.1 , 0.2 ], [0.3 , 0.4 ]], device = self .device )
@@ -53,7 +53,7 @@ def test_normalization_on_two_tensors(self):
5353 (256 , 256 , 256 ),
5454 (4096 , 4096 , 4096 ),
5555 )
56- def test_normalization_on_three_tensors (self , size1 , size2 , size3 ) :
56+ def test_normalization_on_three_tensors (self , size1 : int , size2 : int , size3 : int ) -> None :
5757 """Test balance_Q with multiple tensors of different dynamic ranges."""
5858 Q1 = torch .randn (size1 , size1 , device = self .device ) * 10.0
5959 Q2 = torch .randn (size2 , size2 , device = self .device ) * 0.01
@@ -76,7 +76,7 @@ def test_normalization_on_three_tensors(self, size1, size2, size3):
7676 self .assertAlmostEqual (new_max2 .item (), expected_max .item (), places = 5 )
7777 self .assertAlmostEqual (new_max3 .item (), expected_max .item (), places = 5 )
7878
79- def test_modifies_in_place_on_three_tensors (self ):
79+ def test_modifies_in_place_on_three_tensors (self ) -> None :
8080 """Test that balance_Q modifies tensors in place."""
8181 Q = torch .randn (3 , 3 , device = self .device )
8282 original_id = id (Q )
@@ -89,12 +89,12 @@ def test_modifies_in_place_on_three_tensors(self):
8989class NormLowerBoundSpdTest (parameterized .TestCase ):
9090 """Test cases for norm_lower_bound_spd function."""
9191
92- def setUp (self ):
92+ def setUp (self ) -> None :
9393 """Set up test fixtures."""
9494 torch .manual_seed (42 ) # For reproducible tests
9595 self .device = torch .device ("cuda" ) if torch .cuda .is_available () else torch .device ("cpu" )
9696
97- def test_diagonal_matrix (self ):
97+ def test_diagonal_matrix (self ) -> None :
9898 """Test norm_lower_bound_spd with diagonal matrix."""
9999 # For diagonal matrix, spectral norm equals largest diagonal entry
100100 diag_values = torch .tensor ([1.0 , 3.0 , 2.0 ], device = self .device )
@@ -108,15 +108,15 @@ def test_diagonal_matrix(self):
108108 # For diagonal matrix, bound should be reasonably tight
109109 self .assertGreater (bound .item (), 0.5 * actual_norm .item ())
110110
111- def test_identity_matrix (self ):
111+ def test_identity_matrix (self ) -> None :
112112 """Test norm_lower_bound_spd with identity matrix."""
113113 A = torch .eye (3 , device = self .device )
114114 bound = norm_lower_bound_spd (A )
115115
116116 # For identity matrix, spectral norm is 1
117117 self .assertAlmostEqual (bound .item (), 1.0 , places = 5 )
118118
119- def test_zero_matrix (self ):
119+ def test_zero_matrix (self ) -> None :
120120 """Test norm_lower_bound_spd with zero matrix."""
121121 A = torch .zeros (3 , 3 , device = self .device )
122122 bound = norm_lower_bound_spd (A )
@@ -128,7 +128,7 @@ def test_zero_matrix(self):
128128 dtype = [torch .float32 , torch .bfloat16 ],
129129 size = [32 , 256 , 4096 ],
130130 )
131- def test_norm_lower_bound_spd_is_lower_bound (self , dtype , size ) :
131+ def test_norm_lower_bound_spd_is_lower_bound (self , dtype : torch . dtype , size : int ) -> None :
132132 """Test that norm_lower_bound_spd provides a valid lower bound."""
133133 # Create a random SPD matrix
134134 B = torch .randn (size , size , dtype = dtype , device = self .device )
@@ -150,20 +150,20 @@ def test_norm_lower_bound_spd_is_lower_bound(self, dtype, size):
150150class NormLowerBoundSkewTest (parameterized .TestCase ):
151151 """Test cases for norm_lower_bound_skew function."""
152152
153- def setUp (self ):
153+ def setUp (self ) -> None :
154154 """Set up test fixtures."""
155155 torch .manual_seed (42 ) # For reproducible tests
156156 self .device = torch .device ("cuda" ) if torch .cuda .is_available () else torch .device ("cpu" )
157157
158- def test_zero_matrix (self ):
158+ def test_zero_matrix (self ) -> None :
159159 """Test norm_lower_bound_skew with zero matrix."""
160160 A = torch .zeros (3 , 3 , device = self .device )
161161 bound = norm_lower_bound_skew (A )
162162
163163 # For zero matrix, bound should be 0
164164 self .assertAlmostEqual (bound .item (), 0.0 , places = 5 )
165165
166- def test_small_skew_symmetric_matrix (self ):
166+ def test_small_skew_symmetric_matrix (self ) -> None :
167167 """Test norm_lower_bound_skew with a simple skew-symmetric matrix."""
168168 # Create a simple 3x3 skew-symmetric matrix
169169 A = torch .tensor ([[0.0 , 1.0 , - 2.0 ], [- 1.0 , 0.0 , 3.0 ], [2.0 , - 3.0 , 0.0 ]], device = self .device )
@@ -177,7 +177,7 @@ def test_small_skew_symmetric_matrix(self):
177177 # Bound should be positive for non-zero matrix
178178 self .assertGreater (bound .item (), 0.0 )
179179
180- def test_identity_based_skew_matrix (self ):
180+ def test_identity_based_skew_matrix (self ) -> None :
181181 """Test norm_lower_bound_skew with matrix based on identity structure."""
182182 # Create skew-symmetric matrix from anti-symmetric part of random matrix
183183 n = 4
@@ -194,7 +194,7 @@ def test_identity_based_skew_matrix(self):
194194 dtype = [torch .float32 , torch .float64 ],
195195 size = [32 , 128 , 256 ],
196196 )
197- def test_norm_lower_bound_skew_is_lower_bound (self , dtype , size ) :
197+ def test_norm_lower_bound_skew_is_lower_bound (self , dtype : torch . dtype , size : int ) -> None :
198198 """Test that norm_lower_bound_skew provides a valid lower bound."""
199199 # Create a random skew-symmetric matrix
200200 B = torch .randn (size , size , dtype = dtype , device = self .device )
@@ -211,7 +211,7 @@ def test_norm_lower_bound_skew_is_lower_bound(self, dtype, size):
211211 self .assertGreaterEqual (bound .item (), 0.0 )
212212
213213 @parameterized .parameters ([4 , 16 , 32 ])
214- def test_different_subspace_dimensions (self , rank ) :
214+ def test_different_subspace_dimensions (self , rank : int ) -> None :
215215 """Test norm_lower_bound_skew with different subspace dimensions."""
216216 # Create a skew-symmetric matrix
217217 B = torch .randn (64 , 64 , device = self .device )
@@ -222,7 +222,7 @@ def test_different_subspace_dimensions(self, rank):
222222 self .assertGreaterEqual (bound .item (), 0.0 )
223223
224224 actual_norm = torch .linalg .matrix_norm (A , ord = 2 )
225- self .assertLessEqual (bound .item (), actual_norm .item () + 1e-4 )
225+ self .assertLessEqual (bound .item (), actual_norm .item () + 1e-5 )
226226
227227
228228if __name__ == "__main__" :
0 commit comments