Commit 56d89c2

refactor: AdaBelief optimizer
1 parent 3a8c2b6 commit 56d89c2

File tree

1 file changed: +10 -5 lines changed

pytorch_optimizer/adabelief.py

Lines changed: 10 additions & 5 deletions
@@ -44,17 +44,22 @@ def __init__(
         degenerated_to_sgd: bool = True,
     ):
         """AdaBelief optimizer
-        :param params: PARAMS. iterable of parameters to optimize or dicts defining parameter groups
+        :param params: PARAMS. iterable of parameters to optimize
+            or dicts defining parameter groups
         :param lr: float. learning rate
-        :param betas: BETAS. coefficients used for computing running averages of gradient and the squared hessian trace
-        :param eps: float. term added to the denominator to improve numerical stability
+        :param betas: BETAS. coefficients used for computing running averages
+            of gradient and the squared hessian trace
+        :param eps: float. term added to the denominator
+            to improve numerical stability
         :param weight_decay: float. weight decay (L2 penalty)
         :param n_sma_threshold: (recommended is 5)
         :param amsgrad: bool. whether to use the AMSBound variant
-        :param weight_decouple: bool. the optimizer uses decoupled weight decay as in AdamW
+        :param weight_decouple: bool. the optimizer uses decoupled weight decay
+            as in AdamW
         :param fixed_decay: bool.
         :param rectify: bool. perform the rectified update similar to RAdam
-        :param degenerated_to_sgd: bool. perform SGD update when variance of gradient is high
+        :param degenerated_to_sgd: bool. perform SGD update
+            when variance of gradient is high
         """
         self.lr = lr
         self.betas = betas
