Skip to content

Commit c3c451b

Browse files
authored
Improve document (#33)
Use math equation in some of the docstrings. Get rid off long lines. Modified some unnecessary content. No function change. Signed-off-by: Hao Wu <[email protected]>
1 parent 2ebdb41 commit c3c451b

File tree

8 files changed

+60
-32
lines changed

8 files changed

+60
-32
lines changed

docs/apidocs/soap.md

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,4 +21,15 @@ emerging_optimizers.soap
2121
.. autofunction:: update_kronecker_factors
2222
2323
.. autofunction:: update_eigenbasis_and_momentum
24+
25+
emerging_optimizers.soap.soap_utils
26+
=====================================
27+
28+
.. automodule:: emerging_optimizers.soap.soap_utils
29+
:members:
30+
31+
.. autofunction:: _orthogonal_iteration
32+
33+
.. autofunction:: _conjugate
34+
2435
```

docs/conf.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,7 @@
7272
"numpy": ("https://numpy.org/doc/stable", None),
7373
"torch": ("https://pytorch.org/docs/2.5", None),
7474
}
75+
autodoc_typehints = "description"
7576

7677

7778
def linkcode_resolve(domain, info):

emerging_optimizers/orthogonalized_optimizers/muon.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -37,9 +37,12 @@ class Muon(OrthogonalizedOptimizer):
3737
optimization via Frank-Wolfe.
3838
3939
References:
40-
- Jordan, K. *Muon Optimizer Implementation.* [`GitHub <https://github.com/KellerJordan/Muon/blob/master/muon.py>`_]
41-
- *Modular Duality in Deep Learning.* arXiv:2410.21265 (2024). [`arXiv:2410.21265 <https://arxiv.org/abs/2410.21265>`_]
42-
- *Training Deep Learning Models with Norm-Constrained LMOs.* arXiv:2502.07529 (2025). [`arXiv:2502.07529 <https://arxiv.org/abs/2502.07529>`_]
40+
- Jordan, K. *Muon Optimizer Implementation.*
41+
[`GitHub <https://github.com/KellerJordan/Muon/blob/master/muon.py>`_]
42+
- *Modular Duality in Deep Learning.* arXiv:2410.21265 (2024).
43+
[`arXiv:2410.21265 <https://arxiv.org/abs/2410.21265>`_]
44+
- *Training Deep Learning Models with Norm-Constrained LMOs.* arXiv:2502.07529 (2025).
45+
[`arXiv:2502.07529 <https://arxiv.org/abs/2502.07529>`_]
4346
4447
Warning:
4548
- This optimizer requires that all parameters passed in are 2D.
@@ -122,7 +125,8 @@ def get_muon_scale_factor(
122125
# Suggested by K. Jordan and Kimi (https://arxiv.org/abs/2502.16982)
123126
return extra_scale_factor * max(size_out, size_in) ** 0.5
124127
elif mode == "unit_rms_norm":
125-
# Suggested by Scion (https://arxiv.org/abs/2502.07529) and Bernstein et al. (https://jeremybernste.in/writing/deriving-muon)
128+
# Suggested by Scion (https://arxiv.org/abs/2502.07529) and Bernstein et al.
129+
# (https://jeremybernste.in/writing/deriving-muon)
126130
return extra_scale_factor * (size_out / size_in) ** 0.5
127131
else:
128132
raise ValueError(f"Invalid mode for Muon update scale factor: {mode}")

emerging_optimizers/orthogonalized_optimizers/orthogonalized_optimizer.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -45,9 +45,11 @@ class OrthogonalizedOptimizer(optim.Optimizer):
4545
4646
- Carlson, D., Cevher, V., and Carin, L. *Stochastic spectral descent for Restricted Boltzmann Machines.*
4747
In International Conference on Artificial Intelligence and Statistics (2015a).
48-
- Carlson, D., Hsieh, Y.-P., Collins, E., Carin, L., and Cevher, V. *Stochastic Spectral Descent for Discrete Graphical Models.*
48+
- Carlson, D., Hsieh, Y.-P., Collins, E., Carin, L., and Cevher, V.
49+
*Stochastic Spectral Descent for Discrete Graphical Models.*
4950
In IEEE Journal of Selected Topics in Signal Processing, vol. 10, no. 2, pp. 296-311 (2016).
50-
- Carlson, D., Collins, E., Hsieh, Y.-P., Carin, L., and Cevher, V. *Preconditioned spectral descent for deep learning.*
51+
- Carlson, D., Collins, E., Hsieh, Y.-P., Carin, L., and Cevher, V.
52+
*Preconditioned spectral descent for deep learning.*
5153
In Neural Information Processing Systems (2015b).
5254
- Flynn, T. *The duality structure gradient descent algorithm: analysis and applications to neural networks.*
5355
arXiv preprint arXiv:1708.00523 (2017). [`arXiv:1708.00523 <https://arxiv.org/abs/1708.00523>`_]

emerging_optimizers/soap/soap.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -125,7 +125,8 @@ def __init__(
125125
original_adam_warmup_steps = adam_warmup_steps
126126
adam_warmup_steps = max(1, precondition_warmup_steps - 1)
127127
logging.info(
128-
f"adam_warmup_steps ({original_adam_warmup_steps}) should be less than precondition_warmup_steps ({precondition_warmup_steps}). "
128+
f"adam_warmup_steps ({original_adam_warmup_steps}) should be less "
129+
f"than precondition_warmup_steps ({precondition_warmup_steps}). "
129130
f"Setting adam_warmup_steps to {adam_warmup_steps} by default."
130131
)
131132

@@ -193,7 +194,8 @@ def step(self, closure: Callable[[], float] | None = None) -> float | None:
193194
precondition_1d=group["precondition_1d"],
194195
)
195196

196-
# Update preconditioner matrices with gradient statistics, do not use shampoo_beta for EMA at first step
197+
# Update preconditioner matrices with gradient statistics,
198+
# do not use shampoo_beta for EMA at first step
197199
with utils.fp32_matmul_precision(group["fp32_matmul_prec"]):
198200
update_kronecker_factors(
199201
kronecker_factor_list=state["GG"],
@@ -282,7 +284,8 @@ def step(self, closure: Callable[[], float] | None = None) -> float | None:
282284
)
283285
torch.cuda.nvtx.range_pop()
284286

285-
# If current step is the last step to skip preconditioning, initialize eigenbases and end first order warmup
287+
# If current step is the last step to skip preconditioning, initialize eigenbases and
288+
# end first order warmup
286289
if state["step"] == group["adam_warmup_steps"]:
287290
# Obtain kronecker factor eigenbases from kronecker factor matrices using eigendecomposition
288291
state["Q"] = get_eigenbasis_eigh(state["GG"])
@@ -425,7 +428,8 @@ def update_kronecker_factors(
425428
else:
426429
# For 1D tensors, skip preconditioning
427430
logging.error(
428-
"1D tensor is passed to update_kronecker_factors, but precondition_1d is not set to True, skipping preconditioning."
431+
"1D tensor is passed to update_kronecker_factors, "
432+
"but precondition_1d is not set to True, skipping preconditioning."
429433
)
430434
return
431435
else:
@@ -586,7 +590,8 @@ def precondition(
586590
)
587591
else:
588592
# Permute gradient dimensions to process the next dimension in the following iteration
589-
# when preconditioning for the current dimension is skipped (Q is empty), in the case of one-sided preconditioning.
593+
# when preconditioning for the current dimension is skipped (Q is empty), in the case of
594+
# one-sided preconditioning.
590595
permute_order = list(range(1, grad.dim())) + [0]
591596
grad = grad.permute(permute_order)
592597

emerging_optimizers/soap/soap_utils.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,7 @@ def get_eigenbasis_eigh(
8585
# We use an empty tensor so that the `precondition` function will skip this factor.
8686
updated_eigenbasis_list.append(torch.empty(0, device=kronecker_factor.device))
8787
continue
88-
# Construct approximated eigenvalues using QL^T@L@QL or QR^T@R@QR.
88+
# Construct approximated eigenvalues using :math:`Q_L^T L Q_L` or :math:`Q_R^T R Q_R`.
8989
# The approximated eigenvalues should be close to diagonal if the eigenbasis is close to the true
9090
# eigenbasis of the kronecker factor (i.e. the approximated eigenvectors diagonalize the kronecker factor)
9191
approx_eigenvalue_matrix = eigenbasis.T @ kronecker_factor @ eigenbasis
@@ -128,8 +128,8 @@ def get_eigenbasis_qr(
128128
Computes using multiple rounds of power iteration followed by QR decomposition (orthogonal iteration).
129129
130130
Args:
131-
kronecker_factor_list: List containing preconditioner (GGT and GTG)
132-
eigenbasis_list: List containing eigenbases (QL and QR)
131+
kronecker_factor_list: List containing preconditioner (:math:`GG^T` and :math:`G^TG`)
132+
eigenbasis_list: List containing eigenbases (:math:`Q_L` and :math:`Q_R`)
133133
exp_avg_sq: inner adam second moment (exp_avg_sq). This tensor is modified in-place.
134134
convert_to_float: If True, preconditioner matrices and their corresponding
135135
orthonormal matrices will be cast to float. Otherwise, they are left in
@@ -207,7 +207,7 @@ def get_eigenbasis_qr(
207207
# Update eigenbasis when necessary. Update is skipped only when use_adaptive_criteria is True
208208
# but criteria is not met.
209209
if_update = True
210-
# construct approximated eigenvalues using QL^T@L@QL or QR^T@R@QR, which should be close to diagonal
210+
# construct approximated eigenvalues using :math:`Q_L^T L Q_L` or :math:`Q_R^T R Q_R`, which should be close to diagonal
211211
# if the eigenbasis is close to the true eigenbasis of the kronecker factor (i.e. diagonalizes it)
212212
if use_adaptive_criteria:
213213
approx_eigenvalue_matrix = _conjugate(kronecker_factor, eigenbasis)

emerging_optimizers/utils/eig.py

Lines changed: 21 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -37,13 +37,15 @@ def eigh_with_fallback(
3737
Default 2nd argument of eigh UPLO is 'L'.
3838
3939
Args:
40-
x: Tensor of shape (*, n, n) where "*" is zero or more batch dimensions consisting of symmetric or Hermitian matrices.
40+
x: Tensor of shape (*, n, n) where "*" is zero or more batch dimensions consisting of symmetric or
41+
Hermitian matrices.
4142
force_double: Force double precision computation. Default False.
42-
eps: Small offset for numerical stability. If None, uses dtype-appropriate values (1e-7 for float32, 1e-15 for float64). Default None.
43+
eps: Small offset for numerical stability. If None, uses dtype-appropriate values (1e-7 for float32,
44+
1e-15 for float64). Default None.
4345
output_dtype: Desired output dtype. If None, uses input dtype. Default None.
4446
4547
Returns:
46-
tuple[Tensor, Tensor]: Eigenvalues and eigenvectors tuple (eigenvalues in descending order).
48+
Eigenvalues and eigenvectors tuple (eigenvalues in descending order).
4749
"""
4850
input_dtype = x.dtype
4951
if output_dtype is None:
@@ -100,25 +102,26 @@ def eig_orthogonal_iteration(
100102
max_iterations: int = 1,
101103
tolerance: float = 0.01,
102104
) -> tuple[Tensor, Tensor]:
103-
"""Approximately compute the eigendecomposition of a symmetric matrix by performing the orthogonal iteration algorithm.
105+
"""Approximately compute the eigen decomposition
104106
105107
106-
Orthogonal or subspace iteration uses iterative power iteration and QR decomposition to update the approximated eigenvectors.
107-
When the initial estimate is the zero matrix, the eigendecomposition is computed using `eigh_with_fallback`.
108+
Orthogonal or subspace iteration uses iterative power iteration and QR decomposition to update the approximated
109+
eigenvectors. When the initial estimate is the zero matrix, the eigendecomposition is computed
110+
using `eigh_with_fallback`.
108111
109-
Based on Purifying Shampoo (https://www.arxiv.org/abs/2506.03595), we use an early exit criteria to stop the QR iterations.
110-
This generalizes SOAP's algorithm of 1 step of power iteration for updating the eigenbasis.
112+
Based on Purifying Shampoo (https://www.arxiv.org/abs/2506.03595), we use an early exit criteria to stop the
113+
QR iterations. This generalizes SOAP's algorithm of 1 step of power iteration for updating the eigenbasis.
111114
112115
Args:
113116
x: tensor of shape (n, n) where x is a symmetric or Hermitian matrix.
114117
approx_eigenvectors: The current estimate of the eigenvectors of x. If None or a zero matrix,
115118
falls back to using `eigh_with_fallback`.
116-
max_iterations: The maximum number of iterations to perform. (Default: 1)
117-
tolerance: The tolerance for determining convergence in terms of the norm of the off-diagonal elements of the approximated eigenvalues.
118-
(Default: 0.01)
119+
max_iterations: The maximum number of iterations to perform.
120+
tolerance: The tolerance for determining convergence in terms of the norm of the off-diagonal elements
121+
of the approximated eigenvalues.
119122
120123
Returns:
121-
tuple[Tensor, Tensor]: A tuple containing the approximated eigenvalues and eigenvectors matrix of the input matrix A.
124+
A tuple containing the approximated eigenvalues and eigenvectors matrix of the input matrix A.
122125
"""
123126

124127
# Check if x is already a diagonal matrix
@@ -151,12 +154,14 @@ def eig_orthogonal_iteration(
151154
def met_approx_eigvals_criteria(approx_eigenvalues_matrix: Tensor, tolerance: float) -> bool:
152155
"""Evaluates if a criteria using approximated eigenvalues is below or equal to the tolerance.
153156
154-
`approx_eigenvalues_matrix` is a matrix created from the approximated eigenvectors and the symmetric matrix that is being eigendecomposed.
155-
We check if the ratio of the diagonal norm to the matrix norm is greater than or equal to (1 - tolerance).
157+
`approx_eigenvalues_matrix` is a matrix created from the approximated eigenvectors and the symmetric matrix
158+
that is being eigendecomposed. We check if the ratio of the diagonal norm to the matrix norm is greater
159+
than or equal to (1 - tolerance).
156160
157161
Args:
158162
approx_eigenvalues_matrix: The symmetric matrix whose eigenvalues is being eigendecomposed.
159-
tolerance: The tolerance for the early exit criteria, the min relative error between diagonal norm and matrix norm of the approximated eigenvalues and the diagonal.
163+
tolerance: The tolerance for the early exit criteria, the min relative error between diagonal norm
164+
and matrix norm of the approximated eigenvalues and the diagonal.
160165
161166
Returns:
162167
bool: True if the criteria is below or equal to the tolerance, False otherwise.
@@ -189,7 +194,7 @@ def _try_handle_diagonal_matrix(x: Tensor) -> Optional[tuple[Tensor, Tensor]]:
189194
x: Tensor of shape (n, n) where x is a symmetric or Hermitian matrix.
190195
191196
Returns:
192-
Optional[tuple[Tensor, Tensor]]: Sorted eigenvalues and eigenvectors if A is diagonal, None otherwise.
197+
Sorted eigenvalues and eigenvectors if A is diagonal, None otherwise.
193198
"""
194199
input_dtype = x.dtype
195200
if _is_diagonal(x):

tests/test_soap_functions.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -149,7 +149,7 @@ def test_tensordot_vs_matmul(self, m, n):
149149
{"N": 32, "M": 8},
150150
)
151151
def test_project_and_project_back(self, N: int, M: int) -> None:
152-
"""Tests that projecting a tensor to eigenbasis of QL and QR and then projecting it back results in the original tensor.
152+
"""Tests that projecting a tensor to eigenbasis of QL and QR and back
153153
154154
The projected tensor should approximately recover the original tensor.
155155
"""

0 commit comments

Comments
 (0)