
Commit da65344

Merge pull request #222 from kozistr/feature/rex-lr-scheduler
[Feature] Implement REX lr scheduler
2 parents fd717fc + a3b1bfb commit da65344

File tree

8 files changed, +105 -7 lines changed

* README.md
* docs/changelogs/v3.0.0.md
* docs/lr_scheduler.md
* pytorch_optimizer/__init__.py
* pytorch_optimizer/base/scheduler.py
* pytorch_optimizer/lr_scheduler/rex.py
* tests/test_load_modules.py
* tests/test_lr_schedulers.py

README.md

Lines changed: 6 additions & 5 deletions
@@ -10,7 +10,7 @@
 
 **pytorch-optimizer** is optimizer & lr scheduler collections in PyTorch.
 I just re-implemented (speed & memory tweaks, plug-ins) the algorithm while based on the original paper. Also, It includes useful and practical optimization ideas.
-Currently, **62 optimizers (+ `bitsandbytes`)**, **10 lr schedulers**, and **13 loss functions** are supported!
+Currently, **62 optimizers (+ `bitsandbytes`)**, **11 lr schedulers**, and **13 loss functions** are supported!
 
 Highly inspired by [pytorch-optimizer](https://github.com/jettify/pytorch-optimizer).
 
@@ -171,10 +171,11 @@ from pytorch_optimizer import get_supported_lr_schedulers
 supported_lr_schedulers = get_supported_lr_schedulers()
 ```
 
-| LR Scheduler    | Description                                                                      | Official Code | Paper                              | Citation                                                                      |
-|-----------------|----------------------------------------------------------------------------------|---------------|------------------------------------|-------------------------------------------------------------------------------|
-| Explore-Exploit | *Wide-minima Density Hypothesis and the Explore-Exploit Learning Rate Schedule*  |               | <https://arxiv.org/abs/2003.03977> | [cite](https://ui.adsabs.harvard.edu/abs/2020arXiv200303977I/exportcitation)  |
-| Chebyshev       | *Acceleration via Fractal Learning Rate Schedules*                               |               | <https://arxiv.org/abs/2103.01338> | [cite](https://ui.adsabs.harvard.edu/abs/2021arXiv210301338A/exportcitation)  |
+| LR Scheduler    | Description                                                                      | Official Code                                                                                                                        | Paper                              | Citation                                                                      |
+|-----------------|----------------------------------------------------------------------------------|--------------------------------------------------------------------------------------------------------------------------------------|------------------------------------|-------------------------------------------------------------------------------|
+| Explore-Exploit | *Wide-minima Density Hypothesis and the Explore-Exploit Learning Rate Schedule*  |                                                                                                                                      | <https://arxiv.org/abs/2003.03977> | [cite](https://ui.adsabs.harvard.edu/abs/2020arXiv200303977I/exportcitation)  |
+| Chebyshev       | *Acceleration via Fractal Learning Rate Schedules*                               |                                                                                                                                      | <https://arxiv.org/abs/2103.01338> | [cite](https://ui.adsabs.harvard.edu/abs/2021arXiv210301338A/exportcitation)  |
+| REX             | *Revisiting Budgeted Training with an Improved Schedule*                         | [github](https://github.com/Nerogar/OneTrainer/blob/2c6f34ea0838e5a86774a1cf75093d7e97c70f03/modules/util/lr_scheduler_util.py#L66)  | <https://arxiv.org/abs/2107.04197> | [cite](https://ui.adsabs.harvard.edu/abs/2021arXiv210704197C/exportcitation)  |
 
 ## Supported Loss Function
 
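For reference, the REX entry added to the table above refers to the schedule implemented in `pytorch_optimizer/lr_scheduler/rex.py` further down in this commit, which decays the learning rate as

$$
\eta_t = \eta_{\min} + (\eta_{\max} - \eta_{\min}) \cdot \frac{1 - p}{1 - p/2}, \qquad p = \frac{t}{T},
$$

where $T$ is `total_steps` and the value is clamped to $\eta_{\min}$ once $t \ge T$.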

docs/changelogs/v3.0.0.md

Lines changed: 2 additions & 0 deletions
@@ -4,6 +4,8 @@ Major version is updated! (`v2.12.0` -> `v3.0.0`) (#164)
 
 ### Feature
 
+* Implement `REX` lr scheduler. (#217, #222)
+    * [Revisiting Budgeted Training with an Improved Schedule](https://arxiv.org/abs/2107.04197)
 * Implement `Aida` optimizer. (#220, #221)
     * [A DNN Optimizer that Improves over AdaBelief by Suppression of the Adaptive Stepsize Range](https://arxiv.org/abs/2203.13273)
 * Implement `WSAM` optimizer. (#213, #216)

docs/lr_scheduler.md

Lines changed: 4 additions & 0 deletions
@@ -25,3 +25,7 @@
 ::: pytorch_optimizer.ProportionScheduler
     :docstring:
     :members:
+
+::: pytorch_optimizer.REXScheduler
+    :docstring:
+    :members:

pytorch_optimizer/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -26,6 +26,7 @@
 from pytorch_optimizer.lr_scheduler.experimental.deberta_v3_lr_scheduler import deberta_v3_large_lr_scheduler
 from pytorch_optimizer.lr_scheduler.linear_warmup import CosineScheduler, LinearScheduler, PolyScheduler
 from pytorch_optimizer.lr_scheduler.proportion import ProportionScheduler
+from pytorch_optimizer.lr_scheduler.rex import REXScheduler
 from pytorch_optimizer.optimizer.a2grad import A2Grad
 from pytorch_optimizer.optimizer.adabelief import AdaBelief
 from pytorch_optimizer.optimizer.adabound import AdaBound
@@ -195,6 +196,7 @@
     PolyScheduler,
     LinearScheduler,
     ProportionScheduler,
+    REXScheduler,
 ]
 LR_SCHEDULERS: Dict[str, SCHEDULER] = {
     str(lr_scheduler.__name__).lower(): lr_scheduler for lr_scheduler in LR_SCHEDULER_LIST
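As context for the registration above: the `LR_SCHEDULERS` registry keys are the lowercased class names, so adding `REXScheduler` to `LR_SCHEDULER_LIST` makes it reachable under the name `'rexscheduler'`. A minimal sketch of that lookup (illustrative, not part of the diff):

```python
from pytorch_optimizer import LR_SCHEDULERS, REXScheduler

# keys come from str(cls.__name__).lower() in the dict comprehension above
assert LR_SCHEDULERS['rexscheduler'] is REXScheduler
```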

pytorch_optimizer/base/scheduler.py

Lines changed: 0 additions & 1 deletion
@@ -76,7 +76,6 @@ def step(self):
 
         self.step_t += 1
 
-        # apply the lr to optimizer if it's provided
         if self.optimizer is not None:
             for param_group in self.optimizer.param_groups:
                 param_group['lr'] = value

pytorch_optimizer/lr_scheduler/rex.py

Lines changed: 66 additions & 0 deletions
@@ -0,0 +1,66 @@
+from typing import List
+
+from torch.optim.lr_scheduler import _LRScheduler
+
+from pytorch_optimizer.base.types import OPTIMIZER
+
+
+class REXScheduler(_LRScheduler):
+    r"""Revisiting Budgeted Training with an Improved Schedule.
+
+    :param optimizer: Optimizer. wrapped optimizer instance.
+    :param total_steps: int. number of steps to optimize.
+    :param max_lr: float. max lr.
+    :param min_lr: float. min lr.
+    """
+
+    def __init__(
+        self,
+        optimizer: OPTIMIZER,
+        total_steps: int,
+        max_lr: float = 1.0,
+        min_lr: float = 0.0,
+    ):
+        self.total_steps = total_steps
+        self.max_lr = max_lr
+        self.min_lr = min_lr
+
+        self.step_t: int = 0
+        self.base_lrs: List[float] = []
+
+        # record current value in self.last_lr to match the API of torch.optim.lr_scheduler
+        self.last_lr: List[float] = [self.max_lr]
+
+        super().__init__(optimizer)
+
+        self.init_lr()
+
+    def init_lr(self):
+        self.base_lrs = []
+        for param_group in self.optimizer.param_groups:
+            param_group['lr'] = self.min_lr
+            self.base_lrs.append(self.min_lr)
+
+    def get_lr(self) -> float:
+        return self.last_lr[0]
+
+    def get_linear_lr(self) -> float:
+        if self.step_t >= self.total_steps:
+            return self.min_lr
+
+        progress: float = self.step_t / self.total_steps
+
+        return self.min_lr + (self.max_lr - self.min_lr) * ((1.0 - progress) / (1.0 - progress / 2.0))
+
+    def step(self):
+        value: float = self.get_linear_lr()
+
+        self.step_t += 1
+
+        if self.optimizer is not None:
+            for param_group in self.optimizer.param_groups:
+                param_group['lr'] = value
+
+        self.last_lr = [value]
+
+        return value
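A minimal usage sketch of the new scheduler (the model, optimizer choice, and step budget here are illustrative assumptions, not part of this commit):

```python
import torch
from torch import nn

from pytorch_optimizer import AdamP, REXScheduler

# toy model and optimizer; any torch.optim-style optimizer should work the same way
model = nn.Linear(10, 1)
optimizer = AdamP(model.parameters(), lr=1e-3)

# decay from max_lr toward min_lr over the training budget following the REX curve
lr_scheduler = REXScheduler(optimizer, total_steps=1_000, max_lr=1e-3, min_lr=1e-6)

for step in range(1_000):
    optimizer.zero_grad()
    loss = model(torch.randn(8, 10)).pow(2).mean()
    loss.backward()
    optimizer.step()

    lr_scheduler.step()  # writes the new value into each param_group['lr']
```

`get_lr()` returns the most recently applied value, which is how the tests below read the schedule.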

tests/test_load_modules.py

Lines changed: 1 addition & 1 deletion
@@ -42,7 +42,7 @@ def test_get_supported_optimizers():
 
 
 def test_get_supported_lr_schedulers():
-    assert len(get_supported_lr_schedulers()) == 10
+    assert len(get_supported_lr_schedulers()) == 11
 
 
 def test_get_supported_loss_functions():

tests/test_lr_schedulers.py

Lines changed: 24 additions & 0 deletions
@@ -10,6 +10,7 @@
 from pytorch_optimizer.lr_scheduler.experimental.deberta_v3_lr_scheduler import deberta_v3_large_lr_scheduler
 from pytorch_optimizer.lr_scheduler.linear_warmup import CosineScheduler, LinearScheduler, PolyScheduler
 from pytorch_optimizer.lr_scheduler.proportion import ProportionScheduler
+from pytorch_optimizer.lr_scheduler.rex import REXScheduler
 from tests.utils import Example
 
 CAWR_RECIPES = [
@@ -263,6 +264,29 @@ def test_proportion_no_last_lr_scheduler():
     np.testing.assert_almost_equal(2.0, rho_scheduler.get_lr(), 6)
 
 
+def test_rex_lr_scheduler():
+    lrs = [
+        0.888888,
+        0.749999,
+        0.571428,
+        0.333333,
+        0.0,
+    ]
+
+    base_optimizer = AdamP(Example().parameters())
+
+    lr_scheduler = REXScheduler(
+        base_optimizer,
+        total_steps=5,
+        max_lr=1.0,
+        min_lr=0.0,
+    )
+
+    for expected_lr in lrs:
+        _ = lr_scheduler.step()
+        np.testing.assert_almost_equal(expected_lr, lr_scheduler.get_lr(), 6)
+
+
 def test_deberta_v3_large_lr_scheduler():
     model = nn.Sequential(*[nn.Linear(1, 1, bias=False) for _ in range(400)])
     deberta_v3_large_lr_scheduler(model)
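As a quick cross-check of the expected values in `test_rex_lr_scheduler`: with `total_steps=5`, `max_lr=1.0`, `min_lr=0.0`, the curve is `(1 - p) / (1 - p / 2)` at progress `p = step_t / 5`, and since `torch`'s `_LRScheduler` constructor itself calls `step()` once, the test observes `step_t = 1..5`. A small standalone recomputation (illustrative, not part of the diff):

```python
def rex(step_t: int, total_steps: int = 5, max_lr: float = 1.0, min_lr: float = 0.0) -> float:
    # mirrors REXScheduler.get_linear_lr for the test's hyper-parameters
    if step_t >= total_steps:
        return min_lr
    p = step_t / total_steps
    return min_lr + (max_lr - min_lr) * ((1.0 - p) / (1.0 - p / 2.0))


print([round(rex(t), 6) for t in range(1, 6)])
# [0.888889, 0.75, 0.571429, 0.333333, 0.0] -- agrees with the expected lrs within the test's 6-decimal tolerance
```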
