Commit d1171da

refactor: Ranger21 optimizer
1 parent: 80c1bfb

1 file changed

pytorch_optimizer/optimizer/ranger21.py

Lines changed: 15 additions & 29 deletions
@@ -96,12 +96,9 @@ def __init__(  # pylint: disable=R0913
         self.lookahead_blending_alpha = lookahead_blending_alpha
         self.norm_loss_factor = norm_loss_factor

-        # lookahead
         self.lookahead_step: int = 0
-
-        # learning rate
-        self.starting_lr = lr
-        self.current_lr = lr
+        self.starting_lr: float = lr
+        self.current_lr: float = lr

         defaults: DEFAULTS = {
             'lr': lr,
@@ -114,7 +111,6 @@ def __init__(  # pylint: disable=R0913
         }
         super().__init__(params, defaults)

-        # warmup iterations
         self.num_warm_up_iterations: int = (
             self.build_warm_up_iterations(num_iterations, betas[1])
             if num_warm_up_iterations is None
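
Note: the only change in this hunk is dropping the `# warmup iterations` comment. For context, when num_warm_up_iterations is None the warmup length comes from build_warm_up_iterations(num_iterations, betas[1]), which is not part of this diff. The sketch below assumes the 2 / (1 - beta2) warmup heuristic from the Ranger21 paper, purely to illustrate what such a helper might compute:

import math

# Hypothetical stand-in for build_warm_up_iterations; the library's real
# helper is outside this diff and may differ. Assumes the 2 / (1 - beta2)
# rule of thumb, capped at the total iteration budget.
def build_warm_up_iterations_sketch(num_iterations: int, beta2: float) -> int:
    return min(math.ceil(2.0 / (1.0 - beta2)), num_iterations)

print(build_warm_up_iterations_sketch(10_000, 0.999))  # -> 2000 warmup steps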
@@ -140,8 +136,7 @@ def reset(self):

                 state['grad_ma'] = torch.zeros_like(p)
                 state['variance_ma'] = torch.zeros_like(p)
-                state['lookahead_params'] = torch.empty_like(p)
-                state['lookahead_params'].copy_(p)
+                state['lookahead_params'] = p.clone()
                 state['neg_grad_ma'] = torch.zeros_like(p)
                 state['max_variance_ma'] = torch.zeros_like(p)

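Note: replacing the torch.empty_like + copy_ pair with p.clone() allocates and copies in a single call; the resulting lookahead buffer is identical. A minimal standalone check (not part of the diff):

import torch

p = torch.randn(3, 4)

# old pattern: allocate uninitialised storage, then copy the parameter in
old_buf = torch.empty_like(p)
old_buf.copy_(p)

# new pattern: allocate and copy in one call
new_buf = p.clone()

assert torch.equal(old_buf, new_buf)  # same values, dtype and device
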
@@ -162,28 +157,21 @@ def warm_up_dampening(self, lr: float, step: int) -> float:

         warm_up_current_pct: float = min(1.0, (step / self.num_warm_up_iterations))

-        new_lr: float = lr * warm_up_current_pct
-        self.current_lr = new_lr
+        self.current_lr = lr * warm_up_current_pct

-        return new_lr
+        return self.current_lr

     def warm_down(self, lr: float, iteration: int) -> float:
         if iteration < self.start_warm_down:
             return lr

         # start iteration from 1, not 0
-        warm_down_iteration: int = (iteration + 1) - self.start_warm_down
-        warm_down_iteration = max(warm_down_iteration, 1)
-
-        warm_down_pct: float = warm_down_iteration / (self.num_warm_down_iterations + 1)
-        warm_down_pct = min(warm_down_pct, 1.0)
-
-        new_lr: float = self.starting_lr - self.warm_down_lr_delta * warm_down_pct
-        new_lr = max(new_lr, self.min_lr)
+        warm_down_iteration: int = max((iteration + 1) - self.start_warm_down, 1)
+        warm_down_pct: float = min(warm_down_iteration / (self.num_warm_down_iterations + 1), 1.0)

-        self.current_lr = new_lr
+        self.current_lr = max(self.starting_lr - self.warm_down_lr_delta * warm_down_pct, self.min_lr)

-        return new_lr
+        return self.current_lr

     @torch.no_grad()
     def step(self, closure: CLOSURE = None) -> LOSS:
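
Note: both schedules now store the computed value in self.current_lr and return it directly, and the warm-down arithmetic is folded into two clamped expressions. A standalone reproduction of the new warm_down math with made-up hyperparameters (the values below are illustrative, not Ranger21 defaults, and warm_down_lr_delta is assumed to be the gap between the starting and minimum learning rate):

# Illustrative numbers only; not the optimizer's defaults.
starting_lr = 1e-3
min_lr = 3e-5
start_warm_down = 700            # iteration at which decay begins (assumed)
num_warm_down_iterations = 300   # length of the decay phase (assumed)
warm_down_lr_delta = starting_lr - min_lr  # assumed definition

def warm_down(lr: float, iteration: int) -> float:
    if iteration < start_warm_down:
        return lr

    warm_down_iteration = max((iteration + 1) - start_warm_down, 1)
    warm_down_pct = min(warm_down_iteration / (num_warm_down_iterations + 1), 1.0)
    return max(starting_lr - warm_down_lr_delta * warm_down_pct, min_lr)

for it in (0, 700, 850, 1000):
    print(it, warm_down(starting_lr, it))
# lr holds at 1e-3 until step 700, then decays linearly and is clamped at min_lr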
@@ -220,8 +208,7 @@ def step(self, closure: CLOSURE = None) -> LOSS:
                 if len(state) == 0:
                     state['grad_ma'] = torch.zeros_like(p)
                     state['variance_ma'] = torch.zeros_like(p)
-                    state['lookahead_params'] = torch.empty_like(p)
-                    state['lookahead_params'].copy_(p)
+                    state['lookahead_params'] = p.clone()
                     state['neg_grad_ma'] = torch.zeros_like(p)
                     state['max_variance_ma'] = torch.zeros_like(p)

@@ -245,7 +232,6 @@ def step(self, closure: CLOSURE = None) -> LOSS:

         # Phase 2 - Apply weight decay and step
         for group in self.param_groups:
-            lr: float = group['lr']
             beta1, beta2 = group['betas']

             bias_correction1: float = 1.0 - beta1 ** group['step']  # fmt: skip
@@ -254,7 +240,7 @@ def step(self, closure: CLOSURE = None) -> LOSS:
             noise_norm: float = math.sqrt((1.0 + beta2) ** 2 + beta2 ** 2)  # fmt: skip

             # warm up & down
-            lr = self.warm_up_dampening(lr, group['step'])
+            lr: float = self.warm_up_dampening(group['lr'], group['step'])
             lr = self.warm_down(lr, group['step'])

             for p in group['params']:
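
Note: the scheduled learning rate is now rebuilt from the group's base group['lr'] on every step rather than threaded through the local removed in the previous hunk, so the stored base rate is never mutated. A small sketch of that per-step flow, with placeholder schedule functions standing in for the real methods:

# Placeholder schedule functions; the real ones are Ranger21 methods.
def warm_up_dampening(lr: float, step: int) -> float:
    return lr * min(1.0, step / 1_000)   # stand-in linear warmup

def warm_down(lr: float, step: int) -> float:
    return lr                            # stand-in: no decay yet

group = {'lr': 1e-3, 'step': 250}

lr = warm_up_dampening(group['lr'], group['step'])  # base lr -> warmed-up lr
lr = warm_down(lr, group['step'])                   # then apply warm-down

assert group['lr'] == 1e-3  # the stored base lr is left untouched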
@@ -287,16 +273,16 @@ def step(self, closure: CLOSURE = None) -> LOSS:

                 de_nom = (variance_ma.sqrt() / bias_correction2_sq).add_(group['eps'])

+                if self.use_softplus:
+                    de_nom = f.softplus(de_nom, beta=self.beta_softplus)
+
                 grad = p.grad
                 centralize_gradient(grad, gc_conv_only=False)
                 normalize_gradient(grad)

                 grad_ma.mul_(beta1 ** 2).add_(grad, alpha=1.0 - beta1 ** 2)  # fmt: skip

-                step_size: float = lr if group['adam_debias'] else lr / bias_correction1
-
-                if self.use_softplus:
-                    de_nom = f.softplus(de_nom, beta=self.beta_softplus)
+                step_size: float = self.apply_adam_debias(group['adam_debias'], lr, bias_correction1)

                 pn_momentum = grad_ma.mul(1.0 + 1.0).add(neg_grad_ma, alpha=-1.0).mul(1.0 / noise_norm)
                 p.addcdiv_(pn_momentum, de_nom, value=-step_size)
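
Note: the inline expression removed here, lr if group['adam_debias'] else lr / bias_correction1, is what the new self.apply_adam_debias(...) call is expected to produce; the helper itself lives outside this diff (presumably on the shared base optimizer). The softplus smoothing of de_nom also moves up to right after the denominator is built, with no apparent change to the result since de_nom is not used in between. A sketch of the assumed helper behaviour, reconstructed from the removed line:

# Assumed behaviour of apply_adam_debias, reconstructed from the inline
# expression this commit removes; the real helper may differ in signature.
def apply_adam_debias(adam_debias: bool, step_size: float, bias_correction1: float) -> float:
    # With adam_debias enabled the bias correction is skipped; otherwise the
    # usual Adam debiasing divides it out of the step size.
    return step_size if adam_debias else step_size / bias_correction1

print(apply_adam_debias(False, 1e-3, 0.9))  # ~0.001111
print(apply_adam_debias(True, 1e-3, 0.9))   # 0.001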
