Skip to content

Commit 1eb28b3

Browse files
committed
update: test_proportion_no_last_lr_scheduler
1 parent 2051922 commit 1eb28b3

File tree

1 file changed

+24
-1
lines changed

1 file changed

+24
-1
lines changed

tests/test_lr_schedulers.py

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from typing import Tuple
2+
13
import numpy as np
24
import pytest
35
from torch import nn
@@ -182,7 +184,7 @@ def test_linear_warmup_poly_scheduler():
182184

183185

184186
@pytest.mark.parametrize('proportion_learning_rate', PROPORTION_LEARNING_RATES)
185-
def test_proportion_scheduler(proportion_learning_rate):
187+
def test_proportion_scheduler(proportion_learning_rate: Tuple[float, float, float]):
186188
base_optimizer = AdamP(Example().parameters())
187189
lr_scheduler = CosineScheduler(
188190
base_optimizer, t_max=10, max_lr=proportion_learning_rate[0], min_lr=proportion_learning_rate[1], init_lr=1e-2
@@ -200,6 +202,27 @@ def test_proportion_scheduler(proportion_learning_rate):
200202
np.testing.assert_almost_equal(proportion_learning_rate[2], rho_scheduler.get_lr(), 6)
201203

202204

205+
def test_proportion_no_last_lr_scheduler():
206+
base_optimizer = AdamP(Example().parameters())
207+
lr_scheduler = CosineAnnealingWarmupRestarts(
208+
base_optimizer,
209+
first_cycle_steps=10,
210+
max_lr=1e-2,
211+
min_lr=1e-2,
212+
)
213+
rho_scheduler = ProportionScheduler(
214+
lr_scheduler,
215+
max_lr=1e-2,
216+
min_lr=1e-2,
217+
max_value=2.0,
218+
min_value=1.0,
219+
)
220+
221+
for _ in range(10):
222+
_ = rho_scheduler.step()
223+
np.testing.assert_almost_equal(2.0, rho_scheduler.get_lr(), 6)
224+
225+
203226
def test_deberta_v3_large_lr_scheduler():
204227
try:
205228
from transformers import AutoConfig, AutoModel

0 commit comments

Comments
 (0)