@@ -109,98 +109,98 @@ def test_tqdm_progress_bar_misconfiguration():
109109 Trainer (callbacks = TQDMProgressBar (), enable_progress_bar = False )
110110
111111
112+ @patch ("lightning.pytorch.trainer.connectors.callback_connector._RICH_AVAILABLE" , False )
112113@pytest .mark .parametrize ("num_dl" , [1 , 2 ])
113114def test_tqdm_progress_bar_totals (tmp_path , num_dl ):
114115 """Test that the progress finishes with the correct total steps processed."""
115- with patch ("lightning.pytorch.trainer.connectors.callback_connector._RICH_AVAILABLE" , False ):
116-
117- class CustomModel (BoringModel ):
118- def _get_dataloaders (self ):
119- dls = [DataLoader (RandomDataset (32 , 64 )), DataLoader (RandomDataset (32 , 64 ))]
120- return dls [0 ] if num_dl == 1 else dls
121-
122- def val_dataloader (self ):
123- return self ._get_dataloaders ()
124-
125- def test_dataloader (self ):
126- return self ._get_dataloaders ()
127-
128- def predict_dataloader (self ):
129- return self ._get_dataloaders ()
130-
131- def validation_step (self , batch , batch_idx , dataloader_idx = 0 ):
132- return
133-
134- def test_step (self , batch , batch_idx , dataloader_idx = 0 ):
135- return
136-
137- def predict_step (self , batch , batch_idx , dataloader_idx = 0 ):
138- return
139-
140- model = CustomModel ()
141-
142- # check the sanity dataloaders
143- num_sanity_val_steps = 4
144- trainer = Trainer (
145- default_root_dir = tmp_path , max_epochs = 1 , limit_train_batches = 0 , num_sanity_val_steps = num_sanity_val_steps
146- )
147- pbar = trainer .progress_bar_callback
148- with mock .patch ("lightning.pytorch.callbacks.progress.tqdm_progress.Tqdm" , MockTqdm ):
149- trainer .fit (model )
150-
151- expected_sanity_steps = [num_sanity_val_steps ] * num_dl
152- assert not pbar .val_progress_bar .leave
153- assert trainer .num_sanity_val_batches == expected_sanity_steps
154- assert pbar .val_progress_bar .total_values == expected_sanity_steps
155- assert pbar .val_progress_bar .n_values == list (range (num_sanity_val_steps + 1 )) * num_dl
156- assert pbar .val_progress_bar .descriptions == [f"Sanity Checking DataLoader { i } : " for i in range (num_dl )]
157-
158- # fit
159- trainer = Trainer (default_root_dir = tmp_path , max_epochs = 1 )
160- pbar = trainer .progress_bar_callback
161- with mock .patch ("lightning.pytorch.callbacks.progress.tqdm_progress.Tqdm" , MockTqdm ):
162- trainer .fit (model )
163-
164- n = trainer .num_training_batches
165- m = trainer .num_val_batches
166- assert len (trainer .train_dataloader ) == n
167- # train progress bar should have reached the end
168- assert pbar .train_progress_bar .total == n
169- assert pbar .train_progress_bar .n == n
170- assert pbar .train_progress_bar .leave
171-
172- # check val progress bar total
173- assert pbar .val_progress_bar .total_values == m
174- assert pbar .val_progress_bar .n_values == list (range (m [0 ] + 1 )) * num_dl
175- assert pbar .val_progress_bar .descriptions == [f"Validation DataLoader { i } : " for i in range (num_dl )]
176- assert not pbar .val_progress_bar .leave
177-
178- # validate
179- with mock .patch ("lightning.pytorch.callbacks.progress.tqdm_progress.Tqdm" , MockTqdm ):
180- trainer .validate (model )
181- assert trainer .num_val_batches == m
182- assert pbar .val_progress_bar .total_values == m
183- assert pbar .val_progress_bar .n_values == list (range (m [0 ] + 1 )) * num_dl
184- assert pbar .val_progress_bar .descriptions == [f"Validation DataLoader { i } : " for i in range (num_dl )]
185-
186- # test
187- with mock .patch ("lightning.pytorch.callbacks.progress.tqdm_progress.Tqdm" , MockTqdm ):
188- trainer .test (model )
189- assert pbar .test_progress_bar .leave
190- k = trainer .num_test_batches
191- assert pbar .test_progress_bar .total_values == k
192- assert pbar .test_progress_bar .n_values == list (range (k [0 ] + 1 )) * num_dl
193- assert pbar .test_progress_bar .descriptions == [f"Testing DataLoader { i } : " for i in range (num_dl )]
194- assert pbar .test_progress_bar .leave
195-
196- # predict
197- with mock .patch ("lightning.pytorch.callbacks.progress.tqdm_progress.Tqdm" , MockTqdm ):
198- trainer .predict (model )
199- assert pbar .predict_progress_bar .leave
200- k = trainer .num_predict_batches
201- assert pbar .predict_progress_bar .total_values == k
202- assert pbar .predict_progress_bar .n_values == list (range (k [0 ] + 1 )) * num_dl
203- assert pbar .predict_progress_bar .descriptions == [f"Predicting DataLoader { i } : " for i in range (num_dl )]
116+
117+ class CustomModel (BoringModel ):
118+ def _get_dataloaders (self ):
119+ dls = [DataLoader (RandomDataset (32 , 64 )), DataLoader (RandomDataset (32 , 64 ))]
120+ return dls [0 ] if num_dl == 1 else dls
121+
122+ def val_dataloader (self ):
123+ return self ._get_dataloaders ()
124+
125+ def test_dataloader (self ):
126+ return self ._get_dataloaders ()
127+
128+ def predict_dataloader (self ):
129+ return self ._get_dataloaders ()
130+
131+ def validation_step (self , batch , batch_idx , dataloader_idx = 0 ):
132+ return
133+
134+ def test_step (self , batch , batch_idx , dataloader_idx = 0 ):
135+ return
136+
137+ def predict_step (self , batch , batch_idx , dataloader_idx = 0 ):
138+ return
139+
140+ model = CustomModel ()
141+
142+ # check the sanity dataloaders
143+ num_sanity_val_steps = 4
144+ trainer = Trainer (
145+ default_root_dir = tmp_path , max_epochs = 1 , limit_train_batches = 0 , num_sanity_val_steps = num_sanity_val_steps
146+ )
147+ pbar = trainer .progress_bar_callback
148+ with mock .patch ("lightning.pytorch.callbacks.progress.tqdm_progress.Tqdm" , MockTqdm ):
149+ trainer .fit (model )
150+
151+ expected_sanity_steps = [num_sanity_val_steps ] * num_dl
152+ assert not pbar .val_progress_bar .leave
153+ assert trainer .num_sanity_val_batches == expected_sanity_steps
154+ assert pbar .val_progress_bar .total_values == expected_sanity_steps
155+ assert pbar .val_progress_bar .n_values == list (range (num_sanity_val_steps + 1 )) * num_dl
156+ assert pbar .val_progress_bar .descriptions == [f"Sanity Checking DataLoader { i } : " for i in range (num_dl )]
157+
158+ # fit
159+ trainer = Trainer (default_root_dir = tmp_path , max_epochs = 1 )
160+ pbar = trainer .progress_bar_callback
161+ with mock .patch ("lightning.pytorch.callbacks.progress.tqdm_progress.Tqdm" , MockTqdm ):
162+ trainer .fit (model )
163+
164+ n = trainer .num_training_batches
165+ m = trainer .num_val_batches
166+ assert len (trainer .train_dataloader ) == n
167+ # train progress bar should have reached the end
168+ assert pbar .train_progress_bar .total == n
169+ assert pbar .train_progress_bar .n == n
170+ assert pbar .train_progress_bar .leave
171+
172+ # check val progress bar total
173+ assert pbar .val_progress_bar .total_values == m
174+ assert pbar .val_progress_bar .n_values == list (range (m [0 ] + 1 )) * num_dl
175+ assert pbar .val_progress_bar .descriptions == [f"Validation DataLoader { i } : " for i in range (num_dl )]
176+ assert not pbar .val_progress_bar .leave
177+
178+ # validate
179+ with mock .patch ("lightning.pytorch.callbacks.progress.tqdm_progress.Tqdm" , MockTqdm ):
180+ trainer .validate (model )
181+ assert trainer .num_val_batches == m
182+ assert pbar .val_progress_bar .total_values == m
183+ assert pbar .val_progress_bar .n_values == list (range (m [0 ] + 1 )) * num_dl
184+ assert pbar .val_progress_bar .descriptions == [f"Validation DataLoader { i } : " for i in range (num_dl )]
185+
186+ # test
187+ with mock .patch ("lightning.pytorch.callbacks.progress.tqdm_progress.Tqdm" , MockTqdm ):
188+ trainer .test (model )
189+ assert pbar .test_progress_bar .leave
190+ k = trainer .num_test_batches
191+ assert pbar .test_progress_bar .total_values == k
192+ assert pbar .test_progress_bar .n_values == list (range (k [0 ] + 1 )) * num_dl
193+ assert pbar .test_progress_bar .descriptions == [f"Testing DataLoader { i } : " for i in range (num_dl )]
194+ assert pbar .test_progress_bar .leave
195+
196+ # predict
197+ with mock .patch ("lightning.pytorch.callbacks.progress.tqdm_progress.Tqdm" , MockTqdm ):
198+ trainer .predict (model )
199+ assert pbar .predict_progress_bar .leave
200+ k = trainer .num_predict_batches
201+ assert pbar .predict_progress_bar .total_values == k
202+ assert pbar .predict_progress_bar .n_values == list (range (k [0 ] + 1 )) * num_dl
203+ assert pbar .predict_progress_bar .descriptions == [f"Predicting DataLoader { i } : " for i in range (num_dl )]
204204 assert pbar .predict_progress_bar .leave
205205
206206
@@ -414,24 +414,30 @@ def test_test_progress_bar_update_amount(tmp_path, test_batches: int, refresh_ra
414414 assert progress_bar .test_progress_bar .n_values == updates
415415
416416
417+ @patch ("lightning.pytorch.trainer.connectors.callback_connector._RICH_AVAILABLE" , False )
417418def test_tensor_to_float_conversion (tmp_path ):
418419 """Check tensor gets converted to float."""
419- with patch ("lightning.pytorch.trainer.connectors.callback_connector._RICH_AVAILABLE" , False ):
420-
421- class TestModel (BoringModel ):
422- def training_step (self , batch , batch_idx ):
423- self .log ("a" , torch .tensor (0.123 ), prog_bar = True , on_epoch = False )
424- self .log ("b" , torch .tensor ([1 ]), prog_bar = True , on_epoch = False )
425- self .log ("c" , 2 , prog_bar = True , on_epoch = False )
426- return super ().training_step (batch , batch_idx )
427-
428- trainer = Trainer (
429- default_root_dir = tmp_path , max_epochs = 1 , limit_train_batches = 2 , logger = False , enable_checkpointing = False
430- )
420+
421+ class TestModel (BoringModel ):
422+ def training_step (self , batch , batch_idx ):
423+ self .log ("a" , torch .tensor (0.123 ), prog_bar = True , on_epoch = False )
424+ self .log ("b" , torch .tensor ([1 ]), prog_bar = True , on_epoch = False )
425+ self .log ("c" , 2 , prog_bar = True , on_epoch = False )
426+ return super ().training_step (batch , batch_idx )
427+
428+ trainer = Trainer (
429+ default_root_dir = tmp_path , max_epochs = 1 , limit_train_batches = 2 , logger = False , enable_checkpointing = False
430+ )
431+
432+ with mock .patch .object (sys .stdout , "write" ) as mock_write :
431433 trainer .fit (TestModel ())
434+ bar_updates = "" .join (call .args [0 ] for call in mock_write .call_args_list )
435+ assert "a=0.123" in bar_updates
436+ assert "b=1.000" in bar_updates
437+ assert "c=2.000" in bar_updates
432438
433- torch .testing .assert_close (trainer .progress_bar_metrics ["a" ], 0.123 )
434- assert trainer .progress_bar_metrics ["b" ] == 1.0
439+ torch .testing .assert_close (trainer .progress_bar_metrics ["a" ], 0.123 )
440+ assert trainer .progress_bar_metrics ["b" ] == 1.0
435441 assert trainer .progress_bar_metrics ["c" ] == 2.0
436442 pbar = trainer .progress_bar_callback .train_progress_bar
437443 actual = str (pbar .postfix )
0 commit comments