1313# See the License for the specific language governing permissions and
1414# limitations under the License.
1515import math
16- from typing import Callable , List , Tuple , override
16+ from typing import Callable , overload
17+
18+
19+ try :
20+ from typing import override
21+ except ImportError :
22+ from typing_extensions import override
1723
1824import torch
1925from torch .optim .optimizer import ParamsT
@@ -85,6 +91,12 @@ def __init__(
8591 }
8692 super ().__init__ (params , defaults )
8793
94+ @overload
95+ def step (self , closure : None = ...) -> None : ...
96+
97+ @overload
98+ def step (self , closure : Callable [[], float ]) -> float : ...
99+
88100 @torch .no_grad () # type: ignore[misc]
89101 @override
90102 def step (self , closure : Callable [[], float ] | None = None ) -> float | None :
@@ -154,7 +166,7 @@ def step(self, closure: Callable[[], float] | None = None) -> float | None:
154166def _init_psgd_kron_states (
155167 grad : torch .Tensor ,
156168 precond_init_scale : float = 1.0 ,
157- ) -> Tuple [ List [torch .Tensor ], List [torch .Tensor ]]:
169+ ) -> tuple [ list [torch .Tensor ], list [torch .Tensor ]]:
158170 """Initialize the Kronecker factor matrices and Lipschitz constants.
159171
160172 Args:
@@ -165,8 +177,8 @@ def _init_psgd_kron_states(
165177 q_list: List of Kronecker factors.
166178 lip_const_list: List of Lipschitz constants for the Kronecker factors.
167179 """
168- q_list : List [torch .Tensor ] = []
169- lip_const_list : List [torch .Tensor ] = []
180+ q_list : list [torch .Tensor ] = []
181+ lip_const_list : list [torch .Tensor ] = []
170182
171183 # Create identity matrices scaled by precond_init_scale for each dimension
172184 for size in grad .shape :
@@ -177,13 +189,13 @@ def _init_psgd_kron_states(
177189
178190
179191def _update_precond_procrustes (
180- q_list : List [torch .Tensor ],
181- lip_const_list : List [torch .Tensor ],
192+ q_list : list [torch .Tensor ],
193+ lip_const_list : list [torch .Tensor ],
182194 exp_avg : torch .Tensor ,
183195 damping_noise_scale : float = 1e-9 ,
184196 precond_lr : float = 0.1 ,
185197 beta_lip : float = 0.9 ,
186- ) -> Tuple [ List [torch .Tensor ], List [torch .Tensor ]]:
198+ ) -> tuple [ list [torch .Tensor ], list [torch .Tensor ]]:
187199 r"""Update the Kron preconditioner Q using procrustes step and uniformization.
188200
189201 Args:
@@ -201,8 +213,8 @@ def _update_precond_procrustes(
201213 dampened_momentum = exp_avg + (damping_noise_scale + 1e-7 * exp_avg .abs ()) * torch .randn_like (exp_avg )
202214 pg = psgd_kron_contractions .apply_preconditioner (q_list , dampened_momentum )
203215 total_numel = pg .numel ()
204- updated_q_list : List [torch .Tensor ] = []
205- updated_lip_const_list : List [torch .Tensor ] = []
216+ updated_q_list : list [torch .Tensor ] = []
217+ updated_lip_const_list : list [torch .Tensor ] = []
206218 for dim , q in enumerate (q_list ):
207219 # compute gradient covariance
208220 precond_grad_cov = psgd_kron_contractions .partial_contraction (pg , pg , dim )
@@ -229,7 +241,7 @@ def _update_matrix_preconditioner(
229241 total_numel : int ,
230242 precond_lr : float ,
231243 beta_lip : float ,
232- ) -> Tuple [torch .Tensor , torch .Tensor ]:
244+ ) -> tuple [torch .Tensor , torch .Tensor ]:
233245 r"""Update matrix-structured preconditioner with adaptive Lipschitz constant.
234246
235247 Args:
@@ -259,7 +271,7 @@ def _update_1d_preconditioner(
259271 total_numel : int ,
260272 precond_lr : float ,
261273 beta_lip : float ,
262- ) -> Tuple [torch .Tensor , torch .Tensor ]:
274+ ) -> tuple [torch .Tensor , torch .Tensor ]:
263275 r"""Update 1D preconditioner with adaptive Lipschitz constant.
264276
265277 Args:
0 commit comments