Skip to content

Commit 172c725

Browse files
committed
improved test for procrustes step and added third order expansion
Signed-off-by: mikail <[email protected]>
1 parent fac3056 commit 172c725

File tree

2 files changed

+123
-27
lines changed

2 files changed

+123
-27
lines changed

emerging_optimizers/psgd/procrustes_step.py

Lines changed: 29 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@
1212
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1313
# See the License for the specific language governing permissions and
1414
# limitations under the License.
15+
from typing import Literal
16+
1517
import torch
1618

1719
import emerging_optimizers.utils as utils
@@ -24,23 +26,28 @@
2426

2527

2628
@torch.compile # type: ignore[misc]
27-
def procrustes_step(Q: torch.Tensor, max_step_size: float = 0.125, eps: float = 1e-8) -> torch.Tensor:
29+
def procrustes_step(
30+
Q: torch.Tensor, max_step_size: float = 0.125, eps: float = 1e-8, order: Literal[2, 3] = 2
31+
) -> torch.Tensor:
2832
r"""One step of an online solver for the orthogonal Procrustes problem.
2933
3034
The orthogonal Procrustes problem is :math:`\min_U \| U Q - I \|_F` s.t. :math:`U^H U = I`
3135
by rotating Q as :math:`\exp(a R) Q`, where :math:`R = Q^H - Q` is the generator and :math:`\|a R\| < 1`.
3236
33-
`max_step_size` should be less than :math:`1/4` as we only expand :math:`\exp(a R)` to its 2nd order term.
34-
35-
This method is a second order expansion of a Lie algebra parametrized rotation that
37+
If using 2nd order expansion, `max_step_size` should be less than :math:`1/4` as we only expand :math:`\exp(a R)`
38+
to its 2nd order term. If using 3rd order expansion, `max_step_size` should be less than :math:`5/8`.
39+
This method is an expansion of a Lie algebra parametrized rotation that
3640
uses a simple approximate line search to find the optimal step size, from Xi-Lin Li.
3741
3842
Args:
3943
Q: Tensor of shape (n, n), general square matrix to orthogonalize.
4044
max_step_size: Maximum step size for the line search. Default is 1/8. (0.125)
4145
eps: Small number for numerical stability.
46+
order: Order of the Taylor expansion. Must be 2 or 3. Default is 2.
4247
"""
43-
# Note: this function is written in fp32 to avoid numerical instability while computing the taylor expansion of the exponential map
48+
if order not in (2, 3):
49+
raise ValueError(f"order must be 2 or 3, got {order}")
50+
# Note: this function is written in fp32 to avoid numerical instability while computing the expansion of the exponential map
4451
with utils.fp32_matmul_precision("highest"):
4552
R = Q.T - Q
4653
R /= torch.clamp(norm_lower_bound_skew(R), min=eps)
@@ -50,13 +57,21 @@ def procrustes_step(Q: torch.Tensor, max_step_size: float = 0.125, eps: float =
5057
tr_RQ = torch.trace(RQ)
5158
RRQ = R @ RQ
5259
tr_RRQ = torch.trace(RRQ)
53-
# clip step size to max_step_size, based on a 2nd order expansion.
54-
_step_size = torch.clamp(-tr_RQ / tr_RRQ, min=0, max=max_step_size)
55-
# If tr_RRQ >= 0, the quadratic approximation is not concave, we fallback to max_step_size.
56-
step_size = torch.where(tr_RRQ < 0, _step_size, max_step_size)
57-
# rotate Q as exp(a R) Q ~ (I + a R + a^2 R^2/2) Q with an optimal step size by line search
58-
# for 2nd order expansion, only expand exp(a R) to its 2nd term.
59-
# Q += step_size * (RQ + 0.5 * step_size * RRQ)
60-
Q = torch.add(Q, torch.add(RQ, RRQ, alpha=0.5 * step_size), alpha=step_size)
61-
60+
if order == 2:
61+
# clip step size to max_step_size, based on a 2nd order expansion.
62+
_step_size = torch.clamp(-tr_RQ / tr_RRQ, min=0, max=max_step_size)
63+
# If tr_RRQ >= 0, the quadratic approximation is not concave, we fallback to max_step_size.
64+
step_size = torch.where(tr_RRQ < 0, _step_size, max_step_size)
65+
# rotate Q as exp(a R) Q ~ (I + a R + a^2 R^2/2) Q with an optimal step size by line search
66+
# for 2nd order expansion, only expand exp(a R) to its 2nd term.
67+
# Q += _step_size * (RQ + 0.5 * _step_size * RRQ)
68+
Q = torch.add(Q, torch.add(RQ, RRQ, alpha=0.5 * step_size), alpha=step_size)
69+
if order == 3:
70+
RRRQ = R @ RRQ
71+
tr_RRRQ = torch.trace(RRRQ)
72+
# for a 3rd order expansion, we take the larger root of the cubic.
73+
_step_size = (-tr_RRQ - torch.sqrt(tr_RRQ * tr_RRQ - 1.5 * tr_RQ * tr_RRRQ)) / (0.75 * tr_RRRQ)
74+
step_size = torch.clamp(_step_size, max=max_step_size)
75+
# Q += step_size * (RQ + 0.5 * step_size * (RRQ + 0.25 * step_size * RRRQ))
76+
Q.add_(step_size * (RQ + 0.5 * step_size * (RRQ + 0.25 * step_size * RRRQ)))
6277
return Q

tests/test_procrustes_step.py

Lines changed: 94 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -54,12 +54,11 @@ def test_improves_orthogonality_simple_case(self) -> None:
5454

5555
self.assertLessEqual(final_obj.item(), initial_obj.item() + 1e-6)
5656

57-
@parameterized.parameters(
58-
(8,),
59-
(128,),
60-
(1024,),
57+
@parameterized.product(
58+
size=[8, 128, 1024],
59+
order=[2, 3],
6160
)
62-
def test_minimal_change_when_already_orthogonal(self, size: int) -> None:
61+
def test_minimal_change_when_already_orthogonal(self, size: int, order: int) -> None:
6362
"""Test that procrustes_step makes minimal changes to an already orthogonal matrix."""
6463
# Create an orthogonal matrix using QR decomposition
6564
A = torch.randn(size, size, device=self.device, dtype=torch.float32)
@@ -68,7 +67,7 @@ def test_minimal_change_when_already_orthogonal(self, size: int) -> None:
6867

6968
initial_obj = self._procrustes_objective(Q)
7069

71-
Q = procrustes_step(Q, max_step_size=1 / 16)
70+
Q = procrustes_step(Q, max_step_size=1 / 16, order=order)
7271

7372
final_obj = self._procrustes_objective(Q)
7473

@@ -94,19 +93,17 @@ def test_handles_small_norm_gracefully(self) -> None:
9493
self.assertLess(final_obj.item(), 1e-6)
9594
self.assertLess(final_obj.item(), initial_obj.item() + 1e-6)
9695

97-
@parameterized.parameters(
98-
(0.015625,),
99-
(0.03125,),
100-
(0.0625,),
101-
(0.125,),
96+
@parameterized.product(
97+
max_step_size=[0.015625, 0.03125, 0.0625, 0.125],
98+
order=[2, 3],
10299
)
103-
def test_different_step_sizes_reduces_objective(self, max_step_size: float) -> None:
100+
def test_different_step_sizes_reduces_objective(self, max_step_size: float, order: int) -> None:
104101
"""Test procrustes_step improvement with different step sizes."""
105102
perturbation = 1e-1 * torch.randn(10, 10, device=self.device, dtype=torch.float32) / math.sqrt(10)
106103
Q = torch.linalg.qr(torch.randn(10, 10, device=self.device, dtype=torch.float32)).Q + perturbation
107104
initial_obj = self._procrustes_objective(Q)
108105

109-
Q = procrustes_step(Q, max_step_size=max_step_size)
106+
Q = procrustes_step(Q, max_step_size=max_step_size, order=order)
110107

111108
final_obj = self._procrustes_objective(Q)
112109

@@ -155,6 +152,90 @@ def test_preserves_determinant_sign_for_real_matrices(self) -> None:
155152
self.assertGreater(initial_det_pos.item() * final_det_pos.item(), 0)
156153
self.assertGreater(initial_det_neg.item() * final_det_neg.item(), 0)
157154

155+
def test_order3_converges_faster_amplitude_recovery(self) -> None:
156+
"""Test that order 3 converges faster than order 2 in amplitude recovery setting."""
157+
# Use amplitude recovery setup to compare convergence speed
158+
n = 10
159+
Q_init = torch.randn(n, n, device=self.device, dtype=torch.float32)
160+
U, S, Vh = torch.linalg.svd(Q_init)
161+
Amplitude = Vh.mH @ torch.diag(S) @ Vh
162+
163+
# Start from the same initial point for both orders
164+
Q_order2 = torch.clone(Q_init)
165+
Q_order3 = torch.clone(Q_init)
166+
167+
max_steps = 200
168+
err_order2_list = []
169+
err_order3_list = []
170+
171+
# Run procrustes steps and track error
172+
for _ in range(max_steps):
173+
Q_order2 = procrustes_step(Q_order2, order=2)
174+
Q_order3 = procrustes_step(Q_order3, order=3)
175+
176+
err_order2 = torch.max(torch.abs(Q_order2 - Amplitude)) / torch.max(torch.abs(Amplitude))
177+
err_order3 = torch.max(torch.abs(Q_order3 - Amplitude)) / torch.max(torch.abs(Amplitude))
178+
179+
err_order2_list.append(err_order2.item())
180+
err_order3_list.append(err_order3.item())
181+
182+
# Stop if both have converged
183+
if err_order2 < 0.01 and err_order3 < 0.01:
184+
break
185+
186+
# Count steps to convergence for each order
187+
steps_to_converge_order2 = next((i for i, err in enumerate(err_order2_list) if err < 0.01), max_steps)
188+
steps_to_converge_order3 = next((i for i, err in enumerate(err_order3_list) if err < 0.01), max_steps)
189+
190+
# Order 3 should converge in fewer steps or at least as fast
191+
self.assertLessEqual(
192+
steps_to_converge_order3,
193+
steps_to_converge_order2,
194+
f"Order 3 converged in {steps_to_converge_order3} steps, "
195+
f"order 2 in {steps_to_converge_order2} steps. Order 3 should be faster.",
196+
)
197+
198+
# After the same number of steps, order 3 should have lower error
199+
comparison_step = min(len(err_order2_list), len(err_order3_list)) - 1
200+
if comparison_step > 0:
201+
self.assertLessEqual(
202+
err_order3_list[comparison_step],
203+
err_order2_list[comparison_step],
204+
f"At step {comparison_step}: order 3 error={err_order3_list[comparison_step]:.6f}, "
205+
f"order 2 error={err_order2_list[comparison_step]:.6f}. Order 3 should have lower error.",
206+
)
207+
208+
@parameterized.product(
209+
order=[2, 3],
210+
)
211+
def test_recovers_amplitude_with_sign_ambiguity(self, order: int) -> None:
212+
"""Test procrustes_step recovers amplitude of real matrix up to sign ambiguity.
213+
214+
This is the main functional test for procrustes_step. It must recover the amplitude
215+
of a real matrix up to a sign ambiguity with probability 1.
216+
"""
217+
for trial in range(10):
218+
n = 10
219+
Q = torch.randn(n, n, device=self.device, dtype=torch.float32)
220+
U, S, Vh = torch.linalg.svd(Q)
221+
Amplitude = Vh.mH @ torch.diag(S) @ Vh
222+
Q1, Q2 = torch.clone(Q), torch.clone(Q)
223+
Q2[1] *= -1 # add a reflection to Q2
224+
225+
err1, err2 = float("inf"), float("inf")
226+
for _ in range(1000):
227+
Q1 = procrustes_step(Q1, order=order)
228+
Q2 = procrustes_step(Q2, order=order)
229+
err1 = torch.max(torch.abs(Q1 - Amplitude)) / torch.max(torch.abs(Amplitude))
230+
err2 = torch.max(torch.abs(Q2 - Amplitude)) / torch.max(torch.abs(Amplitude))
231+
if err1 < 0.01 or err2 < 0.01:
232+
break
233+
234+
self.assertTrue(
235+
err1 < 0.01 or err2 < 0.01,
236+
f"Trial {trial} (order={order}): procrustes_step failed to recover amplitude. err1={err1:.4f}, err2={err2:.4f}",
237+
)
238+
158239

159240
if __name__ == "__main__":
160241
torch.manual_seed(42)

0 commit comments

Comments
 (0)