Skip to content

Commit 9b433f7

Browse files
authored
Update eig utils (#35)
* Update eigen value criteria testing to save compute Signed-off-by: Hao Wu <[email protected]> * remove deprecated function in __all__ Signed-off-by: Hao Wu <[email protected]> --------- Signed-off-by: Hao Wu <[email protected]>
1 parent c3c451b commit 9b433f7

File tree

4 files changed

+147
-145
lines changed

4 files changed

+147
-145
lines changed

docs/apidocs/soap.md

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -27,9 +27,4 @@ emerging_optimizers.soap.soap_utils
2727
2828
.. automodule:: emerging_optimizers.soap.soap_utils
2929
:members:
30-
31-
.. autofunction:: _orthogonal_iteration
32-
33-
.. autofunction:: _conjugate
34-
3530
```

emerging_optimizers/soap/soap_utils.py

Lines changed: 15 additions & 99 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616

1717
import torch
1818

19-
from emerging_optimizers import utils
19+
from emerging_optimizers.utils import eig as eig_utils
2020

2121

2222
__all__ = [
@@ -85,13 +85,10 @@ def get_eigenbasis_eigh(
8585
# We use an empty tensor so that the `precondition` function will skip this factor.
8686
updated_eigenbasis_list.append(torch.empty(0, device=kronecker_factor.device))
8787
continue
88-
# Construct approximated eigenvalues using :math:`Q_L^T L Q_L` or :math:`Q_R^T R Q_R`.
89-
# The approximated eigenvalues should be close to diagonal if the eigenbasis is close to the true
90-
# eigenbasis of the kronecker factor (i.e. the approximated eigenvectors diagonalize the kronecker factor)
91-
approx_eigenvalue_matrix = eigenbasis.T @ kronecker_factor @ eigenbasis
92-
# Update eigenbasis when necessary. Update is skipped only when adaptive update criteria is met.
93-
if utils.eig.met_approx_eigvals_criteria(approx_eigenvalue_matrix, adaptive_update_tolerance):
94-
_, Q = utils.eig.eigh_with_fallback(
88+
89+
approx_eigvals = eig_utils.conjugate(kronecker_factor, eigenbasis, diag=True)
90+
if not eig_utils.met_approx_eigvals_criteria(kronecker_factor, approx_eigvals, adaptive_update_tolerance):
91+
_, Q = eig_utils.eigh_with_fallback(
9592
kronecker_factor,
9693
force_double=False,
9794
eps=eps,
@@ -106,7 +103,7 @@ def get_eigenbasis_eigh(
106103
if kronecker_factor.numel() == 0:
107104
updated_eigenbasis_list.append(torch.empty(0, device=kronecker_factor.device))
108105
continue
109-
_, Q = utils.eig.eigh_with_fallback(
106+
_, Q = eig_utils.eigh_with_fallback(
110107
kronecker_factor, force_double=False, eps=eps, output_dtype=torch.float if convert_to_float else None
111108
)
112109
updated_eigenbasis_list.append(Q)
@@ -204,21 +201,19 @@ def get_eigenbasis_qr(
204201
updated_eigenbasis_list.append(torch.empty(0, device=kronecker_factor.device))
205202
continue
206203

207-
# Update eigenbasis when necessary. Update is skipped only when use_adaptive_criteria is True
208-
# but criteria is not met.
204+
# Update eigenbasis when necessary. Update is skipped only when use_adaptive_criteria is True while
205+
# criteria is not met.
209206
if_update = True
210-
# construct approximated eigenvalues using :math:`Q_L^T L Q_L` or :math:`Q_R^T R Q_R`, which should be close to diagonal
211-
# if the eigenbasis is close to the true eigenbasis of the kronecker factor (i.e. diagonalizes it)
207+
# construct approximated eigenvalues using Q_L^T L Q_L or Q_R^T R Q_R, which should be close to
208+
# diagonal if the eigenbasis is close to the true eigenbasis of the kronecker factor (i.e. diagonalizes it)
209+
approx_eigvals = eig_utils.conjugate(kronecker_factor, eigenbasis, diag=True)
212210
if use_adaptive_criteria:
213-
approx_eigenvalue_matrix = _conjugate(kronecker_factor, eigenbasis)
214-
if_update = not utils.eig.met_approx_eigvals_criteria(approx_eigenvalue_matrix, adaptive_update_tolerance)
215-
if if_update:
216-
approx_eigvals = torch.diag(approx_eigenvalue_matrix)
217-
else:
218-
approx_eigvals = _conjugate(kronecker_factor, eigenbasis, diag=True)
211+
if_update = not eig_utils.met_approx_eigvals_criteria(
212+
kronecker_factor, approx_eigvals, adaptive_update_tolerance
213+
)
219214

220215
if if_update:
221-
Q, exp_avg_sq = _orthogonal_iteration(
216+
Q, exp_avg_sq = eig_utils.orthogonal_iteration(
222217
approx_eigvals=approx_eigvals,
223218
kronecker_factor=kronecker_factor,
224219
eigenbasis=eigenbasis,
@@ -233,82 +228,3 @@ def get_eigenbasis_qr(
233228
updated_eigenbasis_list.append(eigenbasis)
234229

235230
return updated_eigenbasis_list, exp_avg_sq
236-
237-
238-
def _orthogonal_iteration(
239-
approx_eigvals: torch.Tensor,
240-
kronecker_factor: torch.Tensor,
241-
eigenbasis: torch.Tensor,
242-
ind: int,
243-
exp_avg_sq: torch.Tensor,
244-
convert_to_float: bool,
245-
power_iter_steps: int,
246-
) -> Tuple[torch.Tensor, torch.Tensor]:
247-
"""Computes the eigenbases of the preconditioner using power iteration and QR decomposition.
248-
249-
This function performs multiple rounds of power iteration followed by QR decomposition
250-
to recompute the eigenbases of the preconditioner kronecker factor. Generalizes Vyas et al.'s (SOAP) algorithm of 1 step of power iteration for updating the eigenbasis.
251-
252-
Args:
253-
approx_eigenvalue_matrix : Projection of kronecker factor onto the eigenbasis, should be close to diagonal
254-
kronecker_factor : Kronecker factor matrix.
255-
eigenbasis : Kronecker factor eigenbasis matrix.
256-
ind : Index for selecting dimension in the exp_avg_sq matrix to apply the sorting order over.
257-
exp_avg_sq : inner Adam second moment (exp_avg_sq).
258-
convert_to_float : If True, preconditioner matrices and their corresponding
259-
orthonormal matrices will be cast to float. Otherwise, they are left in
260-
their original type. Defaults to False.
261-
power_iter_steps: Number of power iteration steps to perform before QR decomposition.
262-
More steps can lead to better convergence but increased computation time.
263-
264-
Returns:
265-
tuple[torch.Tensor, torch.Tensor]: A tuple containing:
266-
- Q: The updated eigenbasis
267-
- exp_avg_sq: The updated (sorted) inner Adam second moment
268-
"""
269-
# Sort the approximated eigenvalues according to their magnitudes
270-
sort_idx = torch.argsort(approx_eigvals, descending=True)
271-
# re-order the inner adam second moment
272-
exp_avg_sq = exp_avg_sq.index_select(ind, sort_idx)
273-
274-
# Initialize power iteration after sorting the columns of the eigenbasis matrix according to the descending eigenvalues
275-
Q = eigenbasis[:, sort_idx]
276-
277-
# By default, perform QR decomposition with power iteration with FP32 precision
278-
# Perform multiple steps of power iteration
279-
for _ in range(power_iter_steps):
280-
# Project current eigenbases on kronecker factor
281-
Q = kronecker_factor @ Q
282-
# Perform QR to maintain orthogonality between iterations
283-
Q = torch.linalg.qr(Q).Q
284-
285-
# When not converting to float, ensure that Q is in the original dtype
286-
if not convert_to_float:
287-
Q = Q.to(kronecker_factor.dtype)
288-
289-
return Q, exp_avg_sq
290-
291-
292-
def _conjugate(a: torch.Tensor, p: torch.Tensor, diag: bool = False) -> torch.Tensor:
293-
"""Calculate similarity transformation
294-
295-
This function calculates :math:`B = P^T A P`. It assumes P is orthogonal so that :math:`P^{-1} = P^T` and
296-
the similarity transformation exists.
297-
298-
Args:
299-
a: matrix to be transformed
300-
p: An orthogonal matrix.
301-
diag: If True, only return the diagonal of the similarity transformation
302-
303-
Returns:
304-
b
305-
"""
306-
if a.dim() != 2 or p.dim() != 2:
307-
raise TypeError("a and p must be 2D matrices")
308-
pta = p.T @ a
309-
if not diag:
310-
b = pta @ p
311-
else:
312-
# return the diagonal of the similarity transformation
313-
b = (pta * p.T).sum(dim=1)
314-
return b

emerging_optimizers/utils/eig.py

Lines changed: 122 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -12,16 +12,19 @@
1212
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1313
# See the License for the specific language governing permissions and
1414
# limitations under the License.
15-
from typing import Optional
15+
from typing import Optional, Tuple
1616

1717
import torch
1818
from absl import logging
1919
from torch import Tensor
2020

21-
from emerging_optimizers import utils
2221

23-
24-
__all__ = ["eigh_with_fallback", "eig_orthogonal_iteration", "met_approx_eigvals_criteria"]
22+
__all__ = [
23+
"eigh_with_fallback",
24+
"met_approx_eigvals_criteria",
25+
"conjugate",
26+
"orthogonal_iteration",
27+
]
2528

2629

2730
def eigh_with_fallback(
@@ -104,6 +107,7 @@ def eig_orthogonal_iteration(
104107
) -> tuple[Tensor, Tensor]:
105108
"""Approximately compute the eigen decomposition
106109
110+
[DEPRECATED] Use `orthogonal_iteration` instead.
107111
108112
Orthogonal or subspace iteration uses iterative power iteration and QR decomposition to update the approximated
109113
eigenvectors. When the initial estimate is the zero matrix, the eigendecomposition is computed
@@ -133,46 +137,129 @@ def eig_orthogonal_iteration(
133137
return eigh_with_fallback(x, force_double=True)
134138

135139
# Perform power iteration and QR decomposition iteratively.
136-
with utils.fp32_matmul_precision("highest"):
137-
Q = approx_eigenvectors
138-
approx_eigenvalues_matrix = Q.T @ x @ Q
139-
approx_eigenvalues = torch.diag(approx_eigenvalues_matrix)
140-
iteration = 0
141-
while iteration < max_iterations and not met_approx_eigvals_criteria(approx_eigenvalues_matrix, tolerance):
142-
power_iteration = x @ Q
143-
Q = torch.linalg.qr(power_iteration).Q
144-
approx_eigenvalues_matrix = Q.T @ x @ Q
145-
iteration += 1
146-
# Sort eigenvalues in descending order and reorder eigenvectors accordingly
147-
# Sorting can help mitigate numerical instability since QR decompositions can mix the approximated eigenvectors
148-
approx_eigenvalues, indices = torch.diag(approx_eigenvalues_matrix).sort(stable=True, descending=True)
149-
Q = Q[:, indices]
150-
151-
return approx_eigenvalues, Q
152-
153-
154-
def met_approx_eigvals_criteria(approx_eigenvalues_matrix: Tensor, tolerance: float) -> bool:
155-
"""Evaluates if a criteria using approximated eigenvalues is below or equal to the tolerance.
156-
157-
`approx_eigenvalues_matrix` is a matrix created from the approximated eigenvectors and the symmetric matrix
158-
that is being eigendecomposed. We check if the ratio of the diagonal norm to the matrix norm is greater
159-
than or equal to (1 - tolerance).
140+
Q = approx_eigenvectors
141+
approx_eigvals = conjugate(x, Q, diag=True)
142+
iteration = 0
143+
while iteration < max_iterations and not met_approx_eigvals_criteria(x, approx_eigvals, tolerance):
144+
power_iteration = x @ Q
145+
Q = torch.linalg.qr(power_iteration).Q
146+
approx_eigvals = conjugate(x, Q, diag=True)
147+
iteration += 1
148+
# Sort eigenvalues in descending order and reorder eigenvectors accordingly
149+
# Sorting can help mitigate numerical instability since QR decompositions can mix the approximated eigenvectors
150+
sorted_approx_eigvals, indices = approx_eigvals.sort(stable=True, descending=True)
151+
Q = Q[:, indices]
152+
153+
return sorted_approx_eigvals, Q
154+
155+
156+
def met_approx_eigvals_criteria(
157+
kronecker_factor: torch.Tensor,
158+
approx_eigvals: torch.Tensor,
159+
tolerance: float,
160+
) -> bool:
161+
"""Determines whether the eigenbasis for a factor matrix met the desired criteria
162+
163+
The approximated eigenvalues update criteria is then defined as
164+
:math:`||diag(Q^T K Q)||_F >= (1 - tolerance) * (Q^T K Q)_F`, where :math:`Q` is the approximated eigenvectors and
165+
:math:`K` is the kronecker factor (L or R).
166+
167+
We use the kronecker factor and approximated eigenvalues directly to save compute because Frobenius norm of
168+
kronecker factor is the same as that of the approximated eigenvalues matrix.
160169
161170
Args:
162-
approx_eigenvalues_matrix: The symmetric matrix whose eigenvalues is being eigendecomposed.
163-
tolerance: The tolerance for the early exit criteria, the min relative error between diagonal norm
164-
and matrix norm of the approximated eigenvalues and the diagonal.
171+
kronecker_factor: Kronecker factor matrix.
172+
approx_eigvals: Approximated eigenvalues
173+
tolerance: Tolerance threshold for the normalized diagonal component of approximated eigenvalue matrix.
165174
166175
Returns:
167-
bool: True if the criteria is below or equal to the tolerance, False otherwise.
168-
176+
perform_update: Whether to update eigenbasis this iteration
169177
"""
170-
matrix_norm = torch.linalg.norm(approx_eigenvalues_matrix)
171-
approx_eigvals = torch.diag(approx_eigenvalues_matrix)
178+
matrix_norm = torch.linalg.norm(kronecker_factor)
172179
diagonal_norm = torch.linalg.norm(approx_eigvals)
180+
173181
return diagonal_norm >= (1 - tolerance) * matrix_norm
174182

175183

184+
def orthogonal_iteration(
185+
approx_eigvals: torch.Tensor,
186+
kronecker_factor: torch.Tensor,
187+
eigenbasis: torch.Tensor,
188+
ind: int,
189+
exp_avg_sq: torch.Tensor,
190+
convert_to_float: bool,
191+
power_iter_steps: int,
192+
) -> Tuple[torch.Tensor, torch.Tensor]:
193+
"""Computes the eigenbases of the preconditioner using power iteration and QR decomposition.
194+
195+
This function performs multiple rounds of power iteration followed by QR decomposition
196+
to recompute the eigenbases of the preconditioner kronecker factor. Generalizes Vyas et al.'s (SOAP) algorithm of 1 step of power iteration for updating the eigenbasis.
197+
198+
Args:
199+
approx_eigenvalue_matrix : Projection of kronecker factor onto the eigenbasis, should be close to diagonal
200+
kronecker_factor : Kronecker factor matrix.
201+
eigenbasis : Kronecker factor eigenbasis matrix.
202+
ind : Index for selecting dimension in the exp_avg_sq matrix to apply the sorting order over.
203+
exp_avg_sq : inner Adam second moment (exp_avg_sq).
204+
convert_to_float : If True, preconditioner matrices and their corresponding
205+
orthonormal matrices will be cast to float. Otherwise, they are left in
206+
their original type. Defaults to False.
207+
power_iter_steps: Number of power iteration steps to perform before QR decomposition.
208+
More steps can lead to better convergence but increased computation time.
209+
210+
Returns:
211+
tuple[torch.Tensor, torch.Tensor]: A tuple containing:
212+
- Q: The updated eigenbasis
213+
- exp_avg_sq: The updated (sorted) inner Adam second moment
214+
"""
215+
# Sort the approximated eigenvalues according to their magnitudes
216+
sort_idx = torch.argsort(approx_eigvals, descending=True)
217+
# re-order the inner adam second moment
218+
exp_avg_sq = exp_avg_sq.index_select(ind, sort_idx)
219+
220+
# Initialize power iteration after sorting the columns of the eigenbasis matrix according to the descending eigenvalues
221+
Q = eigenbasis[:, sort_idx]
222+
223+
# By default, perform QR decomposition with power iteration with FP32 precision
224+
# Perform multiple steps of power iteration
225+
for _ in range(power_iter_steps):
226+
# Project current eigenbases on kronecker factor
227+
Q = kronecker_factor @ Q
228+
# Perform QR to maintain orthogonality between iterations
229+
Q = torch.linalg.qr(Q).Q
230+
231+
# When not converting to float, ensure that Q is in the original dtype
232+
if not convert_to_float:
233+
Q = Q.to(kronecker_factor.dtype)
234+
235+
return Q, exp_avg_sq
236+
237+
238+
def conjugate(a: torch.Tensor, p: torch.Tensor, diag: bool = False) -> torch.Tensor:
239+
"""Calculate similarity transformation
240+
241+
This function calculates :math:`B = P^T A P`. It assumes P is orthogonal so that :math:`P^{-1} = P^T` and
242+
the similarity transformation exists.
243+
244+
Args:
245+
a: matrix to be transformed
246+
p: An orthogonal matrix.
247+
diag: If True, only return the diagonal of the similarity transformation
248+
249+
Returns:
250+
b
251+
"""
252+
if a.dim() != 2 or p.dim() != 2:
253+
raise TypeError("a and p must be 2D matrices")
254+
pta = p.T @ a
255+
if not diag:
256+
b = pta @ p
257+
else:
258+
# return the diagonal of the similarity transformation
259+
b = (pta * p.T).sum(dim=1)
260+
return b
261+
262+
176263
def _is_diagonal(x: Tensor) -> bool:
177264
r"""Checks if symmetric matrix is diagonal. Raises an error if the input is not a square matrix."""
178265

0 commit comments

Comments
 (0)