@@ -163,36 +163,44 @@ def test_order3_converges_faster_amplitude_recovery(self, max_step_size: float =
163163 # Use amplitude recovery setup to compare convergence speed
164164 n = 10
165165 Q_init = torch .randn (n , n , device = self .device , dtype = torch .float32 )
166- U , S , Vh = torch .linalg .svd (Q_init )
167- Amplitude = Vh .mH @ torch .diag (S ) @ Vh
166+ u , s , vh = torch .linalg .svd (Q_init )
167+ amplitude = vh .mH @ torch .diag (s ) @ vh
168168
169169 # Start from the same initial point for both orders
170170 Q_order2 = torch .clone (Q_init )
171171 Q_order3 = torch .clone (Q_init )
172172
173173 max_steps = 200
174+ tolerance = 0.01
174175 err_order2_list = []
175176 err_order3_list = []
176177
177- # Run procrustes steps and track error
178- for _ in range (max_steps ):
178+ # Track convergence steps directly
179+ steps_to_converge_order2 = max_steps # Default to max_steps if doesn't converge
180+ steps_to_converge_order3 = max_steps
181+ step_count = 0
182+
183+ while step_count < max_steps :
179184 Q_order2 = procrustes_step (Q_order2 , order = 2 , max_step_size = max_step_size )
180185 Q_order3 = procrustes_step (Q_order3 , order = 3 , max_step_size = max_step_size )
181186
182- err_order2 = torch .max (torch .abs (Q_order2 - Amplitude )) / torch .max (torch .abs (Amplitude ))
183- err_order3 = torch .max (torch .abs (Q_order3 - Amplitude )) / torch .max (torch .abs (Amplitude ))
187+ err_order2 = torch .max (torch .abs (Q_order2 - amplitude )) / torch .max (torch .abs (amplitude ))
188+ err_order3 = torch .max (torch .abs (Q_order3 - amplitude )) / torch .max (torch .abs (amplitude ))
184189
185190 err_order2_list .append (err_order2 .item ())
186191 err_order3_list .append (err_order3 .item ())
192+ step_count += 1
193+
194+ # Record convergence step for each order (only record the first time)
195+ if err_order2 < tolerance and steps_to_converge_order2 == max_steps :
196+ steps_to_converge_order2 = step_count
197+ if err_order3 < tolerance and steps_to_converge_order3 == max_steps :
198+ steps_to_converge_order3 = step_count
187199
188200 # Stop if both have converged
189- if err_order2 < 0.01 and err_order3 < 0.01 :
201+ if err_order2 < tolerance and err_order3 < tolerance :
190202 break
191203
192- # Count steps to convergence for each order
193- steps_to_converge_order2 = next ((i for i , err in enumerate (err_order2_list ) if err < 0.01 ), max_steps )
194- steps_to_converge_order3 = next ((i for i , err in enumerate (err_order3_list ) if err < 0.01 ), max_steps )
195-
196204 # Order 3 should converge in fewer steps or at least as fast
197205 self .assertLessEqual (
198206 steps_to_converge_order3 ,
@@ -213,23 +221,31 @@ def test_recovers_amplitude_with_sign_ambiguity(self, order: int) -> None:
213221 for trial in range (10 ):
214222 n = 10
215223 Q = torch .randn (n , n , device = self .device , dtype = torch .float32 )
216- U , S , Vh = torch .linalg .svd (Q )
217- Amplitude = Vh .mH @ torch .diag (S ) @ Vh
224+ u , s , vh = torch .linalg .svd (Q )
225+ amplitude = vh .mH @ torch .diag (s ) @ vh
218226 Q1 , Q2 = torch .clone (Q ), torch .clone (Q )
219- Q2 [1 ] *= - 1 # add a reflection to Q2
227+ Q2 [1 ] *= - 1 # add a reflection to Q2 to get Q2'
220228
221229 err1 , err2 = float ("inf" ), float ("inf" )
222- for _ in range (1000 ):
230+ max_iterations = 1000
231+ tolerance = 0.01
232+ step_count = 0
233+
234+ while step_count < max_iterations and err1 >= tolerance and err2 >= tolerance :
223235 Q1 = procrustes_step (Q1 , order = order )
224236 Q2 = procrustes_step (Q2 , order = order )
225- err1 = torch .max (torch .abs (Q1 - Amplitude )) / torch .max (torch .abs (Amplitude ))
226- err2 = torch .max (torch .abs (Q2 - Amplitude )) / torch .max (torch .abs (Amplitude ))
227- if err1 < 0.01 or err2 < 0.01 :
228- break
237+ err1 = torch .max (torch .abs (Q1 - amplitude )) / torch .max (torch .abs (amplitude ))
238+ err2 = torch .max (torch .abs (Q2 - amplitude )) / torch .max (torch .abs (amplitude ))
239+ step_count += 1
240+
241+ # Record convergence information
242+ converged = err1 < tolerance or err2 < tolerance
243+ final_error = min (err1 , err2 )
229244
230245 self .assertTrue (
231- err1 < 0.01 or err2 < 0.01 ,
232- f"Trial { trial } (order={ order } ): procrustes_step failed to recover amplitude. err1={ err1 :.4f} , err2={ err2 :.4f} " ,
246+ converged ,
247+ f"Trial { trial } (order={ order } ): procrustes_step failed to recover amplitude after { step_count } steps. "
248+ f"Final errors: err1={ err1 :.4f} , err2={ err2 :.4f} , best_error={ final_error :.4f} " ,
233249 )
234250
235251
0 commit comments