1414# limitations under the License.
1515from functools import partial
1616from itertools import chain
17- from typing import Callable , List , Optional , Tuple , Union
17+ from typing import Callable
1818
1919
2020# TODO(@boxiangw): remove this once bump to python 3.12
@@ -86,14 +86,14 @@ def __init__(
8686 self ,
8787 params : ParamsT ,
8888 lr : float ,
89- betas : Tuple [float , float ] = (0.9 , 0.95 ),
89+ betas : tuple [float , float ] = (0.9 , 0.95 ),
9090 shampoo_beta : float = 0.95 ,
9191 eps : float = 1e-8 ,
9292 weight_decay : float = 0.01 ,
9393 * ,
9494 weight_decay_method : opt_mixin .WeightDecayT = "decoupled" ,
9595 use_nesterov : bool = False ,
96- precondition_frequency : Union [ int , Callable [[int ], int ] ] = 1 ,
96+ precondition_frequency : int | Callable [[int ], int ] = 1 ,
9797 adam_warmup_steps : int = 0 ,
9898 precondition_1d : bool = False ,
9999 correct_bias : bool = True ,
@@ -293,7 +293,7 @@ def step(self, closure: Callable[[], float] | None = None) -> float | None:
293293def init_kronecker_factors (
294294 grad : torch .Tensor ,
295295 precondition_1d : bool = False ,
296- ) -> List [torch .Tensor ]:
296+ ) -> list [torch .Tensor ]:
297297 """Initializes the kronecker factor matrices for the SOAP optimizer.
298298
299299 This function creates the initial Kronecker factor matrices (L and R) used for
@@ -338,7 +338,7 @@ def init_kronecker_factors(
338338 >>> print(precond_2d[1].shape) # (20, 20)
339339
340340 """
341- kronecker_factor_list : List [torch .Tensor ] = []
341+ kronecker_factor_list : list [torch .Tensor ] = []
342342
343343 if grad .dim () == 1 :
344344 if not precondition_1d :
@@ -358,7 +358,7 @@ def init_kronecker_factors(
358358
359359@torch .no_grad () # type: ignore[misc]
360360def update_kronecker_factors (
361- kronecker_factor_list : List [torch .Tensor ],
361+ kronecker_factor_list : list [torch .Tensor ],
362362 grad : torch .Tensor ,
363363 shampoo_beta : float ,
364364 precondition_1d : bool = False ,
@@ -414,10 +414,10 @@ def update_kronecker_factors(
414414
415415@torch .no_grad () # type: ignore[misc]
416416def update_kronecker_factors_kl_shampoo (
417- kronecker_factor_list : List [torch .Tensor ],
417+ kronecker_factor_list : list [torch .Tensor ],
418418 grad : torch .Tensor ,
419419 shampoo_beta : float ,
420- eigenbasis_list : List [torch .Tensor ],
420+ eigenbasis_list : list [torch .Tensor ],
421421 eps : float ,
422422 eigval_exp : float = - 1.0 ,
423423) -> None :
@@ -457,16 +457,16 @@ def update_kronecker_factors_kl_shampoo(
457457
458458@torch .no_grad () # type: ignore[misc]
459459def update_eigenbasis_and_momentum (
460- kronecker_factor_list : List [torch .Tensor ],
461- eigenbasis_list : List [torch .Tensor ],
460+ kronecker_factor_list : list [torch .Tensor ],
461+ eigenbasis_list : list [torch .Tensor ],
462462 exp_avg_sq : torch .Tensor ,
463463 momentum : torch .Tensor ,
464464 use_eigh : bool = False ,
465465 use_adaptive_criteria : bool = False ,
466- adaptive_update_tolerance : Optional [ float ] = None ,
466+ adaptive_update_tolerance : float | None = None ,
467467 power_iter_steps : int = 1 ,
468468 convert_to_float : bool = True ,
469- ) -> Tuple [ List [torch .Tensor ], torch .Tensor , torch .Tensor ]:
469+ ) -> tuple [ list [torch .Tensor ], torch .Tensor , torch .Tensor ]:
470470 """Updates the eigenbases using QR decomposition and power iteration or eigh.
471471
472472 This function performs an update of the eigenbases (QL and QR)
@@ -559,8 +559,8 @@ def update_eigenbasis_and_momentum(
559559@torch .compile # type: ignore[misc]
560560def precondition (
561561 grad : torch .Tensor ,
562- eigenbasis_list : Optional [ List [ torch .Tensor ]] = None ,
563- dims : Optional [ List [ List [ int ]]] = None ,
562+ eigenbasis_list : list [ torch .Tensor ] | None = None ,
563+ dims : list [ list [ int ]] | None = None ,
564564) -> torch .Tensor :
565565 """Projects the gradient to and from the eigenbases of the kronecker factor matrices.
566566
@@ -610,7 +610,7 @@ def precondition(
610610def _is_eigenbasis_update_step (
611611 step : int ,
612612 adam_warmup_steps : int ,
613- precondition_frequency : Union [ int , Callable [[int ], int ] ],
613+ precondition_frequency : int | Callable [[int ], int ],
614614) -> bool :
615615 """Checks if amortized computation of the eigenbasis should be recomputed.
616616
0 commit comments