
Commit 7f2af5d

refactor: Ranger21 optimizer
1 parent 94f69ed commit 7f2af5d

File tree

1 file changed

pytorch_optimizer/ranger21.py

Lines changed: 24 additions & 54 deletions
@@ -7,7 +7,6 @@
     '@TheZothen',
 ]

-import collections
 import math
 from typing import Dict, List, Optional

@@ -31,7 +30,7 @@

 class Ranger21(Optimizer):
     """
-    Reference : https://github.com/lessw2020/Ranger21/blob/main/ranger21/ranger21.py
+    Reference : https://github.com/lessw2020/Ranger21
     Example :
         from pytorch_optimizer import Ranger21
         ...
@@ -82,16 +81,18 @@ def __init__(
         decay_type: str = 'stable',
         warmup_type: str = 'linear',
         warmup_pct_default: float = 0.22,
-        logging_active: bool = False,
     ):
-        """Ranger optimizer (RAdam + Lookahead + Gradient Centralization, combined into one optimizer)
-        :param params: PARAMS. iterable of parameters to optimize or dicts defining parameter groups
+        """
+        :param params: PARAMS. iterable of parameters to optimize
+            or dicts defining parameter groups
         :param lr: float. learning rate.
-        :param betas: BETAS. coefficients used for computing running averages of gradient and the squared hessian trace
-        :param eps: float. term added to the denominator to improve numerical stability
+        :param betas: BETAS. coefficients used for computing running averages
+            of gradient and the squared hessian trace
+        :param eps: float. term added to the denominator
+            to improve numerical stability
         :param weight_decay: float. weight decay (L2 penalty)
-        :param use_gc: bool. use Gradient Centralization (both convolution & fc layers)
-        :param gc_conv_only: bool. use Gradient Centralization (only convolution layer)
+        :param use_gc: bool. use GC both convolution & fc layers
+        :param gc_conv_only: bool. use GC only convolution layer
         """
         defaults: DEFAULT_PARAMETERS = dict(
             lr=lr,
@@ -102,8 +103,6 @@ def __init__(
         )
         super().__init__(params, defaults)

-        self.logging = logging_active
-
         self.use_madgrad = use_madgrad
         self.core_engine: str = self.get_core_engine(self.use_madgrad)

@@ -207,9 +206,6 @@ def __init__(
         self.param_size: int = 0

         self.tracking_lr: List[float] = []
-        if self.logging:
-            self.tracking_variance_sum: List[float] = []
-            self.tracking_variance_normalized = []

     @staticmethod
     def get_core_engine(use_madgrad: bool = False) -> str:
@@ -255,27 +251,18 @@ def warmup_dampening(self, lr: float, step: int) -> float:

         if step > warmup:
             if not self.warmup_complete:
-                if not self.warmup_curr_pct == 1.0:
-                    print(
-                        f'Error | lr did not achieve full set point from warmup, currently {self.warmup_curr_pct}'
-                    )
-
                 self.warmup_complete = True
-                print(
-                    f'\n** Ranger21 update | Warmup complete - lr set to {lr}\n'
-                )
-
             return lr

         if style == 'linear':
             self.warmup_curr_pct = min(1.0, (step / warmup))
             new_lr: float = lr * self.warmup_curr_pct
             self.current_lr = new_lr
             return new_lr
-        else:
-            raise NotImplementedError(
-                f'warmup style {style} is not supported yet :('
-            )
+
+        raise NotImplementedError(
+            f'warmup style {style} is not supported yet :('
+        )

     def get_warm_down(self, lr: float, iteration: int) -> float:
         if iteration < self.start_warm_down:
@@ -284,21 +271,18 @@ def get_warm_down(self, lr: float, iteration: int) -> float:
         if iteration > self.start_warm_down - 1:
             # start iteration from 1, not 0
             warm_down_iteration: int = (iteration + 1) - self.start_warm_down
-            if warm_down_iteration < 1:
-                warm_down_iteration = 1
+            warm_down_iteration = max(warm_down_iteration, 1)

             warm_down_pct: float = warm_down_iteration / (
                 self.warm_down_total_iterations + 1
             )
-            if warm_down_pct > 1.00:
-                warm_down_pct = 1.00
+            warm_down_pct = min(warm_down_pct, 1.0)

             lr_range: float = self.warm_down_lr_delta
             reduction: float = lr_range * warm_down_pct
-            new_lr: float = self.starting_lr - reduction
-            if new_lr < self.min_lr:
-                new_lr = self.min_lr

+            new_lr: float = self.starting_lr - reduction
+            new_lr = max(new_lr, self.min_lr)
             self.current_lr = new_lr

             return new_lr
@@ -323,21 +307,13 @@ def get_chebyshev_lr(self, lr: float, iteration: int) -> float:
         self.current_epoch = current_epoch

         index: int = current_epoch - 2
-        if index < 0:
-            index = 0
-        if index > len(self.chebyshev_schedule) - 1:
-            index = len(self.chebyshev_schedule) - 1
+        index = max(0, index)
+        index = min(index, len(self.chebyshev_schedule) - 1)

         chebyshev_value = self.chebyshev_schedule[index]

-        if self.cheb_logging[:-1] != chebyshev_value:
-            self.cheb_logging.append(chebyshev_value)
-
         return lr * chebyshev_value

-    def get_variance(self):
-        return self.tracking_variance_sum
-
     @staticmethod
     def get_state_values(group, state):
         beta1, beta2 = group['betas']
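Note on the two hunks above: the warm-down and Chebyshev changes both replace explicit if-based bound checks with single min/max calls. A standalone sketch of the idiom, using illustrative names that are not taken from the diff:

# before: pin a value into [lower, upper] with explicit branches
if value < lower:
    value = lower
if value > upper:
    value = upper

# after: equivalent clamping in one expression
value = min(max(value, lower), upper)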
@@ -348,16 +324,16 @@ def get_state_values(group, state):
     @torch.no_grad()
     def step(self, closure: CLOSURE = None) -> LOSS:
         loss: LOSS = None
-        if closure is not None and isinstance(closure, collections.Callable):
+        if closure is not None:
             with torch.enable_grad():
                 loss = closure()

         param_size: float = 0
         variance_ma_sum: float = 1.0

         # phase 1 - accumulate all of the variance_ma_sum to use in stable weight decay
-        for i, group in enumerate(self.param_groups):
-            for j, p in enumerate(group['params']):
+        for group in self.param_groups:
+            for p in group['params']:
                 if p.grad is None:
                     continue

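The hunk above drops the isinstance(closure, collections.Callable) guard, which is why the collections import is removed in the first hunk. For reference, a closure is simply a callable that re-evaluates the loss; a hedged sketch of how one would be passed to step, with model, criterion, inputs, and targets as illustrative names:

def closure():
    # re-evaluate the model and return the loss; step() re-enables
    # gradients around this call via torch.enable_grad()
    optimizer.zero_grad()
    loss = criterion(model(inputs), targets)
    loss.backward()
    return loss

loss = optimizer.step(closure)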
@@ -369,7 +345,6 @@ def step(self, closure: CLOSURE = None) -> LOSS:
                 )

                 grad = p.grad
-
                 if grad.is_sparse:
                     raise RuntimeError('sparse matrix not supported atm')

@@ -443,11 +418,6 @@ def step(self, closure: CLOSURE = None) -> LOSS:
         if math.isnan(variance_normalized):
             raise RuntimeError('hit nan for variance_normalized')

-        # debugging/logging
-        if self.logging:
-            self.tracking_variance_sum.append(variance_ma_sum.item())
-            self.tracking_variance_normalized.append(variance_normalized)
-
         # phase 2 - apply weight decay and step
         for group in self.param_groups:
             step = state['step']
@@ -464,7 +434,7 @@ def step(self, closure: CLOSURE = None) -> LOSS:
             # warm-down
             if self.warm_down_active:
                 lr = self.get_warm_down(lr, step)
-                if 0 > lr:
+                if lr < 0.0:
                     raise ValueError(f'{lr} went negative')

             # MADGRAD outer
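For context, the class docstring in this diff imports the optimizer from the package root. A minimal usage sketch along those lines; only params and lr appear in the hunks shown here, so treat everything else about the constructor as an assumption (other required arguments may exist outside this diff):

import torch
from pytorch_optimizer import Ranger21

model = torch.nn.Linear(10, 2)
criterion = torch.nn.MSELoss()

# lr is shown in the diff; additional constructor arguments are assumed
# to take their defaults here.
optimizer = Ranger21(model.parameters(), lr=1e-3)

for _ in range(10):
    optimizer.zero_grad()
    loss = criterion(model(torch.randn(8, 10)), torch.randn(8, 2))
    loss.backward()
    optimizer.step()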
