
Commit d246453

Impose more strict type checking (#86)
* Added torch as additional_dependencies to mypy in pre-commit config.
* Enabled error code explicit-override to mypy.
* Fixed type issues based on AI's recommendation.

Signed-off-by: Hao Wu <[email protected]>
1 parent a98438d commit d246453
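The explicit-override error code makes mypy require the @override decorator on every method that overrides a base-class method, which is why the step() implementations in the diffs below gain @override (plus @overload stubs for precise return types). A minimal, hypothetical sketch of what this check enforces; the class names are illustrative and not from the repository:

from typing_extensions import override  # typing.override on Python 3.12+


class Base:
    def step(self) -> None:
        ...


class Derived(Base):
    @override  # with explicit-override enabled, omitting this decorator is a mypy error
    def step(self) -> None:
        ...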


20 files changed (+181, -90 lines)

.pre-commit-config.yaml

Lines changed: 2 additions & 1 deletion
@@ -33,10 +33,11 @@ repos:
       - id: ruff-format
 
   - repo: https://github.com/pre-commit/mirrors-mypy
-    rev: v1.14.0
+    rev: v1.19.1
     hooks:
       - id: mypy
         exclude: ^docs|^tests|^benchmarks|^docker
+        additional_dependencies: ["torch"]
 
   - repo: local
     hooks:
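additional_dependencies installs the listed packages into the isolated environment that pre-commit builds for the mypy hook, so torch's own type annotations are resolved there instead of the module degrading to Any (or mypy failing on the missing import). A small illustrative sketch, not taken from the repository, of the kind of call that only gets checked once torch is visible to mypy:

import torch


def unit_normalize(x: torch.Tensor) -> torch.Tensor:
    # With torch importable in mypy's environment, this call is checked against
    # torch's annotations; a misspelled keyword (e.g. `epsilon=` instead of `eps=`)
    # would be reported instead of passing silently through an untyped module.
    return torch.nn.functional.normalize(x, p=2.0, dim=-1, eps=1e-7)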

emerging_optimizers/orthogonalized_optimizers/adaptive_muon.py

Lines changed: 11 additions & 3 deletions
@@ -12,7 +12,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Callable, Literal
+from typing import Callable, Literal, overload
 
 
 # TODO(@boxiangw): remove this once bump to python 3.12
@@ -27,6 +27,8 @@
 from emerging_optimizers import mixin as opt_mixin
 from emerging_optimizers import utils
 from emerging_optimizers.orthogonalized_optimizers import muon
+from emerging_optimizers.orthogonalized_optimizers.muon_utils import NSCoeffT
+from emerging_optimizers.utils import FP32MatmulPrecT
 
 
 class AdaptiveMuon(muon.Muon):
@@ -65,8 +67,8 @@ def __init__(
         *,
         use_nesterov: bool,
         weight_decay_method: opt_mixin.WeightDecayT = "decoupled",
-        fp32_matmul_prec: str,
-        coefficient_type: str = "quintic",
+        fp32_matmul_prec: FP32MatmulPrecT,
+        coefficient_type: NSCoeffT = "quintic",
         num_ns_steps: int = 5,
         scale_mode: muon.MuonScaleT = "spectral",
         extra_scale_factor: float = 1.0,
@@ -179,6 +181,12 @@ def _apply_moment2_normalization(
         else:
             raise TypeError(f"Invalid second moment method: {self.moment2_method}")
 
+    @overload
+    def step(self, closure: None = ...) -> None: ...
+
+    @overload
+    def step(self, closure: Callable[[], float]) -> float: ...
+
     @torch.no_grad()  # type: ignore[misc]
     @override
     def step(self, closure: Callable[[], float] | None = None) -> float | None:
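The paired @overload declarations give callers a precise return type: step() with no closure is typed as returning None, while step(closure) returns float, and the decorated implementation keeps the single runtime signature. A self-contained toy mirror of the pattern (the class is illustrative, not part of the repository):

from typing import Callable, overload


class _StepLike:
    @overload
    def step(self, closure: None = ...) -> None: ...

    @overload
    def step(self, closure: Callable[[], float]) -> float: ...

    def step(self, closure: Callable[[], float] | None = None) -> float | None:
        # Run the closure if one was given, mirroring the optimizer convention.
        return closure() if closure is not None else None


opt = _StepLike()
opt.step()                            # checker infers None
loss: float = opt.step(lambda: 0.5)   # checker infers float, no Optional unwrapping needed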

emerging_optimizers/orthogonalized_optimizers/mop.py

Lines changed: 3 additions & 1 deletion
@@ -22,6 +22,7 @@
 from emerging_optimizers.mixin import WeightDecayT
 from emerging_optimizers.orthogonalized_optimizers import muon
 from emerging_optimizers.orthogonalized_optimizers.orthogonalized_optimizer import OrthogonalizedOptimizer, _args_doc
+from emerging_optimizers.utils import FP32MatmulPrecT
 
 
 __all__ = ["MOP"]
@@ -49,13 +50,14 @@ def __init__(
         *,
         use_nesterov: bool = False,
         weight_decay_method: WeightDecayT = "decoupled",
-        fp32_matmul_prec: str = "highest",
+        fp32_matmul_prec: FP32MatmulPrecT = "highest",
         scale_mode: muon.MuonScaleT | Literal["nuclear_norm"] = "nuclear_norm",
         extra_scale_factor: float = 1.0,
     ) -> None:
         def scaled_orthogonalize_fn(grad: torch.Tensor) -> torch.Tensor:
             orth_grad, _, S = polar_via_svd(grad, False)
 
+            scale_factor: float | torch.Tensor
             if scale_mode != "nuclear_norm":
                 scale_factor = muon.get_muon_scale_factor(grad.size(-2), grad.size(-1), mode=scale_mode)
             else:
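The new scale_factor: float | torch.Tensor annotation declares the variable's type ahead of the branch, so mypy accepts that one branch assigns a float (from get_muon_scale_factor) and the other a tensor; without the pre-declaration, mypy pins the variable to whichever type the first assignment produces. A minimal sketch of the same idiom, with hypothetical names:

import torch


def pick_scale(use_nuclear_norm: bool, s: torch.Tensor) -> torch.Tensor:
    scale_factor: float | torch.Tensor  # pre-declared union, assigned in both branches
    if not use_nuclear_norm:
        scale_factor = 0.2          # float branch
    else:
        scale_factor = s.abs().sum()  # tensor branch
    return s * scale_factor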

emerging_optimizers/orthogonalized_optimizers/muon.py

Lines changed: 4 additions & 2 deletions
@@ -22,7 +22,9 @@
 from emerging_optimizers import triton_kernels
 from emerging_optimizers.mixin import WeightDecayT
 from emerging_optimizers.orthogonalized_optimizers import muon_utils
+from emerging_optimizers.orthogonalized_optimizers.muon_utils import NSCoeffT
 from emerging_optimizers.orthogonalized_optimizers.orthogonalized_optimizer import OrthogonalizedOptimizer, _args_doc
+from emerging_optimizers.utils import FP32MatmulPrecT
 
 
 MuonScaleT = Literal["shape_scaling", "spectral", "unit_rms_norm"]
@@ -75,8 +77,8 @@ def __init__(
         *,
         use_nesterov: bool = False,
         weight_decay_method: WeightDecayT = "decoupled",
-        fp32_matmul_prec: str = "medium",
-        coefficient_type: str = "quintic",
+        fp32_matmul_prec: FP32MatmulPrecT = "medium",
+        coefficient_type: NSCoeffT = "quintic",
         num_ns_steps: int = 5,
         scale_mode: MuonScaleT = "spectral",
         extra_scale_factor: float = 1.0,

emerging_optimizers/orthogonalized_optimizers/muon_utils.py

Lines changed: 6 additions & 4 deletions
@@ -20,7 +20,9 @@
 from emerging_optimizers import triton_kernels
 
 
-__all__ = ["newton_schulz", "newton_schulz_tp"]
+__all__ = ["newton_schulz", "newton_schulz_tp", "NSCoeffT"]
+
+NSCoeffT = Literal["simple", "quintic", "polar_express", "aol", "custom"]
 
 _COEFFICIENT_SETS = {
     "simple": [
@@ -67,7 +69,7 @@ def distributed_normalize_p2(x: torch.Tensor, eps: float, group: torch.distribut
 def newton_schulz(
     x: torch.Tensor,
     steps: int,
-    coefficient_type: str = "quintic",
+    coefficient_type: NSCoeffT = "quintic",
     custom_coefficient_sets: list[tuple[float, float, float]] | None = None,
     eps: float = 1e-7,
     transpose: bool | None = None,
@@ -121,7 +123,7 @@ def newton_schulz(
     if tp_group is not None:
         X = distributed_normalize_p2(x, eps, tp_group)
     else:
-        X = torch.nn.functional.normalize(x, p=2, dim=(-2, -1), eps=eps)
+        X = torch.nn.functional.normalize(x, p=2, dim=(-2, -1), eps=eps)  # type: ignore[arg-type]
 
     if coefficient_type in _COEFFICIENT_SETS:
         coefficient_sets = _COEFFICIENT_SETS[coefficient_type]
@@ -164,7 +166,7 @@ def newton_schulz(
 def newton_schulz_tp(
     x: torch.Tensor,
     steps: int,
-    coefficient_type: str,
+    coefficient_type: NSCoeffT,
     tp_group: torch.distributed.ProcessGroup,
     partition_dim: int | None = None,
     mode: Literal["duplicated", "distributed"] = "duplicated",
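NSCoeffT narrows coefficient_type from a free-form str to the Literal of supported coefficient sets, so an unsupported name is rejected by mypy rather than surfacing as a runtime error (FP32MatmulPrecT in emerging_optimizers.utils plays the same role for fp32_matmul_prec). A short sketch of how such a Literal alias behaves; the describe function is illustrative only:

from typing import Literal

NSCoeffT = Literal["simple", "quintic", "polar_express", "aol", "custom"]


def describe(coefficient_type: NSCoeffT = "quintic") -> str:
    return f"Newton-Schulz with {coefficient_type} coefficients"


describe("polar_express")  # accepted
# describe("qunitic")      # typo: mypy reports an invalid Literal argument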

emerging_optimizers/orthogonalized_optimizers/orthogonalized_optimizer.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1313
# See the License for the specific language governing permissions and
1414
# limitations under the License.
15-
from typing import Any, Callable
15+
from typing import Any, Callable, overload
1616

1717

1818
# TODO(@boxiangw): remove this once bump to python 3.12
@@ -28,6 +28,7 @@
2828

2929
from emerging_optimizers import mixin as opt_mixin
3030
from emerging_optimizers import utils
31+
from emerging_optimizers.utils import FP32MatmulPrecT
3132

3233

3334
_args_doc = """params: Iterable of parameters to optimize or dicts defining parameter groups
@@ -103,7 +104,7 @@ def __init__(
103104
*,
104105
use_nesterov: bool,
105106
weight_decay_method: opt_mixin.WeightDecayT,
106-
fp32_matmul_prec: str,
107+
fp32_matmul_prec: FP32MatmulPrecT,
107108
scaled_orthogonalize_fn: Callable | None = None,
108109
**kwargs: Any,
109110
):
@@ -125,6 +126,12 @@ def __init__(
125126
super().__init__(params, default_args_dict)
126127
self.scaled_orthogonalize_fn = scaled_orthogonalize_fn
127128

129+
@overload
130+
def step(self, closure: None = ...) -> None: ...
131+
132+
@overload
133+
def step(self, closure: Callable[[], float]) -> float: ...
134+
128135
@torch.no_grad() # type: ignore[misc]
129136
@override
130137
def step(self, closure: Callable[[], float] | None = None) -> float | None:

emerging_optimizers/orthogonalized_optimizers/scion.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,11 @@
1717
from absl import logging
1818
from torch.optim.optimizer import ParamsT
1919

20+
from emerging_optimizers.orthogonalized_optimizers import muon_utils
2021
from emerging_optimizers.orthogonalized_optimizers.muon import get_muon_scale_factor
21-
from emerging_optimizers.orthogonalized_optimizers.muon_utils import newton_schulz
22+
from emerging_optimizers.orthogonalized_optimizers.muon_utils import NSCoeffT
2223
from emerging_optimizers.orthogonalized_optimizers.orthogonalized_optimizer import OrthogonalizedOptimizer
24+
from emerging_optimizers.utils import FP32MatmulPrecT
2325

2426

2527
class Scion(OrthogonalizedOptimizer):
@@ -61,8 +63,8 @@ def __init__(
6163
lr: float = 3e-4,
6264
momentum_beta: float = 0.95,
6365
*,
64-
fp32_matmul_prec: str = "medium",
65-
coefficient_type: str = "quintic",
66+
fp32_matmul_prec: FP32MatmulPrecT = "medium",
67+
coefficient_type: NSCoeffT = "quintic",
6668
num_ns_steps: int = 5,
6769
spectral_radius: float = 1.0,
6870
) -> None:
@@ -83,7 +85,9 @@ def scaled_orthogonalize_fn(grad: torch.Tensor) -> torch.Tensor:
8385
logging.debug(
8486
f"Orthogonalizing grad with {num_ns_steps} steps, {coefficient_type} coefficient, spectral_radius={spectral_radius}"
8587
)
86-
orth_grad = newton_schulz(grad, steps=num_ns_steps, coefficient_type=coefficient_type, use_syrk=False)
88+
orth_grad = muon_utils.newton_schulz(
89+
grad, steps=num_ns_steps, coefficient_type=coefficient_type, use_syrk=False
90+
)
8791
width_factor = get_muon_scale_factor(grad.size(-2), grad.size(-1), mode="unit_rms_norm")
8892
return orth_grad * width_factor * spectral_radius
8993

emerging_optimizers/psgd/procrustes_step.py

Lines changed: 5 additions & 3 deletions
@@ -65,15 +65,17 @@ def procrustes_step(
     # rotate Q as exp(a R) Q ~ (I + a R + a^2 R^2/2) Q with an optimal step size by line search
     # for 2nd order expansion, only expand exp(a R) to its 2nd term.
     # Q += _step_size * (RQ + 0.5 * _step_size * RRQ)
-    Q = torch.add(Q, torch.add(RQ, RRQ, alpha=0.5 * step_size), alpha=step_size)
+    Q = torch.add(Q, torch.add(RQ, RRQ, alpha=0.5 * step_size), alpha=step_size)  # type: ignore[call-overload]
     if order == 3:
         RRRQ = R @ RRQ
         tr_RRRQ = torch.trace(RRRQ)
         # for a 3rd order expansion, we take the larger root of the cubic.
         _step_size = (-tr_RRQ - torch.sqrt(tr_RRQ * tr_RRQ - 1.5 * tr_RQ * tr_RRRQ)) / (0.75 * tr_RRRQ)
         step_size = torch.clamp(_step_size, max=max_step_size)
         # Q += step_size * (RQ + 0.5 * step_size * (RRQ + 0.25 * step_size * RRRQ))
-        Q = torch.add(
-            Q, torch.add(RQ, torch.add(RRQ, RRRQ, alpha=0.25 * step_size), alpha=0.5 * step_size), alpha=step_size
+        Q = torch.add(  # type: ignore[call-overload]
+            Q,
+            torch.add(RQ, torch.add(RRQ, RRRQ, alpha=0.25 * step_size), alpha=0.5 * step_size),  # type: ignore[call-overload]
+            alpha=step_size,  # type: ignore[call-overload]
         )
     return Q
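The # type: ignore[call-overload] markers cover the calls where alpha is the tensor-valued step size: torch.add's stubs annotate alpha as a plain Python number, so the 0-dim tensor produced by torch.clamp above matches no stub overload even though that is the form the existing code relies on at runtime. A minimal sketch of the mismatch, under the assumption that the stubs still type alpha as a Number:

import torch

q = torch.eye(3)
update = torch.randn(3, 3)
step_size = torch.clamp(torch.tensor(0.3), max=0.5)  # 0-dim tensor, not a Python float

# Same pattern as in procrustes_step: mypy cannot match an overload because the
# stubs expect a numeric alpha, hence the ignore comment.
q = torch.add(q, update, alpha=step_size)  # type: ignore[call-overload]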

emerging_optimizers/psgd/psgd.py

Lines changed: 23 additions & 11 deletions
@@ -13,7 +13,13 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import math
-from typing import Callable, List, Tuple, override
+from typing import Callable, overload
+
+
+try:
+    from typing import override
+except ImportError:
+    from typing_extensions import override
 
 import torch
 from torch.optim.optimizer import ParamsT
@@ -85,6 +91,12 @@ def __init__(
         }
         super().__init__(params, defaults)
 
+    @overload
+    def step(self, closure: None = ...) -> None: ...
+
+    @overload
+    def step(self, closure: Callable[[], float]) -> float: ...
+
     @torch.no_grad()  # type: ignore[misc]
     @override
     def step(self, closure: Callable[[], float] | None = None) -> float | None:
@@ -154,7 +166,7 @@ def step(self, closure: Callable[[], float] | None = None) -> float | None:
 def _init_psgd_kron_states(
     grad: torch.Tensor,
     precond_init_scale: float = 1.0,
-) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
+) -> tuple[list[torch.Tensor], list[torch.Tensor]]:
     """Initialize the Kronecker factor matrices and Lipschitz constants.
 
     Args:
@@ -165,8 +177,8 @@ def _init_psgd_kron_states(
         q_list: List of Kronecker factors.
         lip_const_list: List of Lipschitz constants for the Kronecker factors.
     """
-    q_list: List[torch.Tensor] = []
-    lip_const_list: List[torch.Tensor] = []
+    q_list: list[torch.Tensor] = []
+    lip_const_list: list[torch.Tensor] = []
 
     # Create identity matrices scaled by precond_init_scale for each dimension
     for size in grad.shape:
@@ -177,13 +189,13 @@ def _init_psgd_kron_states(
 
 
 def _update_precond_procrustes(
-    q_list: List[torch.Tensor],
-    lip_const_list: List[torch.Tensor],
+    q_list: list[torch.Tensor],
+    lip_const_list: list[torch.Tensor],
     exp_avg: torch.Tensor,
     damping_noise_scale: float = 1e-9,
     precond_lr: float = 0.1,
     beta_lip: float = 0.9,
-) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
+) -> tuple[list[torch.Tensor], list[torch.Tensor]]:
     r"""Update the Kron preconditioner Q using procrustes step and uniformization.
 
     Args:
@@ -201,8 +213,8 @@
     dampened_momentum = exp_avg + (damping_noise_scale + 1e-7 * exp_avg.abs()) * torch.randn_like(exp_avg)
     pg = psgd_kron_contractions.apply_preconditioner(q_list, dampened_momentum)
    total_numel = pg.numel()
-    updated_q_list: List[torch.Tensor] = []
-    updated_lip_const_list: List[torch.Tensor] = []
+    updated_q_list: list[torch.Tensor] = []
+    updated_lip_const_list: list[torch.Tensor] = []
     for dim, q in enumerate(q_list):
         # compute gradient covariance
         precond_grad_cov = psgd_kron_contractions.partial_contraction(pg, pg, dim)
@@ -229,7 +241,7 @@ def _update_matrix_preconditioner(
     total_numel: int,
     precond_lr: float,
     beta_lip: float,
-) -> Tuple[torch.Tensor, torch.Tensor]:
+) -> tuple[torch.Tensor, torch.Tensor]:
     r"""Update matrix-structured preconditioner with adaptive Lipschitz constant.
 
     Args:
@@ -259,7 +271,7 @@ def _update_1d_preconditioner(
     total_numel: int,
     precond_lr: float,
     beta_lip: float,
-) -> Tuple[torch.Tensor, torch.Tensor]:
+) -> tuple[torch.Tensor, torch.Tensor]:
     r"""Update 1D preconditioner with adaptive Lipschitz constant.
 
     Args:
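Two smaller cleanups run through psgd.py: override is imported from typing on Python 3.12+ with a typing_extensions fallback, and the deprecated typing.List/typing.Tuple aliases are replaced by the built-in generics (PEP 585). A compact sketch combining both, with illustrative class names that are not part of the repository:

try:
    from typing import override  # Python 3.12+
except ImportError:
    from typing_extensions import override


class Preconditioner:
    def factors(self) -> tuple[list[float], list[float]]:
        return [], []


class KronPreconditioner(Preconditioner):
    @override
    def factors(self) -> tuple[list[float], list[float]]:
        return [1.0], [0.9]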

emerging_optimizers/psgd/psgd_utils.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,6 @@
1212
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1313
# See the License for the specific language governing permissions and
1414
# limitations under the License.
15-
from typing import List
16-
1715
import torch
1816

1917

@@ -25,7 +23,7 @@
2523

2624

2725
@torch.compile # type: ignore[misc]
28-
def uniformize_q_in_place(Q_list: List[torch.Tensor]) -> None:
26+
def uniformize_q_in_place(Q_list: list[torch.Tensor]) -> None:
2927
"""Balance the dynamic ranges of kronecker factors in place to prevent numerical underflow or overflow.
3028
3129
Each tensor in `Q_list` is rescaled so that its maximum absolute entry
