1111# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212# See the License for the specific language governing permissions and
1313# limitations under the License.
14+ import math
1415import os
1516import pickle
1617import sys
@@ -361,10 +362,10 @@ def test_tqdm_progress_bar_value_on_colab(tmpdir):
361362 [2 , 3 , 1 , [1 , 2 , 3 , 4 , 5 ], [1 , 2 , 3 ]],
362363 [0 , 0 , 3 , None , None ],
363364 [1 , 0 , 3 , [1 ], None ],
364- [1 , 1 , 3 , [1 , 2 ], [1 ]],
365+ [1 , 1 , 3 , [2 ], [1 ]],
365366 [5 , 0 , 3 , [3 , 5 ], None ],
366- [5 , 2 , 3 , [3 , 5 , 7 ], [2 ]],
367- [5 , 2 , 6 , [5 , 7 ], [2 ]],
367+ [5 , 2 , 3 , [3 , 6 , 7 ], [2 ]],
368+ [5 , 2 , 6 , [6 , 7 ], [2 ]],
368369 ],
369370)
370371def test_main_progress_bar_update_amount (
@@ -563,16 +564,56 @@ def test_tqdm_progress_bar_can_be_pickled():
563564 pickle .dumps (bar )
564565
565566
566- @RunIf (min_gpus = 2 , standalone = True )
567567@pytest .mark .parametrize (
568- ["total_train_samples " , "train_batch_size " , "total_val_samples" , "val_batch_size" , "val_check_interval " ],
569- [(8 , 4 , 2 , 1 , 0.2 ), ( 8 , 4 , 2 , 1 , 0.5 )],
568+ ["val_check_interval " , "main_progress_bar_updates " , "val_progress_bar_updates " ],
569+ [(4 , [ 3 , 6 , 9 , 12 , 14 ], [ 3 , 6 , 7 ]), ( 0.5 , [ 3 , 6 , 9 , 12 , 15 , 18 , 21 ], [ 3 , 6 , 7 ] )],
570570)
571571def test_progress_bar_max_val_check_interval (
572- tmpdir , total_train_samples , train_batch_size , total_val_samples , val_batch_size , val_check_interval
572+ tmpdir , val_check_interval , main_progress_bar_updates , val_progress_bar_updates
573573):
574+ limit_batches = 7
575+ model = BoringModel ()
576+ trainer = Trainer (
577+ default_root_dir = tmpdir ,
578+ num_sanity_val_steps = 0 ,
579+ max_epochs = 1 ,
580+ enable_model_summary = False ,
581+ val_check_interval = val_check_interval ,
582+ limit_train_batches = limit_batches ,
583+ limit_val_batches = limit_batches ,
584+ callbacks = TQDMProgressBar (refresh_rate = 3 ),
585+ )
586+ with mock .patch ("pytorch_lightning.callbacks.progress.tqdm_progress.Tqdm" , MockTqdm ):
587+ trainer .fit (model )
588+
589+ pbar = trainer .progress_bar_callback
590+ assert pbar .main_progress_bar .n_values == main_progress_bar_updates
591+ assert pbar .val_progress_bar .n_values == val_progress_bar_updates
592+
593+ val_check_batch = (
594+ max (1 , int (limit_batches * val_check_interval )) if isinstance (val_check_interval , float ) else val_check_interval
595+ )
596+ assert trainer .val_check_batch == val_check_batch
597+ val_checks_per_epoch = math .ceil (limit_batches // val_check_batch )
598+ pbar_callback = trainer .progress_bar_callback
599+ total_val_batches = limit_batches * val_checks_per_epoch
600+
601+ assert pbar_callback .val_progress_bar .n == limit_batches
602+ assert pbar_callback .val_progress_bar .total == limit_batches
603+ assert pbar_callback .main_progress_bar .n == limit_batches + total_val_batches
604+ assert pbar_callback .main_progress_bar .total == limit_batches + total_val_batches
605+ assert pbar_callback .is_enabled
606+
607+
608+ @RunIf (min_gpus = 2 , standalone = True )
609+ @pytest .mark .parametrize ("val_check_interval" , [0.2 , 0.5 ])
610+ def test_progress_bar_max_val_check_interval_ddp (tmpdir , val_check_interval ):
574611 world_size = 2
575- train_data = DataLoader (RandomDataset (32 , total_train_samples ), batch_size = train_batch_size )
612+ total_train_samples = 16
613+ train_batch_size = 4
614+ total_val_samples = 2
615+ val_batch_size = 1
616+ train_data = DataLoader (RandomDataset (32 , 8 ), batch_size = train_batch_size )
576617 val_data = DataLoader (RandomDataset (32 , total_val_samples ), batch_size = val_batch_size )
577618
578619 model = BoringModel ()
@@ -599,8 +640,8 @@ def test_progress_bar_max_val_check_interval(
599640 assert pbar_callback .val_progress_bar .n == total_val_batches
600641 assert pbar_callback .val_progress_bar .total == total_val_batches
601642 total_val_batches = total_val_batches * val_checks_per_epoch
602- assert pbar_callback .main_progress_bar .n == total_train_batches + total_val_batches
603- assert pbar_callback .main_progress_bar .total == total_train_batches + total_val_batches
643+ assert pbar_callback .main_progress_bar .n == ( total_train_batches + total_val_batches ) // world_size
644+ assert pbar_callback .main_progress_bar .total == ( total_train_batches + total_val_batches ) // world_size
604645 assert pbar_callback .is_enabled
605646
606647
0 commit comments