Skip to content

Commit 770cf64

Browse files
committed
update critical typing issues
Signed-off-by: Hao Wu <[email protected]>
1 parent a98438d commit 770cf64

File tree

8 files changed

+52
-58
lines changed

8 files changed

+52
-58
lines changed

emerging_optimizers/psgd/psgd.py

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
# See the License for the specific language governing permissions and
1414
# limitations under the License.
1515
import math
16-
from typing import Callable, List, Tuple, override
16+
from typing import Callable, override
1717

1818
import torch
1919
from torch.optim.optimizer import ParamsT
@@ -154,7 +154,7 @@ def step(self, closure: Callable[[], float] | None = None) -> float | None:
154154
def _init_psgd_kron_states(
155155
grad: torch.Tensor,
156156
precond_init_scale: float = 1.0,
157-
) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
157+
) -> tuple[list[torch.Tensor], list[torch.Tensor]]:
158158
"""Initialize the Kronecker factor matrices and Lipschitz constants.
159159
160160
Args:
@@ -165,8 +165,8 @@ def _init_psgd_kron_states(
165165
q_list: List of Kronecker factors.
166166
lip_const_list: List of Lipschitz constants for the Kronecker factors.
167167
"""
168-
q_list: List[torch.Tensor] = []
169-
lip_const_list: List[torch.Tensor] = []
168+
q_list: list[torch.Tensor] = []
169+
lip_const_list: list[torch.Tensor] = []
170170

171171
# Create identity matrices scaled by precond_init_scale for each dimension
172172
for size in grad.shape:
@@ -177,13 +177,13 @@ def _init_psgd_kron_states(
177177

178178

179179
def _update_precond_procrustes(
180-
q_list: List[torch.Tensor],
181-
lip_const_list: List[torch.Tensor],
180+
q_list: list[torch.Tensor],
181+
lip_const_list: list[torch.Tensor],
182182
exp_avg: torch.Tensor,
183183
damping_noise_scale: float = 1e-9,
184184
precond_lr: float = 0.1,
185185
beta_lip: float = 0.9,
186-
) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
186+
) -> tuple[list[torch.Tensor], list[torch.Tensor]]:
187187
r"""Update the Kron preconditioner Q using procrustes step and uniformization.
188188
189189
Args:
@@ -201,8 +201,8 @@ def _update_precond_procrustes(
201201
dampened_momentum = exp_avg + (damping_noise_scale + 1e-7 * exp_avg.abs()) * torch.randn_like(exp_avg)
202202
pg = psgd_kron_contractions.apply_preconditioner(q_list, dampened_momentum)
203203
total_numel = pg.numel()
204-
updated_q_list: List[torch.Tensor] = []
205-
updated_lip_const_list: List[torch.Tensor] = []
204+
updated_q_list: list[torch.Tensor] = []
205+
updated_lip_const_list: list[torch.Tensor] = []
206206
for dim, q in enumerate(q_list):
207207
# compute gradient covariance
208208
precond_grad_cov = psgd_kron_contractions.partial_contraction(pg, pg, dim)
@@ -229,7 +229,7 @@ def _update_matrix_preconditioner(
229229
total_numel: int,
230230
precond_lr: float,
231231
beta_lip: float,
232-
) -> Tuple[torch.Tensor, torch.Tensor]:
232+
) -> tuple[torch.Tensor, torch.Tensor]:
233233
r"""Update matrix-structured preconditioner with adaptive Lipschitz constant.
234234
235235
Args:
@@ -259,7 +259,7 @@ def _update_1d_preconditioner(
259259
total_numel: int,
260260
precond_lr: float,
261261
beta_lip: float,
262-
) -> Tuple[torch.Tensor, torch.Tensor]:
262+
) -> tuple[torch.Tensor, torch.Tensor]:
263263
r"""Update 1D preconditioner with adaptive Lipschitz constant.
264264
265265
Args:

emerging_optimizers/psgd/psgd_utils.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,6 @@
1212
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1313
# See the License for the specific language governing permissions and
1414
# limitations under the License.
15-
from typing import List
16-
1715
import torch
1816

1917

@@ -25,7 +23,7 @@
2523

2624

2725
@torch.compile # type: ignore[misc]
28-
def uniformize_q_in_place(Q_list: List[torch.Tensor]) -> None:
26+
def uniformize_q_in_place(Q_list: list[torch.Tensor]) -> None:
2927
"""Balance the dynamic ranges of kronecker factors in place to prevent numerical underflow or overflow.
3028
3129
Each tensor in `Q_list` is rescaled so that its maximum absolute entry

