Skip to content

Commit e4e5d65

Browse files
committed
3rd order convergence is only about number of steps, changed test to reflect that
Signed-off-by: mikail <[email protected]>
1 parent 6ef6f90 commit e4e5d65

File tree

1 file changed

+9
-13
lines changed

1 file changed

+9
-13
lines changed

tests/test_procrustes_step.py

Lines changed: 9 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -152,7 +152,13 @@ def test_preserves_determinant_sign_for_real_matrices(self) -> None:
152152
self.assertGreater(initial_det_pos.item() * final_det_pos.item(), 0)
153153
self.assertGreater(initial_det_neg.item() * final_det_neg.item(), 0)
154154

155-
def test_order3_converges_faster_amplitude_recovery(self) -> None:
155+
@parameterized.parameters(
156+
(0.015625,),
157+
(0.03125,),
158+
(0.0625,),
159+
(0.125,),
160+
)
161+
def test_order3_converges_faster_amplitude_recovery(self, max_step_size: float = 0.0625) -> None:
156162
"""Test that order 3 converges faster than order 2 in amplitude recovery setting."""
157163
# Use amplitude recovery setup to compare convergence speed
158164
n = 10
@@ -170,8 +176,8 @@ def test_order3_converges_faster_amplitude_recovery(self) -> None:
170176

171177
# Run procrustes steps and track error
172178
for _ in range(max_steps):
173-
Q_order2 = procrustes_step(Q_order2, order=2)
174-
Q_order3 = procrustes_step(Q_order3, order=3)
179+
Q_order2 = procrustes_step(Q_order2, order=2, max_step_size=max_step_size)
180+
Q_order3 = procrustes_step(Q_order3, order=3, max_step_size=max_step_size)
175181

176182
err_order2 = torch.max(torch.abs(Q_order2 - Amplitude)) / torch.max(torch.abs(Amplitude))
177183
err_order3 = torch.max(torch.abs(Q_order3 - Amplitude)) / torch.max(torch.abs(Amplitude))
@@ -195,16 +201,6 @@ def test_order3_converges_faster_amplitude_recovery(self) -> None:
195201
f"order 2 in {steps_to_converge_order2} steps. Order 3 should be faster.",
196202
)
197203

198-
# After the same number of steps, order 3 should have lower error
199-
comparison_step = min(len(err_order2_list), len(err_order3_list)) - 1
200-
if comparison_step > 0:
201-
self.assertLessEqual(
202-
err_order3_list[comparison_step],
203-
err_order2_list[comparison_step],
204-
f"At step {comparison_step}: order 3 error={err_order3_list[comparison_step]:.6f}, "
205-
f"order 2 error={err_order2_list[comparison_step]:.6f}. Order 3 should have lower error.",
206-
)
207-
208204
@parameterized.product(
209205
order=[2, 3],
210206
)

0 commit comments

Comments
 (0)