Commit 457120d

update: REX lr scheduler
1 parent 4af1880 commit 457120d

File tree: 3 files changed, +3 -2 lines changed

pytorch_optimizer/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -26,6 +26,7 @@
 from pytorch_optimizer.lr_scheduler.experimental.deberta_v3_lr_scheduler import deberta_v3_large_lr_scheduler
 from pytorch_optimizer.lr_scheduler.linear_warmup import CosineScheduler, LinearScheduler, PolyScheduler
 from pytorch_optimizer.lr_scheduler.proportion import ProportionScheduler
+from pytorch_optimizer.lr_scheduler.rex import REXScheduler
 from pytorch_optimizer.optimizer.a2grad import A2Grad
 from pytorch_optimizer.optimizer.adabelief import AdaBelief
 from pytorch_optimizer.optimizer.adabound import AdaBound
@@ -195,6 +196,7 @@
     PolyScheduler,
     LinearScheduler,
     ProportionScheduler,
+    REXScheduler,
 ]
 LR_SCHEDULERS: Dict[str, SCHEDULER] = {
     str(lr_scheduler.__name__).lower(): lr_scheduler for lr_scheduler in LR_SCHEDULER_LIST
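
With the import and registration above, REXScheduler is reachable from the package root. A minimal usage sketch follows; the constructor keywords (total_steps, max_lr, min_lr) and the toy model/optimizer setup are illustrative assumptions, not taken from this commit, so check pytorch_optimizer/lr_scheduler/rex.py for the actual signature.

import torch

from pytorch_optimizer import REXScheduler

# toy setup; any optimizer exposing param_groups works
model = torch.nn.Linear(10, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

# assumed keyword names, shown only to sketch the call site
scheduler = REXScheduler(optimizer, total_steps=100, max_lr=1e-3, min_lr=1e-6)

for _ in range(100):
    optimizer.step()
    # if REXScheduler follows the base scheduler changed below, step() writes
    # the newly computed lr into optimizer.param_groups
    scheduler.step()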

pytorch_optimizer/base/scheduler.py

Lines changed: 0 additions & 1 deletion
@@ -76,7 +76,6 @@ def step(self):
 
         self.step_t += 1
 
-        # apply the lr to optimizer if it's provided
         if self.optimizer is not None:
             for param_group in self.optimizer.param_groups:
                 param_group['lr'] = value

tests/test_load_modules.py

Lines changed: 1 addition & 1 deletion
@@ -42,7 +42,7 @@ def test_get_supported_optimizers():
 
 
 def test_get_supported_lr_schedulers():
-    assert len(get_supported_lr_schedulers()) == 10
+    assert len(get_supported_lr_schedulers()) == 11
 
 
 def test_get_supported_loss_functions():
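
The new expected count of 11 follows from the LR_SCHEDULERS registry in __init__.py, which keys each scheduler by its lowercased class name. A small sketch of what the registration implies; that get_supported_lr_schedulers is importable from the package root and reflects exactly the registered entries is assumed from the test above, not shown in this diff.

from pytorch_optimizer import LR_SCHEDULERS, REXScheduler, get_supported_lr_schedulers

# key format is str(cls.__name__).lower(), per the LR_SCHEDULERS comprehension above
assert LR_SCHEDULERS['rexscheduler'] is REXScheduler

# 10 previously registered schedulers plus REXScheduler, matching the updated test
assert len(get_supported_lr_schedulers()) == 11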
