Commit ce0f192

enforce stricter type check
Signed-off-by: Hao Wu <[email protected]>
1 parent fc589ee commit ce0f192

12 files changed: +86 -15 lines changed

.pre-commit-config.yaml (2 additions, 1 deletion)

@@ -33,10 +33,11 @@ repos:
       - id: ruff-format

   - repo: https://github.com/pre-commit/mirrors-mypy
-    rev: v1.14.0
+    rev: v1.19.1
     hooks:
       - id: mypy
         exclude: ^docs|^tests|^benchmarks|^docker
+        additional_dependencies: ["torch"]

   - repo: local
     hooks:

emerging_optimizers/orthogonalized_optimizers/adaptive_muon.py (7 additions, 1 deletion)

@@ -12,7 +12,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Callable, Literal
+from typing import Callable, Literal, overload


 # TODO(@boxiangw): remove this once bump to python 3.12
@@ -181,6 +181,12 @@ def _apply_moment2_normalization(
         else:
             raise TypeError(f"Invalid second moment method: {self.moment2_method}")

+    @overload
+    def step(self, closure: None = ...) -> None: ...
+
+    @overload
+    def step(self, closure: Callable[[], float]) -> float: ...
+
     @torch.no_grad()  # type: ignore[misc]
     @override
     def step(self, closure: Callable[[], float] | None = None) -> float | None:
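Note: the same pair of overloads is added to every optimizer touched by this commit. A minimal sketch (toy class, not from the repo) of what they buy: mypy can narrow the return type of step() to None when no closure is passed and to float when one is, instead of always reporting float | None.

from typing import Callable, overload


class _ToyOptimizer:
    @overload
    def step(self, closure: None = ...) -> None: ...

    @overload
    def step(self, closure: Callable[[], float]) -> float: ...

    def step(self, closure: Callable[[], float] | None = None) -> float | None:
        # Runtime behavior is unchanged; the overloads only affect static checking.
        return closure() if closure is not None else None


opt = _ToyOptimizer()
loss: float = opt.step(lambda: 0.0)  # accepted: the closure overload returns float
nothing: None = opt.step()           # accepted: the bare overload returns None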

emerging_optimizers/orthogonalized_optimizers/mop.py (1 addition, 0 deletions)

@@ -57,6 +57,7 @@ def __init__(
         def scaled_orthogonalize_fn(grad: torch.Tensor) -> torch.Tensor:
             orth_grad, _, S = polar_via_svd(grad, False)

+            scale_factor: float | torch.Tensor
             if scale_mode != "nuclear_norm":
                 scale_factor = muon.get_muon_scale_factor(grad.size(-2), grad.size(-1), mode=scale_mode)
             else:
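Note: the new annotation exists because mypy otherwise pins a local variable's type to whatever the first branch assigns. A minimal sketch of the pattern, with made-up names:

import torch


def pick_scale(use_constant: bool, s: torch.Tensor) -> float | torch.Tensor:
    # Without the explicit annotation, mypy would infer `float` from the first
    # branch and reject the tensor assignment in the second one.
    scale: float | torch.Tensor
    if use_constant:
        scale = 2.0
    else:
        scale = s.abs().max()
    return scale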

emerging_optimizers/orthogonalized_optimizers/muon_utils.py (1 addition, 1 deletion)

@@ -123,7 +123,7 @@ def newton_schulz(
     if tp_group is not None:
         X = distributed_normalize_p2(x, eps, tp_group)
     else:
-        X = torch.nn.functional.normalize(x, p=2, dim=(-2, -1), eps=eps)
+        X = torch.nn.functional.normalize(x, p=2, dim=(-2, -1), eps=eps)  # type: ignore[arg-type]

     if coefficient_type in _COEFFICIENT_SETS:
         coefficient_sets = _COEFFICIENT_SETS[coefficient_type]
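Note: the ignore appears to be needed because the stubs annotate the dim parameter of torch.nn.functional.normalize as a single int, while a tuple of dims works at runtime. A hedged alternative sketch (not what the commit does, helper name made up) that avoids the ignore by computing the joint 2-norm explicitly:

import torch


def normalize_last_two_dims(x: torch.Tensor, eps: float = 1e-12) -> torch.Tensor:
    # Joint 2-norm over the last two dimensions, divided out with an eps floor,
    # roughly what F.normalize does for a single dim.
    norm = torch.linalg.vector_norm(x, ord=2, dim=(-2, -1), keepdim=True)
    return x / norm.clamp_min(eps)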

emerging_optimizers/orthogonalized_optimizers/orthogonalized_optimizer.py (7 additions, 1 deletion)

@@ -12,7 +12,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Any, Callable
+from typing import Any, Callable, overload


 # TODO(@boxiangw): remove this once bump to python 3.12
@@ -126,6 +126,12 @@ def __init__(
         super().__init__(params, default_args_dict)
         self.scaled_orthogonalize_fn = scaled_orthogonalize_fn

+    @overload
+    def step(self, closure: None = ...) -> None: ...
+
+    @overload
+    def step(self, closure: Callable[[], float]) -> float: ...
+
     @torch.no_grad()  # type: ignore[misc]
     @override
     def step(self, closure: Callable[[], float] | None = None) -> float | None:

emerging_optimizers/psgd/procrustes_step.py (5 additions, 3 deletions)

@@ -65,15 +65,17 @@ def procrustes_step(
     # rotate Q as exp(a R) Q ~ (I + a R + a^2 R^2/2) Q with an optimal step size by line search
     # for 2nd order expansion, only expand exp(a R) to its 2nd term.
     # Q += _step_size * (RQ + 0.5 * _step_size * RRQ)
-    Q = torch.add(Q, torch.add(RQ, RRQ, alpha=0.5 * step_size), alpha=step_size)
+    Q = torch.add(Q, torch.add(RQ, RRQ, alpha=0.5 * step_size), alpha=step_size)  # type: ignore[call-overload]
     if order == 3:
         RRRQ = R @ RRQ
         tr_RRRQ = torch.trace(RRRQ)
         # for a 3rd order expansion, we take the larger root of the cubic.
         _step_size = (-tr_RRQ - torch.sqrt(tr_RRQ * tr_RRQ - 1.5 * tr_RQ * tr_RRRQ)) / (0.75 * tr_RRRQ)
         step_size = torch.clamp(_step_size, max=max_step_size)
         # Q += step_size * (RQ + 0.5 * step_size * (RRQ + 0.25 * step_size * RRRQ))
-        Q = torch.add(
-            Q, torch.add(RQ, torch.add(RRQ, RRRQ, alpha=0.25 * step_size), alpha=0.5 * step_size), alpha=step_size
+        Q = torch.add(  # type: ignore[call-overload]
+            Q,
+            torch.add(RQ, torch.add(RRQ, RRRQ, alpha=0.25 * step_size), alpha=0.5 * step_size),  # type: ignore[call-overload]
+            alpha=step_size,  # type: ignore[call-overload]
         )
     return Q

emerging_optimizers/psgd/psgd.py (13 additions, 1 deletion)

@@ -13,7 +13,13 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import math
-from typing import Callable, override
+from typing import Callable, overload
+
+
+try:
+    from typing import override
+except ImportError:
+    from typing_extensions import override

 import torch
 from torch.optim.optimizer import ParamsT
@@ -85,6 +91,12 @@ def __init__(
         }
         super().__init__(params, defaults)

+    @overload
+    def step(self, closure: None = ...) -> None: ...
+
+    @overload
+    def step(self, closure: Callable[[], float]) -> float: ...
+
     @torch.no_grad()  # type: ignore[misc]
     @override
     def step(self, closure: Callable[[], float] | None = None) -> float | None:
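Note: typing.override only exists on Python 3.12+, which is why several files in this commit guard the import and fall back to typing_extensions. A small sketch (toy classes, not from the repo) of what the decorator provides: it is only a marker at runtime, and mypy uses it to flag methods that claim to override something absent from every base class.

try:
    from typing import override  # Python >= 3.12
except ImportError:
    from typing_extensions import override  # backport for older interpreters


class Base:
    def step(self) -> None: ...


class Child(Base):
    @override
    def step(self) -> None:  # accepted: Base.step exists
        print("stepping")

    # Decorating a method here that has no counterpart in Base would be a mypy error.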

emerging_optimizers/riemannian_optimizers/normalized_optimizer.py (22 additions, 2 deletions)

@@ -12,7 +12,13 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Callable
+from typing import Callable, overload
+
+
+try:
+    from typing import override
+except ImportError:
+    from typing_extensions import override

 import torch
 from torch.optim.optimizer import Optimizer
@@ -65,8 +71,15 @@ def __init__(
         )
         super().__init__(params, defaults)

+    @overload
+    def step(self, closure: None = ...) -> None: ...
+
+    @overload
+    def step(self, closure: Callable[[], float]) -> float: ...
+
     @torch.no_grad()  # type: ignore[misc]
-    def step(self, closure: Callable[[], float] | None = None) -> float | None:
+    @override
+    def step(self, closure: None = None) -> float | None:
         """Performs a single optimization step.

         Args:
@@ -154,7 +167,14 @@ def __init__(
         )
         super().__init__(params, defaults)

+    @overload
+    def step(self, closure: None = ...) -> None: ...
+
+    @overload
+    def step(self, closure: Callable[[], float]) -> float: ...
+
     @torch.no_grad()  # type: ignore[misc]
+    @override
     def step(self, closure: Callable[[], float] | None = None) -> float | None:
         """Performs a single optimization step.

emerging_optimizers/soap/soap.py (7 additions, 1 deletion)

@@ -14,7 +14,7 @@
 # limitations under the License.
 from functools import partial
 from itertools import chain
-from typing import Callable
+from typing import Callable, overload


 # TODO(@boxiangw): remove this once bump to python 3.12
@@ -136,6 +136,12 @@ def __init__(
         }
         super().__init__(params, defaults)

+    @overload
+    def step(self, closure: None = ...) -> None: ...
+
+    @overload
+    def step(self, closure: Callable[[], float]) -> float: ...
+
     @torch.no_grad()  # type: ignore[misc]
     @override
     def step(self, closure: Callable[[], float] | None = None) -> float | None:

emerging_optimizers/utils/modules.py (12 additions, 3 deletions)

@@ -16,6 +16,12 @@
 import math
 from typing import Any, Self

+
+try:
+    from typing import override
+except ImportError:
+    from typing_extensions import override
+
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
@@ -51,8 +57,8 @@ def __init__(self, *args: Any, **kwargs: Any) -> None:

         assert self.padding_mode == "zeros", "Only zeros padding is supported"

-        self.weight: nn.Parameter[torch.Tensor]
-        self.bias: nn.Parameter[torch.Tensor] | None | str
+        self.weight: nn.Parameter
+        self.bias: nn.Parameter | None

         flat_weight_shape = [self.out_channels, math.prod(self.weight.shape[1:])]
         if self.bias is not None:
@@ -63,7 +69,6 @@ def __init__(self, *args: Any, **kwargs: Any) -> None:
             flat_weight_buffer[..., -1].copy_(self.bias)
             del self.bias
             self.has_bias = True
-            self.bias = "dummy"  # Trick con1d.extra_repr() to not print bias=False
         else:
             flat_weight_buffer.copy_(self.weight.view(self.out_channels, -1))
             self.has_bias = False
@@ -98,6 +103,7 @@ def from_conv1d(cls, conv1d: nn.Conv1d) -> Self:
     def weight_shape(self) -> tuple[int, int, int]:
         return (self.out_channels, self.in_channels // self.groups, self.kernel_size[0])

+    @override
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         if self.has_bias:
             weight = self.weight[..., :-1].view(self.weight_shape)
@@ -108,6 +114,9 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:

         return F.conv1d(x, weight, bias, self.stride, self.padding, self.dilation, self.groups)

+    @override
     def extra_repr(self) -> str:
         base_repr = super().extra_repr()
+        if self.has_bias:
+            base_repr += ", bias=True"
         return f"{base_repr}, flattened_param_shape={tuple(self.weight.shape)}"
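Note: with the "dummy" bias trick gone, the module now reports its bias through extra_repr instead. A minimal sketch (toy module, not from the repo) of the mechanism relied on here: whatever extra_repr returns is printed inside the parentheses of repr(module).

import torch.nn as nn


class Tiny(nn.Module):
    def __init__(self, has_bias: bool) -> None:
        super().__init__()
        self.has_bias = has_bias

    def extra_repr(self) -> str:
        # This string ends up inside the parentheses of repr(self).
        return f"has_bias={self.has_bias}"


print(Tiny(True))  # prints: Tiny(has_bias=True)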
