|
18 | 18 | from collections import defaultdict |
19 | 19 | from typing import Union |
20 | 20 | from unittest import mock |
21 | | -from unittest.mock import ANY, Mock, PropertyMock, call |
| 21 | +from unittest.mock import ANY, Mock, PropertyMock, call, patch |
22 | 22 |
|
23 | 23 | import pytest |
24 | 24 | import torch |
@@ -112,120 +112,122 @@ def test_tqdm_progress_bar_misconfiguration(): |
112 | 112 | @pytest.mark.parametrize("num_dl", [1, 2]) |
113 | 113 | def test_tqdm_progress_bar_totals(tmp_path, num_dl): |
114 | 114 | """Test that the progress finishes with the correct total steps processed.""" |
115 | | - |
116 | | - class CustomModel(BoringModel): |
117 | | - def _get_dataloaders(self): |
118 | | - dls = [DataLoader(RandomDataset(32, 64)), DataLoader(RandomDataset(32, 64))] |
119 | | - return dls[0] if num_dl == 1 else dls |
120 | | - |
121 | | - def val_dataloader(self): |
122 | | - return self._get_dataloaders() |
123 | | - |
124 | | - def test_dataloader(self): |
125 | | - return self._get_dataloaders() |
126 | | - |
127 | | - def predict_dataloader(self): |
128 | | - return self._get_dataloaders() |
129 | | - |
130 | | - def validation_step(self, batch, batch_idx, dataloader_idx=0): |
131 | | - return |
132 | | - |
133 | | - def test_step(self, batch, batch_idx, dataloader_idx=0): |
134 | | - return |
135 | | - |
136 | | - def predict_step(self, batch, batch_idx, dataloader_idx=0): |
137 | | - return |
138 | | - |
139 | | - model = CustomModel() |
140 | | - |
141 | | - # check the sanity dataloaders |
142 | | - num_sanity_val_steps = 4 |
143 | | - trainer = Trainer( |
144 | | - default_root_dir=tmp_path, max_epochs=1, limit_train_batches=0, num_sanity_val_steps=num_sanity_val_steps |
145 | | - ) |
146 | | - pbar = trainer.progress_bar_callback |
147 | | - with mock.patch("lightning.pytorch.callbacks.progress.tqdm_progress.Tqdm", MockTqdm): |
148 | | - trainer.fit(model) |
149 | | - |
150 | | - expected_sanity_steps = [num_sanity_val_steps] * num_dl |
151 | | - assert not pbar.val_progress_bar.leave |
152 | | - assert trainer.num_sanity_val_batches == expected_sanity_steps |
153 | | - assert pbar.val_progress_bar.total_values == expected_sanity_steps |
154 | | - assert pbar.val_progress_bar.n_values == list(range(num_sanity_val_steps + 1)) * num_dl |
155 | | - assert pbar.val_progress_bar.descriptions == [f"Sanity Checking DataLoader {i}: " for i in range(num_dl)] |
156 | | - |
157 | | - # fit |
158 | | - trainer = Trainer(default_root_dir=tmp_path, max_epochs=1) |
159 | | - pbar = trainer.progress_bar_callback |
160 | | - with mock.patch("lightning.pytorch.callbacks.progress.tqdm_progress.Tqdm", MockTqdm): |
161 | | - trainer.fit(model) |
162 | | - |
163 | | - n = trainer.num_training_batches |
164 | | - m = trainer.num_val_batches |
165 | | - assert len(trainer.train_dataloader) == n |
166 | | - # train progress bar should have reached the end |
167 | | - assert pbar.train_progress_bar.total == n |
168 | | - assert pbar.train_progress_bar.n == n |
169 | | - assert pbar.train_progress_bar.leave |
170 | | - |
171 | | - # check val progress bar total |
172 | | - assert pbar.val_progress_bar.total_values == m |
173 | | - assert pbar.val_progress_bar.n_values == list(range(m[0] + 1)) * num_dl |
174 | | - assert pbar.val_progress_bar.descriptions == [f"Validation DataLoader {i}: " for i in range(num_dl)] |
175 | | - assert not pbar.val_progress_bar.leave |
176 | | - |
177 | | - # validate |
178 | | - with mock.patch("lightning.pytorch.callbacks.progress.tqdm_progress.Tqdm", MockTqdm): |
179 | | - trainer.validate(model) |
180 | | - assert trainer.num_val_batches == m |
181 | | - assert pbar.val_progress_bar.total_values == m |
182 | | - assert pbar.val_progress_bar.n_values == list(range(m[0] + 1)) * num_dl |
183 | | - assert pbar.val_progress_bar.descriptions == [f"Validation DataLoader {i}: " for i in range(num_dl)] |
184 | | - |
185 | | - # test |
186 | | - with mock.patch("lightning.pytorch.callbacks.progress.tqdm_progress.Tqdm", MockTqdm): |
187 | | - trainer.test(model) |
188 | | - assert pbar.test_progress_bar.leave |
189 | | - k = trainer.num_test_batches |
190 | | - assert pbar.test_progress_bar.total_values == k |
191 | | - assert pbar.test_progress_bar.n_values == list(range(k[0] + 1)) * num_dl |
192 | | - assert pbar.test_progress_bar.descriptions == [f"Testing DataLoader {i}: " for i in range(num_dl)] |
193 | | - assert pbar.test_progress_bar.leave |
194 | | - |
195 | | - # predict |
196 | | - with mock.patch("lightning.pytorch.callbacks.progress.tqdm_progress.Tqdm", MockTqdm): |
197 | | - trainer.predict(model) |
198 | | - assert pbar.predict_progress_bar.leave |
199 | | - k = trainer.num_predict_batches |
200 | | - assert pbar.predict_progress_bar.total_values == k |
201 | | - assert pbar.predict_progress_bar.n_values == list(range(k[0] + 1)) * num_dl |
202 | | - assert pbar.predict_progress_bar.descriptions == [f"Predicting DataLoader {i}: " for i in range(num_dl)] |
| 115 | + with patch("lightning.pytorch.trainer.connectors.callback_connector._RICH_AVAILABLE", False): |
| 116 | + |
| 117 | + class CustomModel(BoringModel): |
| 118 | + def _get_dataloaders(self): |
| 119 | + dls = [DataLoader(RandomDataset(32, 64)), DataLoader(RandomDataset(32, 64))] |
| 120 | + return dls[0] if num_dl == 1 else dls |
| 121 | + |
| 122 | + def val_dataloader(self): |
| 123 | + return self._get_dataloaders() |
| 124 | + |
| 125 | + def test_dataloader(self): |
| 126 | + return self._get_dataloaders() |
| 127 | + |
| 128 | + def predict_dataloader(self): |
| 129 | + return self._get_dataloaders() |
| 130 | + |
| 131 | + def validation_step(self, batch, batch_idx, dataloader_idx=0): |
| 132 | + return |
| 133 | + |
| 134 | + def test_step(self, batch, batch_idx, dataloader_idx=0): |
| 135 | + return |
| 136 | + |
| 137 | + def predict_step(self, batch, batch_idx, dataloader_idx=0): |
| 138 | + return |
| 139 | + |
| 140 | + model = CustomModel() |
| 141 | + |
| 142 | + # check the sanity dataloaders |
| 143 | + num_sanity_val_steps = 4 |
| 144 | + trainer = Trainer( |
| 145 | + default_root_dir=tmp_path, max_epochs=1, limit_train_batches=0, num_sanity_val_steps=num_sanity_val_steps |
| 146 | + ) |
| 147 | + pbar = trainer.progress_bar_callback |
| 148 | + with mock.patch("lightning.pytorch.callbacks.progress.tqdm_progress.Tqdm", MockTqdm): |
| 149 | + trainer.fit(model) |
| 150 | + |
| 151 | + expected_sanity_steps = [num_sanity_val_steps] * num_dl |
| 152 | + assert not pbar.val_progress_bar.leave |
| 153 | + assert trainer.num_sanity_val_batches == expected_sanity_steps |
| 154 | + assert pbar.val_progress_bar.total_values == expected_sanity_steps |
| 155 | + assert pbar.val_progress_bar.n_values == list(range(num_sanity_val_steps + 1)) * num_dl |
| 156 | + assert pbar.val_progress_bar.descriptions == [f"Sanity Checking DataLoader {i}: " for i in range(num_dl)] |
| 157 | + |
| 158 | + # fit |
| 159 | + trainer = Trainer(default_root_dir=tmp_path, max_epochs=1) |
| 160 | + pbar = trainer.progress_bar_callback |
| 161 | + with mock.patch("lightning.pytorch.callbacks.progress.tqdm_progress.Tqdm", MockTqdm): |
| 162 | + trainer.fit(model) |
| 163 | + |
| 164 | + n = trainer.num_training_batches |
| 165 | + m = trainer.num_val_batches |
| 166 | + assert len(trainer.train_dataloader) == n |
| 167 | + # train progress bar should have reached the end |
| 168 | + assert pbar.train_progress_bar.total == n |
| 169 | + assert pbar.train_progress_bar.n == n |
| 170 | + assert pbar.train_progress_bar.leave |
| 171 | + |
| 172 | + # check val progress bar total |
| 173 | + assert pbar.val_progress_bar.total_values == m |
| 174 | + assert pbar.val_progress_bar.n_values == list(range(m[0] + 1)) * num_dl |
| 175 | + assert pbar.val_progress_bar.descriptions == [f"Validation DataLoader {i}: " for i in range(num_dl)] |
| 176 | + assert not pbar.val_progress_bar.leave |
| 177 | + |
| 178 | + # validate |
| 179 | + with mock.patch("lightning.pytorch.callbacks.progress.tqdm_progress.Tqdm", MockTqdm): |
| 180 | + trainer.validate(model) |
| 181 | + assert trainer.num_val_batches == m |
| 182 | + assert pbar.val_progress_bar.total_values == m |
| 183 | + assert pbar.val_progress_bar.n_values == list(range(m[0] + 1)) * num_dl |
| 184 | + assert pbar.val_progress_bar.descriptions == [f"Validation DataLoader {i}: " for i in range(num_dl)] |
| 185 | + |
| 186 | + # test |
| 187 | + with mock.patch("lightning.pytorch.callbacks.progress.tqdm_progress.Tqdm", MockTqdm): |
| 188 | + trainer.test(model) |
| 189 | + assert pbar.test_progress_bar.leave |
| 190 | + k = trainer.num_test_batches |
| 191 | + assert pbar.test_progress_bar.total_values == k |
| 192 | + assert pbar.test_progress_bar.n_values == list(range(k[0] + 1)) * num_dl |
| 193 | + assert pbar.test_progress_bar.descriptions == [f"Testing DataLoader {i}: " for i in range(num_dl)] |
| 194 | + assert pbar.test_progress_bar.leave |
| 195 | + |
| 196 | + # predict |
| 197 | + with mock.patch("lightning.pytorch.callbacks.progress.tqdm_progress.Tqdm", MockTqdm): |
| 198 | + trainer.predict(model) |
| 199 | + assert pbar.predict_progress_bar.leave |
| 200 | + k = trainer.num_predict_batches |
| 201 | + assert pbar.predict_progress_bar.total_values == k |
| 202 | + assert pbar.predict_progress_bar.n_values == list(range(k[0] + 1)) * num_dl |
| 203 | + assert pbar.predict_progress_bar.descriptions == [f"Predicting DataLoader {i}: " for i in range(num_dl)] |
203 | 204 | assert pbar.predict_progress_bar.leave |
204 | 205 |
|
205 | 206 |
|
206 | 207 | def test_tqdm_progress_bar_fast_dev_run(tmp_path): |
207 | | - model = BoringModel() |
| 208 | + with patch("lightning.pytorch.trainer.connectors.callback_connector._RICH_AVAILABLE", False): |
| 209 | + model = BoringModel() |
208 | 210 |
|
209 | | - trainer = Trainer(default_root_dir=tmp_path, fast_dev_run=True) |
| 211 | + trainer = Trainer(default_root_dir=tmp_path, fast_dev_run=True) |
210 | 212 |
|
211 | | - trainer.fit(model) |
| 213 | + trainer.fit(model) |
212 | 214 |
|
213 | | - pbar = trainer.progress_bar_callback |
| 215 | + pbar = trainer.progress_bar_callback |
214 | 216 |
|
215 | | - assert pbar.val_progress_bar.n == 1 |
216 | | - assert pbar.val_progress_bar.total == 1 |
| 217 | + assert pbar.val_progress_bar.n == 1 |
| 218 | + assert pbar.val_progress_bar.total == 1 |
217 | 219 |
|
218 | | - # the train progress bar should display 1 batch |
219 | | - assert pbar.train_progress_bar.total == 1 |
220 | | - assert pbar.train_progress_bar.n == 1 |
| 220 | + # the train progress bar should display 1 batch |
| 221 | + assert pbar.train_progress_bar.total == 1 |
| 222 | + assert pbar.train_progress_bar.n == 1 |
221 | 223 |
|
222 | | - trainer.validate(model) |
| 224 | + trainer.validate(model) |
223 | 225 |
|
224 | | - # the validation progress bar should display 1 batch |
225 | | - assert pbar.val_progress_bar.total == 1 |
226 | | - assert pbar.val_progress_bar.n == 1 |
| 226 | + # the validation progress bar should display 1 batch |
| 227 | + assert pbar.val_progress_bar.total == 1 |
| 228 | + assert pbar.val_progress_bar.n == 1 |
227 | 229 |
|
228 | | - trainer.test(model) |
| 230 | + trainer.test(model) |
229 | 231 |
|
230 | 232 | # the test progress bar should display 1 batch |
231 | 233 | assert pbar.test_progress_bar.total == 1 |
@@ -325,14 +327,15 @@ def test_tqdm_progress_bar_default_value(tmp_path): |
325 | 327 | @mock.patch.dict(os.environ, {"COLAB_GPU": "1"}) |
326 | 328 | def test_tqdm_progress_bar_value_on_colab(tmp_path): |
327 | 329 | """Test that Trainer will override the default in Google COLAB.""" |
328 | | - trainer = Trainer(default_root_dir=tmp_path) |
329 | | - assert trainer.progress_bar_callback.refresh_rate == 20 |
| 330 | + with patch("lightning.pytorch.trainer.connectors.callback_connector._RICH_AVAILABLE", False): |
| 331 | + trainer = Trainer(default_root_dir=tmp_path) |
| 332 | + assert trainer.progress_bar_callback.refresh_rate == 20 |
330 | 333 |
|
331 | | - trainer = Trainer(default_root_dir=tmp_path, callbacks=TQDMProgressBar()) |
332 | | - assert trainer.progress_bar_callback.refresh_rate == 20 |
| 334 | + trainer = Trainer(default_root_dir=tmp_path, callbacks=TQDMProgressBar()) |
| 335 | + assert trainer.progress_bar_callback.refresh_rate == 20 |
333 | 336 |
|
334 | | - trainer = Trainer(default_root_dir=tmp_path, callbacks=TQDMProgressBar(refresh_rate=19)) |
335 | | - assert trainer.progress_bar_callback.refresh_rate == 19 |
| 337 | + trainer = Trainer(default_root_dir=tmp_path, callbacks=TQDMProgressBar(refresh_rate=19)) |
| 338 | + assert trainer.progress_bar_callback.refresh_rate == 19 |
336 | 339 |
|
337 | 340 |
|
338 | 341 | @pytest.mark.parametrize( |
@@ -413,21 +416,22 @@ def test_test_progress_bar_update_amount(tmp_path, test_batches: int, refresh_ra |
413 | 416 |
|
414 | 417 | def test_tensor_to_float_conversion(tmp_path): |
415 | 418 | """Check tensor gets converted to float.""" |
416 | | - |
417 | | - class TestModel(BoringModel): |
418 | | - def training_step(self, batch, batch_idx): |
419 | | - self.log("a", torch.tensor(0.123), prog_bar=True, on_epoch=False) |
420 | | - self.log("b", torch.tensor([1]), prog_bar=True, on_epoch=False) |
421 | | - self.log("c", 2, prog_bar=True, on_epoch=False) |
422 | | - return super().training_step(batch, batch_idx) |
423 | | - |
424 | | - trainer = Trainer( |
425 | | - default_root_dir=tmp_path, max_epochs=1, limit_train_batches=2, logger=False, enable_checkpointing=False |
426 | | - ) |
427 | | - trainer.fit(TestModel()) |
428 | | - |
429 | | - torch.testing.assert_close(trainer.progress_bar_metrics["a"], 0.123) |
430 | | - assert trainer.progress_bar_metrics["b"] == 1.0 |
| 419 | + with patch("lightning.pytorch.trainer.connectors.callback_connector._RICH_AVAILABLE", False): |
| 420 | + |
| 421 | + class TestModel(BoringModel): |
| 422 | + def training_step(self, batch, batch_idx): |
| 423 | + self.log("a", torch.tensor(0.123), prog_bar=True, on_epoch=False) |
| 424 | + self.log("b", torch.tensor([1]), prog_bar=True, on_epoch=False) |
| 425 | + self.log("c", 2, prog_bar=True, on_epoch=False) |
| 426 | + return super().training_step(batch, batch_idx) |
| 427 | + |
| 428 | + trainer = Trainer( |
| 429 | + default_root_dir=tmp_path, max_epochs=1, limit_train_batches=2, logger=False, enable_checkpointing=False |
| 430 | + ) |
| 431 | + trainer.fit(TestModel()) |
| 432 | + |
| 433 | + torch.testing.assert_close(trainer.progress_bar_metrics["a"], 0.123) |
| 434 | + assert trainer.progress_bar_metrics["b"] == 1.0 |
431 | 435 | assert trainer.progress_bar_metrics["c"] == 2.0 |
432 | 436 | pbar = trainer.progress_bar_callback.train_progress_bar |
433 | 437 | actual = str(pbar.postfix) |
|
0 commit comments