 from pytorch_lightning import Callback, Trainer
 from pytorch_lightning.callbacks import ModelCheckpoint
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
-from tests.base import EvalModelTemplate
 from tests.helpers.boring_model import BoringModel
 from tests.helpers.runif import RunIf
 
@@ -79,7 +78,7 @@ def test_reducelronplateau_with_no_monitor_raises(tmpdir):
     """
     Test exception when a ReduceLROnPlateau is used with no monitor
     """
-    model = EvalModelTemplate()
+    model = BoringModel()
     optimizer = optim.Adam(model.parameters())
     model.configure_optimizers = lambda: ([optimizer], [optim.lr_scheduler.ReduceLROnPlateau(optimizer)])
     trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True)
@@ -93,7 +92,7 @@ def test_reducelronplateau_with_no_monitor_in_lr_scheduler_dict_raises(tmpdir):
     """
     Test exception when lr_scheduler dict has a ReduceLROnPlateau with no monitor
     """
-    model = EvalModelTemplate()
+    model = BoringModel()
     optimizer = optim.Adam(model.parameters())
     model.configure_optimizers = lambda: {
         "optimizer": optimizer,
@@ -376,33 +375,47 @@ def configure_optimizers(self):
     trainer.fit(model)
 
 
-def test_lr_scheduler_strict(tmpdir):
+@pytest.mark.parametrize("complete_epoch", [True, False])
+@mock.patch("torch.optim.lr_scheduler.ReduceLROnPlateau.step")
+def test_lr_scheduler_strict(step_mock, tmpdir, complete_epoch):
     """
     Test "strict" support in lr_scheduler dict
     """
-    model = EvalModelTemplate()
+    model = BoringModel()
     optimizer = optim.Adam(model.parameters())
     scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer)
-    trainer = Trainer(default_root_dir=tmpdir, max_epochs=1)
+    max_epochs = 1 if complete_epoch else None
+    max_steps = None if complete_epoch else 1
+    trainer = Trainer(default_root_dir=tmpdir, max_epochs=max_epochs, max_steps=max_steps)
 
     model.configure_optimizers = lambda: {
         "optimizer": optimizer,
         "lr_scheduler": {"scheduler": scheduler, "monitor": "giraffe", "strict": True},
     }
-    with pytest.raises(
-        MisconfigurationException,
-        match=r"ReduceLROnPlateau conditioned on metric .* which is not available\. Available metrics are:",
-    ):
+
+    if complete_epoch:
+        with pytest.raises(
+            MisconfigurationException,
+            match=r"ReduceLROnPlateau conditioned on metric .* which is not available\. Available metrics are:",
+        ):
+            trainer.fit(model)
+    else:
         trainer.fit(model)
 
+    step_mock.assert_not_called()
+
     model.configure_optimizers = lambda: {
         "optimizer": optimizer,
         "lr_scheduler": {"scheduler": scheduler, "monitor": "giraffe", "strict": False},
     }
-    with pytest.warns(
-        RuntimeWarning, match=r"ReduceLROnPlateau conditioned on metric .* which is not available but strict"
-    ):
-        trainer.fit(model)
+
+    if complete_epoch:
+        with pytest.warns(
+            RuntimeWarning, match=r"ReduceLROnPlateau conditioned on metric .* which is not available but strict"
+        ):
+            trainer.fit(model)
+
+    step_mock.assert_not_called()
 
 
 def test_unknown_configure_optimizers_raises(tmpdir):
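
For context, a minimal sketch of the `lr_scheduler` dict shape these tests exercise, assuming the standard `configure_optimizers` API. This is not part of the commit: the model class, layer sizes, and the monitored metric name "val_loss" are illustrative.

import torch
from torch import optim
from pytorch_lightning import LightningModule


class PlateauModel(LightningModule):
    # Hypothetical model; only configure_optimizers matters for this sketch.
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def training_step(self, batch, batch_idx):
        loss = self.layer(batch).sum()
        # The monitored key must actually be logged; otherwise strict=True
        # raises MisconfigurationException, while strict=False only warns
        # and skips the scheduler step, as the tests above verify.
        self.log("val_loss", loss)
        return loss

    def configure_optimizers(self):
        optimizer = optim.Adam(self.parameters(), lr=1e-3)
        scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer)
        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": scheduler,
                "monitor": "val_loss",  # metric fed to scheduler.step(...)
                "strict": True,  # raise if the metric is unavailable
            },
        }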