emerging_optimizers/scalar_optimizers/adam.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,6 @@
1212
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1313
# See the License for the specific language governing permissions and
1414
# limitations under the License.
15-
from typing import Tuple
16-
1715
import torch
1816

1917

@@ -28,7 +26,7 @@ def calculate_adam_update(
2826
grad: torch.Tensor,
2927
exp_avg: torch.Tensor,
3028
exp_avg_sq: torch.Tensor,
31-
betas: Tuple[float, float],
29+
betas: tuple[float, float],
3230
correct_bias: bool,
3331
use_nesterov: bool,
3432
step: int,

emerging_optimizers/scalar_optimizers/lion.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,6 @@
1212
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1313
# See the License for the specific language governing permissions and
1414
# limitations under the License.
15-
from typing import Optional
16-
1715
import torch
1816

1917

@@ -28,7 +26,7 @@ def calculate_lion_update(
2826
grad: torch.Tensor,
2927
exp_avg: torch.Tensor,
3028
momentum_beta: float,
31-
momentum_beta2: Optional[float] = None,
29+
momentum_beta2: float | None = None,
3230
) -> torch.Tensor:
3331
"""Performs the Lion update.
3432

emerging_optimizers/soap/soap.py

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
# limitations under the License.
1515
from functools import partial
1616
from itertools import chain
17-
from typing import Callable, List, Optional, Tuple, Union
17+
from typing import Callable
1818

1919

2020
# TODO(@boxiangw): remove this once bump to python 3.12
@@ -86,14 +86,14 @@ def __init__(
8686
self,
8787
params: ParamsT,
8888
lr: float,
89-
betas: Tuple[float, float] = (0.9, 0.95),
89+
betas: tuple[float, float] = (0.9, 0.95),
9090
shampoo_beta: float = 0.95,
9191
eps: float = 1e-8,
9292
weight_decay: float = 0.01,
9393
*,
9494
weight_decay_method: opt_mixin.WeightDecayT = "decoupled",
9595
use_nesterov: bool = False,
96-
precondition_frequency: Union[int, Callable[[int], int]] = 1,
96+
precondition_frequency: int | Callable[[int], int] = 1,
9797
adam_warmup_steps: int = 0,
9898
precondition_1d: bool = False,
9999
correct_bias: bool = True,
@@ -293,7 +293,7 @@ def step(self, closure: Callable[[], float] | None = None) -> float | None:
293293
def init_kronecker_factors(
294294
grad: torch.Tensor,
295295
precondition_1d: bool = False,
296-
) -> List[torch.Tensor]:
296+
) -> list[torch.Tensor]:
297297
"""Initializes the kronecker factor matrices for the SOAP optimizer.
298298
299299
This function creates the initial Kronecker factor matrices (L and R) used for
@@ -338,7 +338,7 @@ def init_kronecker_factors(
338338
>>> print(precond_2d[1].shape) # (20, 20)
339339
340340
"""
341-
kronecker_factor_list: List[torch.Tensor] = []
341+
kronecker_factor_list: list[torch.Tensor] = []
342342

343343
if grad.dim() == 1:
344344
if not precondition_1d:
@@ -358,7 +358,7 @@ def init_kronecker_factors(
358358

359359
@torch.no_grad() # type: ignore[misc]
360360
def update_kronecker_factors(
361-
kronecker_factor_list: List[torch.Tensor],
361+
kronecker_factor_list: list[torch.Tensor],
362362
grad: torch.Tensor,
363363
shampoo_beta: float,
364364
precondition_1d: bool = False,
@@ -414,10 +414,10 @@ def update_kronecker_factors(
414414

415415
@torch.no_grad() # type: ignore[misc]
416416
def update_kronecker_factors_kl_shampoo(
417-
kronecker_factor_list: List[torch.Tensor],
417+
kronecker_factor_list: list[torch.Tensor],
418418
grad: torch.Tensor,
419419
shampoo_beta: float,
420-
eigenbasis_list: List[torch.Tensor],
420+
eigenbasis_list: list[torch.Tensor],
421421
eps: float,
422422
eigval_exp: float = -1.0,
423423
) -> None:
@@ -457,16 +457,16 @@ def update_kronecker_factors_kl_shampoo(
457457

458458
@torch.no_grad() # type: ignore[misc]
459459
def update_eigenbasis_and_momentum(
460-
kronecker_factor_list: List[torch.Tensor],
461-
eigenbasis_list: List[torch.Tensor],
460+
kronecker_factor_list: list[torch.Tensor],
461+
eigenbasis_list: list[torch.Tensor],
462462
exp_avg_sq: torch.Tensor,
463463
momentum: torch.Tensor,
464464
use_eigh: bool = False,
465465
use_adaptive_criteria: bool = False,
466-
adaptive_update_tolerance: Optional[float] = None,
466+
adaptive_update_tolerance: float | None = None,
467467
power_iter_steps: int = 1,
468468
convert_to_float: bool = True,
469-
) -> Tuple[List[torch.Tensor], torch.Tensor, torch.Tensor]:
469+
) -> tuple[list[torch.Tensor], torch.Tensor, torch.Tensor]:
470470
"""Updates the eigenbases using QR decomposition and power iteration or eigh.
471471
472472
This function performs an update of the eigenbases (QL and QR)
@@ -559,8 +559,8 @@ def update_eigenbasis_and_momentum(
559559
@torch.compile # type: ignore[misc]
560560
def precondition(
561561
grad: torch.Tensor,
562-
eigenbasis_list: Optional[List[torch.Tensor]] = None,
563-
dims: Optional[List[List[int]]] = None,
562+
eigenbasis_list: list[torch.Tensor] | None = None,
563+
dims: list[list[int]] | None = None,
564564
) -> torch.Tensor:
565565
"""Projects the gradient to and from the eigenbases of the kronecker factor matrices.
566566
@@ -610,7 +610,7 @@ def precondition(
610610
def _is_eigenbasis_update_step(
611611
step: int,
612612
adam_warmup_steps: int,
613-
precondition_frequency: Union[int, Callable[[int], int]],
613+
precondition_frequency: int | Callable[[int], int],
614614
) -> bool:
615615
"""Checks if amortized computation of the eigenbasis should be recomputed.
616616

emerging_optimizers/soap/soap_utils.py

Lines changed: 18 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -12,27 +12,30 @@
1212
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1313
# See the License for the specific language governing permissions and
1414
# limitations under the License.
15-
from typing import List, Optional, Tuple
15+
from typing import TypeAlias
1616

1717
import torch
1818

1919
from emerging_optimizers.utils import eig as eig_utils
2020

2121

22+
TensorList: TypeAlias = list[torch.Tensor]
23+
24+
2225
__all__ = [
2326
"get_eigenbasis_eigh",
2427
"get_eigenbasis_qr",
2528
]
2629

2730

2831
def get_eigenbasis_eigh(
29-
kronecker_factor_list: List[torch.Tensor],
32+
kronecker_factor_list: TensorList,
3033
convert_to_float: bool = True,
31-
eigenbasis_list: Optional[List[torch.Tensor]] = None,
34+
eigenbasis_list: TensorList | None = None,
3235
use_adaptive_criteria: bool = False,
33-
adaptive_update_tolerance: Optional[float] = None,
34-
eps: Optional[float] = None,
35-
) -> List[torch.Tensor]:
36+
adaptive_update_tolerance: float | None = None,
37+
eps: float | None = None,
38+
) -> TensorList:
3639
"""Computes the eigenbases of the preconditioner using torch.linalg.eigh decomposition.
3740
3841
Args:
@@ -66,7 +69,7 @@ def get_eigenbasis_eigh(
6669
adaptive_update_tolerance = 1e-7
6770

6871
# cast the kronecker factor matrices to float32 if convert_to_float is True
69-
casted_matrix_list: List[torch.Tensor] = []
72+
casted_matrix_list: TensorList = []
7073
for kronecker_factor in kronecker_factor_list:
7174
if kronecker_factor.numel() == 0:
7275
casted_matrix_list.append(torch.empty(0, device=kronecker_factor.device))
@@ -76,7 +79,7 @@ def get_eigenbasis_eigh(
7679
else:
7780
casted_matrix_list.append(kronecker_factor)
7881

79-
updated_eigenbasis_list: List[torch.Tensor] = []
82+
updated_eigenbasis_list: TensorList = []
8083

8184
# use adaptive early exit criteria
8285
if use_adaptive_criteria and eigenbasis_list is not None:
@@ -112,14 +115,14 @@ def get_eigenbasis_eigh(
112115

113116

114117
def get_eigenbasis_qr(
115-
kronecker_factor_list: List[torch.Tensor],
116-
eigenbasis_list: List[torch.Tensor],
118+
kronecker_factor_list: TensorList,
119+
eigenbasis_list: TensorList,
117120
exp_avg_sq: torch.Tensor,
118121
convert_to_float: bool = True,
119122
use_adaptive_criteria: bool = False,
120-
adaptive_update_tolerance: Optional[float] = None,
123+
adaptive_update_tolerance: float | None = None,
121124
power_iter_steps: int = 1,
122-
) -> Tuple[List[torch.Tensor], torch.Tensor]:
125+
) -> tuple[TensorList, torch.Tensor]:
123126
"""Updates the eigenbases of the preconditioner using power iteration and QR.
124127
125128
Computes using multiple rounds of power iteration followed by QR decomposition (orthogonal iteration).
@@ -175,8 +178,8 @@ def get_eigenbasis_qr(
175178
if adaptive_update_tolerance is None:
176179
adaptive_update_tolerance = 1e-7
177180

178-
casted_matrix_list: List[torch.Tensor] = []
179-
casted_eigenbasis_list: List[torch.Tensor] = []
181+
casted_matrix_list: TensorList = []
182+
casted_eigenbasis_list: TensorList = []
180183
for kronecker_factor, eigenbasis in zip(kronecker_factor_list, eigenbasis_list, strict=True):
181184
# If the tensor is empty, propagate an empty tensor to the output lists.
182185
if kronecker_factor.numel() == 0:
@@ -195,7 +198,7 @@ def get_eigenbasis_qr(
195198
if convert_to_float and exp_avg_sq.dtype != torch.float:
196199
exp_avg_sq = exp_avg_sq.to(torch.float)
197200

198-
updated_eigenbasis_list: List[torch.Tensor] = []
201+
updated_eigenbasis_list: TensorList = []
199202
for ind, (kronecker_factor, eigenbasis) in enumerate(zip(casted_matrix_list, casted_eigenbasis_list, strict=True)):
200203
if kronecker_factor.numel() == 0:
201204
updated_eigenbasis_list.append(torch.empty(0, device=kronecker_factor.device))

emerging_optimizers/utils/eig.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,6 @@
1212
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1313
# See the License for the specific language governing permissions and
1414
# limitations under the License.
15-
from typing import Optional, Tuple
16-
1715
import torch
1816
from absl import logging
1917
from torch import Tensor
@@ -30,8 +28,8 @@
3028
def eigh_with_fallback(
3129
x: Tensor,
3230
force_double: bool = False,
33-
eps: Optional[float] = None,
34-
output_dtype: Optional[torch.dtype] = None,
31+
eps: float | None = None,
32+
output_dtype: torch.dtype | None = None,
3533
) -> tuple[Tensor, Tensor]:
3634
r"""torch.linalg.eigh() function with double precision fallback
3735
@@ -190,7 +188,7 @@ def orthogonal_iteration(
190188
exp_avg_sq: torch.Tensor,
191189
convert_to_float: bool,
192190
power_iter_steps: int,
193-
) -> Tuple[torch.Tensor, torch.Tensor]:
191+
) -> tuple[torch.Tensor, torch.Tensor]:
194192
"""Computes the eigenbases of the preconditioner using power iteration and QR decomposition.
195193
196194
This function performs multiple rounds of power iteration followed by QR decomposition
@@ -275,7 +273,7 @@ def _is_diagonal(x: Tensor) -> bool:
275273
return not x.triu(diagonal=1).any() and not x.tril(diagonal=-1).any()
276274

277275

278-
def _try_handle_diagonal_matrix(x: Tensor) -> Optional[tuple[Tensor, Tensor]]:
276+
def _try_handle_diagonal_matrix(x: Tensor) -> tuple[Tensor, Tensor] | None:
279277
"""Checks if matrix A is diagonal and returns its eigenvalues/vectors in ascending order if so.
280278
281279
Args:

emerging_optimizers/utils/precondition_schedules.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414
# limitations under the License.
1515
import math
1616
from abc import ABC, abstractmethod
17-
from typing import Dict
1817

1918

2019
__all__ = [
@@ -160,7 +159,7 @@ class StepSchedule(PreconditionSchedule):
160159
})
161160
"""
162161

163-
def __init__(self, schedule_dict: Dict[int, int], start_step: int = 0):
162+
def __init__(self, schedule_dict: dict[int, int], start_step: int = 0):
164163
"""Initialize with a dictionary mapping steps to frequencies.
165164
166165
Args:

0 commit comments

Comments
 (0)