|
12 | 12 | # See the License for the specific language governing permissions and |
13 | 13 | # limitations under the License. |
14 | 14 | import os |
| 15 | +from unittest.mock import Mock, call |
| 16 | + |
15 | 17 | import pytest |
16 | 18 | from unittest import mock |
17 | 19 |
|
18 | 20 | from pytorch_lightning import Trainer |
19 | 21 | from pytorch_lightning.callbacks import ProgressBarBase, ProgressBar, ModelCheckpoint |
20 | 22 | from pytorch_lightning.utilities.exceptions import MisconfigurationException |
21 | | -from tests.base import EvalModelTemplate |
| 23 | +from tests.base import EvalModelTemplate, BoringModel |
22 | 24 |
|
23 | 25 |
|
24 | 26 | @pytest.mark.parametrize('callbacks,refresh_rate', [ |
@@ -252,3 +254,77 @@ def test_progress_bar_warning_on_colab(tmpdir): |
252 | 254 | ) |
253 | 255 |
|
254 | 256 | assert trainer.progress_bar_callback.refresh_rate == 19 |
| 257 | + |
| 258 | + |
| 259 | +class MockedUpdateProgressBars(ProgressBar): |
| 260 | + """ Mocks the update method once bars get initializied. """ |
| 261 | + |
| 262 | + def _mock_bar_update(self, bar): |
| 263 | + bar.update = Mock(wraps=bar.update) |
| 264 | + return bar |
| 265 | + |
| 266 | + def init_train_tqdm(self): |
| 267 | + bar = super().init_train_tqdm() |
| 268 | + return self._mock_bar_update(bar) |
| 269 | + |
| 270 | + def init_validation_tqdm(self): |
| 271 | + bar = super().init_validation_tqdm() |
| 272 | + return self._mock_bar_update(bar) |
| 273 | + |
| 274 | + def init_test_tqdm(self): |
| 275 | + bar = super().init_test_tqdm() |
| 276 | + return self._mock_bar_update(bar) |
| 277 | + |
| 278 | + |
| 279 | +@pytest.mark.parametrize("train_batches,val_batches,refresh_rate,train_deltas,val_deltas", [ |
| 280 | + [2, 3, 1, [1, 1, 1, 1, 1], [1, 1, 1]], |
| 281 | + [0, 0, 3, [], []], |
| 282 | + [1, 0, 3, [1], []], |
| 283 | + [1, 1, 3, [2], [1]], |
| 284 | + [5, 0, 3, [3, 2], []], |
| 285 | + [5, 2, 3, [3, 3, 1], [2]], |
| 286 | + [5, 2, 6, [6, 1], [2]], |
| 287 | +]) |
| 288 | +def test_main_progress_bar_update_amount(tmpdir, train_batches, val_batches, refresh_rate, train_deltas, val_deltas): |
| 289 | + """ |
| 290 | + Test that the main progress updates with the correct amount together with the val progress. At the end of |
| 291 | + the epoch, the progress must not overshoot if the number of steps is not divisible by the refresh rate. |
| 292 | + """ |
| 293 | + model = BoringModel() |
| 294 | + progress_bar = MockedUpdateProgressBars(refresh_rate=refresh_rate) |
| 295 | + trainer = Trainer( |
| 296 | + default_root_dir=tmpdir, |
| 297 | + max_epochs=1, |
| 298 | + limit_train_batches=train_batches, |
| 299 | + limit_val_batches=val_batches, |
| 300 | + callbacks=[progress_bar], |
| 301 | + logger=False, |
| 302 | + checkpoint_callback=False, |
| 303 | + ) |
| 304 | + trainer.fit(model) |
| 305 | + progress_bar.main_progress_bar.update.assert_has_calls([call(delta) for delta in train_deltas]) |
| 306 | + if val_batches > 0: |
| 307 | + progress_bar.val_progress_bar.update.assert_has_calls([call(delta) for delta in val_deltas]) |
| 308 | + |
| 309 | + |
| 310 | +@pytest.mark.parametrize("test_batches,refresh_rate,test_deltas", [ |
| 311 | + [1, 3, [1]], |
| 312 | + [3, 1, [1, 1, 1]], |
| 313 | + [5, 3, [3, 2]], |
| 314 | +]) |
| 315 | +def test_test_progress_bar_update_amount(tmpdir, test_batches, refresh_rate, test_deltas): |
| 316 | + """ |
| 317 | + Test that test progress updates with the correct amount. |
| 318 | + """ |
| 319 | + model = BoringModel() |
| 320 | + progress_bar = MockedUpdateProgressBars(refresh_rate=refresh_rate) |
| 321 | + trainer = Trainer( |
| 322 | + default_root_dir=tmpdir, |
| 323 | + max_epochs=1, |
| 324 | + limit_test_batches=test_batches, |
| 325 | + callbacks=[progress_bar], |
| 326 | + logger=False, |
| 327 | + checkpoint_callback=False, |
| 328 | + ) |
| 329 | + trainer.test(model) |
| 330 | + progress_bar.test_progress_bar.update.assert_has_calls([call(delta) for delta in test_deltas]) |
0 commit comments