Skip to content

Commit 619c169

Browse files
committed
refactor: weight_decay part
1 parent 3e68698 commit 619c169

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

pytorch_optimizer/radam.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -152,7 +152,7 @@ def step(self, closure: CLOSURE = None) -> LOSS:
152152
step_size = -1
153153
buffered[2] = step_size
154154

155-
if (n_sma >= self.n_sma_threshold or step_size > 0) and group['weight_decay'] != 0:
155+
if group['weight_decay'] != 0 and (n_sma >= self.n_sma_threshold or step_size > 0):
156156
p_fp32.add_(p_fp32, alpha=-group['weight_decay'] * group['lr'])
157157

158158
if n_sma >= self.n_sma_threshold:

0 commit comments

Comments
 (0)