diff --git a/emerging_optimizers/psgd/procrustes_step.py b/emerging_optimizers/psgd/procrustes_step.py index e35b1f1..b3a6039 100644 --- a/emerging_optimizers/psgd/procrustes_step.py +++ b/emerging_optimizers/psgd/procrustes_step.py @@ -12,6 +12,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from typing import Literal + import torch import emerging_optimizers.utils as utils @@ -24,23 +26,28 @@ @torch.compile # type: ignore[misc] -def procrustes_step(Q: torch.Tensor, max_step_size: float = 0.125, eps: float = 1e-8) -> torch.Tensor: +def procrustes_step( + Q: torch.Tensor, max_step_size: float = 0.125, eps: float = 1e-8, order: Literal[2, 3] = 2 +) -> torch.Tensor: r"""One step of an online solver for the orthogonal Procrustes problem. The orthogonal Procrustes problem is :math:`\min_U \| U Q - I \|_F` s.t. :math:`U^H U = I` by rotating Q as :math:`\exp(a R) Q`, where :math:`R = Q^H - Q` is the generator and :math:`\|a R\| < 1`. - `max_step_size` should be less than :math:`1/4` as we only expand :math:`\exp(a R)` to its 2nd order term. - - This method is a second order expansion of a Lie algebra parametrized rotation that + If using 2nd order expansion, `max_step_size` should be less than :math:`1/4` as we only expand :math:`\exp(a R)` + to its 2nd order term. If using 3rd order expansion, `max_step_size` should be less than :math:`5/8`. + This method is an expansion of a Lie algebra parametrized rotation that uses a simple approximate line search to find the optimal step size, from Xi-Lin Li. Args: Q: Tensor of shape (n, n), general square matrix to orthogonalize. max_step_size: Maximum step size for the line search. Default is 1/8. (0.125) eps: Small number for numerical stability. + order: Order of the Taylor expansion. Must be 2 or 3. Default is 2. """ - # Note: this function is written in fp32 to avoid numerical instability while computing the taylor expansion of the exponential map + if order not in (2, 3): + raise ValueError(f"order must be 2 or 3, got {order}") + # Note: this function is written in fp32 to avoid numerical instability while computing the expansion of the exponential map with utils.fp32_matmul_precision("highest"): R = Q.T - Q R /= torch.clamp(norm_lower_bound_skew(R), min=eps) @@ -50,13 +57,23 @@ def procrustes_step(Q: torch.Tensor, max_step_size: float = 0.125, eps: float = tr_RQ = torch.trace(RQ) RRQ = R @ RQ tr_RRQ = torch.trace(RRQ) - # clip step size to max_step_size, based on a 2nd order expansion. - _step_size = torch.clamp(-tr_RQ / tr_RRQ, min=0, max=max_step_size) - # If tr_RRQ >= 0, the quadratic approximation is not concave, we fallback to max_step_size. - step_size = torch.where(tr_RRQ < 0, _step_size, max_step_size) - # rotate Q as exp(a R) Q ~ (I + a R + a^2 R^2/2) Q with an optimal step size by line search - # for 2nd order expansion, only expand exp(a R) to its 2nd term. - # Q += step_size * (RQ + 0.5 * step_size * RRQ) - Q = torch.add(Q, torch.add(RQ, RRQ, alpha=0.5 * step_size), alpha=step_size) - + if order == 2: + # clip step size to max_step_size, based on a 2nd order expansion. + _step_size = torch.clamp(-tr_RQ / tr_RRQ, min=0, max=max_step_size) + # If tr_RRQ >= 0, the quadratic approximation is not concave, we fallback to max_step_size. + step_size = torch.where(tr_RRQ < 0, _step_size, max_step_size) + # rotate Q as exp(a R) Q ~ (I + a R + a^2 R^2/2) Q with an optimal step size by line search + # for 2nd order expansion, only expand exp(a R) to its 2nd term. + # Q += _step_size * (RQ + 0.5 * _step_size * RRQ) + Q = torch.add(Q, torch.add(RQ, RRQ, alpha=0.5 * step_size), alpha=step_size) + if order == 3: + RRRQ = R @ RRQ + tr_RRRQ = torch.trace(RRRQ) + # for a 3rd order expansion, we take the larger root of the cubic. + _step_size = (-tr_RRQ - torch.sqrt(tr_RRQ * tr_RRQ - 1.5 * tr_RQ * tr_RRRQ)) / (0.75 * tr_RRRQ) + step_size = torch.clamp(_step_size, max=max_step_size) + # Q += step_size * (RQ + 0.5 * step_size * (RRQ + 0.25 * step_size * RRRQ)) + Q = torch.add( + Q, torch.add(RQ, torch.add(RRQ, RRRQ, alpha=0.25 * step_size), alpha=0.5 * step_size), alpha=step_size + ) return Q diff --git a/tests/test_procrustes_step.py b/tests/test_procrustes_step.py index 8593c11..9654937 100644 --- a/tests/test_procrustes_step.py +++ b/tests/test_procrustes_step.py @@ -54,12 +54,11 @@ def test_improves_orthogonality_simple_case(self) -> None: self.assertLessEqual(final_obj.item(), initial_obj.item() + 1e-6) - @parameterized.parameters( - (8,), - (128,), - (1024,), + @parameterized.product( + size=[8, 128, 1024], + order=[2, 3], ) - def test_minimal_change_when_already_orthogonal(self, size: int) -> None: + def test_minimal_change_when_already_orthogonal(self, size: int, order: int) -> None: """Test that procrustes_step makes minimal changes to an already orthogonal matrix.""" # Create an orthogonal matrix using QR decomposition A = torch.randn(size, size, device=self.device, dtype=torch.float32) @@ -68,7 +67,7 @@ def test_minimal_change_when_already_orthogonal(self, size: int) -> None: initial_obj = self._procrustes_objective(Q) - Q = procrustes_step(Q, max_step_size=1 / 16) + Q = procrustes_step(Q, max_step_size=1 / 16, order=order) final_obj = self._procrustes_objective(Q) @@ -94,19 +93,17 @@ def test_handles_small_norm_gracefully(self) -> None: self.assertLess(final_obj.item(), 1e-6) self.assertLess(final_obj.item(), initial_obj.item() + 1e-6) - @parameterized.parameters( - (0.015625,), - (0.03125,), - (0.0625,), - (0.125,), + @parameterized.product( + max_step_size=[0.015625, 0.03125, 0.0625, 0.125], + order=[2, 3], ) - def test_different_step_sizes_reduces_objective(self, max_step_size: float) -> None: + def test_different_step_sizes_reduces_objective(self, max_step_size: float, order: int) -> None: """Test procrustes_step improvement with different step sizes.""" perturbation = 1e-1 * torch.randn(10, 10, device=self.device, dtype=torch.float32) / math.sqrt(10) Q = torch.linalg.qr(torch.randn(10, 10, device=self.device, dtype=torch.float32)).Q + perturbation initial_obj = self._procrustes_objective(Q) - Q = procrustes_step(Q, max_step_size=max_step_size) + Q = procrustes_step(Q, max_step_size=max_step_size, order=order) final_obj = self._procrustes_objective(Q) @@ -155,6 +152,102 @@ def test_preserves_determinant_sign_for_real_matrices(self) -> None: self.assertGreater(initial_det_pos.item() * final_det_pos.item(), 0) self.assertGreater(initial_det_neg.item() * final_det_neg.item(), 0) + @parameterized.parameters( + (0.015625,), + (0.03125,), + (0.0625,), + (0.125,), + ) + def test_order3_converges_faster_amplitude_recovery(self, max_step_size: float = 0.0625) -> None: + """Test that order 3 converges faster than order 2 in amplitude recovery setting.""" + # Use amplitude recovery setup to compare convergence speed + n = 10 + Q_init = torch.randn(n, n, device=self.device, dtype=torch.float32) + u, s, vh = torch.linalg.svd(Q_init) + amplitude = vh.mH @ torch.diag(s) @ vh + + # Start from the same initial point for both orders + Q_order2 = torch.clone(Q_init) + Q_order3 = torch.clone(Q_init) + + max_steps = 200 + tolerance = 0.01 + err_order2_list = [] + err_order3_list = [] + + # Track convergence steps directly + steps_to_converge_order2 = max_steps # Default to max_steps if doesn't converge + steps_to_converge_order3 = max_steps + step_count = 0 + + while step_count < max_steps: + Q_order2 = procrustes_step(Q_order2, order=2, max_step_size=max_step_size) + Q_order3 = procrustes_step(Q_order3, order=3, max_step_size=max_step_size) + + err_order2 = torch.max(torch.abs(Q_order2 - amplitude)) / torch.max(torch.abs(amplitude)) + err_order3 = torch.max(torch.abs(Q_order3 - amplitude)) / torch.max(torch.abs(amplitude)) + + err_order2_list.append(err_order2.item()) + err_order3_list.append(err_order3.item()) + step_count += 1 + + # Record convergence step for each order (only record the first time) + if err_order2 < tolerance and steps_to_converge_order2 == max_steps: + steps_to_converge_order2 = step_count + if err_order3 < tolerance and steps_to_converge_order3 == max_steps: + steps_to_converge_order3 = step_count + + # Stop if both have converged + if err_order2 < tolerance and err_order3 < tolerance: + break + + # Order 3 should converge in fewer steps or at least as fast + self.assertLessEqual( + steps_to_converge_order3, + steps_to_converge_order2, + f"Order 3 converged in {steps_to_converge_order3} steps, " + f"order 2 in {steps_to_converge_order2} steps. Order 3 should be faster.", + ) + + @parameterized.product( + order=[2, 3], + ) + def test_recovers_amplitude_with_sign_ambiguity(self, order: int) -> None: + """Test procrustes_step recovers amplitude of real matrix up to sign ambiguity. + + This is the main functional test for procrustes_step. It must recover the amplitude + of a real matrix up to a sign ambiguity with probability 1. + """ + for trial in range(10): + n = 10 + Q = torch.randn(n, n, device=self.device, dtype=torch.float32) + u, s, vh = torch.linalg.svd(Q) + amplitude = vh.mH @ torch.diag(s) @ vh + Q1, Q2 = torch.clone(Q), torch.clone(Q) + Q2[1] *= -1 # add a reflection to Q2 to get Q2' + + err1, err2 = float("inf"), float("inf") + max_iterations = 1000 + tolerance = 0.01 + step_count = 0 + + while step_count < max_iterations and err1 >= tolerance and err2 >= tolerance: + Q1 = procrustes_step(Q1, order=order) + Q2 = procrustes_step(Q2, order=order) + err1 = torch.max(torch.abs(Q1 - amplitude)) / torch.max(torch.abs(amplitude)) + err2 = torch.max(torch.abs(Q2 - amplitude)) / torch.max(torch.abs(amplitude)) + step_count += 1 + + # Record convergence information + converged = err1 < tolerance or err2 < tolerance + final_error = min(err1, err2) + + self.assertTrue( + converged, + f"Trial {trial} (order={order}): procrustes_step failed to recover amplitude after {step_count} steps. " + f"Final errors: err1={err1:.4f}, err2={err2:.4f}, best_error={final_error:.4f}", + ) + if __name__ == "__main__": torch.manual_seed(42)