@@ -54,12 +54,11 @@ def test_improves_orthogonality_simple_case(self) -> None:
5454
5555 self .assertLessEqual (final_obj .item (), initial_obj .item () + 1e-6 )
5656
57- @parameterized .parameters (
58- (8 ,),
59- (128 ,),
60- (1024 ,),
57+ @parameterized .product (
58+ size = [8 , 128 , 1024 ],
59+ order = [2 , 3 ],
6160 )
62- def test_minimal_change_when_already_orthogonal (self , size : int ) -> None :
61+ def test_minimal_change_when_already_orthogonal (self , size : int , order : int ) -> None :
6362 """Test that procrustes_step makes minimal changes to an already orthogonal matrix."""
6463 # Create an orthogonal matrix using QR decomposition
6564 A = torch .randn (size , size , device = self .device , dtype = torch .float32 )
@@ -68,7 +67,7 @@ def test_minimal_change_when_already_orthogonal(self, size: int) -> None:
6867
6968 initial_obj = self ._procrustes_objective (Q )
7069
71- Q = procrustes_step (Q , max_step_size = 1 / 16 )
70+ Q = procrustes_step (Q , max_step_size = 1 / 16 , order = order )
7271
7372 final_obj = self ._procrustes_objective (Q )
7473
@@ -94,19 +93,17 @@ def test_handles_small_norm_gracefully(self) -> None:
9493 self .assertLess (final_obj .item (), 1e-6 )
9594 self .assertLess (final_obj .item (), initial_obj .item () + 1e-6 )
9695
97- @parameterized .parameters (
98- (0.015625 ,),
99- (0.03125 ,),
100- (0.0625 ,),
101- (0.125 ,),
96+ @parameterized .product (
97+ max_step_size = [0.015625 , 0.03125 , 0.0625 , 0.125 ],
98+ order = [2 , 3 ],
10299 )
103- def test_different_step_sizes_reduces_objective (self , max_step_size : float ) -> None :
100+ def test_different_step_sizes_reduces_objective (self , max_step_size : float , order : int ) -> None :
104101 """Test procrustes_step improvement with different step sizes."""
105102 perturbation = 1e-1 * torch .randn (10 , 10 , device = self .device , dtype = torch .float32 ) / math .sqrt (10 )
106103 Q = torch .linalg .qr (torch .randn (10 , 10 , device = self .device , dtype = torch .float32 )).Q + perturbation
107104 initial_obj = self ._procrustes_objective (Q )
108105
109- Q = procrustes_step (Q , max_step_size = max_step_size )
106+ Q = procrustes_step (Q , max_step_size = max_step_size , order = order )
110107
111108 final_obj = self ._procrustes_objective (Q )
112109
@@ -155,6 +152,90 @@ def test_preserves_determinant_sign_for_real_matrices(self) -> None:
155152 self .assertGreater (initial_det_pos .item () * final_det_pos .item (), 0 )
156153 self .assertGreater (initial_det_neg .item () * final_det_neg .item (), 0 )
157154
def test_order3_converges_faster_amplitude_recovery(self) -> None:
    """Test that order 3 converges at least as fast as order 2 in amplitude recovery.

    Both orders start from the same random matrix and are stepped in lockstep.
    After every step the relative max-abs error against the amplitude
    ``Vh.mH @ diag(S) @ Vh`` is recorded, and the test then checks that order 3
    (a) reaches the tolerance in no more steps than order 2 and (b) has an
    error no larger than order 2 at the last recorded step.
    """
    n = 10
    max_steps = 200
    tolerance = 0.01

    Q_init = torch.randn(n, n, device=self.device, dtype=torch.float32)
    # The amplitude below has a non-negative determinant (product of singular
    # values), while procrustes_step is expected to keep the sign of the
    # determinant (see test_preserves_determinant_sign_for_real_matrices).
    # A start with negative determinant therefore cannot reach the amplitude,
    # which would turn the comparison below into a comparison of two stalled
    # runs. Flipping one row makes the problem well posed for every draw.
    if torch.linalg.det(Q_init) < 0:
        Q_init[0] *= -1

    # Only the singular values and right singular vectors are needed.
    _, S, Vh = torch.linalg.svd(Q_init)
    Amplitude = Vh.mH @ torch.diag(S) @ Vh
    # Loop-invariant normaliser for the relative error.
    amplitude_scale = torch.max(torch.abs(Amplitude))

    # Start from the same initial point for both orders.
    Q_order2 = Q_init.clone()
    Q_order3 = Q_init.clone()

    err_order2_list = []
    err_order3_list = []

    # Run procrustes steps in lockstep and track the error of each order.
    for _ in range(max_steps):
        Q_order2 = procrustes_step(Q_order2, order=2)
        Q_order3 = procrustes_step(Q_order3, order=3)

        err_order2 = torch.max(torch.abs(Q_order2 - Amplitude)) / amplitude_scale
        err_order3 = torch.max(torch.abs(Q_order3 - Amplitude)) / amplitude_scale

        err_order2_list.append(err_order2.item())
        err_order3_list.append(err_order3.item())

        # Stop once both orders have converged.
        if err_order2 < tolerance and err_order3 < tolerance:
            break

    # Index of the first step below tolerance; max_steps means "never converged".
    steps_to_converge_order2 = next((i for i, err in enumerate(err_order2_list) if err < tolerance), max_steps)
    steps_to_converge_order3 = next((i for i, err in enumerate(err_order3_list) if err < tolerance), max_steps)

    # Order 3 should converge in fewer steps or at least as fast.
    self.assertLessEqual(
        steps_to_converge_order3,
        steps_to_converge_order2,
        f"Order 3 converged in { steps_to_converge_order3 } steps, "
        f"order 2 in { steps_to_converge_order2 } steps. Order 3 should be faster." ,
    )

    # Both lists grow together, so the last index is shared by the two orders.
    comparison_step = len(err_order2_list) - 1
    if comparison_step > 0:
        # After the same number of steps, order 3 should have the lower error.
        self.assertLessEqual(
            err_order3_list[comparison_step],
            err_order2_list[comparison_step],
            f"At step { comparison_step } : order 3 error={ err_order3_list [comparison_step ]:.6f} , "
            f"order 2 error={ err_order2_list [comparison_step ]:.6f} . Order 3 should have lower error." ,
        )
207+
@parameterized.product(
    order=[2, 3],
)
def test_recovers_amplitude_with_sign_ambiguity(self, order: int) -> None:
    """Check that repeated procrustes_step calls recover the amplitude up to a reflection.

    For each of ten random real matrices the amplitude ``Vh.mH @ diag(S) @ Vh``
    is built from the SVD. Two copies of the matrix are iterated, one of them
    with a single row negated, and at least one of the two must come within a
    1% relative max-abs error of the amplitude in at most 1000 steps.
    """
    for trial in range(10):
        n = 10
        Q = torch.randn(n, n, device=self.device, dtype=torch.float32)
        _, S, Vh = torch.linalg.svd(Q)
        Amplitude = Vh.mH @ torch.diag(S) @ Vh

        def relative_error(estimate):
            # Max-abs deviation from the amplitude, normalised by its max-abs entry.
            return (estimate - Amplitude).abs().max() / Amplitude.abs().max()

        Q1 = Q.clone()
        Q2 = Q.clone()
        Q2[1] *= -1  # reflected copy: covers the other determinant sign

        err1 = float("inf")
        err2 = float("inf")
        for _ in range(1000):
            Q1 = procrustes_step(Q1, order=order)
            Q2 = procrustes_step(Q2, order=order)
            err1 = relative_error(Q1)
            err2 = relative_error(Q2)
            # One of the two copies reaching the tolerance is enough.
            if err1 < 0.01 or err2 < 0.01:
                break

        self.assertTrue(
            err1 < 0.01 or err2 < 0.01,
            f"Trial { trial } (order={ order } ): procrustes_step failed to recover amplitude. err1={ err1 :.4f} , err2={ err2 :.4f} " ,
        )
238+
158239
159240if __name__ == "__main__" :
160241 torch .manual_seed (42 )
0 commit comments