Skip to content

Commit 5f830fe

Browse files
committed
addressed PR comments and made step counting change in test
Signed-off-by: mikail <[email protected]>
1 parent e4e5d65 commit 5f830fe

File tree

1 file changed

+37
-21
lines changed

1 file changed

+37
-21
lines changed

tests/test_procrustes_step.py

Lines changed: 37 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -163,36 +163,44 @@ def test_order3_converges_faster_amplitude_recovery(self, max_step_size: float =
163163
# Use amplitude recovery setup to compare convergence speed
164164
n = 10
165165
Q_init = torch.randn(n, n, device=self.device, dtype=torch.float32)
166-
U, S, Vh = torch.linalg.svd(Q_init)
167-
Amplitude = Vh.mH @ torch.diag(S) @ Vh
166+
u, s, vh = torch.linalg.svd(Q_init)
167+
amplitude = vh.mH @ torch.diag(s) @ vh
168168

169169
# Start from the same initial point for both orders
170170
Q_order2 = torch.clone(Q_init)
171171
Q_order3 = torch.clone(Q_init)
172172

173173
max_steps = 200
174+
tolerance = 0.01
174175
err_order2_list = []
175176
err_order3_list = []
176177

177-
# Run procrustes steps and track error
178-
for _ in range(max_steps):
178+
# Track convergence steps directly
179+
steps_to_converge_order2 = max_steps # Default to max_steps if doesn't converge
180+
steps_to_converge_order3 = max_steps
181+
step_count = 0
182+
183+
while step_count < max_steps:
179184
Q_order2 = procrustes_step(Q_order2, order=2, max_step_size=max_step_size)
180185
Q_order3 = procrustes_step(Q_order3, order=3, max_step_size=max_step_size)
181186

182-
err_order2 = torch.max(torch.abs(Q_order2 - Amplitude)) / torch.max(torch.abs(Amplitude))
183-
err_order3 = torch.max(torch.abs(Q_order3 - Amplitude)) / torch.max(torch.abs(Amplitude))
187+
err_order2 = torch.max(torch.abs(Q_order2 - amplitude)) / torch.max(torch.abs(amplitude))
188+
err_order3 = torch.max(torch.abs(Q_order3 - amplitude)) / torch.max(torch.abs(amplitude))
184189

185190
err_order2_list.append(err_order2.item())
186191
err_order3_list.append(err_order3.item())
192+
step_count += 1
193+
194+
# Record convergence step for each order (only record the first time)
195+
if err_order2 < tolerance and steps_to_converge_order2 == max_steps:
196+
steps_to_converge_order2 = step_count
197+
if err_order3 < tolerance and steps_to_converge_order3 == max_steps:
198+
steps_to_converge_order3 = step_count
187199

188200
# Stop if both have converged
189-
if err_order2 < 0.01 and err_order3 < 0.01:
201+
if err_order2 < tolerance and err_order3 < tolerance:
190202
break
191203

192-
# Count steps to convergence for each order
193-
steps_to_converge_order2 = next((i for i, err in enumerate(err_order2_list) if err < 0.01), max_steps)
194-
steps_to_converge_order3 = next((i for i, err in enumerate(err_order3_list) if err < 0.01), max_steps)
195-
196204
# Order 3 should converge in fewer steps or at least as fast
197205
self.assertLessEqual(
198206
steps_to_converge_order3,
@@ -213,23 +221,31 @@ def test_recovers_amplitude_with_sign_ambiguity(self, order: int) -> None:
213221
for trial in range(10):
214222
n = 10
215223
Q = torch.randn(n, n, device=self.device, dtype=torch.float32)
216-
U, S, Vh = torch.linalg.svd(Q)
217-
Amplitude = Vh.mH @ torch.diag(S) @ Vh
224+
u, s, vh = torch.linalg.svd(Q)
225+
amplitude = vh.mH @ torch.diag(s) @ vh
218226
Q1, Q2 = torch.clone(Q), torch.clone(Q)
219-
Q2[1] *= -1 # add a reflection to Q2
227+
Q2[1] *= -1 # add a reflection to Q2 to get Q2'
220228

221229
err1, err2 = float("inf"), float("inf")
222-
for _ in range(1000):
230+
max_iterations = 1000
231+
tolerance = 0.01
232+
step_count = 0
233+
234+
while step_count < max_iterations and err1 >= tolerance and err2 >= tolerance:
223235
Q1 = procrustes_step(Q1, order=order)
224236
Q2 = procrustes_step(Q2, order=order)
225-
err1 = torch.max(torch.abs(Q1 - Amplitude)) / torch.max(torch.abs(Amplitude))
226-
err2 = torch.max(torch.abs(Q2 - Amplitude)) / torch.max(torch.abs(Amplitude))
227-
if err1 < 0.01 or err2 < 0.01:
228-
break
237+
err1 = torch.max(torch.abs(Q1 - amplitude)) / torch.max(torch.abs(amplitude))
238+
err2 = torch.max(torch.abs(Q2 - amplitude)) / torch.max(torch.abs(amplitude))
239+
step_count += 1
240+
241+
# Record convergence information
242+
converged = err1 < tolerance or err2 < tolerance
243+
final_error = min(err1, err2)
229244

230245
self.assertTrue(
231-
err1 < 0.01 or err2 < 0.01,
232-
f"Trial {trial} (order={order}): procrustes_step failed to recover amplitude. err1={err1:.4f}, err2={err2:.4f}",
246+
converged,
247+
f"Trial {trial} (order={order}): procrustes_step failed to recover amplitude after {step_count} steps. "
248+
f"Final errors: err1={err1:.4f}, err2={err2:.4f}, best_error={final_error:.4f}",
233249
)
234250

235251

0 commit comments

Comments
 (0)