@@ -545,13 +545,14 @@ def test_warning_with_few_workers(_, tmp_path, ckpt_path, stage):
545545
546546 trainer = Trainer (default_root_dir = tmp_path , max_epochs = 1 , limit_val_batches = 0.1 , limit_train_batches = 0.2 )
547547
548- with pytest . warns ( UserWarning , match = f"The ' { stage } _dataloader' does not have many workers" ) :
549- if stage == "test" :
550- if ckpt_path in ( "specific" , "best" ):
551- trainer . fit ( model , train_dataloaders = train_dl , val_dataloaders = val_dl )
552- ckpt_path = trainer . checkpoint_callback . best_model_path if ckpt_path == "specific" else ckpt_path
548+ if stage == "test" :
549+ if ckpt_path in ( "specific" , "best" ) :
550+ trainer . fit ( model , train_dataloaders = train_dl , val_dataloaders = val_dl )
551+ ckpt_path = trainer . checkpoint_callback . best_model_path if ckpt_path == "specific" else ckpt_path
552+ with pytest . warns ( UserWarning , match = f"The ' { stage } _dataloader' does not have many workers" ):
553553 trainer .test (model , dataloaders = train_dl , ckpt_path = ckpt_path )
554- else :
554+ else :
555+ with pytest .warns (UserWarning , match = f"The '{ stage } _dataloader' does not have many workers" ):
555556 trainer .fit (model , train_dataloaders = train_dl , val_dataloaders = val_dl )
556557
557558
@@ -579,16 +580,15 @@ def training_step(self, batch, batch_idx):
579580
580581 trainer = Trainer (default_root_dir = tmp_path , max_epochs = 1 , limit_val_batches = 0.1 , limit_train_batches = 0.2 )
581582
582- with pytest .warns (
583- UserWarning ,
584- match = f"The '{ stage } _dataloader' does not have many workers" ,
585- ):
586- if stage == "test" :
587- if ckpt_path in ("specific" , "best" ):
588- trainer .fit (model , train_dataloaders = train_multi_dl , val_dataloaders = val_multi_dl )
589- ckpt_path = trainer .checkpoint_callback .best_model_path if ckpt_path == "specific" else ckpt_path
583+ if stage == "test" :
584+ if ckpt_path in ("specific" , "best" ):
585+ trainer .fit (model , train_dataloaders = train_multi_dl , val_dataloaders = val_multi_dl )
586+ ckpt_path = trainer .checkpoint_callback .best_model_path if ckpt_path == "specific" else ckpt_path
587+
588+ with pytest .warns (UserWarning , match = f"The '{ stage } _dataloader' does not have many workers" ):
590589 trainer .test (model , dataloaders = test_multi_dl , ckpt_path = ckpt_path )
591- else :
590+ else :
591+ with pytest .warns (UserWarning , match = f"The '{ stage } _dataloader' does not have many workers" ):
592592 trainer .fit (model , train_dataloaders = train_multi_dl , val_dataloaders = val_multi_dl )
593593
594594
@@ -669,28 +669,35 @@ def test_auto_add_worker_init_fn_distributed(tmp_path, monkeypatch):
669669 trainer .fit (model , train_dataloaders = dataloader )
670670
671671
672- def test_warning_with_small_dataloader_and_logging_interval (tmp_path ):
672+ @pytest .mark .parametrize ("log_interval" , [2 , 11 ])
673+ def test_warning_with_small_dataloader_and_logging_interval (log_interval , tmp_path ):
673674 """Test that a warning message is shown if the dataloader length is too short for the chosen logging interval."""
674675 model = BoringModel ()
675676 dataloader = DataLoader (RandomDataset (32 , length = 10 ))
676677 model .train_dataloader = lambda : dataloader
677678
678- with pytest .warns (UserWarning , match = r"The number of training batches \(10\) is smaller than the logging interval" ):
679- trainer = Trainer (default_root_dir = tmp_path , max_epochs = 1 , log_every_n_steps = 11 , logger = CSVLogger (tmp_path ))
679+ trainer = Trainer (
680+ default_root_dir = tmp_path ,
681+ max_epochs = 1 ,
682+ log_every_n_steps = log_interval ,
683+ limit_train_batches = 1 if log_interval < 10 else None ,
684+ logger = CSVLogger (tmp_path ),
685+ )
686+ with pytest .warns (
687+ UserWarning ,
688+ match = rf"The number of training batches \({ log_interval - 1 } \) is smaller than the logging interval" ,
689+ ):
680690 trainer .fit (model )
681691
682- with pytest .warns (UserWarning , match = r"The number of training batches \(1\) is smaller than the logging interval" ):
683- trainer = Trainer (
684- default_root_dir = tmp_path ,
685- max_epochs = 1 ,
686- log_every_n_steps = 2 ,
687- limit_train_batches = 1 ,
688- logger = CSVLogger (tmp_path ),
689- )
690- trainer .fit (model )
691692
693+ def test_warning_with_small_dataloader_and_fast_dev_run (tmp_path ):
694+ """Test that a warning message is shown if the dataloader length is too short for the chosen logging interval."""
695+ model = BoringModel ()
696+ dataloader = DataLoader (RandomDataset (32 , length = 10 ))
697+ model .train_dataloader = lambda : dataloader
698+
699+ trainer = Trainer (default_root_dir = tmp_path , fast_dev_run = True , log_every_n_steps = 2 )
692700 with no_warning_call (UserWarning , match = "The number of training batches" ):
693- trainer = Trainer (default_root_dir = tmp_path , fast_dev_run = True , log_every_n_steps = 2 )
694701 trainer .fit (model )
695702
696703
0 commit comments