@@ -3,17 +3,21 @@
 from typing import List, Optional, Tuple, Union
 
 import torch
+from torch.optim import Optimizer
 
 from pytorch_optimizer.base.exception import NegativeLRError, NegativeStepError
-from pytorch_optimizer.base.types import BETAS, HUTCHINSON_G, PARAMETERS, STATE
+from pytorch_optimizer.base.types import BETAS, CLOSURE, DEFAULTS, HUTCHINSON_G, LOSS, PARAMETERS, STATE
 
 
-class BaseOptimizer(ABC):
-    r"""Base optimizer class."""
+class BaseOptimizer(ABC, Optimizer):
+    r"""Base optimizer class. Provides common functionalities for the optimizers."""
+
+    def __init__(self, params: PARAMETERS, defaults: DEFAULTS) -> None:
+        super().__init__(params, defaults)
 
     @staticmethod
     @torch.no_grad()
-    def set_hessian(param_groups: PARAMETERS, state: STATE, hessian: List[torch.Tensor]):
+    def set_hessian(param_groups: PARAMETERS, state: STATE, hessian: List[torch.Tensor]) -> None:
         r"""Set hessian to state from external source. Generally useful when using functorch as a base.
 
         Example:
@@ -45,7 +49,7 @@ def set_hessian(param_groups: PARAMETERS, state: STATE, hessian: List[torch.Tensor]) -> None:
                 i += 1
 
     @staticmethod
-    def zero_hessian(param_groups: PARAMETERS, state: STATE, pre_zero: bool = True):
+    def zero_hessian(param_groups: PARAMETERS, state: STATE, pre_zero: bool = True) -> None:
         r"""Zero-out hessian.
 
         :param param_groups: PARAMETERS. parameter groups.
@@ -68,7 +72,7 @@ def compute_hutchinson_hessian(
         num_samples: int = 1,
         alpha: float = 1.0,
         distribution: HUTCHINSON_G = 'gaussian',
-    ):
+    ) -> None:
         r"""Hutchinson's approximate hessian, added to the state under key `hessian`.
 
         :param param_groups: PARAMETERS. parameter groups.
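For readers unfamiliar with the estimator behind `compute_hutchinson_hessian`: Hutchinson's trick approximates the Hessian diagonal as `E[z * (H @ z)]` for random `z` drawn from a Rademacher or Gaussian distribution, where `H @ z` is obtained via a Hessian-vector product. The sketch below is a standalone illustration of that idea; the helper name and signature are made up for the example and are not the library's implementation.

```python
# Illustrative sketch of Hutchinson's diagonal-Hessian estimator (not the
# library's code): diag(H) is approximated by E[z * (H @ z)] with Rademacher z.
import torch


def hutchinson_diag_hessian(loss: torch.Tensor, param: torch.Tensor, num_samples: int = 1) -> torch.Tensor:
    grad = torch.autograd.grad(loss, param, create_graph=True)[0]
    estimate = torch.zeros_like(param)
    for _ in range(num_samples):
        z = torch.randint_like(param, 0, 2) * 2.0 - 1.0                                  # Rademacher +/-1
        hvp = torch.autograd.grad(grad, param, grad_outputs=z, retain_graph=True)[0]      # H @ z
        estimate += z * hvp / num_samples
    return estimate


# usage: the loss must carry a graph, e.g.
# w = torch.randn(3, requires_grad=True); loss = (w ** 4).sum()
# diag_h = hutchinson_diag_hessian(loss, w)   # close to 12 * w ** 2 on average
```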
@@ -110,7 +114,7 @@ def apply_weight_decay(
         weight_decouple: bool,
         fixed_decay: bool,
         ratio: Optional[float] = None,
-    ):
+    ) -> None:
         r"""Apply weight decay.
 
         :param p: torch.Tensor. parameter.
@@ -145,6 +149,27 @@ def apply_ams_bound(
 
         return de_nom.sqrt_().add_(eps)
 
+    @staticmethod
+    def debias(beta: float, step: int) -> float:
+        r"""Adam-style debias correction. Returns `1.0 - beta ** step`.
+
+        :param beta: float. beta.
+        :param step: int. number of step.
+        """
+        return 1.0 - math.pow(beta, step)  # fmt: skip
+
+    @staticmethod
+    def debias_beta(beta: float, step: int) -> float:
+        r"""Apply the Adam-style debias correction into beta.
+
+        Simplified version of `\^{beta} = beta * (1.0 - beta ** (step - 1)) / (1.0 - beta ** step)`
+
+        :param beta: float. beta.
+        :param step: int. number of step.
+        """
+        beta_n: float = math.pow(beta, step)
+        return (beta_n - beta) / (beta_n - 1.0)  # fmt: skip
+
     @staticmethod
     def apply_adam_debias(adam_debias: bool, step_size: float, bias_correction1: float) -> float:
         r"""Apply AdamD variant.
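A quick numerical check of the `debias` / `debias_beta` helpers added in the hunk above, confirming that the simplified expression matches the formula quoted in the docstring (plain-Python sketch, not part of the diff):

```python
import math

beta, step = 0.9, 10

bias_correction = 1.0 - math.pow(beta, step)        # what BaseOptimizer.debias computes

beta_n = math.pow(beta, step)
simplified = (beta_n - beta) / (beta_n - 1.0)                          # debias_beta's form
reference = beta * (1.0 - beta ** (step - 1)) / (1.0 - beta ** step)   # docstring formula

assert math.isclose(simplified, reference)   # both ~= 0.8465
print(bias_correction, simplified)           # 0.6513..., 0.8465...
```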
@@ -205,14 +230,14 @@ def get_adanorm_gradient(
         :param exp_grad_norm: Optional[torch.Tensor]. exp_grad_norm.
         :param r: float. Optional[float]. momentum (ratio).
         """
-        if not adanorm:
+        if not adanorm or exp_grad_norm is None:
             return grad
 
         grad_norm = torch.linalg.norm(grad)
 
         exp_grad_norm.mul_(r).add_(grad_norm, alpha=1.0 - r)
 
-        return grad * exp_grad_norm / grad_norm if exp_grad_norm > grad_norm else grad
+        return grad.mul(exp_grad_norm).div_(grad_norm) if exp_grad_norm > grad_norm else grad
 
     @staticmethod
     def get_rms(x: torch.Tensor) -> float:
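Usage sketch for the revised `get_adanorm_gradient`: the caller keeps a running `exp_grad_norm` buffer per parameter, and the helper rescales the gradient only when its norm exceeds that running estimate. The keyword names below follow the docstring shown above, and the import path is an assumption; treat this as an illustrative example rather than documented API.

```python
import torch

from pytorch_optimizer.base.optimizer import BaseOptimizer  # assumption: adjust to where this class lives

grad = torch.randn(10)
exp_grad_norm = torch.zeros(1)  # running (EMA) estimate of the gradient norm

scaled = BaseOptimizer.get_adanorm_gradient(
    grad=grad, adanorm=True, exp_grad_norm=exp_grad_norm, r=0.95
)

# With adanorm=False, or when exp_grad_norm is None (the new guard above),
# the gradient is returned unchanged.
```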
@@ -299,5 +324,8 @@ def validate_nus(self, nus: Union[float, Tuple[float, float]]) -> None:
         self.validate_range(nus[1], 'nu2', 0.0, 1.0, range_type='[]')
 
     @abstractmethod
-    def reset(self):  # pragma: no cover
+    def reset(self) -> None:  # pragma: no cover
+        raise NotImplementedError
+
+    def step(self, closure: CLOSURE = None) -> LOSS:  # pragma: no cover
         raise NotImplementedError
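Finally, since `BaseOptimizer` now inherits from `torch.optim.Optimizer`, forwards `(params, defaults)` in `__init__`, and declares both `reset()` and `step()`, a concrete subclass looks roughly like the sketch below. `PlainSGD` and the import path are made up to illustrate the contract; this is not an optimizer from the library.

```python
import torch

from pytorch_optimizer.base.optimizer import BaseOptimizer  # assumption: adjust to the actual module


class PlainSGD(BaseOptimizer):
    """Toy SGD built on the new base class (illustrative only)."""

    def __init__(self, params, lr: float = 1e-3):
        # BaseOptimizer.__init__ forwards (params, defaults) to torch.optim.Optimizer.
        super().__init__(params, {'lr': lr})

    @torch.no_grad()
    def reset(self) -> None:
        # Nothing to re-initialize for plain SGD; real optimizers reset per-parameter state here.
        pass

    @torch.no_grad()
    def step(self, closure=None):
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            for p in group['params']:
                if p.grad is not None:
                    p.add_(p.grad, alpha=-group['lr'])

        return loss


# usage: opt = PlainSGD(model.parameters(), lr=0.1)
```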