 # limitations under the License.
 import pytest
 
+import tests.base.develop_utils as tutils
 from pytorch_lightning import Trainer
 from pytorch_lightning.callbacks import LearningRateMonitor
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
-from tests.base import EvalModelTemplate
-import tests.base.develop_utils as tutils
+from tests.base import BoringModel, EvalModelTemplate
 
 
 def test_lr_monitor_single_lr(tmpdir):
@@ -43,7 +43,7 @@ def test_lr_monitor_single_lr(tmpdir):
         'Momentum should not be logged by default'
     assert len(lr_monitor.lrs) == len(trainer.lr_schedulers), \
         'Number of learning rates logged does not match number of lr schedulers'
-    assert all([k in ['lr-Adam'] for k in lr_monitor.lrs.keys()]), \
+    assert lr_monitor.lr_sch_names == list(lr_monitor.lrs.keys()) == ['lr-Adam'], \
         'Names of learning rates not set correctly'
 
 
@@ -134,7 +134,7 @@ def test_lr_monitor_multi_lrs(tmpdir, logging_interval):
     assert lr_monitor.lrs, 'No learning rates logged'
     assert len(lr_monitor.lrs) == len(trainer.lr_schedulers), \
         'Number of learning rates logged does not match number of lr schedulers'
-    assert all([k in ['lr-Adam', 'lr-Adam-1'] for k in lr_monitor.lrs.keys()]), \
+    assert lr_monitor.lr_sch_names == ['lr-Adam', 'lr-Adam-1'], \
         'Names of learning rates not set correctly'
 
     if logging_interval == 'step':
@@ -167,5 +167,27 @@ def test_lr_monitor_param_groups(tmpdir):
     assert lr_monitor.lrs, 'No learning rates logged'
     assert len(lr_monitor.lrs) == 2 * len(trainer.lr_schedulers), \
         'Number of learning rates logged does not match number of param groups'
-    assert all([k in ['lr-Adam/pg1', 'lr-Adam/pg2'] for k in lr_monitor.lrs.keys()]), \
+    assert lr_monitor.lr_sch_names == ['lr-Adam']
+    assert list(lr_monitor.lrs.keys()) == ['lr-Adam/pg1', 'lr-Adam/pg2'], \
         'Names of learning rates not set correctly'
+
+
+def test_lr_monitor_custom_name(tmpdir):
+    class TestModel(BoringModel):
+        def configure_optimizers(self):
+            optimizer, [scheduler] = super().configure_optimizers()
+            lr_scheduler = {'scheduler': scheduler, 'name': 'my_logging_name'}
+            return optimizer, [lr_scheduler]
+
+    lr_monitor = LearningRateMonitor()
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        max_epochs=2,
+        limit_val_batches=0.1,
+        limit_train_batches=0.5,
+        callbacks=[lr_monitor],
+        progress_bar_refresh_rate=0,
+        weights_summary=None,
+    )
+    trainer.fit(TestModel())
+    assert lr_monitor.lr_sch_names == list(lr_monitor.lrs.keys()) == ['my_logging_name']
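For context, a minimal standalone sketch of the behaviour the new test_lr_monitor_custom_name exercises: passing a 'name' key in the lr scheduler dict makes LearningRateMonitor log the learning rate under that key instead of the default 'lr-Adam'. The model, dataloader, optimizer and scheduler below are illustrative choices, not code from this patch.

# Sketch only: illustrates the 'name' key handled by LearningRateMonitor,
# as asserted in test_lr_monitor_custom_name above. Model/dataloader/scheduler
# choices here are hypothetical, not part of the patch.
import torch
from torch.utils.data import DataLoader
from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.callbacks import LearningRateMonitor


class NamedSchedulerModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def training_step(self, batch, batch_idx):
        # toy loss so the optimizer and scheduler actually step
        return self.layer(batch).sum()

    def train_dataloader(self):
        return DataLoader(torch.randn(64, 32), batch_size=8)

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=0.1)
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1)
        # 'name' overrides the default 'lr-Adam' logging key
        return [optimizer], [{'scheduler': scheduler, 'name': 'my_logging_name'}]


lr_monitor = LearningRateMonitor(logging_interval='epoch')
trainer = Trainer(max_epochs=2, limit_train_batches=2, callbacks=[lr_monitor])
trainer.fit(NamedSchedulerModel())
assert list(lr_monitor.lrs.keys()) == ['my_logging_name']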