@@ -152,7 +152,13 @@ def test_preserves_determinant_sign_for_real_matrices(self) -> None:
152152 self .assertGreater (initial_det_pos .item () * final_det_pos .item (), 0 )
153153 self .assertGreater (initial_det_neg .item () * final_det_neg .item (), 0 )
154154
155- def test_order3_converges_faster_amplitude_recovery (self ) -> None :
155+ @parameterized .parameters (
156+ (0.015625 ,),
157+ (0.03125 ,),
158+ (0.0625 ,),
159+ (0.125 ,),
160+ )
161+ def test_order3_converges_faster_amplitude_recovery (self , max_step_size : float = 0.0625 ) -> None :
156162 """Test that order 3 converges faster than order 2 in amplitude recovery setting."""
157163 # Use amplitude recovery setup to compare convergence speed
158164 n = 10
@@ -170,8 +176,8 @@ def test_order3_converges_faster_amplitude_recovery(self) -> None:
170176
171177 # Run procrustes steps and track error
172178 for _ in range (max_steps ):
173- Q_order2 = procrustes_step (Q_order2 , order = 2 )
174- Q_order3 = procrustes_step (Q_order3 , order = 3 )
179+ Q_order2 = procrustes_step (Q_order2 , order = 2 , max_step_size = max_step_size )
180+ Q_order3 = procrustes_step (Q_order3 , order = 3 , max_step_size = max_step_size )
175181
176182 err_order2 = torch .max (torch .abs (Q_order2 - Amplitude )) / torch .max (torch .abs (Amplitude ))
177183 err_order3 = torch .max (torch .abs (Q_order3 - Amplitude )) / torch .max (torch .abs (Amplitude ))
@@ -195,16 +201,6 @@ def test_order3_converges_faster_amplitude_recovery(self) -> None:
195201 f"order 2 in { steps_to_converge_order2 } steps. Order 3 should be faster." ,
196202 )
197203
198- # After the same number of steps, order 3 should have lower error
199- comparison_step = min (len (err_order2_list ), len (err_order3_list )) - 1
200- if comparison_step > 0 :
201- self .assertLessEqual (
202- err_order3_list [comparison_step ],
203- err_order2_list [comparison_step ],
204- f"At step { comparison_step } : order 3 error={ err_order3_list [comparison_step ]:.6f} , "
205- f"order 2 error={ err_order2_list [comparison_step ]:.6f} . Order 3 should have lower error." ,
206- )
207-
208204 @parameterized .product (
209205 order = [2 , 3 ],
210206 )
0 commit comments