@@ -545,13 +545,14 @@ def test_warning_with_few_workers(_, tmp_path, ckpt_path, stage):
 
     trainer = Trainer(default_root_dir=tmp_path, max_epochs=1, limit_val_batches=0.1, limit_train_batches=0.2)
 
-    with pytest.warns(UserWarning, match=f"The '{stage}_dataloader' does not have many workers"):
-        if stage == "test":
-            if ckpt_path in ("specific", "best"):
-                trainer.fit(model, train_dataloaders=train_dl, val_dataloaders=val_dl)
-            ckpt_path = trainer.checkpoint_callback.best_model_path if ckpt_path == "specific" else ckpt_path
+    if stage == "test":
+        if ckpt_path in ("specific", "best"):
+            trainer.fit(model, train_dataloaders=train_dl, val_dataloaders=val_dl)
+        ckpt_path = trainer.checkpoint_callback.best_model_path if ckpt_path == "specific" else ckpt_path
+        with pytest.warns(UserWarning, match=f"The '{stage}_dataloader' does not have many workers"):
             trainer.test(model, dataloaders=train_dl, ckpt_path=ckpt_path)
-        else:
+    else:
+        with pytest.warns(UserWarning, match=f"The '{stage}_dataloader' does not have many workers"):
             trainer.fit(model, train_dataloaders=train_dl, val_dataloaders=val_dl)
 
 
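The hunk above narrows each `pytest.warns` block to the single call expected to emit the warning, so the preliminary `trainer.fit` used only to produce a checkpoint can no longer satisfy (or pollute) the assertion. A minimal standalone sketch of the pattern, using hypothetical `setup()` and `call_under_test()` helpers rather than Lightning itself:

```python
import warnings

import pytest


def setup():
    # Setup may emit its own UserWarning; it is kept outside pytest.warns
    # so it cannot accidentally satisfy the assertion for the call under test.
    warnings.warn("unrelated setup warning", UserWarning)


def call_under_test():
    warnings.warn("The 'test_dataloader' does not have many workers", UserWarning)


def test_scoped_warns():
    setup()
    with pytest.warns(UserWarning, match="does not have many workers"):
        call_under_test()  # only this call must raise the matched warning
```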
@@ -579,16 +580,15 @@ def training_step(self, batch, batch_idx):
 
     trainer = Trainer(default_root_dir=tmp_path, max_epochs=1, limit_val_batches=0.1, limit_train_batches=0.2)
 
-    with pytest.warns(
-        UserWarning,
-        match=f"The '{stage}_dataloader' does not have many workers",
-    ):
-        if stage == "test":
-            if ckpt_path in ("specific", "best"):
-                trainer.fit(model, train_dataloaders=train_multi_dl, val_dataloaders=val_multi_dl)
-            ckpt_path = trainer.checkpoint_callback.best_model_path if ckpt_path == "specific" else ckpt_path
+    if stage == "test":
+        if ckpt_path in ("specific", "best"):
+            trainer.fit(model, train_dataloaders=train_multi_dl, val_dataloaders=val_multi_dl)
+        ckpt_path = trainer.checkpoint_callback.best_model_path if ckpt_path == "specific" else ckpt_path
+
+        with pytest.warns(UserWarning, match=f"The '{stage}_dataloader' does not have many workers"):
             trainer.test(model, dataloaders=test_multi_dl, ckpt_path=ckpt_path)
-        else:
+    else:
+        with pytest.warns(UserWarning, match=f"The '{stage}_dataloader' does not have many workers"):
             trainer.fit(model, train_dataloaders=train_multi_dl, val_dataloaders=val_multi_dl)
 
 
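Both hunks keep the f-string `match` patterns. Worth noting: `pytest.warns` applies `match` to the warning message with `re.search`, so any interpolated value containing regex metacharacters would need escaping. A small sketch of that detail, assuming a plain `stage` string:

```python
import re
import warnings

import pytest

stage = "test"


def emit():
    warnings.warn(f"The '{stage}_dataloader' does not have many workers", UserWarning)


def test_match_is_a_regex():
    # `match` is re.search'ed against the message; re.escape guards against
    # metacharacters if the interpolated value ever contains any.
    with pytest.warns(UserWarning, match=re.escape(f"The '{stage}_dataloader'")):
        emit()
```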
@@ -669,28 +669,35 @@ def test_auto_add_worker_init_fn_distributed(tmp_path, monkeypatch):
     trainer.fit(model, train_dataloaders=dataloader)
 
 
-def test_warning_with_small_dataloader_and_logging_interval(tmp_path):
+@pytest.mark.parametrize("log_interval", [2, 11])
+def test_warning_with_small_dataloader_and_logging_interval(log_interval, tmp_path):
     """Test that a warning message is shown if the dataloader length is too short for the chosen logging interval."""
     model = BoringModel()
     dataloader = DataLoader(RandomDataset(32, length=10))
     model.train_dataloader = lambda: dataloader
 
-    with pytest.warns(UserWarning, match=r"The number of training batches \(10\) is smaller than the logging interval"):
-        trainer = Trainer(default_root_dir=tmp_path, max_epochs=1, log_every_n_steps=11, logger=CSVLogger(tmp_path))
+    trainer = Trainer(
+        default_root_dir=tmp_path,
+        max_epochs=1,
+        log_every_n_steps=log_interval,
+        limit_train_batches=1 if log_interval < 10 else None,
+        logger=CSVLogger(tmp_path),
+    )
+    with pytest.warns(
+        UserWarning,
+        match=rf"The number of training batches \({log_interval - 1}\) is smaller than the logging interval",
+    ):
         trainer.fit(model)
 
-    with pytest.warns(UserWarning, match=r"The number of training batches \(1\) is smaller than the logging interval"):
-        trainer = Trainer(
-            default_root_dir=tmp_path,
-            max_epochs=1,
-            log_every_n_steps=2,
-            limit_train_batches=1,
-            logger=CSVLogger(tmp_path),
-        )
-        trainer.fit(model)
 
+def test_warning_with_small_dataloader_and_fast_dev_run(tmp_path):
+    """Test that no warning is raised with ``fast_dev_run`` even if the dataloader is too short for the logging interval."""
+    model = BoringModel()
+    dataloader = DataLoader(RandomDataset(32, length=10))
+    model.train_dataloader = lambda: dataloader
+
+    trainer = Trainer(default_root_dir=tmp_path, fast_dev_run=True, log_every_n_steps=2)
     with no_warning_call(UserWarning, match="The number of training batches"):
-        trainer = Trainer(default_root_dir=tmp_path, fast_dev_run=True, log_every_n_steps=2)
         trainer.fit(model)
 
 
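The parametrization in the last hunk works because both cases are arranged to observe exactly `log_interval - 1` training batches: `RandomDataset(32, length=10)` yields 10 batches with the default `DataLoader` batch size, and `limit_train_batches=1` caps that to 1 when `log_interval` is 2. A standalone sketch of that arithmetic, which is what lets one regex cover both cases:

```python
import pytest


@pytest.mark.parametrize("log_interval", [2, 11])
def test_observed_batch_count(log_interval):
    available = 10  # batches from RandomDataset(32, length=10), batch_size=1
    # Mirror the Trainer arguments in the diff: cap to one batch only for
    # the small logging interval, otherwise run over the full dataset.
    limit = 1 if log_interval < 10 else available
    observed = min(available, limit)
    # Both cases land one batch short of the logging interval, matching
    # the single rf"... \({log_interval - 1}\) ..." pattern in the test.
    assert observed == log_interval - 1
```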