From 172c72575aa6e98139b8039760bb425ce13a5073 Mon Sep 17 00:00:00 2001
From: mikail
Date: Tue, 14 Oct 2025 10:53:14 -0700
Subject: [PATCH 1/5] improved test for procrustes step and added third order expansion

Signed-off-by: mikail
---
 emerging_optimizers/psgd/procrustes_step.py |  43 +++++---
 tests/test_procrustes_step.py               | 107 +++++++++++++++++---
 2 files changed, 123 insertions(+), 27 deletions(-)

diff --git a/emerging_optimizers/psgd/procrustes_step.py b/emerging_optimizers/psgd/procrustes_step.py
index e35b1f1..4b034d9 100644
--- a/emerging_optimizers/psgd/procrustes_step.py
+++ b/emerging_optimizers/psgd/procrustes_step.py
@@ -12,6 +12,8 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from typing import Literal
+
 import torch
 
 import emerging_optimizers.utils as utils
@@ -24,23 +26,28 @@
 
 
 @torch.compile  # type: ignore[misc]
-def procrustes_step(Q: torch.Tensor, max_step_size: float = 0.125, eps: float = 1e-8) -> torch.Tensor:
+def procrustes_step(
+    Q: torch.Tensor, max_step_size: float = 0.125, eps: float = 1e-8, order: Literal[2, 3] = 2
+) -> torch.Tensor:
     r"""One step of an online solver for the orthogonal Procrustes problem.
 
     The orthogonal Procrustes problem is :math:`\min_U \| U Q - I \|_F` s.t. :math:`U^H U = I` by rotating Q as
     :math:`\exp(a R) Q`, where :math:`R = Q^H - Q` is the generator and :math:`\|a R\| < 1`.
 
-    `max_step_size` should be less than :math:`1/4` as we only expand :math:`\exp(a R)` to its 2nd order term.
-
-    This method is a second order expansion of a Lie algebra parametrized rotation that
+    If using 2nd order expansion, `max_step_size` should be less than :math:`1/4` as we only expand :math:`\exp(a R)`
+    to its 2nd order term. If using 3rd order expansion, `max_step_size` should be less than :math:`5/8`.
+    This method is an expansion of a Lie algebra parametrized rotation that
     uses a simple approximate line search to find the optimal step size, from Xi-Lin Li.
 
     Args:
         Q: Tensor of shape (n, n), general square matrix to orthogonalize.
         max_step_size: Maximum step size for the line search. Default is 1/8. (0.125)
         eps: Small number for numerical stability.
+        order: Order of the Taylor expansion. Must be 2 or 3. Default is 2.
     """
-    # Note: this function is written in fp32 to avoid numerical instability while computing the taylor expansion of the exponential map
+    if order not in (2, 3):
+        raise ValueError(f"order must be 2 or 3, got {order}")
+    # Note: this function is written in fp32 to avoid numerical instability while computing the expansion of the exponential map
     with utils.fp32_matmul_precision("highest"):
         R = Q.T - Q
         R /= torch.clamp(norm_lower_bound_skew(R), min=eps)
@@ -50,13 +57,21 @@ def procrustes_step(Q: torch.Tensor, max_step_size: float = 0.125, eps: float =
         tr_RQ = torch.trace(RQ)
         RRQ = R @ RQ
         tr_RRQ = torch.trace(RRQ)
-        # clip step size to max_step_size, based on a 2nd order expansion.
-        _step_size = torch.clamp(-tr_RQ / tr_RRQ, min=0, max=max_step_size)
-        # If tr_RRQ >= 0, the quadratic approximation is not concave, we fallback to max_step_size.
-        step_size = torch.where(tr_RRQ < 0, _step_size, max_step_size)
-        # rotate Q as exp(a R) Q ~ (I + a R + a^2 R^2/2) Q with an optimal step size by line search
-        # for 2nd order expansion, only expand exp(a R) to its 2nd term.
-        # Q += step_size * (RQ + 0.5 * step_size * RRQ)
-        Q = torch.add(Q, torch.add(RQ, RRQ, alpha=0.5 * step_size), alpha=step_size)
-
+        if order == 2:
+            # clip step size to max_step_size, based on a 2nd order expansion.
+            _step_size = torch.clamp(-tr_RQ / tr_RRQ, min=0, max=max_step_size)
+            # If tr_RRQ >= 0, the quadratic approximation is not concave, we fallback to max_step_size.
+            step_size = torch.where(tr_RRQ < 0, _step_size, max_step_size)
+            # rotate Q as exp(a R) Q ~ (I + a R + a^2 R^2/2) Q with an optimal step size by line search
+            # for 2nd order expansion, only expand exp(a R) to its 2nd term.
+            # Q += _step_size * (RQ + 0.5 * _step_size * RRQ)
+            Q = torch.add(Q, torch.add(RQ, RRQ, alpha=0.5 * step_size), alpha=step_size)
+        if order == 3:
+            RRRQ = R @ RRQ
+            tr_RRRQ = torch.trace(RRRQ)
+            # for a 3rd order expansion, we take the larger root of the cubic.
+            _step_size = (-tr_RRQ - torch.sqrt(tr_RRQ * tr_RRQ - 1.5 * tr_RQ * tr_RRRQ)) / (0.75 * tr_RRRQ)
+            step_size = torch.clamp(_step_size, max=max_step_size)
+            # Q += step_size * (RQ + 0.5 * step_size * (RRQ + 0.25 * step_size * RRRQ))
+            Q.add_(step_size * (RQ + 0.5 * step_size * (RRQ + 0.25 * step_size * RRRQ)))
     return Q
diff --git a/tests/test_procrustes_step.py b/tests/test_procrustes_step.py
index 8593c11..6d49857 100644
--- a/tests/test_procrustes_step.py
+++ b/tests/test_procrustes_step.py
@@ -54,12 +54,11 @@ def test_improves_orthogonality_simple_case(self) -> None:
 
         self.assertLessEqual(final_obj.item(), initial_obj.item() + 1e-6)
 
-    @parameterized.parameters(
-        (8,),
-        (128,),
-        (1024,),
+    @parameterized.product(
+        size=[8, 128, 1024],
+        order=[2, 3],
     )
-    def test_minimal_change_when_already_orthogonal(self, size: int) -> None:
+    def test_minimal_change_when_already_orthogonal(self, size: int, order: int) -> None:
         """Test that procrustes_step makes minimal changes to an already orthogonal matrix."""
         # Create an orthogonal matrix using QR decomposition
         A = torch.randn(size, size, device=self.device, dtype=torch.float32)
@@ -68,7 +67,7 @@ def test_minimal_change_when_already_orthogonal(self, size: int) -> None:
 
         initial_obj = self._procrustes_objective(Q)
 
-        Q = procrustes_step(Q, max_step_size=1 / 16)
+        Q = procrustes_step(Q, max_step_size=1 / 16, order=order)
 
         final_obj = self._procrustes_objective(Q)
 
@@ -94,19 +93,17 @@ def test_handles_small_norm_gracefully(self) -> None:
         self.assertLess(final_obj.item(), 1e-6)
         self.assertLess(final_obj.item(), initial_obj.item() + 1e-6)
 
-    @parameterized.parameters(
-        (0.015625,),
-        (0.03125,),
-        (0.0625,),
-        (0.125,),
+    @parameterized.product(
+        max_step_size=[0.015625, 0.03125, 0.0625, 0.125],
+        order=[2, 3],
     )
-    def test_different_step_sizes_reduces_objective(self, max_step_size: float) -> None:
+    def test_different_step_sizes_reduces_objective(self, max_step_size: float, order: int) -> None:
         """Test procrustes_step improvement with different step sizes."""
         perturbation = 1e-1 * torch.randn(10, 10, device=self.device, dtype=torch.float32) / math.sqrt(10)
         Q = torch.linalg.qr(torch.randn(10, 10, device=self.device, dtype=torch.float32)).Q + perturbation
         initial_obj = self._procrustes_objective(Q)
 
-        Q = procrustes_step(Q, max_step_size=max_step_size)
+        Q = procrustes_step(Q, max_step_size=max_step_size, order=order)
 
         final_obj = self._procrustes_objective(Q)
 
@@ -155,6 +152,90 @@ def test_preserves_determinant_sign_for_real_matrices(self) -> None:
         self.assertGreater(initial_det_pos.item() * final_det_pos.item(), 0)
         self.assertGreater(initial_det_neg.item() * final_det_neg.item(), 0)
 
+    def test_order3_converges_faster_amplitude_recovery(self) -> None:
+        """Test that order 3 converges faster than order 2 in amplitude recovery setting."""
+        # Use amplitude recovery setup to compare convergence speed
+        n = 10
+        Q_init = torch.randn(n, n, device=self.device, dtype=torch.float32)
+        U, S, Vh = torch.linalg.svd(Q_init)
+        Amplitude = Vh.mH @ torch.diag(S) @ Vh
+
+        # Start from the same initial point for both orders
+        Q_order2 = torch.clone(Q_init)
+        Q_order3 = torch.clone(Q_init)
+
+        max_steps = 200
+        err_order2_list = []
+        err_order3_list = []
+
+        # Run procrustes steps and track error
+        for _ in range(max_steps):
+            Q_order2 = procrustes_step(Q_order2, order=2)
+            Q_order3 = procrustes_step(Q_order3, order=3)
+
+            err_order2 = torch.max(torch.abs(Q_order2 - Amplitude)) / torch.max(torch.abs(Amplitude))
+            err_order3 = torch.max(torch.abs(Q_order3 - Amplitude)) / torch.max(torch.abs(Amplitude))
+
+            err_order2_list.append(err_order2.item())
+            err_order3_list.append(err_order3.item())
+
+            # Stop if both have converged
+            if err_order2 < 0.01 and err_order3 < 0.01:
+                break
+
+        # Count steps to convergence for each order
+        steps_to_converge_order2 = next((i for i, err in enumerate(err_order2_list) if err < 0.01), max_steps)
+        steps_to_converge_order3 = next((i for i, err in enumerate(err_order3_list) if err < 0.01), max_steps)
+
+        # Order 3 should converge in fewer steps or at least as fast
+        self.assertLessEqual(
+            steps_to_converge_order3,
+            steps_to_converge_order2,
+            f"Order 3 converged in {steps_to_converge_order3} steps, "
+            f"order 2 in {steps_to_converge_order2} steps. Order 3 should be faster.",
+        )
+
+        # After the same number of steps, order 3 should have lower error
+        comparison_step = min(len(err_order2_list), len(err_order3_list)) - 1
+        if comparison_step > 0:
+            self.assertLessEqual(
+                err_order3_list[comparison_step],
+                err_order2_list[comparison_step],
+                f"At step {comparison_step}: order 3 error={err_order3_list[comparison_step]:.6f}, "
+                f"order 2 error={err_order2_list[comparison_step]:.6f}. Order 3 should have lower error.",
+            )
+
+    @parameterized.product(
+        order=[2, 3],
+    )
+    def test_recovers_amplitude_with_sign_ambiguity(self, order: int) -> None:
+        """Test procrustes_step recovers amplitude of real matrix up to sign ambiguity.
+
+        This is the main functional test for procrustes_step. It must recover the amplitude
+        of a real matrix up to a sign ambiguity with probability 1.
+        """
+        for trial in range(10):
+            n = 10
+            Q = torch.randn(n, n, device=self.device, dtype=torch.float32)
+            U, S, Vh = torch.linalg.svd(Q)
+            Amplitude = Vh.mH @ torch.diag(S) @ Vh
+            Q1, Q2 = torch.clone(Q), torch.clone(Q)
+            Q2[1] *= -1  # add a reflection to Q2
+
+            err1, err2 = float("inf"), float("inf")
+            for _ in range(1000):
+                Q1 = procrustes_step(Q1, order=order)
+                Q2 = procrustes_step(Q2, order=order)
+                err1 = torch.max(torch.abs(Q1 - Amplitude)) / torch.max(torch.abs(Amplitude))
+                err2 = torch.max(torch.abs(Q2 - Amplitude)) / torch.max(torch.abs(Amplitude))
+                if err1 < 0.01 or err2 < 0.01:
+                    break
+
+            self.assertTrue(
+                err1 < 0.01 or err2 < 0.01,
+                f"Trial {trial} (order={order}): procrustes_step failed to recover amplitude. err1={err1:.4f}, err2={err2:.4f}",
+            )
+
 
 if __name__ == "__main__":
     torch.manual_seed(42)

From 9859c48b27db9122455140c31a15ff0df956321d Mon Sep 17 00:00:00 2001
From: mikail
Date: Tue, 14 Oct 2025 11:07:28 -0700
Subject: [PATCH 2/5] replace inplace with torch.add

Signed-off-by: mikail
---
 emerging_optimizers/psgd/procrustes_step.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/emerging_optimizers/psgd/procrustes_step.py b/emerging_optimizers/psgd/procrustes_step.py
index 4b034d9..415e197 100644
--- a/emerging_optimizers/psgd/procrustes_step.py
+++ b/emerging_optimizers/psgd/procrustes_step.py
@@ -73,5 +73,6 @@ def procrustes_step(
             _step_size = (-tr_RRQ - torch.sqrt(tr_RRQ * tr_RRQ - 1.5 * tr_RQ * tr_RRRQ)) / (0.75 * tr_RRRQ)
             step_size = torch.clamp(_step_size, max=max_step_size)
             # Q += step_size * (RQ + 0.5 * step_size * (RRQ + 0.25 * step_size * RRRQ))
-            Q.add_(step_size * (RQ + 0.5 * step_size * (RRQ + 0.25 * step_size * RRRQ)))
+            Q = torch.add(Q, torch.add(RQ, RRQ, alpha=0.5 * step_size), alpha=step_size)
+            Q = torch.add(Q, torch.add(RRQ, RRRQ, alpha=0.25 * step_size), alpha=step_size)
     return Q

From 6ef6f908c28213d6c8b773272df14a3b70c7d0d1 Mon Sep 17 00:00:00 2001
From: mikail
Date: Tue, 14 Oct 2025 15:40:31 -0700
Subject: [PATCH 3/5] fixed error in 3rd order procrustes

Signed-off-by: mikail
---
 emerging_optimizers/psgd/procrustes_step.py | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/emerging_optimizers/psgd/procrustes_step.py b/emerging_optimizers/psgd/procrustes_step.py
index 415e197..b3a6039 100644
--- a/emerging_optimizers/psgd/procrustes_step.py
+++ b/emerging_optimizers/psgd/procrustes_step.py
@@ -73,6 +73,7 @@ def procrustes_step(
             _step_size = (-tr_RRQ - torch.sqrt(tr_RRQ * tr_RRQ - 1.5 * tr_RQ * tr_RRRQ)) / (0.75 * tr_RRRQ)
             step_size = torch.clamp(_step_size, max=max_step_size)
             # Q += step_size * (RQ + 0.5 * step_size * (RRQ + 0.25 * step_size * RRRQ))
-            Q = torch.add(Q, torch.add(RQ, RRQ, alpha=0.5 * step_size), alpha=step_size)
-            Q = torch.add(Q, torch.add(RRQ, RRRQ, alpha=0.25 * step_size), alpha=step_size)
+            Q = torch.add(
+                Q, torch.add(RQ, torch.add(RRQ, RRRQ, alpha=0.25 * step_size), alpha=0.5 * step_size), alpha=step_size
+            )
     return Q

From e4e5d65156cbbed2fb4a2666d110c62060618d5b Mon Sep 17 00:00:00 2001
From: mikail
Date: Tue, 14 Oct 2025 16:31:21 -0700
Subject: [PATCH 4/5] 3rd order convergence is only about number of steps, changed test to reflect that

Signed-off-by: mikail
---
 tests/test_procrustes_step.py | 22 +++++++++-------------
 1 file changed, 9 insertions(+), 13 deletions(-)

diff --git a/tests/test_procrustes_step.py b/tests/test_procrustes_step.py
index 6d49857..af1cf5f 100644
--- a/tests/test_procrustes_step.py
+++ b/tests/test_procrustes_step.py
@@ -152,7 +152,13 @@ def test_preserves_determinant_sign_for_real_matrices(self) -> None:
         self.assertGreater(initial_det_pos.item() * final_det_pos.item(), 0)
         self.assertGreater(initial_det_neg.item() * final_det_neg.item(), 0)
 
-    def test_order3_converges_faster_amplitude_recovery(self) -> None:
+    @parameterized.parameters(
+        (0.015625,),
+        (0.03125,),
+        (0.0625,),
+        (0.125,),
+    )
+    def test_order3_converges_faster_amplitude_recovery(self, max_step_size: float = 0.0625) -> None:
         """Test that order 3 converges faster than order 2 in amplitude recovery setting."""
         # Use amplitude recovery setup to compare convergence speed
         n = 10
@@ -170,8 +176,8 @@ def test_order3_converges_faster_amplitude_recovery(self) -> None:
 
         # Run procrustes steps and track error
         for _ in range(max_steps):
-            Q_order2 = procrustes_step(Q_order2, order=2)
-            Q_order3 = procrustes_step(Q_order3, order=3)
+            Q_order2 = procrustes_step(Q_order2, order=2, max_step_size=max_step_size)
+            Q_order3 = procrustes_step(Q_order3, order=3, max_step_size=max_step_size)
 
             err_order2 = torch.max(torch.abs(Q_order2 - Amplitude)) / torch.max(torch.abs(Amplitude))
             err_order3 = torch.max(torch.abs(Q_order3 - Amplitude)) / torch.max(torch.abs(Amplitude))
@@ -195,16 +201,6 @@ def test_order3_converges_faster_amplitude_recovery(self) -> None:
             f"order 2 in {steps_to_converge_order2} steps. Order 3 should be faster.",
         )
 
-        # After the same number of steps, order 3 should have lower error
-        comparison_step = min(len(err_order2_list), len(err_order3_list)) - 1
-        if comparison_step > 0:
-            self.assertLessEqual(
-                err_order3_list[comparison_step],
-                err_order2_list[comparison_step],
-                f"At step {comparison_step}: order 3 error={err_order3_list[comparison_step]:.6f}, "
-                f"order 2 error={err_order2_list[comparison_step]:.6f}. Order 3 should have lower error.",
-            )
-
     @parameterized.product(
         order=[2, 3],
     )

From 5f830fe717e821ffdac41946c017b6b671341b28 Mon Sep 17 00:00:00 2001
From: mikail
Date: Tue, 14 Oct 2025 17:27:13 -0700
Subject: [PATCH 5/5] addressed PR comments and made step counting change in test

Signed-off-by: mikail
---
 tests/test_procrustes_step.py | 58 ++++++++++++++++++++++-------------
 1 file changed, 37 insertions(+), 21 deletions(-)

diff --git a/tests/test_procrustes_step.py b/tests/test_procrustes_step.py
index af1cf5f..9654937 100644
--- a/tests/test_procrustes_step.py
+++ b/tests/test_procrustes_step.py
@@ -163,36 +163,44 @@ def test_order3_converges_faster_amplitude_recovery(self, max_step_size: float =
         # Use amplitude recovery setup to compare convergence speed
         n = 10
         Q_init = torch.randn(n, n, device=self.device, dtype=torch.float32)
-        U, S, Vh = torch.linalg.svd(Q_init)
-        Amplitude = Vh.mH @ torch.diag(S) @ Vh
+        u, s, vh = torch.linalg.svd(Q_init)
+        amplitude = vh.mH @ torch.diag(s) @ vh
 
         # Start from the same initial point for both orders
         Q_order2 = torch.clone(Q_init)
         Q_order3 = torch.clone(Q_init)
 
         max_steps = 200
+        tolerance = 0.01
         err_order2_list = []
         err_order3_list = []
 
-        # Run procrustes steps and track error
-        for _ in range(max_steps):
+        # Track convergence steps directly
+        steps_to_converge_order2 = max_steps  # Default to max_steps if doesn't converge
+        steps_to_converge_order3 = max_steps
+        step_count = 0
+
+        while step_count < max_steps:
             Q_order2 = procrustes_step(Q_order2, order=2, max_step_size=max_step_size)
             Q_order3 = procrustes_step(Q_order3, order=3, max_step_size=max_step_size)
 
-            err_order2 = torch.max(torch.abs(Q_order2 - Amplitude)) / torch.max(torch.abs(Amplitude))
-            err_order3 = torch.max(torch.abs(Q_order3 - Amplitude)) / torch.max(torch.abs(Amplitude))
+            err_order2 = torch.max(torch.abs(Q_order2 - amplitude)) / torch.max(torch.abs(amplitude))
+            err_order3 = torch.max(torch.abs(Q_order3 - amplitude)) / torch.max(torch.abs(amplitude))
 
             err_order2_list.append(err_order2.item())
             err_order3_list.append(err_order3.item())
+            step_count += 1
+
+            # Record convergence step for each order (only record the first time)
+            if err_order2 < tolerance and steps_to_converge_order2 == max_steps:
+                steps_to_converge_order2 = step_count
+            if err_order3 < tolerance and steps_to_converge_order3 == max_steps:
+                steps_to_converge_order3 = step_count
 
             # Stop if both have converged
-            if err_order2 < 0.01 and err_order3 < 0.01:
+            if err_order2 < tolerance and err_order3 < tolerance:
                 break
 
-        # Count steps to convergence for each order
-        steps_to_converge_order2 = next((i for i, err in enumerate(err_order2_list) if err < 0.01), max_steps)
-        steps_to_converge_order3 = next((i for i, err in enumerate(err_order3_list) if err < 0.01), max_steps)
-
         # Order 3 should converge in fewer steps or at least as fast
         self.assertLessEqual(
             steps_to_converge_order3,
@@ -213,23 +221,31 @@ def test_recovers_amplitude_with_sign_ambiguity(self, order: int) -> None:
         for trial in range(10):
             n = 10
             Q = torch.randn(n, n, device=self.device, dtype=torch.float32)
-            U, S, Vh = torch.linalg.svd(Q)
-            Amplitude = Vh.mH @ torch.diag(S) @ Vh
+            u, s, vh = torch.linalg.svd(Q)
+            amplitude = vh.mH @ torch.diag(s) @ vh
             Q1, Q2 = torch.clone(Q), torch.clone(Q)
-            Q2[1] *= -1  # add a reflection to Q2
+            Q2[1] *= -1  # add a reflection to Q2 to get Q2'
 
             err1, err2 = float("inf"), float("inf")
-            for _ in range(1000):
+            max_iterations = 1000
+            tolerance = 0.01
+            step_count = 0
+
+            while step_count < max_iterations and err1 >= tolerance and err2 >= tolerance:
                 Q1 = procrustes_step(Q1, order=order)
                 Q2 = procrustes_step(Q2, order=order)
-                err1 = torch.max(torch.abs(Q1 - Amplitude)) / torch.max(torch.abs(Amplitude))
-                err2 = torch.max(torch.abs(Q2 - Amplitude)) / torch.max(torch.abs(Amplitude))
-                if err1 < 0.01 or err2 < 0.01:
-                    break
+                err1 = torch.max(torch.abs(Q1 - amplitude)) / torch.max(torch.abs(amplitude))
+                err2 = torch.max(torch.abs(Q2 - amplitude)) / torch.max(torch.abs(amplitude))
+                step_count += 1
+
+            # Record convergence information
+            converged = err1 < tolerance or err2 < tolerance
+            final_error = min(err1, err2)
 
             self.assertTrue(
-                err1 < 0.01 or err2 < 0.01,
-                f"Trial {trial} (order={order}): procrustes_step failed to recover amplitude. err1={err1:.4f}, err2={err2:.4f}",
+                converged,
+                f"Trial {trial} (order={order}): procrustes_step failed to recover amplitude after {step_count} steps. "
+                f"Final errors: err1={err1:.4f}, err2={err2:.4f}, best_error={final_error:.4f}",
             )
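
The order=3 step size introduced in these patches is a stationary point of the cubic model implied by the
Horner-form update Q += a * (RQ + 0.5 * a * (RRQ + 0.25 * a * RRRQ)), i.e. of
phi(a) = a * tr(RQ) + a^2/2 * tr(R^2 Q) + a^3/8 * tr(R^3 Q), whose derivative
phi'(a) = tr(RQ) + a * tr(R^2 Q) + 3/8 * a^2 * tr(R^3 Q) vanishes at the closed-form root used in the
order == 3 branch. Below is a minimal self-contained sketch of that relationship; it is not part of any
patch, it only assumes torch is installed, and it uses torch.linalg.matrix_norm as a stand-in for the
repository's norm_lower_bound_skew.

import torch

torch.manual_seed(0)
Q = torch.randn(10, 10, dtype=torch.float32)
R = Q.T - Q
R /= torch.linalg.matrix_norm(R, ord=2)  # stand-in for norm_lower_bound_skew

RQ = R @ Q
RRQ = R @ RQ
RRRQ = R @ RRQ
tr_RQ, tr_RRQ, tr_RRRQ = torch.trace(RQ), torch.trace(RRQ), torch.trace(RRRQ)

# phi'(a) = tr_RQ + a * tr_RRQ + 0.375 * a^2 * tr_RRRQ; its root is the order == 3 step size
disc = tr_RRQ * tr_RRQ - 1.5 * tr_RQ * tr_RRRQ
if disc >= 0:
    a = (-tr_RRQ - torch.sqrt(disc)) / (0.75 * tr_RRRQ)  # same closed form as in the patch
    dphi = tr_RQ + a * tr_RRQ + 0.375 * a * a * tr_RRRQ
    print(f"step size a = {a.item():.6f}, phi'(a) = {dphi.item():.2e}")  # phi'(a) should be ~0
else:
    print("negative discriminant for this draw; no real stationary point")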