Skip to content

Commit 2ebdb41

Browse files
authored
save unnecessary matmul (#30)
* save unnecessary matmul

  Signed-off-by: Hao Wu <[email protected]>

* simplify criteria logic

  Signed-off-by: Hao Wu <[email protected]>

* remove max precondition dim

  Signed-off-by: Hao Wu <[email protected]>
1 parent b015387 commit 2ebdb41

File tree

5 files changed

+83
-125
lines changed

5 files changed

+83
-125
lines changed

emerging_optimizers/soap/soap.py

Lines changed: 14 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,6 @@ class SOAP(optim.Optimizer):
6161
precondition_warmup_steps: How many steps to warm up the preconditioner (i.e. update every step)
6262
adam_warmup_steps: How many steps to skip preconditioning in the beginning (i.e. use standard AdamW updates)
6363
precondition_1d: Whether to precondition 1D gradients (like biases).
64-
max_precond_dim: Maximum dimension of the preconditioner matrices. Skips preconditioning if any tensor dimension exceeds.
6564
trace_normalization: Whether to normalize update by the trace of the kronecker factor matrix
6665
normalize_preconditioned_grads: Whether to normalize preconditioned gradients per layer
6766
correct_bias: Whether to use bias correction in Inner Adam and Kronecker factor matrices EMA
@@ -91,7 +90,6 @@ def __init__(
9190
precondition_warmup_steps: int = 0,
9291
adam_warmup_steps: int = 1,
9392
precondition_1d: bool = False,
94-
max_precond_dim: int = 8192,
9593
trace_normalization: bool = False,
9694
normalize_preconditioned_grads: bool = False,
9795
correct_bias: bool = True,
@@ -141,7 +139,6 @@ def __init__(
141139
"precondition_warmup_steps": precondition_warmup_steps,
142140
"adam_warmup_steps": adam_warmup_steps,
143141
"precondition_1d": precondition_1d,
144-
"max_precond_dim": max_precond_dim,
145142
"trace_normalization": trace_normalization,
146143
"normalize_preconditioned_grads": normalize_preconditioned_grads,
147144
"use_nesterov": use_nesterov,
@@ -194,7 +191,6 @@ def step(self, closure: Callable[[], float] | None = None) -> float | None:
194191
state["GG"] = init_kronecker_factors(
195192
grad,
196193
precondition_1d=group["precondition_1d"],
197-
max_precond_dim=group["max_precond_dim"],
198194
)
199195

200196
# Update preconditioner matrices with gradient statistics, do not use shampoo_beta for EMA at first step
@@ -204,7 +200,6 @@ def step(self, closure: Callable[[], float] | None = None) -> float | None:
204200
grad=grad,
205201
shampoo_beta=0.0,
206202
precondition_1d=group["precondition_1d"],
207-
max_precond_dim=group["max_precond_dim"],
208203
)
209204

210205
# Increment step counter
@@ -284,7 +279,6 @@ def step(self, closure: Callable[[], float] | None = None) -> float | None:
284279
grad=grad,
285280
shampoo_beta=shampoo_beta,
286281
precondition_1d=group["precondition_1d"],
287-
max_precond_dim=group["max_precond_dim"],
288282
)
289283
torch.cuda.nvtx.range_pop()
290284

@@ -330,7 +324,6 @@ def step(self, closure: Callable[[], float] | None = None) -> float | None:
330324
def init_kronecker_factors(
331325
grad: torch.Tensor,
332326
precondition_1d: bool = False,
333-
max_precond_dim: int = 8192,
334327
) -> List[torch.Tensor]:
335328
"""Initializes the kronecker factor matrices for the SOAP optimizer.
336329
@@ -354,8 +347,6 @@ def init_kronecker_factors(
354347
The shape of this tensor determines the size of the kronecker factor matrices.
355348
precondition_1d: Whether to create kronecker factor matrices for 1D tensors
356349
(like biases). If False, 1D tensors will skip preconditioning.
357-
max_precond_dim: Maximum dimension of the preconditioner matrices.
358-
Skips preconditioning if any tensor dimension exceeds.
359350
360351
Returns:
361352
List[torch.Tensor]: List of kronecker factor matrices (L and R in paper).
@@ -387,21 +378,11 @@ def init_kronecker_factors(
387378
else:
388379
# Create a square preconditioner matrix for 1D tensors
389380
size = grad.shape[0]
390-
if size > max_precond_dim:
391-
# if tensor dimension is larger than max_precond_dim, skip preconditioning this dimension
392-
# append empty tensor to kronecker_factor_list so that subsequent check that use numel() to check if preconditioner is initialized will not fail
393-
kronecker_factor_list.append(torch.empty(0, device=grad.device))
394-
else:
395-
kronecker_factor_list.append(torch.zeros(size, size, device=grad.device))
381+
kronecker_factor_list.append(torch.zeros(size, size, device=grad.device))
396382
else:
397383
# Create a square kronecker factor matrix for each dimension
398384
for size in grad.shape:
399-
if size > max_precond_dim:
400-
# append empty tensor to kronecker_factor_list so that subsequent check that use numel() to check if preconditioner is initialized will not fail
401-
# skip preconditioning this dimension
402-
kronecker_factor_list.append(torch.empty(0, device=grad.device))
403-
else:
404-
kronecker_factor_list.append(torch.zeros(size, size, device=grad.device))
385+
kronecker_factor_list.append(torch.zeros(size, size, device=grad.device))
405386

406387
return kronecker_factor_list
407388

@@ -412,7 +393,6 @@ def update_kronecker_factors(
412393
grad: torch.Tensor,
413394
shampoo_beta: float,
414395
precondition_1d: bool = False,
415-
max_precond_dim: int = 8192,
416396
) -> None:
417397
"""Updates the preconditioner matrices using gradient outer products.
418398
@@ -429,8 +409,6 @@ def update_kronecker_factors(
429409
Controls how much weight to give to new vs old gradient statistics.
430410
precondition_1d: Whether to apply preconditioning to 1D tensors (like biases).
431411
If False, 1D tensors will skip preconditioning.
432-
max_precond_dim: Maximum dimension of the preconditioner matrices.
433-
Skips preconditioning if any tensor dimension exceeds.
434412
435413
Example:
436414
>>> grad = torch.randn(10, 20)
@@ -446,20 +424,22 @@ def update_kronecker_factors(
446424
kronecker_factor_list[0].lerp_(outer_product, 1 - shampoo_beta)
447425
else:
448426
# For 1D tensors, skip preconditioning
427+
logging.error(
428+
"1D tensor is passed to update_kronecker_factors, but precondition_1d is not set to True, skipping preconditioning."
429+
)
449430
return
450431
else:
451432
# For higher dimensional tensors, compute outer products for each dimension
452433
for idx, dim_size in enumerate(grad.shape):
453-
if dim_size <= max_precond_dim:
454-
# Compute outer product by contracting all dimensions except idx
455-
contract_dims = [*chain(range(idx), range(idx + 1, grad.dim()))]
456-
outer_product = torch.tensordot(
457-
grad,
458-
grad,
459-
dims=[contract_dims] * 2,
460-
)
461-
# Update the corresponding Kronecker factor
462-
kronecker_factor_list[idx].lerp_(outer_product, 1 - shampoo_beta)
434+
# Compute outer product by contracting all dimensions except idx
435+
contract_dims = [*chain(range(idx), range(idx + 1, grad.dim()))]
436+
outer_product = torch.tensordot(
437+
grad,
438+
grad,
439+
dims=[contract_dims] * 2,
440+
)
441+
# Update the corresponding Kronecker factor
442+
kronecker_factor_list[idx].lerp_(outer_product, 1 - shampoo_beta)
463443

464444

465445
@torch.no_grad() # type: ignore[misc]

emerging_optimizers/soap/soap_utils.py

Lines changed: 34 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -86,14 +86,11 @@ def get_eigenbasis_eigh(
8686
updated_eigenbasis_list.append(torch.empty(0, device=kronecker_factor.device))
8787
continue
8888
# Construct approximated eigenvalues using QL^T@L@QL or QR^T@R@QR.
89-
# The approximated eigenvalues should be close to diagonal if the eigenbasis is close to the true eigenbasis of the kronecker factor
90-
# (i.e. the approximated eigenvectors diagonalize the kronecker factor)
89+
# The approximated eigenvalues should be close to diagonal if the eigenbasis is close to the true
90+
# eigenbasis of the kronecker factor (i.e. the approximated eigenvectors diagonalize the kronecker factor)
9191
approx_eigenvalue_matrix = eigenbasis.T @ kronecker_factor @ eigenbasis
9292
# Update eigenbasis when necessary. Update is skipped only when adaptive update criteria is met.
93-
if _adaptive_criteria_met(
94-
approx_eigenvalue_matrix=approx_eigenvalue_matrix,
95-
tolerance=adaptive_update_tolerance,
96-
):
93+
if utils.eig.met_approx_eigvals_criteria(approx_eigenvalue_matrix, adaptive_update_tolerance):
9794
_, Q = utils.eig.eigh_with_fallback(
9895
kronecker_factor,
9996
force_double=False,
@@ -206,21 +203,23 @@ def get_eigenbasis_qr(
206203
if kronecker_factor.numel() == 0:
207204
updated_eigenbasis_list.append(torch.empty(0, device=kronecker_factor.device))
208205
continue
209-
# construct approximated eigenvalues using QL^T@L@QL or QR^T@R@QR, which should be close to diagonal
210-
# if the eigenbasis is close to the true eigenbasis of the kronecker factor (i.e. diagonalizes it)
211-
approx_eigenvalue_matrix = eigenbasis.T @ kronecker_factor @ eigenbasis
212206

213207
# Update eigenbasis when necessary. Update is skipped only when use_adaptive_criteria is True
214208
# but criteria is not met.
215209
if_update = True
216-
if use_adaptive_criteria and not _adaptive_criteria_met(
217-
approx_eigenvalue_matrix=approx_eigenvalue_matrix,
218-
tolerance=adaptive_update_tolerance,
219-
):
220-
if_update = False
210+
# construct approximated eigenvalues using QL^T@L@QL or QR^T@R@QR, which should be close to diagonal
211+
# if the eigenbasis is close to the true eigenbasis of the kronecker factor (i.e. diagonalizes it)
212+
if use_adaptive_criteria:
213+
approx_eigenvalue_matrix = _conjugate(kronecker_factor, eigenbasis)
214+
if_update = not utils.eig.met_approx_eigvals_criteria(approx_eigenvalue_matrix, adaptive_update_tolerance)
215+
if if_update:
216+
approx_eigvals = torch.diag(approx_eigenvalue_matrix)
217+
else:
218+
approx_eigvals = _conjugate(kronecker_factor, eigenbasis, diag=True)
219+
221220
if if_update:
222221
Q, exp_avg_sq = _orthogonal_iteration(
223-
approx_eigenvalue_matrix=approx_eigenvalue_matrix,
222+
approx_eigvals=approx_eigvals,
224223
kronecker_factor=kronecker_factor,
225224
eigenbasis=eigenbasis,
226225
ind=ind,
@@ -237,13 +236,13 @@ def get_eigenbasis_qr(
237236

238237

239238
def _orthogonal_iteration(
240-
approx_eigenvalue_matrix: torch.Tensor,
239+
approx_eigvals: torch.Tensor,
241240
kronecker_factor: torch.Tensor,
242241
eigenbasis: torch.Tensor,
243242
ind: int,
244243
exp_avg_sq: torch.Tensor,
245-
convert_to_float: bool = True,
246-
power_iter_steps: int = 1,
244+
convert_to_float: bool,
245+
power_iter_steps: int,
247246
) -> Tuple[torch.Tensor, torch.Tensor]:
248247
"""Computes the eigenbases of the preconditioner using power iteration and QR decomposition.
249248
@@ -267,8 +266,6 @@ def _orthogonal_iteration(
267266
- Q: The updated eigenbasis
268267
- exp_avg_sq: The updated (sorted) inner Adam second moment
269268
"""
270-
# extract approximated eigenvalues from the diagonal of the projection of kronecker factor onto eigenbases
271-
approx_eigvals = torch.diag(approx_eigenvalue_matrix)
272269
# Sort the approximated eigenvalues according to their magnitudes
273270
sort_idx = torch.argsort(approx_eigvals, descending=True)
274271
# re-order the inner adam second moment
@@ -292,27 +289,26 @@ def _orthogonal_iteration(
292289
return Q, exp_avg_sq
293290

294291

295-
@torch.compile # type: ignore[misc]
296-
def _adaptive_criteria_met(
297-
approx_eigenvalue_matrix: torch.Tensor,
298-
tolerance: Optional[float] = None,
299-
) -> bool:
300-
"""Determines whether the eigenbasis for a factor matrix should be updated in the next step of the orthogonal iteration.
292+
def _conjugate(a: torch.Tensor, p: torch.Tensor, diag: bool = False) -> torch.Tensor:
293+
"""Calculate similarity transformation
301294
302-
Determines whether the eigenbasis for a factor matrix should be updated based on computing
303-
the approximated eigenvalues Q^T GG Q, where Q is the approximated eigenvectors and
304-
GG is the Kronecker factor. The approximated eigenvalues update criteria is then defined as
305-
||diag(Q^T GG Q)||_F >= (1 - tolerance) * (Q^T GG Q)_F.
295+
This function calculates :math:`B = P^T A P`. It assumes P is orthogonal so that :math:`P^{-1} = P^T` and
296+
the similarity transformation exists.
306297
307298
Args:
308-
approx_eigenvalue_matrix: Projection of kronecker factor onto the eigenbasis, should be close to diagonal
309-
tolerance: Tolerance threshold for the normalized diagonal component of approximated eigenvalue matrix.
299+
a: matrix to be transformed
300+
p: An orthogonal matrix.
301+
diag: If True, only return the diagonal of the similarity transformation
310302
311303
Returns:
312-
perform_update: Whether to update eigenbasis this iteration
304+
b
313305
"""
314-
if tolerance is None:
315-
return True
316-
317-
# check if normalized diagonal component is not smaller than tolerance
318-
return not utils.eig.adaptive_early_exit_criteria(approx_eigenvalue_matrix, tolerance)
306+
if a.dim() != 2 or p.dim() != 2:
307+
raise TypeError("a and p must be 2D matrices")
308+
pta = p.T @ a
309+
if not diag:
310+
b = pta @ p
311+
else:
312+
# return the diagonal of the similarity transformation
313+
b = (pta * p.T).sum(dim=1)
314+
return b

emerging_optimizers/utils/eig.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
from emerging_optimizers import utils
2222

2323

24-
__all__ = ["eigh_with_fallback", "eig_orthogonal_iteration", "adaptive_early_exit_criteria"]
24+
__all__ = ["eigh_with_fallback", "eig_orthogonal_iteration", "met_approx_eigvals_criteria"]
2525

2626

2727
def eigh_with_fallback(
@@ -135,7 +135,7 @@ def eig_orthogonal_iteration(
135135
approx_eigenvalues_matrix = Q.T @ x @ Q
136136
approx_eigenvalues = torch.diag(approx_eigenvalues_matrix)
137137
iteration = 0
138-
while iteration < max_iterations and not adaptive_early_exit_criteria(approx_eigenvalues_matrix, tolerance):
138+
while iteration < max_iterations and not met_approx_eigvals_criteria(approx_eigenvalues_matrix, tolerance):
139139
power_iteration = x @ Q
140140
Q = torch.linalg.qr(power_iteration).Q
141141
approx_eigenvalues_matrix = Q.T @ x @ Q
@@ -148,7 +148,7 @@ def eig_orthogonal_iteration(
148148
return approx_eigenvalues, Q
149149

150150

151-
def adaptive_early_exit_criteria(approx_eigenvalues_matrix: Tensor, tolerance: float) -> bool:
151+
def met_approx_eigvals_criteria(approx_eigenvalues_matrix: Tensor, tolerance: float) -> bool:
152152
"""Evaluates if a criteria using approximated eigenvalues is below or equal to the tolerance.
153153
154154
`approx_eigenvalues_matrix` is a matrix created from the approximated eigenvectors and the symmetric matrix that is being eigendecomposed.

tests/test_soap_functions.py

Lines changed: 2 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -35,28 +35,12 @@ def test_init_preconditioner_multidim_tensor_shapes(self) -> None:
3535
"""Tests init_preconditioner with a multi-dimensional tensor."""
3636
grad = torch.randn(3, 4, 5)
3737
state: dict[str, Any] = {}
38-
# No merge_dims: each dimension gets its own preconditioner unless dimension > max_precond_dim.
39-
state["GG"] = init_kronecker_factors(grad, precondition_1d=False, max_precond_dim=8192)
38+
state["GG"] = init_kronecker_factors(grad, precondition_1d=False)
4039
self.assertEqual(len(state["GG"]), grad.dim())
4140
self.assertEqual(state["GG"][0].shape, (3, 3))
4241
self.assertEqual(state["GG"][1].shape, (4, 4))
4342
self.assertEqual(state["GG"][2].shape, (5, 5))
4443

45-
def test_init_kronecker_factors_max_precond_dim(self) -> None:
46-
"""Tests init_kronecker_factors respects max_precond_dim."""
47-
max_dim = 8
48-
grad = torch.randn(3, max_dim + 2, 5) # Second dimension exceeds max_dim
49-
kronecker_factors = init_kronecker_factors(grad, precondition_1d=False, max_precond_dim=max_dim)
50-
51-
self.assertEqual(len(kronecker_factors), grad.dim())
52-
# Dimension 0 (size 3) <= max_dim
53-
self.assertEqual(kronecker_factors[0].shape, (3, 3))
54-
# Dimension 1 (size max_dim + 2) > max_dim -> Should be empty
55-
self.assertEqual(kronecker_factors[1].shape, (0,))
56-
self.assertEqual(kronecker_factors[1].numel(), 0)
57-
# Dimension 2 (size 5) <= max_dim
58-
self.assertEqual(kronecker_factors[2].shape, (5, 5))
59-
6044
@parameterized.parameters(
6145
(1,),
6246
(2,),
@@ -97,14 +81,13 @@ def test_adam_warmup_steps(self, adam_warmup_steps: int) -> None:
9781
self.assertEqual(state["Q"][1].shape, (3, 3))
9882

9983
def test_update_kronecker_factors(self) -> None:
100-
"""Tests update_kronecker_factors, including max_precond_dim handling."""
10184
max_dim = 8
10285
shampoo_beta = 0.9
10386
dim0, dim1, dim2 = 3, max_dim + 2, 5
10487
grad = torch.randn(dim0, dim1, dim2)
10588

10689
# Initialize factors
107-
initial_factors = init_kronecker_factors(grad, precondition_1d=False, max_precond_dim=max_dim)
90+
initial_factors = init_kronecker_factors(grad, precondition_1d=False)
10891

10992
kronecker_factors = [f.clone() for f in initial_factors]
11093

@@ -113,25 +96,15 @@ def test_update_kronecker_factors(self) -> None:
11396
grad=grad,
11497
shampoo_beta=shampoo_beta,
11598
precondition_1d=False,
116-
max_precond_dim=max_dim,
11799
)
118100

119101
self.assertEqual(len(kronecker_factors), grad.dim())
120102

121-
# Dimension 0 (size 3) <= max_dim: Should be updated
122103
contract_dims_0 = [1, 2]
123104
outer_product_0 = torch.tensordot(grad, grad, dims=[contract_dims_0] * 2)
124105
expected_factor_0 = initial_factors[0] * shampoo_beta + outer_product_0 * (1 - shampoo_beta)
125106
torch.testing.assert_close(kronecker_factors[0], expected_factor_0, atol=1e-6, rtol=1e-6)
126107

127-
# Dimension 1 (size 10) > max_dim: Should NOT be updated (still empty)
128-
self.assertEqual(kronecker_factors[1].shape, (0,))
129-
self.assertEqual(kronecker_factors[1].numel(), 0)
130-
131-
# Check it's the same object or has same properties as initial empty tensor
132-
self.assertTrue(torch.equal(kronecker_factors[1], initial_factors[1]))
133-
134-
# Dimension 2 (size 5) <= max_dim: Should be updated
135108
contract_dims_2 = [0, 1]
136109
outer_product_2 = torch.tensordot(grad, grad, dims=[contract_dims_2] * 2)
137110
expected_factor_2 = initial_factors[2] * shampoo_beta + outer_product_2 * (1 - shampoo_beta)

0 commit comments

Comments (0)