
Commit 90ce1fd

update: move step into group
1 parent 896a16a commit 90ce1fd

1 file changed

pytorch_optimizer/optimizer/ranger21.py

Lines changed: 21 additions & 29 deletions
@@ -200,6 +200,11 @@ def step(self, closure: CLOSURE = None) -> LOSS:
 
         # Phase 1 - Accumulate all the variance_ma_sum to use in stable weight decay
         for group in self.param_groups:
+            if 'step' in group:
+                group['step'] += 1
+            else:
+                group['step'] = 1
+
             beta1, beta2 = group['betas']
             for p in group['params']:
                 if p.grad is None:
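
Note: this hunk is the core of the commit. The step counter now lives on the param group and is bumped exactly once per call to step(), instead of once per parameter inside state. A minimal standalone sketch of that pattern (DemoOptimizer and everything around it are illustrative, not part of pytorch_optimizer):

import torch
from torch.optim import Optimizer


class DemoOptimizer(Optimizer):
    """Toy SGD-like optimizer that keeps one 'step' counter per param group."""

    def __init__(self, params, lr: float = 1e-3):
        super().__init__(params, {'lr': lr})

    @torch.no_grad()
    def step(self, closure=None):
        loss = closure() if closure is not None else None
        for group in self.param_groups:
            # one shared counter per group, incremented once per call (mirrors the hunk above)
            if 'step' in group:
                group['step'] += 1
            else:
                group['step'] = 1

            for p in group['params']:
                if p.grad is None:
                    continue
                p.add_(p.grad, alpha=-group['lr'])
        return loss


w = torch.nn.Parameter(torch.randn(4))
opt = DemoOptimizer([w], lr=0.1)
w.grad = torch.ones(4)
opt.step()
opt.step()
print(opt.param_groups[0]['step'])  # 2: one increment per call, shared by every param in the group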
@@ -216,7 +221,6 @@ def step(self, closure: CLOSURE = None) -> LOSS:
 
                 state = self.state[p]
                 if len(state) == 0:
-                    state['step'] = 0
                     state['grad_ma'] = torch.zeros_like(p)
                     state['variance_ma'] = torch.zeros_like(p)
                     state['lookahead_params'] = torch.empty_like(p)
@@ -229,17 +233,14 @@ def step(self, closure: CLOSURE = None) -> LOSS:
                 grad = centralize_gradient(grad, gc_conv_only=False)
                 grad = normalize_gradient(grad)
 
-                state['step'] += 1
-
-                bias_correction2 = 1.0 - beta2 ** state['step']
+                bias_correction2 = 1.0 - beta2 ** group['step']
 
                 # second moment estimation
                 # using positive-negative momentum and bias correction
                 variance_ma = state['variance_ma']
                 variance_ma.mul_(beta2).addcmul_(grad, grad, value=1.0 - beta2)
                 variance_ma_sum += (variance_ma / bias_correction2).sum()
 
-        # stable weight decay
         if param_size == 0:
             raise ZeroParameterSizeError()
 
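Note: with the counter shared, the second-moment bias correction in Phase 1 reads group['step']. A small self-contained sketch of what that debiasing does (constant toy gradient and beta2 chosen purely for illustration):

import torch

beta2 = 0.999
num_steps = 10                     # stands in for group['step']
grad = torch.tensor([0.1, -0.2, 0.3])

variance_ma = torch.zeros_like(grad)
for _ in range(num_steps):
    # same EMA update as in the diff above
    variance_ma.mul_(beta2).addcmul_(grad, grad, value=1.0 - beta2)

# early in training the EMA underestimates grad**2;
# dividing by (1 - beta2 ** step) removes that bias
bias_correction2 = 1.0 - beta2 ** num_steps
debiased = variance_ma / bias_correction2

print(debiased)               # ~= grad * grad for a constant gradient
print(debiased.sum().item())  # the per-tensor contribution added to variance_ma_sum
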
@@ -249,41 +250,32 @@ def step(self, closure: CLOSURE = None) -> LOSS:
 
         # Phase 2 - Apply weight decay and step
         for group in self.param_groups:
-            if len(self.state) == 0:
-                continue
-
-            p = next(iter(self.state.keys()))
-
             lr = group['lr']
-            step = self.state[p]['step']
-
             beta1, beta2 = group['betas']
-            bias_correction1 = 1.0 - beta1 ** step  # fmt: skip
-            bias_correction2_sq = math.sqrt(1.0 - beta2 ** step)  # fmt: skip
 
-            noise_norm: float = math.sqrt((1.0 + beta2) ** 2 + beta2 ** 2)  # fmt: skip
+            bias_correction1 = 1.0 - beta1 ** group['step']  # fmt: skip
+            bias_correction2_sq = math.sqrt(1.0 - beta2 ** group['step'])  # fmt: skip
 
-            # warm up
-            lr = self.warm_up_dampening(lr, step)
-
-            # warm down
-            lr = self.warm_down(lr, step)
-
-            # stable decay
-            decay = group['weight_decay']
-            if decay:
-                p.mul_(1.0 - decay * lr / variance_normalized)
+            noise_norm: float = math.sqrt((1.0 + beta2) ** 2 + beta2 ** 2)  # fmt: skip
 
-            # norm loss
-            correction = 2.0 * self.norm_loss_factor * (1.0 - torch.div(1, unit_norm(p) + self.eps))
-            p.mul_(1.0 - lr * correction)
+            # warm up & down
+            lr = self.warm_up_dampening(lr, group['step'])
+            lr = self.warm_down(lr, group['step'])
 
             for p in group['params']:
                 if p.grad is None:
                     continue
 
+                # stable weight decay
+                if group['weight_decay']:
+                    p.mul_(1.0 - group['weight_decay'] * lr / variance_normalized)
+
+                # norm loss
+                correction = 2.0 * self.norm_loss_factor * (1.0 - torch.div(1, unit_norm(p) + self.eps))
+                p.mul_(1.0 - lr * correction)
+
                 state = self.state[p]
-                if state['step'] % 2 == 1:
+                if group['step'] % 2 == 1:
                     grad_ma, neg_grad_ma = state['grad_ma'], state['neg_grad_ma']
                 else:
                     grad_ma, neg_grad_ma = state['neg_grad_ma'], state['grad_ma']
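
Note: Phase 2 now also keys everything off group['step'], including the odd/even check that decides which buffer acts as the "positive" momentum in positive-negative momentum. A standalone sketch of that alternation (the buffer names match the state keys in the diff; the EMA coefficient and gradients are made up):

import torch

state = {'grad_ma': torch.zeros(3), 'neg_grad_ma': torch.zeros(3)}
beta1 = 0.9

for step in range(1, 5):                    # plays the role of group['step']
    grad = torch.full((3,), float(step))    # dummy gradient

    # on odd steps 'grad_ma' is updated, on even steps 'neg_grad_ma' is:
    # the same swap the last hunk performs with group['step'] % 2
    if step % 2 == 1:
        grad_ma, neg_grad_ma = state['grad_ma'], state['neg_grad_ma']
    else:
        grad_ma, neg_grad_ma = state['neg_grad_ma'], state['grad_ma']

    grad_ma.mul_(beta1).add_(grad, alpha=1.0 - beta1)  # illustrative EMA update only

    print(step, state['grad_ma'].tolist(), state['neg_grad_ma'].tolist())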
