Skip to content

Commit 9d93954

Browse files
improved test for procrustes step and added third order expansion (#58)
* improved test for procrustes step and added third order expansion Signed-off-by: mikail <[email protected]>
1 parent fac3056 commit 9d93954

File tree

2 files changed

+137
-27
lines changed

2 files changed

+137
-27
lines changed

emerging_optimizers/psgd/procrustes_step.py

Lines changed: 31 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@
1212
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1313
# See the License for the specific language governing permissions and
1414
# limitations under the License.
15+
from typing import Literal
16+
1517
import torch
1618

1719
import emerging_optimizers.utils as utils
@@ -24,23 +26,28 @@
2426

2527

2628
@torch.compile # type: ignore[misc]
27-
def procrustes_step(Q: torch.Tensor, max_step_size: float = 0.125, eps: float = 1e-8) -> torch.Tensor:
29+
def procrustes_step(
30+
Q: torch.Tensor, max_step_size: float = 0.125, eps: float = 1e-8, order: Literal[2, 3] = 2
31+
) -> torch.Tensor:
2832
r"""One step of an online solver for the orthogonal Procrustes problem.
2933
3034
The orthogonal Procrustes problem is :math:`\min_U \| U Q - I \|_F` s.t. :math:`U^H U = I`
3135
by rotating Q as :math:`\exp(a R) Q`, where :math:`R = Q^H - Q` is the generator and :math:`\|a R\| < 1`.
3236
33-
`max_step_size` should be less than :math:`1/4` as we only expand :math:`\exp(a R)` to its 2nd order term.
34-
35-
This method is a second order expansion of a Lie algebra parametrized rotation that
37+
If using 2nd order expansion, `max_step_size` should be less than :math:`1/4` as we only expand :math:`\exp(a R)`
38+
to its 2nd order term. If using 3rd order expansion, `max_step_size` should be less than :math:`5/8`.
39+
This method is an expansion of a Lie algebra parametrized rotation that
3640
uses a simple approximate line search to find the optimal step size, from Xi-Lin Li.
3741
3842
Args:
3943
Q: Tensor of shape (n, n), general square matrix to orthogonalize.
4044
max_step_size: Maximum step size for the line search. Default is 1/8. (0.125)
4145
eps: Small number for numerical stability.
46+
order: Order of the Taylor expansion. Must be 2 or 3. Default is 2.
4247
"""
43-
# Note: this function is written in fp32 to avoid numerical instability while computing the taylor expansion of the exponential map
48+
if order not in (2, 3):
49+
raise ValueError(f"order must be 2 or 3, got {order}")
50+
# Note: this function is written in fp32 to avoid numerical instability while computing the expansion of the exponential map
4451
with utils.fp32_matmul_precision("highest"):
4552
R = Q.T - Q
4653
R /= torch.clamp(norm_lower_bound_skew(R), min=eps)
@@ -50,13 +57,23 @@ def procrustes_step(Q: torch.Tensor, max_step_size: float = 0.125, eps: float =
5057
tr_RQ = torch.trace(RQ)
5158
RRQ = R @ RQ
5259
tr_RRQ = torch.trace(RRQ)
53-
# clip step size to max_step_size, based on a 2nd order expansion.
54-
_step_size = torch.clamp(-tr_RQ / tr_RRQ, min=0, max=max_step_size)
55-
# If tr_RRQ >= 0, the quadratic approximation is not concave, we fallback to max_step_size.
56-
step_size = torch.where(tr_RRQ < 0, _step_size, max_step_size)
57-
# rotate Q as exp(a R) Q ~ (I + a R + a^2 R^2/2) Q with an optimal step size by line search
58-
# for 2nd order expansion, only expand exp(a R) to its 2nd term.
59-
# Q += step_size * (RQ + 0.5 * step_size * RRQ)
60-
Q = torch.add(Q, torch.add(RQ, RRQ, alpha=0.5 * step_size), alpha=step_size)
61-
60+
if order == 2:
61+
# clip step size to max_step_size, based on a 2nd order expansion.
62+
_step_size = torch.clamp(-tr_RQ / tr_RRQ, min=0, max=max_step_size)
63+
# If tr_RRQ >= 0, the quadratic approximation is not concave, we fallback to max_step_size.
64+
step_size = torch.where(tr_RRQ < 0, _step_size, max_step_size)
65+
# rotate Q as exp(a R) Q ~ (I + a R + a^2 R^2/2) Q with an optimal step size by line search
66+
# for 2nd order expansion, only expand exp(a R) to its 2nd term.
67+
# Q += _step_size * (RQ + 0.5 * _step_size * RRQ)
68+
Q = torch.add(Q, torch.add(RQ, RRQ, alpha=0.5 * step_size), alpha=step_size)
69+
if order == 3:
70+
RRRQ = R @ RRQ
71+
tr_RRRQ = torch.trace(RRRQ)
72+
# for a 3rd order expansion, we take the larger root of the cubic.
73+
_step_size = (-tr_RRQ - torch.sqrt(tr_RRQ * tr_RRQ - 1.5 * tr_RQ * tr_RRRQ)) / (0.75 * tr_RRRQ)
74+
step_size = torch.clamp(_step_size, max=max_step_size)
75+
# Q += step_size * (RQ + 0.5 * step_size * (RRQ + 0.25 * step_size * RRRQ))
76+
Q = torch.add(
77+
Q, torch.add(RQ, torch.add(RRQ, RRRQ, alpha=0.25 * step_size), alpha=0.5 * step_size), alpha=step_size
78+
)
6279
return Q

tests/test_procrustes_step.py

Lines changed: 106 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -54,12 +54,11 @@ def test_improves_orthogonality_simple_case(self) -> None:
5454

5555
self.assertLessEqual(final_obj.item(), initial_obj.item() + 1e-6)
5656

57-
@parameterized.parameters(
58-
(8,),
59-
(128,),
60-
(1024,),
57+
@parameterized.product(
58+
size=[8, 128, 1024],
59+
order=[2, 3],
6160
)
62-
def test_minimal_change_when_already_orthogonal(self, size: int) -> None:
61+
def test_minimal_change_when_already_orthogonal(self, size: int, order: int) -> None:
6362
"""Test that procrustes_step makes minimal changes to an already orthogonal matrix."""
6463
# Create an orthogonal matrix using QR decomposition
6564
A = torch.randn(size, size, device=self.device, dtype=torch.float32)
@@ -68,7 +67,7 @@ def test_minimal_change_when_already_orthogonal(self, size: int) -> None:
6867

6968
initial_obj = self._procrustes_objective(Q)
7069

71-
Q = procrustes_step(Q, max_step_size=1 / 16)
70+
Q = procrustes_step(Q, max_step_size=1 / 16, order=order)
7271

7372
final_obj = self._procrustes_objective(Q)
7473

@@ -94,19 +93,17 @@ def test_handles_small_norm_gracefully(self) -> None:
9493
self.assertLess(final_obj.item(), 1e-6)
9594
self.assertLess(final_obj.item(), initial_obj.item() + 1e-6)
9695

97-
@parameterized.parameters(
98-
(0.015625,),
99-
(0.03125,),
100-
(0.0625,),
101-
(0.125,),
96+
@parameterized.product(
97+
max_step_size=[0.015625, 0.03125, 0.0625, 0.125],
98+
order=[2, 3],
10299
)
103-
def test_different_step_sizes_reduces_objective(self, max_step_size: float) -> None:
100+
def test_different_step_sizes_reduces_objective(self, max_step_size: float, order: int) -> None:
104101
"""Test procrustes_step improvement with different step sizes."""
105102
perturbation = 1e-1 * torch.randn(10, 10, device=self.device, dtype=torch.float32) / math.sqrt(10)
106103
Q = torch.linalg.qr(torch.randn(10, 10, device=self.device, dtype=torch.float32)).Q + perturbation
107104
initial_obj = self._procrustes_objective(Q)
108105

109-
Q = procrustes_step(Q, max_step_size=max_step_size)
106+
Q = procrustes_step(Q, max_step_size=max_step_size, order=order)
110107

111108
final_obj = self._procrustes_objective(Q)
112109

@@ -155,6 +152,102 @@ def test_preserves_determinant_sign_for_real_matrices(self) -> None:
155152
self.assertGreater(initial_det_pos.item() * final_det_pos.item(), 0)
156153
self.assertGreater(initial_det_neg.item() * final_det_neg.item(), 0)
157154

155+
@parameterized.parameters(
156+
(0.015625,),
157+
(0.03125,),
158+
(0.0625,),
159+
(0.125,),
160+
)
161+
def test_order3_converges_faster_amplitude_recovery(self, max_step_size: float = 0.0625) -> None:
162+
"""Test that order 3 converges faster than order 2 in amplitude recovery setting."""
163+
# Use amplitude recovery setup to compare convergence speed
164+
n = 10
165+
Q_init = torch.randn(n, n, device=self.device, dtype=torch.float32)
166+
u, s, vh = torch.linalg.svd(Q_init)
167+
amplitude = vh.mH @ torch.diag(s) @ vh
168+
169+
# Start from the same initial point for both orders
170+
Q_order2 = torch.clone(Q_init)
171+
Q_order3 = torch.clone(Q_init)
172+
173+
max_steps = 200
174+
tolerance = 0.01
175+
err_order2_list = []
176+
err_order3_list = []
177+
178+
# Track convergence steps directly
179+
steps_to_converge_order2 = max_steps # Default to max_steps if doesn't converge
180+
steps_to_converge_order3 = max_steps
181+
step_count = 0
182+
183+
while step_count < max_steps:
184+
Q_order2 = procrustes_step(Q_order2, order=2, max_step_size=max_step_size)
185+
Q_order3 = procrustes_step(Q_order3, order=3, max_step_size=max_step_size)
186+
187+
err_order2 = torch.max(torch.abs(Q_order2 - amplitude)) / torch.max(torch.abs(amplitude))
188+
err_order3 = torch.max(torch.abs(Q_order3 - amplitude)) / torch.max(torch.abs(amplitude))
189+
190+
err_order2_list.append(err_order2.item())
191+
err_order3_list.append(err_order3.item())
192+
step_count += 1
193+
194+
# Record convergence step for each order (only record the first time)
195+
if err_order2 < tolerance and steps_to_converge_order2 == max_steps:
196+
steps_to_converge_order2 = step_count
197+
if err_order3 < tolerance and steps_to_converge_order3 == max_steps:
198+
steps_to_converge_order3 = step_count
199+
200+
# Stop if both have converged
201+
if err_order2 < tolerance and err_order3 < tolerance:
202+
break
203+
204+
# Order 3 should converge in fewer steps or at least as fast
205+
self.assertLessEqual(
206+
steps_to_converge_order3,
207+
steps_to_converge_order2,
208+
f"Order 3 converged in {steps_to_converge_order3} steps, "
209+
f"order 2 in {steps_to_converge_order2} steps. Order 3 should be faster.",
210+
)
211+
212+
@parameterized.product(
213+
order=[2, 3],
214+
)
215+
def test_recovers_amplitude_with_sign_ambiguity(self, order: int) -> None:
216+
"""Test procrustes_step recovers amplitude of real matrix up to sign ambiguity.
217+
218+
This is the main functional test for procrustes_step. It must recover the amplitude
219+
of a real matrix up to a sign ambiguity with probability 1.
220+
"""
221+
for trial in range(10):
222+
n = 10
223+
Q = torch.randn(n, n, device=self.device, dtype=torch.float32)
224+
u, s, vh = torch.linalg.svd(Q)
225+
amplitude = vh.mH @ torch.diag(s) @ vh
226+
Q1, Q2 = torch.clone(Q), torch.clone(Q)
227+
Q2[1] *= -1 # add a reflection to Q2 to get Q2'
228+
229+
err1, err2 = float("inf"), float("inf")
230+
max_iterations = 1000
231+
tolerance = 0.01
232+
step_count = 0
233+
234+
while step_count < max_iterations and err1 >= tolerance and err2 >= tolerance:
235+
Q1 = procrustes_step(Q1, order=order)
236+
Q2 = procrustes_step(Q2, order=order)
237+
err1 = torch.max(torch.abs(Q1 - amplitude)) / torch.max(torch.abs(amplitude))
238+
err2 = torch.max(torch.abs(Q2 - amplitude)) / torch.max(torch.abs(amplitude))
239+
step_count += 1
240+
241+
# Record convergence information
242+
converged = err1 < tolerance or err2 < tolerance
243+
final_error = min(err1, err2)
244+
245+
self.assertTrue(
246+
converged,
247+
f"Trial {trial} (order={order}): procrustes_step failed to recover amplitude after {step_count} steps. "
248+
f"Final errors: err1={err1:.4f}, err2={err2:.4f}, best_error={final_error:.4f}",
249+
)
250+
158251

159252
if __name__ == "__main__":
160253
torch.manual_seed(42)

0 commit comments

Comments
 (0)