Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
<!-- Get the codecov badge with a token direct from https://app.codecov.io/gh/NVIDIA-NeMo -->
[![codecov](https://codecov.io/gh/NVIDIA-NeMo/Emerging-Optimizers/graph/badge.svg?token=IQ6U7IFYN0)](https://codecov.io/gh/NVIDIA-NeMo/Emerging-Optimizers)
[![CICD NeMo](https://github.com/NVIDIA-NeMo/Emerging-Optimizers/actions/workflows/cicd-main.yml/badge.svg?branch=main)](https://github.com/NVIDIA-NeMo/Emerging-Optimizers/actions/workflows/cicd-main.yml)
[![Python 3.10+](https://img.shields.io/badge/python-3.10+-blue.svg)](https://www.python.org/downloads/release/python-3100/)
[![Python 3.12+](https://img.shields.io/badge/python-3.12+-blue.svg)](https://www.python.org/downloads/release/python-3120/)
![GitHub Repo stars](https://img.shields.io/github/stars/NVIDIA-NeMo/Emerging-Optimizers)
[![Documentation](https://img.shields.io/badge/api-reference-blue.svg)](https://docs.nvidia.com/nemo/emerging-optimizers/latest/index.html)

Expand Down Expand Up @@ -41,7 +41,7 @@ Emerging optimizers have demonstrated significant practical impact in large-scal

### Prerequisites

- Python 3.10 or higher, 3.12 is recommended
- Python 3.12 (Release v0.1.0 is the last version that supports Python 3.10)
- PyTorch 2.0 or higher

### Install from Source
Expand Down
2 changes: 1 addition & 1 deletion docs/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ Emerging-Optimizers is under active development. All APIs are experimental and s

### Prerequisites

- Python 3.10 or higher, 3.12 is recommended
- Python 3.12
- PyTorch 2.0 or higher

### Install from Source
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Callable, Literal, overload


# TODO(@boxiangw): remove this once bump to python 3.12
try:
from typing import override
except ImportError:
from typing_extensions import override
from typing import Callable, Literal, overload, override

import torch
from torch.optim.optimizer import ParamsT
Expand Down
6 changes: 2 additions & 4 deletions emerging_optimizers/orthogonalized_optimizers/mop.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
# limitations under the License.


from typing import Literal, Optional
from typing import Literal

import torch
from torch.optim.optimizer import ParamsT
Expand Down Expand Up @@ -80,9 +80,7 @@ def scaled_orthogonalize_fn(grad: torch.Tensor) -> torch.Tensor:
MOP.__doc__ = MOP.__doc__.format(_args_doc=_args_doc) # type: ignore[union-attr]


def polar_via_svd(
A: torch.Tensor, return_p: bool = False
) -> tuple[torch.Tensor, Optional[torch.Tensor], torch.Tensor]:
def polar_via_svd(A: torch.Tensor, return_p: bool = False) -> tuple[torch.Tensor, torch.Tensor | None, torch.Tensor]:
"""Compute polar decomposition via SVD

Args:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Callable, overload


# TODO(@boxiangw): remove this once bump to python 3.12
try:
from typing import override
except ImportError:
from typing_extensions import override
from typing import Any, Callable, overload, override

import torch
import torch.optim as optim
Expand Down
8 changes: 1 addition & 7 deletions emerging_optimizers/psgd/psgd.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from typing import Callable, overload


try:
from typing import override
except ImportError:
from typing_extensions import override
from typing import Callable, overload, override

import torch
from torch.optim.optimizer import ParamsT
Expand Down
8 changes: 3 additions & 5 deletions emerging_optimizers/psgd/psgd_kron_contractions.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List

import torch


Expand Down Expand Up @@ -43,7 +41,7 @@ def partial_contraction(G1: torch.Tensor, G2: torch.Tensor, axis: int) -> torch.


@torch.compile # type: ignore[misc]
def apply_kronecker_factors(Q_list: List[torch.Tensor], X: torch.Tensor) -> torch.Tensor:
def apply_kronecker_factors(Q_list: list[torch.Tensor], X: torch.Tensor) -> torch.Tensor:
"""Apply all Kronecker factors once to tensor :math:`X`, each to its corresponding dimension.

This applies each :math:`Q` factor once, for example in 2D case: :math:`Q_1 X Q_2^T`.
Expand All @@ -67,7 +65,7 @@ def apply_kronecker_factors(Q_list: List[torch.Tensor], X: torch.Tensor) -> torc


@torch.compile # type: ignore[misc]
def apply_preconditioner(Q_list: List[torch.Tensor], X: torch.Tensor) -> torch.Tensor:
def apply_preconditioner(Q_list: list[torch.Tensor], X: torch.Tensor) -> torch.Tensor:
"""Apply the full PSGD preconditioner to X.

This is the full Kronecker product of PSGD's kronecker factors Q^T Q, applied to X.
Expand Down Expand Up @@ -130,7 +128,7 @@ def _dim_n_mul_and_permute(X: torch.Tensor, M: torch.Tensor, contract_dim: int)


@torch.compile # type: ignore[misc]
def _apply_single_kronecker_factor(Q_list: List[torch.Tensor], X: torch.Tensor, axis: int) -> torch.Tensor:
def _apply_single_kronecker_factor(Q_list: list[torch.Tensor], X: torch.Tensor, axis: int) -> torch.Tensor:
"""Apply a single Kronecker factor Q to X at dimension `axis`. Helper function for apply_kronecker_factors.

If Q is a vector, we multiply X by Q.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Callable, overload


try:
from typing import override
except ImportError:
from typing_extensions import override
from typing import Callable, overload, override

import torch
from torch.optim.optimizer import Optimizer
Expand Down
11 changes: 5 additions & 6 deletions emerging_optimizers/scalar_optimizers/ademamix.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from typing import Optional, Tuple

import torch

Expand All @@ -30,9 +29,9 @@ def calculate_sim_ademamix_update(
grad: torch.Tensor,
exp_avg: torch.Tensor,
exp_avg_sq: torch.Tensor,
num_beta_fast_warmup_steps: Optional[int],
num_beta_fast_warmup_steps: int | None,
min_beta_fast: float,
betas: Tuple[float, float],
betas: tuple[float, float],
step: int,
eps: float,
correct_bias: bool,
Expand Down Expand Up @@ -107,9 +106,9 @@ def calculate_ademamix_update(
exp_avg_fast: torch.Tensor,
exp_avg_slow: torch.Tensor,
exp_avg_sq: torch.Tensor,
num_beta_slow_warmup_steps: Optional[int],
num_alpha_warmup_steps: Optional[int],
betas: Tuple[float, float, float],
num_beta_slow_warmup_steps: int | None,
num_alpha_warmup_steps: int | None,
betas: tuple[float, float, float],
step: int,
eps: float,
correct_bias: bool,
Expand Down
4 changes: 1 addition & 3 deletions emerging_optimizers/scalar_optimizers/laprop.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Tuple

import torch


Expand All @@ -29,7 +27,7 @@ def calculate_laprop_update(
exp_avg: torch.Tensor,
exp_avg_sq: torch.Tensor,
correct_bias: bool,
betas: Tuple[float, float],
betas: tuple[float, float],
step: int,
eps: float,
) -> torch.Tensor:
Expand Down
17 changes: 5 additions & 12 deletions emerging_optimizers/soap/soap.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,7 @@
# limitations under the License.
from functools import partial
from itertools import chain
from typing import Callable, overload


# TODO(@boxiangw): remove this once bump to python 3.12
try:
from typing import override
except ImportError:
from typing_extensions import override
from typing import Callable, overload, override

import torch
from absl import logging
Expand Down Expand Up @@ -325,7 +318,7 @@ def init_kronecker_factors(
(like biases). If False, 1D tensors will skip preconditioning.

Returns:
List[torch.Tensor]: List of kronecker factor matrices (L and R in paper).
List of kronecker factor matrices (L and R in paper).
- For 1D tensors with precondition_1d=False: List containing an empty tensor
- For 1D tensors with precondition_1d=True: List containing a square matrix
- For higher dimensional tensors: List of square matrices, one per dimension
Expand Down Expand Up @@ -503,9 +496,9 @@ def update_eigenbasis_and_momentum(
orthonormal matrices to float for amortized computation. Otherwise, they are left in their original type.

Returns:
Tuple[List[torch.Tensor], torch.Tensor]: A tuple containing:
- List[torch.Tensor]: Updated list of eigenbases (QL and QR)
- torch.Tensor: Updated momentum tensor projected to the new eigenbasis
A tuple containing:
- Updated list of eigenbases (QL and QR)
- Updated momentum tensor projected to the new eigenbasis

Example:
>>> L = torch.randn(10, 10)
Expand Down
13 changes: 7 additions & 6 deletions emerging_optimizers/soap/soap_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def get_eigenbasis_eigh(
eps: Small offset for numerical stability. If None, uses dtype-appropriate values (1e-7 for float32, 1e-15 for float64)

Returns:
List[torch.Tensor]: List of orthonormal kronecker factor eigenbases matrices
List of orthonormal kronecker factor eigenbases matrices

Example:
.. code-block:: python
Expand Down Expand Up @@ -135,15 +135,16 @@ def get_eigenbasis_qr(
orthonormal matrices will be cast to float. Otherwise, they are left in
their original type.
use_adaptive_criteria: Whether to use update criteria strategy
adaptive_update_tolerance: Tolerance threshold for the normalized diagonal component of approximated eigenvalue matrix.
If None, defaults to 1e-7, which is appropriate for single precision computations. This means adaptive update criteria will be used whenever there is a small change in the approximated eigenvalues
matrix and QR will be used.
adaptive_update_tolerance: Tolerance threshold for the normalized diagonal component of the approximated
eigenvalues matrix. If None, defaults to 1e-7, which is appropriate for single precision computations.
This means that whenever there is a small change in the approximated eigenvalues matrix, the adaptive
update criteria will be triggered and QR will be used.
power_iter_steps: Number of power iteration steps to perform before QR decomposition.
More steps can lead to better convergence but increased computation time.

Returns:
List[torch.Tensor]: Updated list of orthonormal kronecker factor eigenbases matrices
torch.Tensor: Updated (sorted) inner adam second moment
Tuple of the updated list of orthonormal kronecker factor eigenbases matrices and the updated (sorted)
inner Adam second moment.

Example:
.. code-block:: python
Expand Down
8 changes: 1 addition & 7 deletions emerging_optimizers/utils/modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,7 @@
# limitations under the License.

import math
from typing import Any, Self


try:
from typing import override
except ImportError:
from typing_extensions import override
from typing import Any, Self, override

import torch
import torch.nn as nn
Expand Down
6 changes: 1 addition & 5 deletions emerging_optimizers/utils/precondition_schedules.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,9 @@
# limitations under the License.
import math
from abc import ABC, abstractmethod
from typing import override


try:
from typing import override
except ImportError:
from typing_extensions import override

__all__ = [
"LinearSchedule",
"CosineSchedule",
Expand Down
5 changes: 2 additions & 3 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ dynamic = [
]
description = "A research project for emerging optimizers other than AdamW"
license = {file = "LICENSE"}
requires-python = ">=3.10"
requires-python = ">=3.12"
authors = [{ name = "NVIDIA", email = "nemo-toolkit@nvidia.com" }]
maintainers = [{ name = "NVIDIA", email = "nemo-toolkit@nvidia.com" }]
keywords = [
Expand Down Expand Up @@ -57,7 +57,6 @@ classifiers = [
dependencies = [
"torch",
"absl-py",
"typing-extensions",
]

[build-system]
Expand Down Expand Up @@ -106,7 +105,7 @@ Homepage = "https://github.com/NVIDIA-NeMo/Emerging-Optimizers"
[tool.ruff]
# Match black's line length
line-length = 119
target-version = "py310"
target-version = "py312"

# Exclude common directories that shouldn't be linted
exclude = [
Expand Down
3 changes: 1 addition & 2 deletions tests/soap_reference.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,6 @@
# SOFTWARE.

from itertools import chain
from typing import Tuple

import torch
import torch.optim as optim
Expand All @@ -56,7 +55,7 @@ def __init__(
self,
params,
lr: float,
betas: Tuple[float, float],
betas: tuple[float, float],
shampoo_beta: float,
eps: float,
weight_decay: float,
Expand Down
6 changes: 3 additions & 3 deletions tests/test_soap.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
# limitations under the License.
import math
from functools import partial
from typing import Any, List
from typing import Any

import soap_reference
import torch
Expand All @@ -29,9 +29,9 @@


def kl_shampoo_update_ref(
kronecker_factor_list: List[torch.Tensor],
kronecker_factor_list: list[torch.Tensor],
grad: torch.Tensor,
eigenbasis_list: List[torch.Tensor],
eigenbasis_list: list[torch.Tensor],
shampoo_beta: float,
eps: float,
eigval_exp: float = -1.0,
Expand Down
Loading