@@ -109,98 +109,98 @@ def test_tqdm_progress_bar_misconfiguration():
109
109
Trainer (callbacks = TQDMProgressBar (), enable_progress_bar = False )
110
110
111
111
112
+ @patch ("lightning.pytorch.trainer.connectors.callback_connector._RICH_AVAILABLE" , False )
112
113
@pytest .mark .parametrize ("num_dl" , [1 , 2 ])
113
114
def test_tqdm_progress_bar_totals (tmp_path , num_dl ):
114
115
"""Test that the progress finishes with the correct total steps processed."""
115
- with patch ("lightning.pytorch.trainer.connectors.callback_connector._RICH_AVAILABLE" , False ):
116
-
117
- class CustomModel (BoringModel ):
118
- def _get_dataloaders (self ):
119
- dls = [DataLoader (RandomDataset (32 , 64 )), DataLoader (RandomDataset (32 , 64 ))]
120
- return dls [0 ] if num_dl == 1 else dls
121
-
122
- def val_dataloader (self ):
123
- return self ._get_dataloaders ()
124
-
125
- def test_dataloader (self ):
126
- return self ._get_dataloaders ()
127
-
128
- def predict_dataloader (self ):
129
- return self ._get_dataloaders ()
130
-
131
- def validation_step (self , batch , batch_idx , dataloader_idx = 0 ):
132
- return
133
-
134
- def test_step (self , batch , batch_idx , dataloader_idx = 0 ):
135
- return
136
-
137
- def predict_step (self , batch , batch_idx , dataloader_idx = 0 ):
138
- return
139
-
140
- model = CustomModel ()
141
-
142
- # check the sanity dataloaders
143
- num_sanity_val_steps = 4
144
- trainer = Trainer (
145
- default_root_dir = tmp_path , max_epochs = 1 , limit_train_batches = 0 , num_sanity_val_steps = num_sanity_val_steps
146
- )
147
- pbar = trainer .progress_bar_callback
148
- with mock .patch ("lightning.pytorch.callbacks.progress.tqdm_progress.Tqdm" , MockTqdm ):
149
- trainer .fit (model )
150
-
151
- expected_sanity_steps = [num_sanity_val_steps ] * num_dl
152
- assert not pbar .val_progress_bar .leave
153
- assert trainer .num_sanity_val_batches == expected_sanity_steps
154
- assert pbar .val_progress_bar .total_values == expected_sanity_steps
155
- assert pbar .val_progress_bar .n_values == list (range (num_sanity_val_steps + 1 )) * num_dl
156
- assert pbar .val_progress_bar .descriptions == [f"Sanity Checking DataLoader { i } : " for i in range (num_dl )]
157
-
158
- # fit
159
- trainer = Trainer (default_root_dir = tmp_path , max_epochs = 1 )
160
- pbar = trainer .progress_bar_callback
161
- with mock .patch ("lightning.pytorch.callbacks.progress.tqdm_progress.Tqdm" , MockTqdm ):
162
- trainer .fit (model )
163
-
164
- n = trainer .num_training_batches
165
- m = trainer .num_val_batches
166
- assert len (trainer .train_dataloader ) == n
167
- # train progress bar should have reached the end
168
- assert pbar .train_progress_bar .total == n
169
- assert pbar .train_progress_bar .n == n
170
- assert pbar .train_progress_bar .leave
171
-
172
- # check val progress bar total
173
- assert pbar .val_progress_bar .total_values == m
174
- assert pbar .val_progress_bar .n_values == list (range (m [0 ] + 1 )) * num_dl
175
- assert pbar .val_progress_bar .descriptions == [f"Validation DataLoader { i } : " for i in range (num_dl )]
176
- assert not pbar .val_progress_bar .leave
177
-
178
- # validate
179
- with mock .patch ("lightning.pytorch.callbacks.progress.tqdm_progress.Tqdm" , MockTqdm ):
180
- trainer .validate (model )
181
- assert trainer .num_val_batches == m
182
- assert pbar .val_progress_bar .total_values == m
183
- assert pbar .val_progress_bar .n_values == list (range (m [0 ] + 1 )) * num_dl
184
- assert pbar .val_progress_bar .descriptions == [f"Validation DataLoader { i } : " for i in range (num_dl )]
185
-
186
- # test
187
- with mock .patch ("lightning.pytorch.callbacks.progress.tqdm_progress.Tqdm" , MockTqdm ):
188
- trainer .test (model )
189
- assert pbar .test_progress_bar .leave
190
- k = trainer .num_test_batches
191
- assert pbar .test_progress_bar .total_values == k
192
- assert pbar .test_progress_bar .n_values == list (range (k [0 ] + 1 )) * num_dl
193
- assert pbar .test_progress_bar .descriptions == [f"Testing DataLoader { i } : " for i in range (num_dl )]
194
- assert pbar .test_progress_bar .leave
195
-
196
- # predict
197
- with mock .patch ("lightning.pytorch.callbacks.progress.tqdm_progress.Tqdm" , MockTqdm ):
198
- trainer .predict (model )
199
- assert pbar .predict_progress_bar .leave
200
- k = trainer .num_predict_batches
201
- assert pbar .predict_progress_bar .total_values == k
202
- assert pbar .predict_progress_bar .n_values == list (range (k [0 ] + 1 )) * num_dl
203
- assert pbar .predict_progress_bar .descriptions == [f"Predicting DataLoader { i } : " for i in range (num_dl )]
116
+
117
+ class CustomModel (BoringModel ):
118
+ def _get_dataloaders (self ):
119
+ dls = [DataLoader (RandomDataset (32 , 64 )), DataLoader (RandomDataset (32 , 64 ))]
120
+ return dls [0 ] if num_dl == 1 else dls
121
+
122
+ def val_dataloader (self ):
123
+ return self ._get_dataloaders ()
124
+
125
+ def test_dataloader (self ):
126
+ return self ._get_dataloaders ()
127
+
128
+ def predict_dataloader (self ):
129
+ return self ._get_dataloaders ()
130
+
131
+ def validation_step (self , batch , batch_idx , dataloader_idx = 0 ):
132
+ return
133
+
134
+ def test_step (self , batch , batch_idx , dataloader_idx = 0 ):
135
+ return
136
+
137
+ def predict_step (self , batch , batch_idx , dataloader_idx = 0 ):
138
+ return
139
+
140
+ model = CustomModel ()
141
+
142
+ # check the sanity dataloaders
143
+ num_sanity_val_steps = 4
144
+ trainer = Trainer (
145
+ default_root_dir = tmp_path , max_epochs = 1 , limit_train_batches = 0 , num_sanity_val_steps = num_sanity_val_steps
146
+ )
147
+ pbar = trainer .progress_bar_callback
148
+ with mock .patch ("lightning.pytorch.callbacks.progress.tqdm_progress.Tqdm" , MockTqdm ):
149
+ trainer .fit (model )
150
+
151
+ expected_sanity_steps = [num_sanity_val_steps ] * num_dl
152
+ assert not pbar .val_progress_bar .leave
153
+ assert trainer .num_sanity_val_batches == expected_sanity_steps
154
+ assert pbar .val_progress_bar .total_values == expected_sanity_steps
155
+ assert pbar .val_progress_bar .n_values == list (range (num_sanity_val_steps + 1 )) * num_dl
156
+ assert pbar .val_progress_bar .descriptions == [f"Sanity Checking DataLoader { i } : " for i in range (num_dl )]
157
+
158
+ # fit
159
+ trainer = Trainer (default_root_dir = tmp_path , max_epochs = 1 )
160
+ pbar = trainer .progress_bar_callback
161
+ with mock .patch ("lightning.pytorch.callbacks.progress.tqdm_progress.Tqdm" , MockTqdm ):
162
+ trainer .fit (model )
163
+
164
+ n = trainer .num_training_batches
165
+ m = trainer .num_val_batches
166
+ assert len (trainer .train_dataloader ) == n
167
+ # train progress bar should have reached the end
168
+ assert pbar .train_progress_bar .total == n
169
+ assert pbar .train_progress_bar .n == n
170
+ assert pbar .train_progress_bar .leave
171
+
172
+ # check val progress bar total
173
+ assert pbar .val_progress_bar .total_values == m
174
+ assert pbar .val_progress_bar .n_values == list (range (m [0 ] + 1 )) * num_dl
175
+ assert pbar .val_progress_bar .descriptions == [f"Validation DataLoader { i } : " for i in range (num_dl )]
176
+ assert not pbar .val_progress_bar .leave
177
+
178
+ # validate
179
+ with mock .patch ("lightning.pytorch.callbacks.progress.tqdm_progress.Tqdm" , MockTqdm ):
180
+ trainer .validate (model )
181
+ assert trainer .num_val_batches == m
182
+ assert pbar .val_progress_bar .total_values == m
183
+ assert pbar .val_progress_bar .n_values == list (range (m [0 ] + 1 )) * num_dl
184
+ assert pbar .val_progress_bar .descriptions == [f"Validation DataLoader { i } : " for i in range (num_dl )]
185
+
186
+ # test
187
+ with mock .patch ("lightning.pytorch.callbacks.progress.tqdm_progress.Tqdm" , MockTqdm ):
188
+ trainer .test (model )
189
+ assert pbar .test_progress_bar .leave
190
+ k = trainer .num_test_batches
191
+ assert pbar .test_progress_bar .total_values == k
192
+ assert pbar .test_progress_bar .n_values == list (range (k [0 ] + 1 )) * num_dl
193
+ assert pbar .test_progress_bar .descriptions == [f"Testing DataLoader { i } : " for i in range (num_dl )]
194
+ assert pbar .test_progress_bar .leave
195
+
196
+ # predict
197
+ with mock .patch ("lightning.pytorch.callbacks.progress.tqdm_progress.Tqdm" , MockTqdm ):
198
+ trainer .predict (model )
199
+ assert pbar .predict_progress_bar .leave
200
+ k = trainer .num_predict_batches
201
+ assert pbar .predict_progress_bar .total_values == k
202
+ assert pbar .predict_progress_bar .n_values == list (range (k [0 ] + 1 )) * num_dl
203
+ assert pbar .predict_progress_bar .descriptions == [f"Predicting DataLoader { i } : " for i in range (num_dl )]
204
204
assert pbar .predict_progress_bar .leave
205
205
206
206
@@ -414,24 +414,30 @@ def test_test_progress_bar_update_amount(tmp_path, test_batches: int, refresh_ra
414
414
assert progress_bar .test_progress_bar .n_values == updates
415
415
416
416
417
+ @patch ("lightning.pytorch.trainer.connectors.callback_connector._RICH_AVAILABLE" , False )
417
418
def test_tensor_to_float_conversion (tmp_path ):
418
419
"""Check tensor gets converted to float."""
419
- with patch ("lightning.pytorch.trainer.connectors.callback_connector._RICH_AVAILABLE" , False ):
420
-
421
- class TestModel (BoringModel ):
422
- def training_step (self , batch , batch_idx ):
423
- self .log ("a" , torch .tensor (0.123 ), prog_bar = True , on_epoch = False )
424
- self .log ("b" , torch .tensor ([1 ]), prog_bar = True , on_epoch = False )
425
- self .log ("c" , 2 , prog_bar = True , on_epoch = False )
426
- return super ().training_step (batch , batch_idx )
427
-
428
- trainer = Trainer (
429
- default_root_dir = tmp_path , max_epochs = 1 , limit_train_batches = 2 , logger = False , enable_checkpointing = False
430
- )
420
+
421
+ class TestModel (BoringModel ):
422
+ def training_step (self , batch , batch_idx ):
423
+ self .log ("a" , torch .tensor (0.123 ), prog_bar = True , on_epoch = False )
424
+ self .log ("b" , torch .tensor ([1 ]), prog_bar = True , on_epoch = False )
425
+ self .log ("c" , 2 , prog_bar = True , on_epoch = False )
426
+ return super ().training_step (batch , batch_idx )
427
+
428
+ trainer = Trainer (
429
+ default_root_dir = tmp_path , max_epochs = 1 , limit_train_batches = 2 , logger = False , enable_checkpointing = False
430
+ )
431
+
432
+ with mock .patch .object (sys .stdout , "write" ) as mock_write :
431
433
trainer .fit (TestModel ())
434
+ bar_updates = "" .join (call .args [0 ] for call in mock_write .call_args_list )
435
+ assert "a=0.123" in bar_updates
436
+ assert "b=1.000" in bar_updates
437
+ assert "c=2.000" in bar_updates
432
438
433
- torch .testing .assert_close (trainer .progress_bar_metrics ["a" ], 0.123 )
434
- assert trainer .progress_bar_metrics ["b" ] == 1.0
439
+ torch .testing .assert_close (trainer .progress_bar_metrics ["a" ], 0.123 )
440
+ assert trainer .progress_bar_metrics ["b" ] == 1.0
435
441
assert trainer .progress_bar_metrics ["c" ] == 2.0
436
442
pbar = trainer .progress_bar_callback .train_progress_bar
437
443
actual = str (pbar .postfix )
0 commit comments