45 changes: 31 additions & 14 deletions emerging_optimizers/psgd/procrustes_step.py
@@ -12,6 +12,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Literal

import torch

import emerging_optimizers.utils as utils
@@ -24,23 +26,28 @@


@torch.compile # type: ignore[misc]
def procrustes_step(Q: torch.Tensor, max_step_size: float = 0.125, eps: float = 1e-8) -> torch.Tensor:
def procrustes_step(
Q: torch.Tensor, max_step_size: float = 0.125, eps: float = 1e-8, order: Literal[2, 3] = 2
) -> torch.Tensor:
r"""One step of an online solver for the orthogonal Procrustes problem.

The orthogonal Procrustes problem is :math:`\min_U \| U Q - I \|_F` s.t. :math:`U^H U = I`. This step solves it
approximately by rotating Q as :math:`\exp(a R) Q`, where :math:`R = Q^H - Q` is the generator and :math:`\|a R\| < 1`.

`max_step_size` should be less than :math:`1/4` as we only expand :math:`\exp(a R)` to its 2nd order term.

This method is a second order expansion of a Lie algebra parametrized rotation that
If using the 2nd order expansion, `max_step_size` should be less than :math:`1/4`, as we only expand
:math:`\exp(a R)` to its 2nd order term. If using the 3rd order expansion, `max_step_size` should be less than :math:`5/8`.
This method is an expansion of a Lie algebra parametrized rotation that
uses a simple approximate line search to find the optimal step size, following Xi-Lin Li.

Args:
Q: Tensor of shape (n, n), general square matrix to orthogonalize.
max_step_size: Maximum step size for the line search. Default is 1/8 (0.125).
eps: Small number for numerical stability.
order: Order of the Taylor expansion. Must be 2 or 3. Default is 2.
"""
# Note: this function is written in fp32 to avoid numerical instability while computing the Taylor expansion of the exponential map
if order not in (2, 3):
raise ValueError(f"order must be 2 or 3, got {order}")
# Note: this function is written in fp32 to avoid numerical instability while computing the expansion of the exponential map
with utils.fp32_matmul_precision("highest"):
R = Q.T - Q
R /= torch.clamp(norm_lower_bound_skew(R), min=eps)
Expand All @@ -50,13 +57,23 @@ def procrustes_step(Q: torch.Tensor, max_step_size: float = 0.125, eps: float =
tr_RQ = torch.trace(RQ)
RRQ = R @ RQ
tr_RRQ = torch.trace(RRQ)
# clip step size to max_step_size, based on a 2nd order expansion.
_step_size = torch.clamp(-tr_RQ / tr_RRQ, min=0, max=max_step_size)
# If tr_RRQ >= 0, the quadratic approximation is not concave, so we fall back to max_step_size.
step_size = torch.where(tr_RRQ < 0, _step_size, max_step_size)
# rotate Q as exp(a R) Q ~ (I + a R + a^2 R^2/2) Q with an optimal step size by line search
# for the 2nd order expansion, only expand exp(a R) to its 2nd order term.
# Q += step_size * (RQ + 0.5 * step_size * RRQ)
Q = torch.add(Q, torch.add(RQ, RRQ, alpha=0.5 * step_size), alpha=step_size)

if order == 2:
# clip step size to max_step_size, based on a 2nd order expansion.
_step_size = torch.clamp(-tr_RQ / tr_RRQ, min=0, max=max_step_size)
# If tr_RRQ >= 0, the quadratic approximation is not concave, so we fall back to max_step_size.
step_size = torch.where(tr_RRQ < 0, _step_size, max_step_size)
# rotate Q as exp(a R) Q ~ (I + a R + a^2 R^2/2) Q with an optimal step size by line search
# for the 2nd order expansion, only expand exp(a R) to its 2nd order term.
# Q += step_size * (RQ + 0.5 * step_size * RRQ)
Q = torch.add(Q, torch.add(RQ, RRQ, alpha=0.5 * step_size), alpha=step_size)
if order == 3:
RRRQ = R @ RRQ
tr_RRRQ = torch.trace(RRRQ)
# for a 3rd order expansion, we take the root of the cubic model's derivative at which the model has a local maximum.
_step_size = (-tr_RRQ - torch.sqrt(tr_RRQ * tr_RRQ - 1.5 * tr_RQ * tr_RRRQ)) / (0.75 * tr_RRRQ)
step_size = torch.clamp(_step_size, max=max_step_size)
# Q += step_size * (RQ + 0.5 * step_size * (RRQ + 0.25 * step_size * RRRQ))
Q = torch.add(
Q, torch.add(RQ, torch.add(RRQ, RRRQ, alpha=0.25 * step_size), alpha=0.5 * step_size), alpha=step_size
)
return Q
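
As a review aid, the following sketch reconstructs the line search that both branches implement. It is derived only from the code above; the t1/t2/t3 shorthand is ours, not the library's. Because R = Q^T - Q is skew-symmetric, exp(aR) is orthogonal, so minimizing ||exp(aR)Q - I||_F is equivalent to maximizing tr(exp(aR)Q).

```latex
% Line-search model reconstructed from procrustes_step (our notation):
%   t1 = tr(RQ),  t2 = tr(R^2 Q),  t3 = tr(R^3 Q)
% Polynomial model of the trace objective (the cubic coefficient is 1/8,
% matching the nested torch.add factors in the order-3 update):
g(a) = \operatorname{tr}(Q) + a\,t_1 + \tfrac{a^2}{2}\,t_2 + \tfrac{a^3}{8}\,t_3
% Order 2: drop the cubic term and set g'(a) = t_1 + a\,t_2 = 0:
a^\ast = -\,t_1 / t_2 \quad \text{(clamped to } [0,\ \mathtt{max\_step\_size}]\text{)}
% Order 3: set g'(a) = t_1 + a\,t_2 + \tfrac{3}{8}\,a^2\,t_3 = 0 and take
a^\ast = \frac{-t_2 - \sqrt{t_2^2 - \tfrac{3}{2}\,t_1 t_3}}{\tfrac{3}{4}\,t_3}
% This root satisfies g''(a^\ast) = -\sqrt{t_2^2 - \tfrac{3}{2} t_1 t_3} \le 0,
% i.e. it is the local maximum of the model, matching the formula in the code.
```

Both closed forms match `_step_size` in the corresponding branch; the order-2 branch additionally falls back to `max_step_size` when t2 >= 0, where the quadratic model has no maximum.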
119 changes: 106 additions & 13 deletions tests/test_procrustes_step.py
@@ -54,12 +54,11 @@ def test_improves_orthogonality_simple_case(self) -> None:

self.assertLessEqual(final_obj.item(), initial_obj.item() + 1e-6)

@parameterized.parameters(
(8,),
(128,),
(1024,),
@parameterized.product(
size=[8, 128, 1024],
order=[2, 3],
)
def test_minimal_change_when_already_orthogonal(self, size: int) -> None:
def test_minimal_change_when_already_orthogonal(self, size: int, order: int) -> None:
"""Test that procrustes_step makes minimal changes to an already orthogonal matrix."""
# Create an orthogonal matrix using QR decomposition
A = torch.randn(size, size, device=self.device, dtype=torch.float32)
@@ -68,7 +67,7 @@ def test_minimal_change_when_already_orthogonal(self, size: int) -> None:

initial_obj = self._procrustes_objective(Q)

Q = procrustes_step(Q, max_step_size=1 / 16)
Q = procrustes_step(Q, max_step_size=1 / 16, order=order)

final_obj = self._procrustes_objective(Q)

@@ -94,19 +93,17 @@ def test_handles_small_norm_gracefully(self) -> None:
self.assertLess(final_obj.item(), 1e-6)
self.assertLess(final_obj.item(), initial_obj.item() + 1e-6)

@parameterized.parameters(
(0.015625,),
(0.03125,),
(0.0625,),
(0.125,),
@parameterized.product(
max_step_size=[0.015625, 0.03125, 0.0625, 0.125],
order=[2, 3],
)
def test_different_step_sizes_reduces_objective(self, max_step_size: float) -> None:
def test_different_step_sizes_reduces_objective(self, max_step_size: float, order: int) -> None:
"""Test procrustes_step improvement with different step sizes."""
perturbation = 1e-1 * torch.randn(10, 10, device=self.device, dtype=torch.float32) / math.sqrt(10)
Q = torch.linalg.qr(torch.randn(10, 10, device=self.device, dtype=torch.float32)).Q + perturbation
initial_obj = self._procrustes_objective(Q)

Q = procrustes_step(Q, max_step_size=max_step_size)
Q = procrustes_step(Q, max_step_size=max_step_size, order=order)

final_obj = self._procrustes_objective(Q)

@@ -155,6 +152,102 @@ def test_preserves_determinant_sign_for_real_matrices(self) -> None:
self.assertGreater(initial_det_pos.item() * final_det_pos.item(), 0)
self.assertGreater(initial_det_neg.item() * final_det_neg.item(), 0)

@parameterized.parameters(
(0.015625,),
(0.03125,),
(0.0625,),
(0.125,),
)
def test_order3_converges_faster_amplitude_recovery(self, max_step_size: float = 0.0625) -> None:
"""Test that order 3 converges faster than order 2 in amplitude recovery setting."""
# Use amplitude recovery setup to compare convergence speed
n = 10
Q_init = torch.randn(n, n, device=self.device, dtype=torch.float32)
u, s, vh = torch.linalg.svd(Q_init)
amplitude = vh.mH @ torch.diag(s) @ vh

# Start from the same initial point for both orders
Q_order2 = torch.clone(Q_init)
Q_order3 = torch.clone(Q_init)

max_steps = 200
tolerance = 0.01
err_order2_list = []
err_order3_list = []

# Track convergence steps directly
steps_to_converge_order2 = max_steps # Default to max_steps if it doesn't converge
steps_to_converge_order3 = max_steps
step_count = 0

while step_count < max_steps:
Q_order2 = procrustes_step(Q_order2, order=2, max_step_size=max_step_size)
Q_order3 = procrustes_step(Q_order3, order=3, max_step_size=max_step_size)

err_order2 = torch.max(torch.abs(Q_order2 - amplitude)) / torch.max(torch.abs(amplitude))
err_order3 = torch.max(torch.abs(Q_order3 - amplitude)) / torch.max(torch.abs(amplitude))

err_order2_list.append(err_order2.item())
err_order3_list.append(err_order3.item())
step_count += 1

# Record convergence step for each order (only record the first time)
if err_order2 < tolerance and steps_to_converge_order2 == max_steps:
steps_to_converge_order2 = step_count
if err_order3 < tolerance and steps_to_converge_order3 == max_steps:
steps_to_converge_order3 = step_count

# Stop if both have converged
if err_order2 < tolerance and err_order3 < tolerance:
break

# Order 3 should converge in fewer steps or at least as fast
self.assertLessEqual(
steps_to_converge_order3,
steps_to_converge_order2,
f"Order 3 converged in {steps_to_converge_order3} steps, "
f"order 2 in {steps_to_converge_order2} steps. Order 3 should be faster.",
)

@parameterized.product(
order=[2, 3],
)
def test_recovers_amplitude_with_sign_ambiguity(self, order: int) -> None:
"""Test procrustes_step recovers amplitude of real matrix up to sign ambiguity.

This is the main functional test for procrustes_step. It must recover the amplitude
of a real matrix up to a sign ambiguity with probability 1.
"""
for trial in range(10):
n = 10
Q = torch.randn(n, n, device=self.device, dtype=torch.float32)
u, s, vh = torch.linalg.svd(Q)
amplitude = vh.mH @ torch.diag(s) @ vh
Q1, Q2 = torch.clone(Q), torch.clone(Q)
Q2[1] *= -1 # add a reflection to Q2 to get Q2'

err1, err2 = float("inf"), float("inf")
max_iterations = 1000
tolerance = 0.01
step_count = 0

while step_count < max_iterations and err1 >= tolerance and err2 >= tolerance:
Q1 = procrustes_step(Q1, order=order)
Q2 = procrustes_step(Q2, order=order)
err1 = torch.max(torch.abs(Q1 - amplitude)) / torch.max(torch.abs(amplitude))
err2 = torch.max(torch.abs(Q2 - amplitude)) / torch.max(torch.abs(amplitude))
step_count += 1

# Record convergence information
converged = err1 < tolerance or err2 < tolerance
final_error = min(err1, err2)

self.assertTrue(
converged,
f"Trial {trial} (order={order}): procrustes_step failed to recover amplitude after {step_count} steps. "
f"Final errors: err1={err1:.4f}, err2={err2:.4f}, best_error={final_error:.4f}",
)


if __name__ == "__main__":
torch.manual_seed(42)
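
Outside the tests, here is a minimal sketch of how the new `order` argument might be exercised end to end. It is an illustration, not part of the PR; the import path is inferred from the file touched in this diff, and the tolerance and error metric are borrowed from the tests above.

```python
# Minimal usage sketch (assumption: this import path matches the package
# layout shown in the diff). Iterating procrustes_step rotates Q toward the
# symmetric "amplitude" factor V S V^H of its polar decomposition, which is
# exactly what the recovery tests above check.
import torch

from emerging_optimizers.psgd.procrustes_step import procrustes_step

torch.manual_seed(0)
n = 10
Q = torch.randn(n, n)

# Reference amplitude, computed the same way as in the tests.
u, s, vh = torch.linalg.svd(Q)
amplitude = vh.mH @ torch.diag(s) @ vh

for step in range(1000):
    Q = procrustes_step(Q, max_step_size=0.125, order=3)
    # Relative max-abs error against the amplitude, as in the tests.
    err = torch.max(torch.abs(Q - amplitude)) / torch.max(torch.abs(amplitude))
    if err < 0.01:
        break

# As the sign-ambiguity test shows, a run can stall on a reflected fixed
# point; restarting with one row of Q negated recovers convergence there.
```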