Skip to content

Commit 721e1db

Browse files
committed
update: test_rex_lr_scheduler
1 parent 457120d commit 721e1db

File tree

1 file changed

+24
-0
lines changed

1 file changed

+24
-0
lines changed

tests/test_lr_schedulers.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from pytorch_optimizer.lr_scheduler.experimental.deberta_v3_lr_scheduler import deberta_v3_large_lr_scheduler
1111
from pytorch_optimizer.lr_scheduler.linear_warmup import CosineScheduler, LinearScheduler, PolyScheduler
1212
from pytorch_optimizer.lr_scheduler.proportion import ProportionScheduler
13+
from pytorch_optimizer.lr_scheduler.rex import REXScheduler
1314
from tests.utils import Example
1415

1516
CAWR_RECIPES = [
@@ -263,6 +264,29 @@ def test_proportion_no_last_lr_scheduler():
263264
np.testing.assert_almost_equal(2.0, rho_scheduler.get_lr(), 6)
264265

265266

267+
def test_rex_lr_scheduler():
268+
lrs = [
269+
0.888888,
270+
0.749999,
271+
0.571428,
272+
0.333333,
273+
0.0,
274+
]
275+
276+
base_optimizer = AdamP(Example().parameters())
277+
278+
lr_scheduler = REXScheduler(
279+
base_optimizer,
280+
total_steps=5,
281+
max_lr=1.0,
282+
min_lr=0.0,
283+
)
284+
285+
for expected_lr in lrs:
286+
lr: float = lr_scheduler.step()
287+
np.testing.assert_almost_equal(expected_lr, lr, 6)
288+
289+
266290
def test_deberta_v3_large_lr_scheduler():
267291
model = nn.Sequential(*[nn.Linear(1, 1, bias=False) for _ in range(400)])
268292
deberta_v3_large_lr_scheduler(model)

0 commit comments

Comments
 (0)