|
18 | 18 | from collections import defaultdict
|
19 | 19 | from typing import Union
|
20 | 20 | from unittest import mock
|
21 |
| -from unittest.mock import ANY, Mock, PropertyMock, call |
| 21 | +from unittest.mock import ANY, Mock, PropertyMock, call, patch |
22 | 22 |
|
23 | 23 | import pytest
|
24 | 24 | import torch
|
@@ -112,120 +112,122 @@ def test_tqdm_progress_bar_misconfiguration():
|
112 | 112 | @pytest.mark.parametrize("num_dl", [1, 2])
|
113 | 113 | def test_tqdm_progress_bar_totals(tmp_path, num_dl):
|
114 | 114 | """Test that the progress finishes with the correct total steps processed."""
|
115 |
| - |
116 |
| - class CustomModel(BoringModel): |
117 |
| - def _get_dataloaders(self): |
118 |
| - dls = [DataLoader(RandomDataset(32, 64)), DataLoader(RandomDataset(32, 64))] |
119 |
| - return dls[0] if num_dl == 1 else dls |
120 |
| - |
121 |
| - def val_dataloader(self): |
122 |
| - return self._get_dataloaders() |
123 |
| - |
124 |
| - def test_dataloader(self): |
125 |
| - return self._get_dataloaders() |
126 |
| - |
127 |
| - def predict_dataloader(self): |
128 |
| - return self._get_dataloaders() |
129 |
| - |
130 |
| - def validation_step(self, batch, batch_idx, dataloader_idx=0): |
131 |
| - return |
132 |
| - |
133 |
| - def test_step(self, batch, batch_idx, dataloader_idx=0): |
134 |
| - return |
135 |
| - |
136 |
| - def predict_step(self, batch, batch_idx, dataloader_idx=0): |
137 |
| - return |
138 |
| - |
139 |
| - model = CustomModel() |
140 |
| - |
141 |
| - # check the sanity dataloaders |
142 |
| - num_sanity_val_steps = 4 |
143 |
| - trainer = Trainer( |
144 |
| - default_root_dir=tmp_path, max_epochs=1, limit_train_batches=0, num_sanity_val_steps=num_sanity_val_steps |
145 |
| - ) |
146 |
| - pbar = trainer.progress_bar_callback |
147 |
| - with mock.patch("lightning.pytorch.callbacks.progress.tqdm_progress.Tqdm", MockTqdm): |
148 |
| - trainer.fit(model) |
149 |
| - |
150 |
| - expected_sanity_steps = [num_sanity_val_steps] * num_dl |
151 |
| - assert not pbar.val_progress_bar.leave |
152 |
| - assert trainer.num_sanity_val_batches == expected_sanity_steps |
153 |
| - assert pbar.val_progress_bar.total_values == expected_sanity_steps |
154 |
| - assert pbar.val_progress_bar.n_values == list(range(num_sanity_val_steps + 1)) * num_dl |
155 |
| - assert pbar.val_progress_bar.descriptions == [f"Sanity Checking DataLoader {i}: " for i in range(num_dl)] |
156 |
| - |
157 |
| - # fit |
158 |
| - trainer = Trainer(default_root_dir=tmp_path, max_epochs=1) |
159 |
| - pbar = trainer.progress_bar_callback |
160 |
| - with mock.patch("lightning.pytorch.callbacks.progress.tqdm_progress.Tqdm", MockTqdm): |
161 |
| - trainer.fit(model) |
162 |
| - |
163 |
| - n = trainer.num_training_batches |
164 |
| - m = trainer.num_val_batches |
165 |
| - assert len(trainer.train_dataloader) == n |
166 |
| - # train progress bar should have reached the end |
167 |
| - assert pbar.train_progress_bar.total == n |
168 |
| - assert pbar.train_progress_bar.n == n |
169 |
| - assert pbar.train_progress_bar.leave |
170 |
| - |
171 |
| - # check val progress bar total |
172 |
| - assert pbar.val_progress_bar.total_values == m |
173 |
| - assert pbar.val_progress_bar.n_values == list(range(m[0] + 1)) * num_dl |
174 |
| - assert pbar.val_progress_bar.descriptions == [f"Validation DataLoader {i}: " for i in range(num_dl)] |
175 |
| - assert not pbar.val_progress_bar.leave |
176 |
| - |
177 |
| - # validate |
178 |
| - with mock.patch("lightning.pytorch.callbacks.progress.tqdm_progress.Tqdm", MockTqdm): |
179 |
| - trainer.validate(model) |
180 |
| - assert trainer.num_val_batches == m |
181 |
| - assert pbar.val_progress_bar.total_values == m |
182 |
| - assert pbar.val_progress_bar.n_values == list(range(m[0] + 1)) * num_dl |
183 |
| - assert pbar.val_progress_bar.descriptions == [f"Validation DataLoader {i}: " for i in range(num_dl)] |
184 |
| - |
185 |
| - # test |
186 |
| - with mock.patch("lightning.pytorch.callbacks.progress.tqdm_progress.Tqdm", MockTqdm): |
187 |
| - trainer.test(model) |
188 |
| - assert pbar.test_progress_bar.leave |
189 |
| - k = trainer.num_test_batches |
190 |
| - assert pbar.test_progress_bar.total_values == k |
191 |
| - assert pbar.test_progress_bar.n_values == list(range(k[0] + 1)) * num_dl |
192 |
| - assert pbar.test_progress_bar.descriptions == [f"Testing DataLoader {i}: " for i in range(num_dl)] |
193 |
| - assert pbar.test_progress_bar.leave |
194 |
| - |
195 |
| - # predict |
196 |
| - with mock.patch("lightning.pytorch.callbacks.progress.tqdm_progress.Tqdm", MockTqdm): |
197 |
| - trainer.predict(model) |
198 |
| - assert pbar.predict_progress_bar.leave |
199 |
| - k = trainer.num_predict_batches |
200 |
| - assert pbar.predict_progress_bar.total_values == k |
201 |
| - assert pbar.predict_progress_bar.n_values == list(range(k[0] + 1)) * num_dl |
202 |
| - assert pbar.predict_progress_bar.descriptions == [f"Predicting DataLoader {i}: " for i in range(num_dl)] |
| 115 | + with patch("lightning.pytorch.trainer.connectors.callback_connector._RICH_AVAILABLE", False): |
| 116 | + |
| 117 | + class CustomModel(BoringModel): |
| 118 | + def _get_dataloaders(self): |
| 119 | + dls = [DataLoader(RandomDataset(32, 64)), DataLoader(RandomDataset(32, 64))] |
| 120 | + return dls[0] if num_dl == 1 else dls |
| 121 | + |
| 122 | + def val_dataloader(self): |
| 123 | + return self._get_dataloaders() |
| 124 | + |
| 125 | + def test_dataloader(self): |
| 126 | + return self._get_dataloaders() |
| 127 | + |
| 128 | + def predict_dataloader(self): |
| 129 | + return self._get_dataloaders() |
| 130 | + |
| 131 | + def validation_step(self, batch, batch_idx, dataloader_idx=0): |
| 132 | + return |
| 133 | + |
| 134 | + def test_step(self, batch, batch_idx, dataloader_idx=0): |
| 135 | + return |
| 136 | + |
| 137 | + def predict_step(self, batch, batch_idx, dataloader_idx=0): |
| 138 | + return |
| 139 | + |
| 140 | + model = CustomModel() |
| 141 | + |
| 142 | + # check the sanity dataloaders |
| 143 | + num_sanity_val_steps = 4 |
| 144 | + trainer = Trainer( |
| 145 | + default_root_dir=tmp_path, max_epochs=1, limit_train_batches=0, num_sanity_val_steps=num_sanity_val_steps |
| 146 | + ) |
| 147 | + pbar = trainer.progress_bar_callback |
| 148 | + with mock.patch("lightning.pytorch.callbacks.progress.tqdm_progress.Tqdm", MockTqdm): |
| 149 | + trainer.fit(model) |
| 150 | + |
| 151 | + expected_sanity_steps = [num_sanity_val_steps] * num_dl |
| 152 | + assert not pbar.val_progress_bar.leave |
| 153 | + assert trainer.num_sanity_val_batches == expected_sanity_steps |
| 154 | + assert pbar.val_progress_bar.total_values == expected_sanity_steps |
| 155 | + assert pbar.val_progress_bar.n_values == list(range(num_sanity_val_steps + 1)) * num_dl |
| 156 | + assert pbar.val_progress_bar.descriptions == [f"Sanity Checking DataLoader {i}: " for i in range(num_dl)] |
| 157 | + |
| 158 | + # fit |
| 159 | + trainer = Trainer(default_root_dir=tmp_path, max_epochs=1) |
| 160 | + pbar = trainer.progress_bar_callback |
| 161 | + with mock.patch("lightning.pytorch.callbacks.progress.tqdm_progress.Tqdm", MockTqdm): |
| 162 | + trainer.fit(model) |
| 163 | + |
| 164 | + n = trainer.num_training_batches |
| 165 | + m = trainer.num_val_batches |
| 166 | + assert len(trainer.train_dataloader) == n |
| 167 | + # train progress bar should have reached the end |
| 168 | + assert pbar.train_progress_bar.total == n |
| 169 | + assert pbar.train_progress_bar.n == n |
| 170 | + assert pbar.train_progress_bar.leave |
| 171 | + |
| 172 | + # check val progress bar total |
| 173 | + assert pbar.val_progress_bar.total_values == m |
| 174 | + assert pbar.val_progress_bar.n_values == list(range(m[0] + 1)) * num_dl |
| 175 | + assert pbar.val_progress_bar.descriptions == [f"Validation DataLoader {i}: " for i in range(num_dl)] |
| 176 | + assert not pbar.val_progress_bar.leave |
| 177 | + |
| 178 | + # validate |
| 179 | + with mock.patch("lightning.pytorch.callbacks.progress.tqdm_progress.Tqdm", MockTqdm): |
| 180 | + trainer.validate(model) |
| 181 | + assert trainer.num_val_batches == m |
| 182 | + assert pbar.val_progress_bar.total_values == m |
| 183 | + assert pbar.val_progress_bar.n_values == list(range(m[0] + 1)) * num_dl |
| 184 | + assert pbar.val_progress_bar.descriptions == [f"Validation DataLoader {i}: " for i in range(num_dl)] |
| 185 | + |
| 186 | + # test |
| 187 | + with mock.patch("lightning.pytorch.callbacks.progress.tqdm_progress.Tqdm", MockTqdm): |
| 188 | + trainer.test(model) |
| 189 | + assert pbar.test_progress_bar.leave |
| 190 | + k = trainer.num_test_batches |
| 191 | + assert pbar.test_progress_bar.total_values == k |
| 192 | + assert pbar.test_progress_bar.n_values == list(range(k[0] + 1)) * num_dl |
| 193 | + assert pbar.test_progress_bar.descriptions == [f"Testing DataLoader {i}: " for i in range(num_dl)] |
| 194 | + assert pbar.test_progress_bar.leave |
| 195 | + |
| 196 | + # predict |
| 197 | + with mock.patch("lightning.pytorch.callbacks.progress.tqdm_progress.Tqdm", MockTqdm): |
| 198 | + trainer.predict(model) |
| 199 | + assert pbar.predict_progress_bar.leave |
| 200 | + k = trainer.num_predict_batches |
| 201 | + assert pbar.predict_progress_bar.total_values == k |
| 202 | + assert pbar.predict_progress_bar.n_values == list(range(k[0] + 1)) * num_dl |
| 203 | + assert pbar.predict_progress_bar.descriptions == [f"Predicting DataLoader {i}: " for i in range(num_dl)] |
203 | 204 | assert pbar.predict_progress_bar.leave
|
204 | 205 |
|
205 | 206 |
|
206 | 207 | def test_tqdm_progress_bar_fast_dev_run(tmp_path):
|
207 |
| - model = BoringModel() |
| 208 | + with patch("lightning.pytorch.trainer.connectors.callback_connector._RICH_AVAILABLE", False): |
| 209 | + model = BoringModel() |
208 | 210 |
|
209 |
| - trainer = Trainer(default_root_dir=tmp_path, fast_dev_run=True) |
| 211 | + trainer = Trainer(default_root_dir=tmp_path, fast_dev_run=True) |
210 | 212 |
|
211 |
| - trainer.fit(model) |
| 213 | + trainer.fit(model) |
212 | 214 |
|
213 |
| - pbar = trainer.progress_bar_callback |
| 215 | + pbar = trainer.progress_bar_callback |
214 | 216 |
|
215 |
| - assert pbar.val_progress_bar.n == 1 |
216 |
| - assert pbar.val_progress_bar.total == 1 |
| 217 | + assert pbar.val_progress_bar.n == 1 |
| 218 | + assert pbar.val_progress_bar.total == 1 |
217 | 219 |
|
218 |
| - # the train progress bar should display 1 batch |
219 |
| - assert pbar.train_progress_bar.total == 1 |
220 |
| - assert pbar.train_progress_bar.n == 1 |
| 220 | + # the train progress bar should display 1 batch |
| 221 | + assert pbar.train_progress_bar.total == 1 |
| 222 | + assert pbar.train_progress_bar.n == 1 |
221 | 223 |
|
222 |
| - trainer.validate(model) |
| 224 | + trainer.validate(model) |
223 | 225 |
|
224 |
| - # the validation progress bar should display 1 batch |
225 |
| - assert pbar.val_progress_bar.total == 1 |
226 |
| - assert pbar.val_progress_bar.n == 1 |
| 226 | + # the validation progress bar should display 1 batch |
| 227 | + assert pbar.val_progress_bar.total == 1 |
| 228 | + assert pbar.val_progress_bar.n == 1 |
227 | 229 |
|
228 |
| - trainer.test(model) |
| 230 | + trainer.test(model) |
229 | 231 |
|
230 | 232 | # the test progress bar should display 1 batch
|
231 | 233 | assert pbar.test_progress_bar.total == 1
|
@@ -325,14 +327,15 @@ def test_tqdm_progress_bar_default_value(tmp_path):
|
325 | 327 | @mock.patch.dict(os.environ, {"COLAB_GPU": "1"})
|
326 | 328 | def test_tqdm_progress_bar_value_on_colab(tmp_path):
|
327 | 329 | """Test that Trainer will override the default in Google COLAB."""
|
328 |
| - trainer = Trainer(default_root_dir=tmp_path) |
329 |
| - assert trainer.progress_bar_callback.refresh_rate == 20 |
| 330 | + with patch("lightning.pytorch.trainer.connectors.callback_connector._RICH_AVAILABLE", False): |
| 331 | + trainer = Trainer(default_root_dir=tmp_path) |
| 332 | + assert trainer.progress_bar_callback.refresh_rate == 20 |
330 | 333 |
|
331 |
| - trainer = Trainer(default_root_dir=tmp_path, callbacks=TQDMProgressBar()) |
332 |
| - assert trainer.progress_bar_callback.refresh_rate == 20 |
| 334 | + trainer = Trainer(default_root_dir=tmp_path, callbacks=TQDMProgressBar()) |
| 335 | + assert trainer.progress_bar_callback.refresh_rate == 20 |
333 | 336 |
|
334 |
| - trainer = Trainer(default_root_dir=tmp_path, callbacks=TQDMProgressBar(refresh_rate=19)) |
335 |
| - assert trainer.progress_bar_callback.refresh_rate == 19 |
| 337 | + trainer = Trainer(default_root_dir=tmp_path, callbacks=TQDMProgressBar(refresh_rate=19)) |
| 338 | + assert trainer.progress_bar_callback.refresh_rate == 19 |
336 | 339 |
|
337 | 340 |
|
338 | 341 | @pytest.mark.parametrize(
|
@@ -413,21 +416,22 @@ def test_test_progress_bar_update_amount(tmp_path, test_batches: int, refresh_ra
|
413 | 416 |
|
414 | 417 | def test_tensor_to_float_conversion(tmp_path):
|
415 | 418 | """Check tensor gets converted to float."""
|
416 |
| - |
417 |
| - class TestModel(BoringModel): |
418 |
| - def training_step(self, batch, batch_idx): |
419 |
| - self.log("a", torch.tensor(0.123), prog_bar=True, on_epoch=False) |
420 |
| - self.log("b", torch.tensor([1]), prog_bar=True, on_epoch=False) |
421 |
| - self.log("c", 2, prog_bar=True, on_epoch=False) |
422 |
| - return super().training_step(batch, batch_idx) |
423 |
| - |
424 |
| - trainer = Trainer( |
425 |
| - default_root_dir=tmp_path, max_epochs=1, limit_train_batches=2, logger=False, enable_checkpointing=False |
426 |
| - ) |
427 |
| - trainer.fit(TestModel()) |
428 |
| - |
429 |
| - torch.testing.assert_close(trainer.progress_bar_metrics["a"], 0.123) |
430 |
| - assert trainer.progress_bar_metrics["b"] == 1.0 |
| 419 | + with patch("lightning.pytorch.trainer.connectors.callback_connector._RICH_AVAILABLE", False): |
| 420 | + |
| 421 | + class TestModel(BoringModel): |
| 422 | + def training_step(self, batch, batch_idx): |
| 423 | + self.log("a", torch.tensor(0.123), prog_bar=True, on_epoch=False) |
| 424 | + self.log("b", torch.tensor([1]), prog_bar=True, on_epoch=False) |
| 425 | + self.log("c", 2, prog_bar=True, on_epoch=False) |
| 426 | + return super().training_step(batch, batch_idx) |
| 427 | + |
| 428 | + trainer = Trainer( |
| 429 | + default_root_dir=tmp_path, max_epochs=1, limit_train_batches=2, logger=False, enable_checkpointing=False |
| 430 | + ) |
| 431 | + trainer.fit(TestModel()) |
| 432 | + |
| 433 | + torch.testing.assert_close(trainer.progress_bar_metrics["a"], 0.123) |
| 434 | + assert trainer.progress_bar_metrics["b"] == 1.0 |
431 | 435 | assert trainer.progress_bar_metrics["c"] == 2.0
|
432 | 436 | pbar = trainer.progress_bar_callback.train_progress_bar
|
433 | 437 | actual = str(pbar.postfix)
|
|
0 commit comments