 
 
 class BaseOptimizer(ABC, Optimizer):
-    r"""Base optimizer class. Provides common functionalities for the optimizers."""
+    """Base optimizer class. Provides common functionalities for the optimizers."""
 
     def __init__(self, params: Parameters, defaults: Defaults) -> None:
         super().__init__(params, defaults)
 
     @staticmethod
     def load_optimizer(optimizer: OPTIMIZER_INSTANCE_OR_CLASS, **kwargs) -> Optimizer:
-        r"""Build torch.optim.Optimizer class."""
+        """Build torch.optim.Optimizer class."""
         if isinstance(optimizer, Optimizer):
             return optimizer
 
@@ -40,22 +40,22 @@ def load_optimizer(optimizer: OPTIMIZER_INSTANCE_OR_CLASS, **kwargs) -> Optimize
     @staticmethod
     @torch.no_grad()
     def set_hessian(param_groups: Parameters, state: State, hessian: List[torch.Tensor]) -> None:
-        r"""Set hessian to state from external source. Generally useful when using functorch as a base.
+        """Set hessian to state from external source. Generally useful when using functorch as a base.
+
+        Args:
+            param_groups (Parameters): Parameter groups from the optimizer.
+            state (State): Optimizer state dictionary.
+            hessian (List[torch.Tensor]): Sequence of Hessian tensors to set.
 
         Example:
-        -------
-            # Hutchinson's Estimator using HVP
-            noise = tree_map(lambda v: torch.randn_like(v), params)
-            loss_, hvp_est = jvp(grad(run_model_fn), (params,), (noise,))
-            hessian_diag_est = tree_map(lambda a, b: a * b, hvp_est, noise)
+            # Hutchinson's Estimator using Hessian-vector product (HVP)
+            >>> noise = tree_map(lambda v: torch.randn_like(v), params)
+            >>> loss_, hvp_est = jvp(grad(run_model_fn), (params,), (noise,))
+            >>> hessian_diag_est = tree_map(lambda a, b: a * b, hvp_est, noise)
 
-            optimizer.set_hessian(hessian_diag_est)
+            >>> optimizer.set_hessian(hessian_diag_est)
             # OR
-            optimizer.step(hessian=hessian_diag_est)
-
-        :param param_groups: PARAMETERS. parameter groups.
-        :param state: STATE. optimizer state.
-        :param hessian: List[torch.Tensor]. sequence of hessian to set.
+            >>> optimizer.step(hessian=hessian_diag_est)
         """
         i: int = 0
         for group in param_groups:
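For context, a minimal end-to-end sketch of producing such a Hessian-diagonal estimate with torch.func and handing it to the optimizer; the model, data, and loss closure below are hypothetical stand-ins and not part of this change.

# Hedged sketch (not from this PR): Hutchinson-style Hessian-diagonal estimate
# via torch.func, for an optimizer that accepts an external Hessian.
import torch
from torch.func import functional_call, grad, jvp

model = torch.nn.Linear(4, 1)                        # hypothetical model
x, y = torch.randn(8, 4), torch.randn(8, 1)          # hypothetical batch
params = {k: v.detach() for k, v in model.named_parameters()}

def loss_fn(p):
    return torch.nn.functional.mse_loss(functional_call(model, p, (x,)), y)

# jvp of grad(loss_fn) returns (gradient, H @ z) in a single forward-over-reverse pass
noise = {k: torch.randn_like(v) for k, v in params.items()}
_, hvp = jvp(grad(loss_fn), (params,), (noise,))
hessian_diag = [h * z for h, z in zip(hvp.values(), noise.values())]  # z * Hz estimates diag(H)

# optimizer.set_hessian(hessian_diag)   # or: optimizer.step(hessian=hessian_diag)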
@@ -70,11 +70,12 @@ def set_hessian(param_groups: Parameters, state: State, hessian: List[torch.Tens
 
     @staticmethod
     def zero_hessian(param_groups: Parameters, state: State, pre_zero: bool = True) -> None:
-        r"""Zero-out hessian.
+        """Zero-out the Hessian.
 
-        :param param_groups: PARAMETERS. parameter groups.
-        :param state: STATE. optimizer state.
-        :param pre_zero: bool. zero-out hessian before computing the hessian.
+        Args:
+            param_groups (Parameters): Parameter groups from the optimizer.
+            state (State): Optimizer state dictionary.
+            pre_zero (bool): If True, zero-out the Hessian before computing/updating it.
         """
         for group in param_groups:
             for p in group['params']:
@@ -93,13 +94,14 @@ def compute_hutchinson_hessian(
         alpha: float = 1.0,
         distribution: HUTCHINSON_G = 'gaussian',
     ) -> None:
-        r"""Hutchinson's approximate hessian, added to the state under key `hessian`.
-
-        :param param_groups: PARAMETERS. parameter groups.
-        :param state: STATE. optimizer state.
-        :param num_samples: int. number of times to sample `z` for the approximation of the hessian trace.
-        :param alpha: float. alpha.
-        :param distribution: HUTCHINSON_G. type of distribution.
+        """Hutchinson's approximate Hessian, added to the state under key `hessian`.
+
+        Args:
+            param_groups (Parameters): Parameter groups from the optimizer.
+            state (State): Optimizer state dictionary.
+            num_samples (int): Number of times to sample noise vector `z` for the trace approximation.
+            alpha (float): Scaling factor for the Hessian estimate.
+            distribution (HUTCHINSON_G): Noise distribution, either 'gaussian' or 'rademacher'.
         """
         if distribution not in ('gaussian', 'rademacher'):
             raise NotImplementedError(f'hessian with distribution {distribution} is not implemented.')
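For reference, a self-contained illustration of the Hutchinson estimator this helper is named after (the method body itself is truncated in this hunk): the expectation of z * (Hz) over Rademacher noise z recovers the Hessian diagonal.

# Standalone illustration of Hutchinson's estimator, not this method's body.
import torch

w = torch.randn(5, requires_grad=True)
loss = (w ** 4).sum()                                     # toy objective
(g,) = torch.autograd.grad(loss, w, create_graph=True)    # gradient with graph kept

num_samples, diag_est = 10, torch.zeros_like(w)
for _ in range(num_samples):
    z = torch.randint_like(w, 0, 2).mul_(2.0).sub_(1.0)            # Rademacher +/-1
    (hz,) = torch.autograd.grad(g, w, grad_outputs=z, retain_graph=True)  # Hessian-vector product
    diag_est += z * hz / num_samples

# exact Hessian diagonal of sum(w^4) is 12 * w^2, for comparison
print(diag_est, 12 * w.detach() ** 2)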
@@ -135,15 +137,16 @@ def apply_weight_decay(
         fixed_decay: bool,
         ratio: Optional[float] = None,
     ) -> None:
-        r"""Apply weight decay.
-
-        :param p: torch.Tensor. parameter.
-        :param grad: torch.Tensor. gradient.
-        :param lr: float. learning rate.
-        :param weight_decay: float. weight decay (L2 penalty).
-        :param weight_decouple: bool. the optimizer uses decoupled weight decay as in AdamW.
-        :param fixed_decay: bool. fix weight decay.
-        :param ratio: Optional[float]. scale weight decay.
+        """Apply weight decay.
+
+        Args:
+            p (torch.Tensor): Parameter tensor to apply weight decay to.
+            grad (torch.Tensor): Gradient tensor of parameter p.
+            lr (float): Learning rate to scale the update.
+            weight_decay (float): Weight decay coefficient (L2 penalty).
+            weight_decouple (bool): If True, applies decoupled weight decay as in AdamW.
+            fixed_decay (bool): If True, fixes weight decay to not depend on learning rate.
+            ratio (Optional[float]): Optional scaling factor for weight decay.
         """
         if weight_decouple:
             p.mul_(1.0 - weight_decay * (1.0 if fixed_decay else lr) * (ratio if ratio is not None else 1.0))
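A small numeric sketch of the two decay modes; the non-decoupled branch is assumed to fold the L2 term into the gradient, which is the conventional behaviour and is not shown in this hunk.

# Hedged numeric sketch of decoupled vs. classic weight decay; the coupled
# branch (adding weight_decay * p to the gradient) is an assumption here.
import torch

p = torch.tensor([1.0, -2.0, 3.0])
grad = torch.zeros_like(p)
lr, weight_decay = 0.1, 0.01

# decoupled (weight_decouple=True, fixed_decay=False): shrink the parameter directly
p_decoupled = p * (1.0 - weight_decay * lr)        # -> [0.999, -1.998, 2.997]

# coupled L2 (weight_decouple=False): fold the penalty into the gradient
grad_coupled = grad + weight_decay * p             # -> [0.01, -0.02, 0.03]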
@@ -158,13 +161,14 @@ def apply_ams_bound(
         eps: float,
         exp_avg_sq_eps: float = 1e-15,
     ) -> torch.Tensor:
-        r"""Apply AMSBound variant.
-
-        :param ams_bound: bool. whether to apply AMSBound.
-        :param exp_avg_sq: torch.Tensor. exp_avg_sq.
-        :param max_exp_avg_sq: Optional[torch.Tensor]. max_exp_avg_sq.
-        :param eps: float. epsilon.
-        :param exp_avg_sq_eps: float. eps value for numerical stability for exp_avg_sq.
+        """Apply AMSBound variant.
+
+        Args:
+            ams_bound (bool): Whether to apply the AMSBound variant.
+            exp_avg_sq (torch.Tensor): Exponential moving average of squared gradients.
+            max_exp_avg_sq (Optional[torch.Tensor]): Running elementwise maximum of exp_avg_sq, used when ams_bound is True.
+            eps (float): Small epsilon value for numerical stability.
+            exp_avg_sq_eps (float): Epsilon used specifically for numerical stability in exp_avg_sq computations.
         """
         if ams_bound:
             if torch.is_complex(max_exp_avg_sq):
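The rest of the body is cut off in this hunk; a hedged sketch of the AMSGrad-style bound it refers to looks like this.

# Assumed sketch of the AMSGrad-style bound, not the function's exact body:
# keep the running elementwise maximum of exp_avg_sq and build the Adam
# denominator from that maximum instead of exp_avg_sq itself.
import torch

exp_avg_sq = torch.tensor([0.04, 0.09])
max_exp_avg_sq = torch.tensor([0.05, 0.01])
eps = 1e-8

torch.maximum(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq)  # -> [0.05, 0.09]
de_nom = max_exp_avg_sq.sqrt().add_(eps)                       # denominator for the update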
@@ -179,10 +183,11 @@ def apply_ams_bound(
 
     @staticmethod
     def debias(beta: float, step: int) -> float:
-        r"""Adam-style debias correction. Returns `1.0 - beta ** step`.
+        """Adam-style debias correction, i.e. `1.0 - beta ** step`.
 
-        :param beta: float. beta.
-        :param step: int. number of step.
+        Args:
+            beta (float): Exponential decay rate for moment estimates.
+            step (int): Current optimization step number.
         """
         return 1.0 - math.pow(beta, step)  # fmt: skip
 
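A quick worked check of the factor returned above:

# The debias factor 1 - beta ** step for beta = 0.9 starts near 0.1 at step 1
# and approaches 1.0 as training proceeds, which is why early Adam steps are
# inflated when the first moment is divided by it.
import math

for step in (1, 10, 100):
    print(step, 1.0 - math.pow(0.9, step))   # ~0.1, ~0.651, ~0.99997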
@@ -192,19 +197,21 @@ def debias_beta(beta: float, step: int) -> float:
 
         Simplified version of `\^{beta} = beta * (1.0 - beta ** (step - 1)) / (1.0 - beta ** step)`
 
-        :param beta: float. beta.
-        :param step: int. number of step.
+        Args:
+            beta (float): The original beta decay rate.
+            step (int): Current optimization step number.
         """
         beta_n: float = math.pow(beta, step)
         return (beta_n - beta) / (beta_n - 1.0)  # fmt: skip
 
     @staticmethod
     def apply_adam_debias(adam_debias: bool, step_size: float, bias_correction1: float) -> float:
-        r"""Apply AdamD variant.
+        """Apply AdamD variant.
 
-        :param adam_debias: bool. Only correct the denominator to avoid inflating step sizes early in training.
-        :param step_size: float. step size.
-        :param bias_correction1: float. bias_correction.
+        Args:
+            adam_debias (bool): If True, only corrects the denominator to avoid inflating step sizes early in training.
+            step_size (float): The step size for the update.
+            bias_correction1 (float): The bias correction factor for the first moment.
         """
         return step_size if adam_debias else step_size / bias_correction1
 
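A hedged numeric sketch tying the two helpers together: debias_beta gives the step-dependent effective decay rate, and apply_adam_debias toggles the AdamD behaviour of skipping the first-moment bias correction.

# Sketch only; values below are illustrative, not taken from any optimizer run.
import math

def debias_beta(beta, step):
    beta_n = math.pow(beta, step)
    return (beta_n - beta) / (beta_n - 1.0)

print(debias_beta(0.9, 1))   # 0.0 -> the very first step uses the raw gradient only
print(debias_beta(0.9, 2))   # ~0.474, approaching 0.9 as step grows

step_size, bias_correction1 = 1e-3, 1.0 - 0.9 ** 3
adam_debias_off = step_size / bias_correction1   # classic Adam: larger early steps
adam_debias_on = step_size                       # AdamD: leave the step size as-is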
@@ -217,14 +224,15 @@ def get_rectify_step_size(
         n_sma_threshold: int,
         degenerated_to_sgd: bool,
     ) -> Tuple[float, float]:
-        r"""Get step size for rectify optimizer.
-
-        :param is_rectify: bool. whether to apply rectify-variant.
-        :param step: int. number of steps.
-        :param lr: float. learning rate.
-        :param beta2: float. beta2.
-        :param n_sma_threshold: float. SMA threshold.
-        :param degenerated_to_sgd: bool. degenerated to SGD.
+        """Get step size for rectify optimizer.
+
+        Args:
+            is_rectify (bool): Whether to apply the rectify variant.
+            step (int): Current step number.
+            lr (float): Base learning rate.
+            beta2 (float): Exponential decay rate for the second-moment estimate.
+            n_sma_threshold (int): Simple Moving Average (SMA) length threshold for rectification.
+            degenerated_to_sgd (bool): Whether to degenerate to SGD if below threshold.
         """
         step_size: float = lr
         n_sma: float = 0.0
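The rectification math itself is outside this hunk; as a hedged reference, the published RAdam rectification term, which this helper presumably mirrors, can be computed as follows.

# Hedged worked example of the RAdam rectification term (from the RAdam paper),
# not a copy of this helper's body.
import math

beta2, step, lr = 0.999, 10, 1e-3
rho_inf = 2.0 / (1.0 - beta2) - 1.0                        # maximum SMA length
beta2_t = beta2 ** step
rho_t = rho_inf - 2.0 * step * beta2_t / (1.0 - beta2_t)   # SMA length at `step`

if rho_t > 4.0:  # the helper parameterizes this check via n_sma_threshold
    rect = math.sqrt(
        (rho_t - 4.0) * (rho_t - 2.0) * rho_inf / ((rho_inf - 4.0) * (rho_inf - 2.0) * rho_t)
    )
    step_size = lr * rect    # rectified, Adam-style step
else:
    step_size = lr           # fall back to SGD (or skip), per degenerated_to_sgd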
@@ -253,10 +261,11 @@ def get_adanorm_gradient(
     ) -> torch.Tensor:
         r"""Get AdaNorm gradient.
 
-        :param grad: torch.Tensor. gradient.
-        :param adanorm: bool. whether to use the AdaNorm variant.
-        :param exp_grad_norm: Optional[torch.Tensor]. exp_grad_norm.
-        :param r: Optional[float]. EMA factor. between 0.9 ~ 0.99 is preferred.
+        Args:
+            grad (torch.Tensor): Gradient.
+            adanorm (bool): Whether to use the AdaNorm variant.
+            exp_grad_norm (Optional[torch.Tensor]): Exponential moving average of gradient norm.
+            r (Optional[float]): EMA factor; between 0.9 and 0.99 is preferred.
         """
         if not adanorm or exp_grad_norm is None:
             return grad
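A hedged sketch of the AdaNorm idea the truncated remainder implements: keep an EMA of the gradient norm and scale down gradients that exceed it. The exact update below is an assumption, not this file's code.

# Assumed sketch of the AdaNorm gradient scaling.
import torch

grad = torch.tensor([3.0, 4.0])            # norm 5.0
exp_grad_norm = torch.tensor(2.0)          # running EMA of past gradient norms
r = 0.95

grad_norm = grad.norm()
exp_grad_norm.mul_(r).add_(grad_norm, alpha=1.0 - r)    # EMA update -> 2.15
if exp_grad_norm < grad_norm:
    adanorm_grad = grad * (exp_grad_norm / grad_norm)   # shrink the outlier gradient
else:
    adanorm_grad = grad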
@@ -272,7 +281,7 @@ def get_adanorm_gradient(
 
     @staticmethod
     def get_rms(x: torch.Tensor) -> torch.Tensor:
-        r"""Get RMS."""
+        """Get RMS."""
         return x.norm(2) / math.sqrt(x.numel())
 
     @staticmethod
@@ -281,29 +290,31 @@ def approximate_sq_grad(
         exp_avg_sq_col: torch.Tensor,
         output: torch.Tensor,
     ) -> None:
-        r"""Get approximation of EMA of squared gradient."""
+        """Get approximation of EMA of squared gradient."""
         r_factor: torch.Tensor = (exp_avg_sq_row / exp_avg_sq_row.mean(dim=-1, keepdim=True)).rsqrt_().unsqueeze(-1)
         c_factor: torch.Tensor = exp_avg_sq_col.unsqueeze(-2).rsqrt()
         torch.mul(r_factor, c_factor, out=output)
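For shape intuition, the same three lines on concrete tensors: the full 4x5 statistic is rebuilt from a 4-vector and a 5-vector (Adafactor-style factored second moment).

# Worked shape example of the factored second-moment reconstruction.
import torch

exp_avg_sq_row = torch.rand(4)      # per-row running average of squared gradients
exp_avg_sq_col = torch.rand(5)      # per-column running average
output = torch.empty(4, 5)

r_factor = (exp_avg_sq_row / exp_avg_sq_row.mean(dim=-1, keepdim=True)).rsqrt_().unsqueeze(-1)  # (4, 1)
c_factor = exp_avg_sq_col.unsqueeze(-2).rsqrt()                                                 # (1, 5)
torch.mul(r_factor, c_factor, out=output)   # (4, 5): elementwise 1/sqrt of the rank-1 estimate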
 
     @staticmethod
     def apply_cautious(update: torch.Tensor, grad: torch.Tensor) -> None:
-        r"""Apply the Cautious Optimizer feature.
+        """Apply the Cautious Optimizer feature.
 
-        :param update: torch.Tensor. update. it'll be masked in in-place manner.
-        :param grad: torch.Tensor. gradient.
+        Args:
+            update (torch.Tensor): Update tensor; it will be masked in place.
+            grad (torch.Tensor): Gradient tensor.
         """
         mask = (update * grad > 0).to(grad.dtype)
         mask.mul_(mask.numel() / (mask.sum() + 1))
         update.mul_(mask)
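A worked trace of the three lines above:

# Coordinates whose update disagrees in sign with the gradient are zeroed,
# and the survivors are rescaled so the overall magnitude is preserved.
import torch

update = torch.tensor([0.5, -0.2, 0.1])
grad = torch.tensor([1.0, 0.3, -0.4])

mask = (update * grad > 0).to(grad.dtype)       # [1., 0., 0.]  only index 0 agrees in sign
mask.mul_(mask.numel() / (mask.sum() + 1))      # 3 / (1 + 1) = 1.5 rescale for the survivor
update.mul_(mask)                               # -> [0.75, 0., 0.]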
 
     @staticmethod
     def get_stable_adamw_rms(grad: torch.Tensor, exp_avg_sq: torch.Tensor, eps: float = 1e-16) -> float:
-        r"""Get StableAdamW RMS.
+        """Get StableAdamW RMS.
 
-        :param grad: torch.Tensor. gradient.
-        :param exp_avg_sq: torch.Tensor. exp_avg_sq.
-        :param eps: float. epsilon.
+        Args:
+            grad (torch.Tensor): Gradient tensor.
+            exp_avg_sq (torch.Tensor): Exponential moving average of the squared gradient.
+            eps (float): Small value to prevent division by zero.
         """
         return grad.pow(2).div_(exp_avg_sq.clip(min=eps)).mean().sqrt_().clip_(min=1.0).item()
 
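A numeric trace of the expression above, plus the usual (assumed) StableAdamW usage of dividing the learning rate by this RMS:

# The RMS is clipped at 1, so well-behaved steps are untouched while steps
# whose gradients outrun their second-moment estimate get a smaller lr.
import torch

grad = torch.tensor([0.2, -0.4])
exp_avg_sq = torch.tensor([0.01, 0.01])

rms = grad.pow(2).div_(exp_avg_sq.clip(min=1e-16)).mean().sqrt_().clip_(min=1.0).item()
# [0.04, 0.16] / 0.01 -> mean of [4, 16] = 10 -> sqrt ~= 3.162 -> rms ~= 3.162

lr = 1e-3
effective_lr = lr / rms      # ~3.16e-4; assumed usage, not shown in this hunk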
@@ -382,12 +393,12 @@ def validate_nus(self, nus: Union[float, Tuple[float, float]]) -> None:
 
     @abstractmethod
     def init_group(self, group: ParamGroup, **kwargs) -> None:  # pragma: no cover
-        r"""Initialize the group of the optimizer and return is_complex."""
+        """Initialize the group of the optimizer."""
         return
 
     @staticmethod
     def view_as_real(param, *state_and_grads) -> tuple:
-        r"""View imaginary tensors as real tensors."""
+        """View complex tensors as real tensors."""
         if torch.is_complex(param):
             param = torch.view_as_real(param)
             state_and_grads = tuple(
@@ -399,7 +410,7 @@ def view_as_real(param, *state_and_grads) -> tuple:
 
     @staticmethod
     def maximize_gradient(grad: torch.Tensor, maximize: bool = False) -> None:
-        r"""Maximize the objective with respect to the params, instead of minimizing."""
+        """Maximize the objective with respect to the params, instead of minimizing."""
         if maximize:
             grad.neg_()
 