
Commit 74d400f

refactor: RAdam optimizer
1 parent 7f2af5d

1 file changed, 5 insertions(+), 4 deletions(-)

pytorch_optimizer/radam.py

@@ -45,7 +45,8 @@ def __init__(
         :param lr: float. learning rate.
         :param betas: BETAS. coefficients used for computing running averages
             of gradient and the squared hessian trace
-        :param eps: float. term added to the denominator to improve numerical stability
+        :param eps: float. term added to the denominator
+            to improve numerical stability
         :param weight_decay: float. weight decay (L2 penalty)
         :param n_sma_threshold: int. (recommended is 5)
         :param degenerated_to_sgd: float.
@@ -82,11 +83,11 @@ def __init__(
         super().__init__(params, defaults)

     def check_valid_parameters(self):
-        if 0.0 > self.lr:
+        if self.lr < 0.0:
             raise ValueError(f'Invalid learning rate : {self.lr}')
-        if 0.0 > self.eps:
+        if self.eps < 0.0:
             raise ValueError(f'Invalid eps : {self.eps}')
-        if 0.0 > self.weight_decay:
+        if self.weight_decay < 0.0:
             raise ValueError(f'Invalid weight_decay : {self.weight_decay}')
         if not 0.0 <= self.betas[0] < 1.0:
             raise ValueError(f'Invalid beta_0 : {self.betas[0]}')
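The refactor is behavior-preserving: `0.0 > self.lr` and `self.lr < 0.0` are equivalent comparisons; the commit only moves the variable to the left-hand side, the conventional reading order. A minimal usage sketch of the validation after this commit follows. It assumes `RAdam` is exported at the package top level and that `check_valid_parameters` is invoked during construction; neither call site is visible in this diff.

```python
# Sketch, not shown in the diff: assumes `RAdam` is importable from the
# package top level and that check_valid_parameters() runs in __init__.
import torch
from pytorch_optimizer import RAdam

params = [torch.nn.Parameter(torch.zeros(2, 2))]

try:
    RAdam(params, lr=-1e-3)  # negative lr should trip the refactored guard
except ValueError as e:
    print(e)  # Invalid learning rate : -0.001
```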
