
Commit dab19bc

addressed PR comments
Signed-off-by: mikail <[email protected]>
1 parent 34a25e8 commit dab19bc

File tree

6 files changed: +98 -83 lines changed


emerging_optimizers/psgd/procrustes_step.py

Lines changed: 5 additions & 3 deletions
@@ -38,6 +38,7 @@ def procrustes_step(Q: torch.Tensor, max_step_size: float = 0.125) -> torch.Tensor:
         Q: Tensor of shape (n, n), general square matrix to orthogonalize.
         max_step_size: Maximum step size for the line search. Default is 1/8 (0.125).
     """
+    # Note: this function is written in fp32 to avoid numerical instability while computing the Taylor expansion of the exponential map
     with utils.fp32_matmul_precision("highest"):
         R = Q.T - Q
         R /= norm_lower_bound_skew(R) + torch.finfo(R.dtype).smallest_normal
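
The `utils.fp32_matmul_precision` helper is defined elsewhere in the package, not in this diff; presumably it temporarily switches PyTorch's global float32 matmul precision and restores it on exit. A minimal sketch of such a context manager, under that assumption:

    import contextlib
    from typing import Iterator

    import torch

    @contextlib.contextmanager
    def fp32_matmul_precision(precision: str) -> Iterator[None]:
        """Temporarily set float32 matmul precision ('highest', 'high', or 'medium')."""
        previous = torch.get_float32_matmul_precision()
        torch.set_float32_matmul_precision(precision)
        try:
            yield
        finally:
            torch.set_float32_matmul_precision(previous)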
@@ -48,11 +49,12 @@ def procrustes_step(Q: torch.Tensor, max_step_size: float = 0.125) -> torch.Tensor:
         RRQ = R @ RQ
         tr_RRQ = torch.trace(RRQ)
         # clip step size to max_step_size, based on a 2nd order expansion.
-        step_size = torch.clamp(-tr_RQ / tr_RRQ, min=0, max=max_step_size)
+        _step_size = torch.clamp(-tr_RQ / tr_RRQ, min=0, max=max_step_size)
         # If tr_RRQ >= 0, the quadratic approximation is not concave, we fall back to max_step_size.
-        a = torch.where(tr_RRQ < 0, step_size, max_step_size)
+        step_size = torch.where(tr_RRQ < 0, _step_size, max_step_size)
         # rotate Q as exp(a R) Q ~ (I + a R + a^2 R^2/2) Q with an optimal step size by line search
         # for 2nd order expansion, only expand exp(a R) to its 2nd term.
-        Q += a * (RQ + 0.5 * a * RRQ)
+        # Q += step_size * (RQ + 0.5 * step_size * RRQ)
+        Q = torch.add(Q, torch.add(RQ, RRQ, alpha=0.5 * step_size), alpha=step_size)

     return Q
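
The rewritten update is algebraically identical to the commented-out in-place form: the nested `torch.add` calls expand to `Q + step_size * RQ + 0.5 * step_size**2 * RRQ`, the second-order truncation of `exp(step_size * R) @ Q`. A quick standalone check of that equivalence (with a plain float step size; not part of the commit):

    import torch

    torch.manual_seed(0)
    Q = torch.randn(4, 4)
    R = Q.T - Q  # skew-symmetric direction, as in procrustes_step
    RQ = R @ Q
    RRQ = R @ RQ
    step_size = 0.125

    # In-place style: Q + s * (RQ + 0.5 * s * RRQ)
    expected = Q + step_size * (RQ + 0.5 * step_size * RRQ)
    # Fused torch.add style from the commit
    actual = torch.add(Q, torch.add(RQ, RRQ, alpha=0.5 * step_size), alpha=step_size)

    torch.testing.assert_close(actual, expected)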

emerging_optimizers/psgd/psgd_kron_contractions.py

Lines changed: 28 additions & 24 deletions
@@ -90,39 +90,43 @@ def apply_preconditioner(Q_list: List[torch.Tensor], X: torch.Tensor) -> torch.Tensor:
     return Px


-def _mode_n_mul_and_permute(X: torch.Tensor, M: torch.Tensor, mode: int) -> torch.Tensor:
-    """Multiply tensor X along axis `mode` by 2D matrix M.
+def _dim_n_mul_and_permute(X: torch.Tensor, M: torch.Tensor, contract_dim: int) -> torch.Tensor:
+    """Multiply tensor X along axis `contract_dim` by 2D matrix M.

     Helper function for `_apply_single_kronecker_factor`.
-    If M is (d_out, d_in) we contract M's second index with X's `mode` index.
-    `torch.tensordot` is used to contract the two tensors, and then the result is permuted to move the new axis 0 to position `mode`.
-    Returns a new tensor of the same rank, but with size[mode] replaced by d_out.
-    Note that d_{mode} == d_in.
+    If M is (d_out, d_in) we contract M's second index with X's `contract_dim` index.
+    `torch.tensordot` is used to contract the two tensors, and then the result is permuted to move the new axis 0 to position `contract_dim`.
+    Returns a new tensor of the same rank, but with size[contract_dim] replaced by d_out.
+    Note that d_{contract_dim} == d_in.

     Args:
-        X: Tensor of shape (d_0, d_1, ..., d_{mode-1}, d_{mode}, d_{mode+1}, ..., d_N)
+        X: Tensor of shape (d_0, d_1, ..., d_{contract_dim-1}, d_{contract_dim}, d_{contract_dim+1}, ..., d_N)
         M: Tensor of shape (d_out, d_in)
-        mode: int, the mode to contract with M, with d_{mode} == d_in
+        contract_dim: int, the dimension to contract with M, with d_{contract_dim} == d_in

     Returns:
-        Tensor of shape (d_0, d_1, ..., d_{mode-1}, d_out, d_{mode+1}, ..., d_N)
+        Tensor of shape (d_0, d_1, ..., d_{contract_dim-1}, d_out, d_{contract_dim+1}, ..., d_N)

-    Example:
-        X = torch.randn(2, 3, 6)
-        M = torch.randn(5, 6)
-        mode = 2
-        result = _mode_n_mul_and_permute(X, M, mode)
-        print(result.shape)  # Output: torch.Size([2, 3, 5])
+    Examples
+    --------
+    >>> X = torch.randn(2, 3, 6)
+    >>> M = torch.randn(5, 6)
+    >>> contract_dim = 2
+    >>> result = _dim_n_mul_and_permute(X, M, contract_dim)
+    >>> print(result.shape)
+    torch.Size([2, 3, 5])

     """
-    if X.shape[mode] != M.shape[1]:
-        raise ValueError(f"Shape mismatch: X.shape[{mode}] = {X.shape[mode]}, M.shape[1] = {M.shape[1]}")
-    # Contract M's 2nd dim (idx=1) with X's `mode` dim
-    Y = torch.tensordot(M, X, dims=([1], [mode]))
-    # Y now has shape (d_out, d_0, …, d_{mode-1}, d_{mode+1}, …).
-    # We want to move that new axis 0 back to position `mode`, due to `torch.tensordot`.
+    if X.shape[contract_dim] != M.shape[1]:
+        raise ValueError(
+            f"Shape mismatch: X.shape[{contract_dim}] = {X.shape[contract_dim]}, M.shape[1] = {M.shape[1]}"
+        )
+    # Contract M's 2nd dim (idx=1) with X's `contract_dim` dim
+    Y = torch.tensordot(M, X, dims=([1], [contract_dim]))
+    # Y now has shape (d_out, d_0, …, d_{contract_dim-1}, d_{contract_dim+1}, …).
+    # We want to move that new axis 0 back to position `contract_dim`, due to `torch.tensordot`.
     nd = X.dim()
-    perm = list(range(1, mode + 1)) + [0] + list(range(mode + 1, nd))
+    perm = list(range(1, contract_dim + 1)) + [0] + list(range(contract_dim + 1, nd))
     return Y.permute(perm)


@@ -141,5 +145,5 @@ def _apply_single_kronecker_factor(Q_list: List[torch.Tensor], X: torch.Tensor,
         shape = [1] * X.dim()
         shape[axis] = Q.size(0)
         return X * Q.view(shape)
-    else:
-        return _mode_n_mul_and_permute(X, Q, mode=axis)
+
+    return _dim_n_mul_and_permute(X, Q, contract_dim=axis)
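
The tensordot-plus-permute pattern above can be cross-checked against an equivalent `torch.einsum` contraction; a small standalone sketch (not part of the commit):

    import torch

    X = torch.randn(2, 3, 6)
    M = torch.randn(5, 6)

    # Contract M's second index (size 6) with dim 2 of X; tensordot puts the new axis first.
    Y = torch.tensordot(M, X, dims=([1], [2]))  # shape (5, 2, 3)
    # Move the new axis back to position 2, as _dim_n_mul_and_permute does.
    result = Y.permute(1, 2, 0)  # shape (2, 3, 5)

    # The same contraction written as an einsum over the shared index k.
    torch.testing.assert_close(result, torch.einsum("ok,abk->abo", M, X))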

emerging_optimizers/psgd/psgd_utils.py

Lines changed: 7 additions & 7 deletions
@@ -18,14 +18,14 @@


 __all__ = [
-    "balance_q_in_place",
+    "uniformize_q_in_place",
     "norm_lower_bound_spd",
     "norm_lower_bound_skew",
 ]


 @torch.compile  # type: ignore[misc]
-def balance_q_in_place(Q_list: List[torch.Tensor]) -> None:
+def uniformize_q_in_place(Q_list: List[torch.Tensor]) -> None:
     """Balance the dynamic ranges of kronecker factors in place to prevent numerical underflow or overflow.

     Each tensor in `Q_list` is rescaled so that its maximum absolute entry
@@ -71,7 +71,7 @@ def balance_q_in_place(Q_list: List[torch.Tensor]) -> None:

 @torch.compile  # type: ignore[misc]
 def norm_lower_bound_spd(A: torch.Tensor, k: int = 4, half_iters: int = 2, eps: float = 1e-8) -> torch.Tensor:
-    r"""Returns a cheap lower bound for the spectral norm of a symmetric positive definite matrix.
+    r"""A cheap lower bound for the spectral norm of a symmetric positive definite matrix.


     Args:
@@ -84,7 +84,7 @@ def norm_lower_bound_spd(A: torch.Tensor, k: int = 4, half_iters: int = 2, eps: float = 1e-8) -> torch.Tensor:
         A scalar giving a lower bound on :math:`\\|A\\|_2`.
     """

-    # Compute normalizing factor from the largest diagonal entry to prevent overflow/underflow and use smallest representable normal number for numerical stability
+    # Compute normalizing factor from the largest diagonal entry to prevent overflow/underflow and use small number for numerical stability
     normalization = A.diagonal().amax() + eps
     A = A / normalization

@@ -95,7 +95,7 @@ def norm_lower_bound_spd(A: torch.Tensor, k: int = 4, half_iters: int = 2, eps: float = 1e-8) -> torch.Tensor:

 @torch.compile  # type: ignore[misc]
 def norm_lower_bound_skew(A: torch.Tensor, k: int = 32, half_iters: int = 2, eps: float = 1e-8) -> torch.Tensor:
-    """Compute a cheap lower bound on the spectral norm (largest eigenvalue) of skew-symmetric matrix.
+    """A cheap lower bound on the spectral norm (largest eigenvalue) of skew-symmetric matrix.


     Note: For skew-symmetric matrices, all diagonal entries are zero and :math:`A^T = -A`.
@@ -112,7 +112,7 @@ def norm_lower_bound_skew(A: torch.Tensor, k: int = 32, half_iters: int = 2, eps: float = 1e-8) -> torch.Tensor:

     """

-    # Normalize to avoid extreme values, by extracting the max absolute value and use smallest representable normal number for numerical stability
+    # Normalize to avoid extreme values, by extracting the max absolute value and use small number for numerical stability
     normalizing_factor = A.abs().amax() + eps
     A = A / normalizing_factor

@@ -128,7 +128,7 @@ def _subspace_iteration_bound(
     half_iters: int = 2,
     eps: float = 1e-8,
 ) -> torch.Tensor:
-    """Helper function for subspace iteration to estimate spectral norm bounds.
+    """A helper function for subspace iteration to estimate spectral norm bounds.

     Uses numerically stable subspace iteration with a random initialization that aligns with the
     largest row of A to approximate the dominant eigenspace. This is more robust than simple
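
Both bounds follow the same recipe: rescale the matrix so its entries are O(1), bound the rescaled matrix with a few matvecs, then undo the scaling. Since ||A @ x|| <= ||A||_2 for any unit vector x, even one well-chosen matvec yields a valid (if loose) lower bound. A toy illustration of that principle for SPD matrices; this is not the library's `_subspace_iteration_bound`, which tightens the estimate with k vectors over several iterations:

    import torch

    def crude_norm_lower_bound_spd(A: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
        """One-matvec lower bound on ||A||_2 for a symmetric positive definite A."""
        # Rescale so entries are O(1); undo the scaling on the way out.
        normalization = A.diagonal().amax() + eps
        A = A / normalization
        # Use the column with the largest diagonal entry as a cheap test vector.
        i = torch.argmax(A.diagonal())
        x = A[:, i] / (A[:, i].norm() + eps)
        # ||A @ x|| <= ||A||_2 for any unit x, so this never overestimates.
        return (A @ x).norm() * normalization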

tests/test_procrustes_step.py

Lines changed: 20 additions & 21 deletions
@@ -15,20 +15,25 @@
 import math

 import torch
-from absl import testing
+from absl import flags, testing
 from absl.testing import parameterized

 from emerging_optimizers.psgd.procrustes_step import procrustes_step
 from emerging_optimizers.utils import fp32_matmul_precision


+# Define command line flags
+flags.DEFINE_string("device", "cpu", "Device to run tests on: 'cpu' or 'cuda'")
+
+FLAGS = flags.FLAGS
+
+
 class ProcrustesStepTest(parameterized.TestCase):
     """Test cases for procrustes_step function."""

     def setUp(self) -> None:
         """Set up test fixtures."""
-        torch.manual_seed(42)
-        self.device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
+        self.device = FLAGS.device

     def _procrustes_objective(self, Q: torch.Tensor) -> torch.Tensor:
         """Helper function to compute Procrustes objective ||Q^H Q - I||_F^2."""
@@ -49,15 +54,6 @@ def test_improves_orthogonality_simple_case(self) -> None:

         self.assertLessEqual(final_obj.item(), initial_obj.item() + 1e-6)

-    def test_modifies_matrix_in_place(self) -> None:
-        """Test that procrustes_step modifies the matrix in place."""
-        Q = torch.randn(3, 3, device=self.device)
-        Q_original_id = id(Q)
-
-        Q = procrustes_step(Q, max_step_size=1 / 16)
-
-        self.assertEqual(id(Q), Q_original_id)
-
     @parameterized.parameters(
         (8,),
         (128,),
@@ -91,20 +87,20 @@ def test_handles_small_norm_gracefully(self) -> None:

         initial_obj = self._procrustes_objective(Q)

-        Q = procrustes_step(Q, max_step_size=1 / 16)
+        Q = procrustes_step(Q, max_step_size=0.0625)

         final_obj = self._procrustes_objective(Q)

         self.assertLess(final_obj.item(), 1e-6)
         self.assertLess(final_obj.item(), initial_obj.item() + 1e-6)

     @parameterized.parameters(
-        (1 / 64,),
-        (1 / 32,),
-        (1 / 16,),
-        (1 / 8,),
+        (0.015625,),
+        (0.03125,),
+        (0.0625,),
+        (0.125,),
     )
-    def test_different_step_sizes(self, max_step_size: float) -> None:
+    def test_different_step_sizes_reduces_objective(self, max_step_size: float) -> None:
         """Test procrustes_step improvement with different step sizes."""
         perturbation = 1e-1 * torch.randn(10, 10, device=self.device, dtype=torch.float32) / math.sqrt(10)
         Q = torch.linalg.qr(torch.randn(10, 10, device=self.device, dtype=torch.float32)).Q + perturbation
@@ -122,12 +118,14 @@ def test_different_step_sizes(self, max_step_size: float) -> None:
         (512,),
         (8192,),
     )
-    def test_different_matrix_sizes(self, size: int) -> None:
+    def test_different_matrix_sizes_reduces_objective(self, size: int) -> None:
         """Test procrustes_step improvement with different matrix sizes."""
         # Create a non-orthogonal matrix by scaling an orthogonal one
         A = torch.randn(size, size, device=self.device, dtype=torch.float32)
         with fp32_matmul_precision("highest"):
             Q_orth, _ = torch.linalg.qr(A)
+        # Add perturbation, we choose 1e-2 to be small enough to not affect the objective too much
+        # but large enough to make the matrix non-orthogonal.
         Q = Q_orth + 1e-2 * torch.randn(size, size, device=self.device, dtype=torch.float32) / math.sqrt(size)
         max_step_size = 0.5 * size ** (-1 / 3)
         initial_obj = self._procrustes_objective(Q)
@@ -147,8 +145,8 @@ def test_preserves_determinant_sign_for_real_matrices(self) -> None:
         initial_det_pos = torch.det(Q_pos)
         initial_det_neg = torch.det(Q_neg)

-        Q_pos = procrustes_step(Q_pos, max_step_size=1 / 16)
-        Q_neg = procrustes_step(Q_neg, max_step_size=1 / 16)
+        Q_pos = procrustes_step(Q_pos, max_step_size=0.0625)
+        Q_neg = procrustes_step(Q_neg, max_step_size=0.0625)

         final_det_pos = torch.det(Q_pos)
         final_det_neg = torch.det(Q_neg)
@@ -159,4 +157,5 @@ def test_preserves_determinant_sign_for_real_matrices(self) -> None:


 if __name__ == "__main__":
+    torch.manual_seed(42)
     testing.absltest.main()
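
With the device now an absl flag parsed by `absltest.main()`, the target device would be picked per invocation, e.g. `python tests/test_procrustes_step.py --device=cuda`. A minimal standalone sketch of the flag pattern (hypothetical test, not from the commit):

    from absl import flags
    from absl.testing import absltest

    flags.DEFINE_string("device", "cpu", "Device to run tests on: 'cpu' or 'cuda'")
    FLAGS = flags.FLAGS

    class DeviceFlagTest(absltest.TestCase):
        def test_flag_is_parsed(self) -> None:
            # absltest.main() parses command-line flags before running tests.
            self.assertIn(FLAGS.device, ("cpu", "cuda"))

    if __name__ == "__main__":
        absltest.main()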

tests/test_psgd_contractions.py

Lines changed: 16 additions & 10 deletions
@@ -13,25 +13,30 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import torch
-from absl import testing
+from absl import flags, testing
 from absl.testing import parameterized

 from emerging_optimizers.psgd.psgd_kron_contractions import (
-    _mode_n_mul_and_permute,
+    _dim_n_mul_and_permute,
     apply_kronecker_factors,
     apply_preconditioner,
     partial_contraction,
 )
 from emerging_optimizers.utils import fp32_matmul_precision


+# Define command line flags
+flags.DEFINE_string("device", "cpu", "Device to run tests on: 'cpu' or 'cuda'")
+
+FLAGS = flags.FLAGS
+
+
 class TestPSGDKronContractions(parameterized.TestCase):
     """Test cases for PSGD Kronecker contractions."""

     def setUp(self) -> None:
         """Set up test fixtures."""
-        torch.manual_seed(42)
-        self.device = torch.device("cuda")
+        self.device = FLAGS.device

     @parameterized.parameters(
         (2, 3, 3),
@@ -111,22 +116,23 @@ def test_apply_preconditioner_matches_reconstructed(self) -> None:
         (2, 3, 5, 2),
         (4, 6, 2, 1),
     )
-    def test_mode_n_mul_and_permute_shapes(self, dim0: int, dim1: int, dim2: int, mode: int) -> None:
-        """Test `_mode_n_mul_and_permute` with non-uniform shapes and different modes."""
+    def test_dim_n_mul_and_permute_matches_shapes(self, dim0: int, dim1: int, dim2: int, contract_dim: int) -> None:
+        """Test `_dim_n_mul_and_permute` with non-uniform shapes and different contract_dim."""
         X = torch.randn(dim0, dim1, dim2, device=self.device)
         input_shape = X.shape

-        input_dim = input_shape[mode]
+        input_dim = input_shape[contract_dim]
         output_dim = 7  # arbitrary output dimension
         M = torch.randn(output_dim, input_dim, device=self.device)

-        result = _mode_n_mul_and_permute(X, M, mode)
+        result = _dim_n_mul_and_permute(X, M, contract_dim)

-        # Verify output shape: same as input but dimension `mode` replaced by output_dim
+        # Verify output shape: same as input but dimension `contract_dim` replaced by output_dim
         expected_shape = list(input_shape)
-        expected_shape[mode] = output_dim
+        expected_shape[contract_dim] = output_dim
         self.assertEqual(result.shape, torch.Size(expected_shape))


 if __name__ == "__main__":
+    torch.manual_seed(42)
     testing.absltest.main()
