|
12 | 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | 13 | # See the License for the specific language governing permissions and |
14 | 14 | # limitations under the License. |
15 | | -from typing import Optional |
| 15 | +from typing import Optional, Tuple |
16 | 16 |
|
17 | 17 | import torch |
18 | 18 | from absl import logging |
19 | 19 | from torch import Tensor |
20 | 20 |
|
21 | | -from emerging_optimizers import utils |
22 | 21 |
|
23 | | - |
24 | | -__all__ = ["eigh_with_fallback", "eig_orthogonal_iteration", "met_approx_eigvals_criteria"] |
| 22 | +__all__ = [ |
| 23 | + "eigh_with_fallback", |
| 24 | + "met_approx_eigvals_criteria", |
| 25 | + "conjugate", |
| 26 | + "orthogonal_iteration", |
| 27 | +] |
25 | 28 |
|
26 | 29 |
|
27 | 30 | def eigh_with_fallback( |
@@ -104,6 +107,7 @@ def eig_orthogonal_iteration( |
104 | 107 | ) -> tuple[Tensor, Tensor]: |
105 | 108 | """Approximately compute the eigen decomposition |
106 | 109 |
|
| 110 | + [DEPRECATED] Use `orthogonal_iteration` instead. |
107 | 111 |
|
108 | 112 | Orthogonal or subspace iteration uses iterative power iteration and QR decomposition to update the approximated |
109 | 113 | eigenvectors. When the initial estimate is the zero matrix, the eigendecomposition is computed |
@@ -133,46 +137,129 @@ def eig_orthogonal_iteration( |
133 | 137 | return eigh_with_fallback(x, force_double=True) |
134 | 138 |
|
135 | 139 | # Perform power iteration and QR decomposition iteratively. |
136 | | - with utils.fp32_matmul_precision("highest"): |
137 | | - Q = approx_eigenvectors |
138 | | - approx_eigenvalues_matrix = Q.T @ x @ Q |
139 | | - approx_eigenvalues = torch.diag(approx_eigenvalues_matrix) |
140 | | - iteration = 0 |
141 | | - while iteration < max_iterations and not met_approx_eigvals_criteria(approx_eigenvalues_matrix, tolerance): |
142 | | - power_iteration = x @ Q |
143 | | - Q = torch.linalg.qr(power_iteration).Q |
144 | | - approx_eigenvalues_matrix = Q.T @ x @ Q |
145 | | - iteration += 1 |
146 | | - # Sort eigenvalues in descending order and reorder eigenvectors accordingly |
147 | | - # Sorting can help mitigate numerical instability since QR decompositions can mix the approximated eigenvectors |
148 | | - approx_eigenvalues, indices = torch.diag(approx_eigenvalues_matrix).sort(stable=True, descending=True) |
149 | | - Q = Q[:, indices] |
150 | | - |
151 | | - return approx_eigenvalues, Q |
152 | | - |
153 | | - |
154 | | -def met_approx_eigvals_criteria(approx_eigenvalues_matrix: Tensor, tolerance: float) -> bool: |
155 | | - """Evaluates if a criteria using approximated eigenvalues is below or equal to the tolerance. |
156 | | -
|
157 | | - `approx_eigenvalues_matrix` is a matrix created from the approximated eigenvectors and the symmetric matrix |
158 | | - that is being eigendecomposed. We check if the ratio of the diagonal norm to the matrix norm is greater |
159 | | - than or equal to (1 - tolerance). |
| 140 | + Q = approx_eigenvectors |
| 141 | + approx_eigvals = conjugate(x, Q, diag=True) |
| 142 | + iteration = 0 |
| 143 | + while iteration < max_iterations and not met_approx_eigvals_criteria(x, approx_eigvals, tolerance): |
| 144 | + power_iteration = x @ Q |
| 145 | + Q = torch.linalg.qr(power_iteration).Q |
| 146 | + approx_eigvals = conjugate(x, Q, diag=True) |
| 147 | + iteration += 1 |
| 148 | + # Sort eigenvalues in descending order and reorder eigenvectors accordingly |
| 149 | + # Sorting can help mitigate numerical instability since QR decompositions can mix the approximated eigenvectors |
| 150 | + sorted_approx_eigvals, indices = approx_eigvals.sort(stable=True, descending=True) |
| 151 | + Q = Q[:, indices] |
| 152 | + |
| 153 | + return sorted_approx_eigvals, Q |
| 154 | + |
| 155 | + |
| 156 | +def met_approx_eigvals_criteria( |
| 157 | + kronecker_factor: torch.Tensor, |
| 158 | + approx_eigvals: torch.Tensor, |
| 159 | + tolerance: float, |
| 160 | +) -> bool: |
| 161 | + """Determines whether the eigenbasis for a factor matrix met the desired criteria |
| 162 | +
|
| 163 | + The approximated eigenvalues update criteria is then defined as |
| 164 | + :math:`||diag(Q^T K Q)||_F >= (1 - tolerance) * (Q^T K Q)_F`, where :math:`Q` is the approximated eigenvectors and |
| 165 | + :math:`K` is the kronecker factor (L or R). |
| 166 | +
|
| 167 | + We use the kronecker factor and approximated eigenvalues directly to save compute because Frobenius norm of |
| 168 | + kronecker factor is the same as that of the approximated eigenvalues matrix. |
160 | 169 |
|
161 | 170 | Args: |
162 | | - approx_eigenvalues_matrix: The symmetric matrix whose eigenvalues is being eigendecomposed. |
163 | | - tolerance: The tolerance for the early exit criteria, the min relative error between diagonal norm |
164 | | - and matrix norm of the approximated eigenvalues and the diagonal. |
| 171 | + kronecker_factor: Kronecker factor matrix. |
| 172 | + approx_eigvals: Approximated eigenvalues |
| 173 | + tolerance: Tolerance threshold for the normalized diagonal component of approximated eigenvalue matrix. |
165 | 174 |
|
166 | 175 | Returns: |
167 | | - bool: True if the criteria is below or equal to the tolerance, False otherwise. |
168 | | -
|
| 176 | + perform_update: Whether to update eigenbasis this iteration |
169 | 177 | """ |
170 | | - matrix_norm = torch.linalg.norm(approx_eigenvalues_matrix) |
171 | | - approx_eigvals = torch.diag(approx_eigenvalues_matrix) |
| 178 | + matrix_norm = torch.linalg.norm(kronecker_factor) |
172 | 179 | diagonal_norm = torch.linalg.norm(approx_eigvals) |
| 180 | + |
173 | 181 | return diagonal_norm >= (1 - tolerance) * matrix_norm |
174 | 182 |
|
175 | 183 |
|
| 184 | +def orthogonal_iteration( |
| 185 | + approx_eigvals: torch.Tensor, |
| 186 | + kronecker_factor: torch.Tensor, |
| 187 | + eigenbasis: torch.Tensor, |
| 188 | + ind: int, |
| 189 | + exp_avg_sq: torch.Tensor, |
| 190 | + convert_to_float: bool, |
| 191 | + power_iter_steps: int, |
| 192 | +) -> Tuple[torch.Tensor, torch.Tensor]: |
| 193 | + """Computes the eigenbases of the preconditioner using power iteration and QR decomposition. |
| 194 | +
|
| 195 | + This function performs multiple rounds of power iteration followed by QR decomposition |
| 196 | + to recompute the eigenbases of the preconditioner kronecker factor. Generalizes Vyas et al.'s (SOAP) algorithm of 1 step of power iteration for updating the eigenbasis. |
| 197 | +
|
| 198 | + Args: |
| 199 | + approx_eigenvalue_matrix : Projection of kronecker factor onto the eigenbasis, should be close to diagonal |
| 200 | + kronecker_factor : Kronecker factor matrix. |
| 201 | + eigenbasis : Kronecker factor eigenbasis matrix. |
| 202 | + ind : Index for selecting dimension in the exp_avg_sq matrix to apply the sorting order over. |
| 203 | + exp_avg_sq : inner Adam second moment (exp_avg_sq). |
| 204 | + convert_to_float : If True, preconditioner matrices and their corresponding |
| 205 | + orthonormal matrices will be cast to float. Otherwise, they are left in |
| 206 | + their original type. Defaults to False. |
| 207 | + power_iter_steps: Number of power iteration steps to perform before QR decomposition. |
| 208 | + More steps can lead to better convergence but increased computation time. |
| 209 | +
|
| 210 | + Returns: |
| 211 | + tuple[torch.Tensor, torch.Tensor]: A tuple containing: |
| 212 | + - Q: The updated eigenbasis |
| 213 | + - exp_avg_sq: The updated (sorted) inner Adam second moment |
| 214 | + """ |
| 215 | + # Sort the approximated eigenvalues according to their magnitudes |
| 216 | + sort_idx = torch.argsort(approx_eigvals, descending=True) |
| 217 | + # re-order the inner adam second moment |
| 218 | + exp_avg_sq = exp_avg_sq.index_select(ind, sort_idx) |
| 219 | + |
| 220 | + # Initialize power iteration after sorting the columns of the eigenbasis matrix according to the descending eigenvalues |
| 221 | + Q = eigenbasis[:, sort_idx] |
| 222 | + |
| 223 | + # By default, perform QR decomposition with power iteration with FP32 precision |
| 224 | + # Perform multiple steps of power iteration |
| 225 | + for _ in range(power_iter_steps): |
| 226 | + # Project current eigenbases on kronecker factor |
| 227 | + Q = kronecker_factor @ Q |
| 228 | + # Perform QR to maintain orthogonality between iterations |
| 229 | + Q = torch.linalg.qr(Q).Q |
| 230 | + |
| 231 | + # When not converting to float, ensure that Q is in the original dtype |
| 232 | + if not convert_to_float: |
| 233 | + Q = Q.to(kronecker_factor.dtype) |
| 234 | + |
| 235 | + return Q, exp_avg_sq |
| 236 | + |
| 237 | + |
| 238 | +def conjugate(a: torch.Tensor, p: torch.Tensor, diag: bool = False) -> torch.Tensor: |
| 239 | + """Calculate similarity transformation |
| 240 | +
|
| 241 | + This function calculates :math:`B = P^T A P`. It assumes P is orthogonal so that :math:`P^{-1} = P^T` and |
| 242 | + the similarity transformation exists. |
| 243 | +
|
| 244 | + Args: |
| 245 | + a: matrix to be transformed |
| 246 | + p: An orthogonal matrix. |
| 247 | + diag: If True, only return the diagonal of the similarity transformation |
| 248 | +
|
| 249 | + Returns: |
| 250 | + b |
| 251 | + """ |
| 252 | + if a.dim() != 2 or p.dim() != 2: |
| 253 | + raise TypeError("a and p must be 2D matrices") |
| 254 | + pta = p.T @ a |
| 255 | + if not diag: |
| 256 | + b = pta @ p |
| 257 | + else: |
| 258 | + # return the diagonal of the similarity transformation |
| 259 | + b = (pta * p.T).sum(dim=1) |
| 260 | + return b |
| 261 | + |
| 262 | + |
176 | 263 | def _is_diagonal(x: Tensor) -> bool: |
177 | 264 | r"""Checks if symmetric matrix is diagonal. Raises an error if the input is not a square matrix.""" |
178 | 265 |
|
|
0 commit comments