
Commit 944a353

Merge pull request #181 from kozistr/fix/chebyshev-scheduler
[Fix] Chebyshev LR Scheduler
2 parents c56b36a + 5622f95 commit 944a353

7 files changed (+114 −23 lines)


README.rst

Lines changed: 22 additions & 8 deletions
@@ -87,14 +87,13 @@ If you want to build the optimizer with parameters & configs, there's `create_op
 Supported Optimizers
 --------------------
 
-You can check the supported optimizers & lr schedulers.
+You can check the supported optimizers with below code.
 
 ::
 
-    from pytorch_optimizer import get_supported_optimizers, get_supported_lr_schedulers
+    from pytorch_optimizer import get_supported_optimizers
 
     supported_optimizers = get_supported_optimizers()
-    supported_lr_schedulers = get_supported_lr_schedulers()
 
 +--------------+---------------------------------------------------------------------------------------------------+-----------------------------------------------------------------------------------+-----------------------------------------------------------------------------------------------+----------------------------------------------------------------------------------------------------------------------+
 | Optimizer | Description | Official Code | Paper | Citation |
@@ -201,14 +200,10 @@ You can check the supported optimizers & lr schedulers.
 +--------------+---------------------------------------------------------------------------------------------------+-----------------------------------------------------------------------------------+-----------------------------------------------------------------------------------------------+----------------------------------------------------------------------------------------------------------------------+
 | Softplus T | *Calibrating the Adaptive Learning Rate to Improve Convergence of ADAM* | | `https://arxiv.org/abs/1908.00700 <https://arxiv.org/abs/1908.00700>`__ | `cite <https://ui.adsabs.harvard.edu/abs/2019arXiv190800700T/exportcitation>`__ |
 +--------------+---------------------------------------------------------------------------------------------------+-----------------------------------------------------------------------------------+-----------------------------------------------------------------------------------------------+----------------------------------------------------------------------------------------------------------------------+
-| EE LRS | *Wide-minima Density Hypothesis and the Explore-Exploit Learning Rate Schedule* | | `https://arxiv.org/abs/2003.03977 <https://arxiv.org/abs/2003.03977>`__ | `cite <https://ui.adsabs.harvard.edu/abs/2020arXiv200303977I/exportcitation>`__ |
+| Un-tuned w/u | *On the adequacy of untuned warmup for adaptive optimization* | | `https://arxiv.org/abs/1910.04209 <https://arxiv.org/abs/1910.04209>`__ | `cite <https://ui.adsabs.harvard.edu/abs/2019arXiv191004209M/exportcitation>`__ |
 +--------------+---------------------------------------------------------------------------------------------------+-----------------------------------------------------------------------------------+-----------------------------------------------------------------------------------------------+----------------------------------------------------------------------------------------------------------------------+
 | Norm Loss | *An efficient yet effective regularization method for deep neural networks* | | `https://arxiv.org/abs/2103.06583 <https://arxiv.org/abs/2103.06583>`__ | `cite <https://ui.adsabs.harvard.edu/abs/2021arXiv210306583G/exportcitation>`__ |
 +--------------+---------------------------------------------------------------------------------------------------+-----------------------------------------------------------------------------------+-----------------------------------------------------------------------------------------------+----------------------------------------------------------------------------------------------------------------------+
-| Chebyshev LR | *Acceleration via Fractal Learning Rate Schedules* | | `https://arxiv.org/abs/2103.01338 <https://arxiv.org/abs/2103.01338>`__ | `cite <https://ui.adsabs.harvard.edu/abs/2021arXiv210301338A/exportcitation>`__ |
-+--------------+---------------------------------------------------------------------------------------------------+-----------------------------------------------------------------------------------+-----------------------------------------------------------------------------------------------+----------------------------------------------------------------------------------------------------------------------+
-| Un-tuned WU | *On the adequacy of untuned warmup for adaptive optimization* | | `https://arxiv.org/abs/1910.04209 <https://arxiv.org/abs/1910.04209>`__ | `cite <https://ui.adsabs.harvard.edu/abs/2019arXiv191004209M/exportcitation>`__ |
-+--------------+---------------------------------------------------------------------------------------------------+-----------------------------------------------------------------------------------+-----------------------------------------------------------------------------------------------+----------------------------------------------------------------------------------------------------------------------+
 | AdaShift | *Decorrelation and Convergence of Adaptive Learning Rate Methods* | `github <https://github.com/MichaelKonobeev/adashift>`__ | `https://arxiv.org/abs/1810.00143v4 <https://arxiv.org/abs/1810.00143v4>`__ | `cite <https://ui.adsabs.harvard.edu/abs/2018arXiv181000143Z/exportcitation>`__ |
 +--------------+---------------------------------------------------------------------------------------------------+-----------------------------------------------------------------------------------+-----------------------------------------------------------------------------------------------+----------------------------------------------------------------------------------------------------------------------+
 | AdaDelta | *An Adaptive Learning Rate Method* | | `https://arxiv.org/abs/1212.5701v1 <https://arxiv.org/abs/1212.5701v1>`__ | `cite <https://ui.adsabs.harvard.edu/abs/2012arXiv1212.5701Z/exportcitation>`__ |
@@ -222,6 +217,25 @@ You can check the supported optimizers & lr schedulers.
 | Sophia | *A Scalable Stochastic Second-order Optimizer for Language Model Pre-training* | `github <https://github.com/Liuhong99/Sophia>`__ | `https://arxiv.org/abs/2305.14342 <https://arxiv.org/abs/2305.14342>`__ | `cite <https://github.com/Liuhong99/Sophia>`__ |
 +--------------+---------------------------------------------------------------------------------------------------+-----------------------------------------------------------------------------------+-----------------------------------------------------------------------------------------------+----------------------------------------------------------------------------------------------------------------------+
 
+Supported LR Scheduler
+----------------------
+
+You can check the supported learning rate schedulers with below code.
+
+::
+
+    from pytorch_optimizer import get_supported_lr_schedulers
+
+    supported_lr_schedulers = get_supported_lr_schedulers()
+
++------------------+---------------------------------------------------------------------------------------------------+-----------------------------------------------------------------------------------+-----------------------------------------------------------------------------------------------+----------------------------------------------------------------------------------------------------------------------+
+| LR Scheduler | Description | Official Code | Paper | Citation |
++==================+===================================================================================================+===================================================================================+===============================================================================================+======================================================================================================================+
+| Explore-Exploit | *Wide-minima Density Hypothesis and the Explore-Exploit Learning Rate Schedule* | | `https://arxiv.org/abs/2003.03977 <https://arxiv.org/abs/2003.03977>`__ | `cite <https://ui.adsabs.harvard.edu/abs/2020arXiv200303977I/exportcitation>`__ |
++------------------+---------------------------------------------------------------------------------------------------+-----------------------------------------------------------------------------------+-----------------------------------------------------------------------------------------------+----------------------------------------------------------------------------------------------------------------------+
+| Chebyshev | *Acceleration via Fractal Learning Rate Schedules* | | `https://arxiv.org/abs/2103.01338 <https://arxiv.org/abs/2103.01338>`__ | `cite <https://ui.adsabs.harvard.edu/abs/2021arXiv210301338A/exportcitation>`__ |
++------------------+---------------------------------------------------------------------------------------------------+-----------------------------------------------------------------------------------+-----------------------------------------------------------------------------------------------+----------------------------------------------------------------------------------------------------------------------+
+
 Useful Resources
 ----------------

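As a quick check of the two helpers documented in the README above, here is a minimal sketch (not part of this commit; it only assumes pytorch-optimizer is installed and that both helpers return an iterable registry):

    from pytorch_optimizer import get_supported_lr_schedulers, get_supported_optimizers

    # Summarize what the installed version exposes.
    optimizers = get_supported_optimizers()
    lr_schedulers = get_supported_lr_schedulers()

    print(f'{len(optimizers)} optimizers, {len(lr_schedulers)} lr schedulers')
    for entry in lr_schedulers:
        print(entry)  # expect the Explore-Exploit and Chebyshev schedulers from the table above

Note that `get_chebyshev_lr`, added in this commit, is also re-exported from the package root (see the `pytorch_optimizer/__init__.py` change below).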
docs/changelogs/v2.10.1.md

Lines changed: 1 addition & 0 deletions
@@ -3,6 +3,7 @@
 ### Fix
 
 * `perturb` isn't multiplied by `-step_size` in SWATS optimizer. (#179)
+* `chebyshev step` has size of `T` while the permutation is `2^T`. (#168, #181)
 
 ### Diff

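To make this changelog entry concrete, here is a small standalone sketch of the pre-fix behaviour (the `old_*` functions are illustrative copies of the removed `chebyshev_steps` / `chebyshev_perm` helpers shown in the diff further below): the steps array has length `T`, while the permutation doubles until it reaches the next power of two at or above `T`, so indexing breaks whenever `T` is not a power of two:

    import numpy as np

    def old_chebyshev_steps(small_m: float, big_m: float, num_epochs: int) -> np.ndarray:
        # copy of the removed chebyshev_steps
        c, r = (big_m + small_m) / 2.0, (big_m - small_m) / 2.0
        thetas = (np.arange(num_epochs) + 0.5) / num_epochs * np.pi
        return 1.0 / (c - r * np.cos(thetas))

    def old_chebyshev_perm(num_epochs: int) -> np.ndarray:
        # copy of the removed chebyshev_perm
        perm = np.array([0])
        while len(perm) < num_epochs:
            perm = np.vstack([perm, 2 * len(perm) - 1 - perm]).T.flatten()
        return perm

    steps = old_chebyshev_steps(0.1, 1, 5)  # 5 step sizes
    perm = old_chebyshev_perm(5)            # length 8: [0, 7, 3, 4, 1, 6, 2, 5]
    try:
        steps[perm]
    except IndexError as exc:
        print(exc)  # index 7 is out of range for a size-5 array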
pyproject.toml

Lines changed: 1 addition & 1 deletion
@@ -92,7 +92,7 @@ target-version = "py39"
 "./tests/test_general_optimizer_parameters.py" = ["D", "S101"]
 "./tests/test_load_optimizers.py" = ["D", "S101"]
 "./tests/test_load_lr_schedulers.py" = ["D", "S101"]
-"./tests/test_lr_schedulers.py" = ["D"]
+"./tests/test_lr_schedulers.py" = ["D", "S101"]
 "./tests/test_lr_scheduler_parameters.py" = ["D", "S101"]
 "./tests/test_create_optimizer.py" = ["D"]
 "./pytorch_optimizer/__init__.py" = ["F401"]

pytorch_optimizer/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -11,7 +11,7 @@
     CyclicLR,
     OneCycleLR,
 )
-from pytorch_optimizer.lr_scheduler.chebyshev import get_chebyshev_schedule
+from pytorch_optimizer.lr_scheduler.chebyshev import get_chebyshev_lr, get_chebyshev_schedule
 from pytorch_optimizer.lr_scheduler.cosine_anealing import CosineAnnealingWarmupRestarts
 from pytorch_optimizer.lr_scheduler.experimental.deberta_v3_lr_scheduler import deberta_v3_large_lr_scheduler
 from pytorch_optimizer.lr_scheduler.linear_warmup import CosineScheduler, LinearScheduler, PolyScheduler
pytorch_optimizer/lr_scheduler/chebyshev.py

Lines changed: 44 additions & 8 deletions
@@ -1,30 +1,66 @@
 import numpy as np
 
 
-def chebyshev_steps(small_m: float, big_m: float, num_epochs: int) -> np.ndarray:
+def get_chebyshev_steps(num_epochs: int, small_m: float = 0.05, big_m: float = 1.0) -> np.ndarray:
     r"""Chebyshev steps.
 
+    gamma_{t} = (M + m) / 2.0 - (M - m) * cos((t - 0.5) * pi / T) / 2, where t = 1, ..., T
+
+    :param num_epochs: int. stands for 'T' notation.
     :param small_m: float. stands for 'm' notation.
     :param big_m: float. stands for 'M' notation.
-    :param num_epochs: int. stands for 'T' notation.
     :return: np.array. chebyshev_steps.
     """
     c, r = (big_m + small_m) / 2.0, (big_m - small_m) / 2.0
-    thetas = (np.arange(num_epochs) + 0.5) / num_epochs * np.pi
+    thetas = (np.arange(num_epochs) + 0.5) * np.pi / num_epochs  # epoch starts from 0, so +0.5 instead of -0.5
 
     return 1.0 / (c - r * np.cos(thetas))
 
 
-def chebyshev_perm(num_epochs: int) -> np.ndarray:
-    r"""Chebyshev permutation."""
+def get_chebyshev_permutation(num_epochs: int) -> np.ndarray:
+    r"""Fractal chebyshev permutation.
+
+    sigma_{2T} := interlace(sigma_{T}, 2T + 1 - sigma_{T}), where
+    interlace([a_{1}, ..., a_{n}], [b_{1}, ..., b_{n}]) := [a_{1}, b_{1}, ..., a_{n}, b_{n}]
+
+    :param num_epochs: int. number of epochs.
+    """
     perm = np.array([0])
     while len(perm) < num_epochs:
         perm = np.vstack([perm, 2 * len(perm) - 1 - perm]).T.flatten()
     return perm
 
 
 def get_chebyshev_schedule(num_epochs: int) -> np.ndarray:
-    r"""Get Chebyshev schedules."""
-    steps: np.ndarray = chebyshev_steps(0.1, 1, num_epochs - 2)
-    perm: np.ndarray = chebyshev_perm(num_epochs - 2)
+    r"""Get Chebyshev schedules.
+
+    :param num_epochs: int. number of total epochs.
+    """
+    steps: np.ndarray = get_chebyshev_steps(num_epochs)
+    perm: np.ndarray = get_chebyshev_permutation(num_epochs - 2)
     return steps[perm]
+
+
+def get_chebyshev_lr(lr: float, epoch: int, num_epochs: int, is_warmup: bool = False) -> float:
+    r"""Get chebyshev learning rate.
+
+    :param lr: float. learning rate.
+    :param epoch: int. current epochs.
+    :param num_epochs: int. number of total epochs.
+    :param is_warmup: bool. whether warm-up stage or not.
+    """
+    if is_warmup:
+        return lr
+
+    epoch_power: int = np.power(2, int(np.log2(num_epochs - 1)) + 1) if num_epochs > 1 else 1
+    scheduler = get_chebyshev_schedule(epoch_power)
+
+    idx: int = epoch - 2
+    if idx < 0:
+        idx = 0
+    elif idx > len(scheduler) - 1:
+        idx = len(scheduler) - 1
+
+    chebyshev_value: float = scheduler[idx]
+
+    return lr * chebyshev_value

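The new `get_chebyshev_lr` helper is intended to be called once per epoch. Below is a minimal training-loop sketch (not from this commit; the model, data, and epoch count are placeholder assumptions) that sets each parameter group's learning rate from it:

    import torch
    from torch import nn

    from pytorch_optimizer import AdamP, get_chebyshev_lr

    model = nn.Linear(10, 1)
    optimizer = AdamP(model.parameters(), lr=1e-3)

    num_epochs = 16
    for epoch in range(num_epochs):
        # treat the first epoch as warm-up so the base lr is used unchanged
        lr = get_chebyshev_lr(1e-3, epoch, num_epochs, is_warmup=epoch == 0)
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr

        loss = model(torch.randn(8, 10)).mean()
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

As the diff above shows, the helper rounds the epoch count up to a power of two before building the schedule and clamps the epoch index to the schedule's bounds, so any epoch in `range(num_epochs)` is safe to pass.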
tests/test_lr_scheduler_parameters.py

Lines changed: 1 addition & 1 deletion
@@ -62,4 +62,4 @@ def test_linear_warmup_lr_scheduler_params():
 
 def test_chebyshev_params():
     with pytest.raises(IndexError):
-        get_chebyshev_schedule(2)
+        get_chebyshev_schedule(0)

tests/test_lr_schedulers.py

Lines changed: 44 additions & 4 deletions
@@ -4,8 +4,8 @@
 import pytest
 from torch import nn
 
-from pytorch_optimizer import AdamP, get_chebyshev_schedule
-from pytorch_optimizer.lr_scheduler.chebyshev import chebyshev_perm
+from pytorch_optimizer import AdamP, get_chebyshev_lr, get_chebyshev_schedule
+from pytorch_optimizer.lr_scheduler.chebyshev import get_chebyshev_permutation
 from pytorch_optimizer.lr_scheduler.cosine_anealing import CosineAnnealingWarmupRestarts
 from pytorch_optimizer.lr_scheduler.experimental.deberta_v3_lr_scheduler import deberta_v3_large_lr_scheduler
 from pytorch_optimizer.lr_scheduler.linear_warmup import CosineScheduler, LinearScheduler, PolyScheduler
@@ -152,8 +152,48 @@ def test_cosine_annealing_warmup_restarts(cosine_annealing_warmup_restart_param)
 
 
 def test_get_chebyshev_scheduler():
-    np.testing.assert_almost_equal(get_chebyshev_schedule(3), 1.81818182, decimal=6)
-    np.testing.assert_array_equal(chebyshev_perm(5), np.asarray([0, 7, 3, 4, 1, 6, 2, 5]))
+    # test the first nontrivial permutations sigma_{T}
+    recipes = {
+        2: np.asarray([0, 1]),
+        4: np.asarray([0, 3, 1, 2]),
+        8: np.asarray([0, 7, 3, 4, 1, 6, 2, 5]),
+        16: np.asarray([0, 15, 7, 8, 3, 12, 4, 11, 1, 14, 6, 9, 2, 13, 5, 10]),
+    }
+
+    for k, v in recipes.items():
+        np.testing.assert_array_equal(get_chebyshev_permutation(k), v)
+
+    np.testing.assert_almost_equal(get_chebyshev_schedule(1), 1.904762, decimal=6)
+    np.testing.assert_almost_equal(get_chebyshev_schedule(3), 8.799878, decimal=6)
+
+
+def test_get_chebyshev_lr():
+    recipes = [
+        0.019125119558059765,
+        0.019125119558059765,
+        0.0010022924983586518,
+        0.0020901181252459123,
+        0.0017496032811320122,
+        0.006336331139456458,
+        0.0011208500962143087,
+        0.004471008393917827,
+        0.0012101602977446309,
+        0.014193791132074378,
+        0.0010208804147606497,
+        0.0025832131864890117,
+        0.0015085567867114075,
+        0.009426190153875151,
+        0.0010594201194061095,
+        0.0033213041232648503,
+        0.001335267780289186,
+        0.001335267780289186,
+        0.001335267780289186,
+    ]
+
+    np.testing.assert_almost_equal(get_chebyshev_lr(1e-3, 0, 16, is_warmup=True), 1e-3)
+
+    for i, expected_lr in enumerate(recipes, start=1):
+        np.testing.assert_almost_equal(get_chebyshev_lr(1e-3, i, 16, is_warmup=False), expected_lr)
 
 
 def test_linear_warmup_linear_scheduler():

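To verify the fix locally, the updated tests can be run from the repository root (assuming a development install of the package) with, for example, `pytest tests/test_lr_schedulers.py tests/test_lr_scheduler_parameters.py -k chebyshev`.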