@@ -54,12 +54,11 @@ def test_improves_orthogonality_simple_case(self) -> None:
5454
5555 self .assertLessEqual (final_obj .item (), initial_obj .item () + 1e-6 )
5656
57- @parameterized .parameters (
58- (8 ,),
59- (128 ,),
60- (1024 ,),
57+ @parameterized .product (
58+ size = [8 , 128 , 1024 ],
59+ order = [2 , 3 ],
6160 )
62- def test_minimal_change_when_already_orthogonal (self , size : int ) -> None :
61+ def test_minimal_change_when_already_orthogonal (self , size : int , order : int ) -> None :
6362 """Test that procrustes_step makes minimal changes to an already orthogonal matrix."""
6463 # Create an orthogonal matrix using QR decomposition
6564 A = torch .randn (size , size , device = self .device , dtype = torch .float32 )
@@ -68,7 +67,7 @@ def test_minimal_change_when_already_orthogonal(self, size: int) -> None:
6867
6968 initial_obj = self ._procrustes_objective (Q )
7069
71- Q = procrustes_step (Q , max_step_size = 1 / 16 )
70+ Q = procrustes_step (Q , max_step_size = 1 / 16 , order = order )
7271
7372 final_obj = self ._procrustes_objective (Q )
7473
@@ -94,19 +93,17 @@ def test_handles_small_norm_gracefully(self) -> None:
9493 self .assertLess (final_obj .item (), 1e-6 )
9594 self .assertLess (final_obj .item (), initial_obj .item () + 1e-6 )
9695
97- @parameterized .parameters (
98- (0.015625 ,),
99- (0.03125 ,),
100- (0.0625 ,),
101- (0.125 ,),
96+ @parameterized .product (
97+ max_step_size = [0.015625 , 0.03125 , 0.0625 , 0.125 ],
98+ order = [2 , 3 ],
10299 )
103- def test_different_step_sizes_reduces_objective (self , max_step_size : float ) -> None :
100+ def test_different_step_sizes_reduces_objective (self , max_step_size : float , order : int ) -> None :
104101 """Test procrustes_step improvement with different step sizes."""
105102 perturbation = 1e-1 * torch .randn (10 , 10 , device = self .device , dtype = torch .float32 ) / math .sqrt (10 )
106103 Q = torch .linalg .qr (torch .randn (10 , 10 , device = self .device , dtype = torch .float32 )).Q + perturbation
107104 initial_obj = self ._procrustes_objective (Q )
108105
109- Q = procrustes_step (Q , max_step_size = max_step_size )
106+ Q = procrustes_step (Q , max_step_size = max_step_size , order = order )
110107
111108 final_obj = self ._procrustes_objective (Q )
112109
@@ -155,6 +152,102 @@ def test_preserves_determinant_sign_for_real_matrices(self) -> None:
155152 self .assertGreater (initial_det_pos .item () * final_det_pos .item (), 0 )
156153 self .assertGreater (initial_det_neg .item () * final_det_neg .item (), 0 )
157154
155+ @parameterized .parameters (
156+ (0.015625 ,),
157+ (0.03125 ,),
158+ (0.0625 ,),
159+ (0.125 ,),
160+ )
161+ def test_order3_converges_faster_amplitude_recovery (self , max_step_size : float = 0.0625 ) -> None :
162+ """Test that order 3 converges faster than order 2 in amplitude recovery setting."""
163+ # Use amplitude recovery setup to compare convergence speed
164+ n = 10
165+ Q_init = torch .randn (n , n , device = self .device , dtype = torch .float32 )
166+ u , s , vh = torch .linalg .svd (Q_init )
167+ amplitude = vh .mH @ torch .diag (s ) @ vh
168+
169+ # Start from the same initial point for both orders
170+ Q_order2 = torch .clone (Q_init )
171+ Q_order3 = torch .clone (Q_init )
172+
173+ max_steps = 200
174+ tolerance = 0.01
175+ err_order2_list = []
176+ err_order3_list = []
177+
178+ # Track convergence steps directly
179+ steps_to_converge_order2 = max_steps # Default to max_steps if doesn't converge
180+ steps_to_converge_order3 = max_steps
181+ step_count = 0
182+
183+ while step_count < max_steps :
184+ Q_order2 = procrustes_step (Q_order2 , order = 2 , max_step_size = max_step_size )
185+ Q_order3 = procrustes_step (Q_order3 , order = 3 , max_step_size = max_step_size )
186+
187+ err_order2 = torch .max (torch .abs (Q_order2 - amplitude )) / torch .max (torch .abs (amplitude ))
188+ err_order3 = torch .max (torch .abs (Q_order3 - amplitude )) / torch .max (torch .abs (amplitude ))
189+
190+ err_order2_list .append (err_order2 .item ())
191+ err_order3_list .append (err_order3 .item ())
192+ step_count += 1
193+
194+ # Record convergence step for each order (only record the first time)
195+ if err_order2 < tolerance and steps_to_converge_order2 == max_steps :
196+ steps_to_converge_order2 = step_count
197+ if err_order3 < tolerance and steps_to_converge_order3 == max_steps :
198+ steps_to_converge_order3 = step_count
199+
200+ # Stop if both have converged
201+ if err_order2 < tolerance and err_order3 < tolerance :
202+ break
203+
204+ # Order 3 should converge in fewer steps or at least as fast
205+ self .assertLessEqual (
206+ steps_to_converge_order3 ,
207+ steps_to_converge_order2 ,
208+ f"Order 3 converged in { steps_to_converge_order3 } steps, "
209+ f"order 2 in { steps_to_converge_order2 } steps. Order 3 should be faster." ,
210+ )
211+
212+ @parameterized .product (
213+ order = [2 , 3 ],
214+ )
215+ def test_recovers_amplitude_with_sign_ambiguity (self , order : int ) -> None :
216+ """Test procrustes_step recovers amplitude of real matrix up to sign ambiguity.
217+
218+ This is the main functional test for procrustes_step. It must recover the amplitude
219+ of a real matrix up to a sign ambiguity with probability 1.
220+ """
221+ for trial in range (10 ):
222+ n = 10
223+ Q = torch .randn (n , n , device = self .device , dtype = torch .float32 )
224+ u , s , vh = torch .linalg .svd (Q )
225+ amplitude = vh .mH @ torch .diag (s ) @ vh
226+ Q1 , Q2 = torch .clone (Q ), torch .clone (Q )
227+ Q2 [1 ] *= - 1 # add a reflection to Q2 to get Q2'
228+
229+ err1 , err2 = float ("inf" ), float ("inf" )
230+ max_iterations = 1000
231+ tolerance = 0.01
232+ step_count = 0
233+
234+ while step_count < max_iterations and err1 >= tolerance and err2 >= tolerance :
235+ Q1 = procrustes_step (Q1 , order = order )
236+ Q2 = procrustes_step (Q2 , order = order )
237+ err1 = torch .max (torch .abs (Q1 - amplitude )) / torch .max (torch .abs (amplitude ))
238+ err2 = torch .max (torch .abs (Q2 - amplitude )) / torch .max (torch .abs (amplitude ))
239+ step_count += 1
240+
241+ # Record convergence information
242+ converged = err1 < tolerance or err2 < tolerance
243+ final_error = min (err1 , err2 )
244+
245+ self .assertTrue (
246+ converged ,
247+ f"Trial { trial } (order={ order } ): procrustes_step failed to recover amplitude after { step_count } steps. "
248+ f"Final errors: err1={ err1 :.4f} , err2={ err2 :.4f} , best_error={ final_error :.4f} " ,
249+ )
250+
158251
159252if __name__ == "__main__" :
160253 torch .manual_seed (42 )
0 commit comments