24 | 24 |
25 | 25 | from pl_examples.bug_report_model import RandomDataset |
26 | 26 | from pytorch_lightning import LightningModule, Trainer |
27 | | -from pytorch_lightning.callbacks import ModelCheckpoint |
| 27 | +from pytorch_lightning.callbacks import Callback, ModelCheckpoint |
28 | 28 | from pytorch_lightning.loops import Loop, TrainingBatchLoop |
29 | 29 | from pytorch_lightning.trainer.progress import BaseProgress |
30 | 30 | from tests.helpers import BoringModel |
@@ -912,8 +912,10 @@ def val_dataloader(self): |
912 | 912 |
913 | 913 |
914 | 914 | @RunIf(min_torch="1.8.0") |
915 | | -@pytest.mark.parametrize("persistent_workers", (False, True)) |
916 | | -def test_workers_are_shutdown(tmpdir, persistent_workers): |
| 915 | +@pytest.mark.parametrize("should_fail", [False, True]) |
| 916 | +# `False` is deactivated due to slowness |
| 917 | +@pytest.mark.parametrize("persistent_workers", [True]) |
| 918 | +def test_workers_are_shutdown(tmpdir, should_fail, persistent_workers): |
917 | 919 |     # `num_workers == 1` uses `_MultiProcessingDataLoaderIter` |
918 | 920 |     # `persistent_workers` makes sure `self._iterator` gets set on the `DataLoader` instance |
919 | 921 |
@@ -941,12 +943,30 @@ def _get_iterator(self): |
941 | 943 |     train_dataloader = TestDataLoader(RandomDataset(32, 64), num_workers=1, persistent_workers=persistent_workers) |
942 | 944 |     val_dataloader = TestDataLoader(RandomDataset(32, 64), num_workers=1, persistent_workers=persistent_workers) |
943 | 945 |
| 946 | +    class TestCallback(Callback): |
| 947 | +        def on_train_epoch_end(self, trainer, *_): |
| 948 | +            if trainer.current_epoch == 1: |
| 949 | +                raise CustomException |
| 950 | + |
944 | 951 |     max_epochs = 3 |
| 952 | + |
945 | 953 |     model = BoringModel() |
946 | | -    trainer = Trainer(default_root_dir=tmpdir, limit_train_batches=2, limit_val_batches=2, max_epochs=max_epochs) |
947 | | -    trainer.fit(model, train_dataloader, val_dataloader) |
948 | | -    assert train_dataloader.count_shutdown_workers == (2 if persistent_workers else max_epochs) |
| 954 | +    trainer = Trainer( |
| 955 | +        default_root_dir=tmpdir, |
| 956 | +        limit_train_batches=2, |
| 957 | +        limit_val_batches=2, |
| 958 | +        max_epochs=max_epochs, |
| 959 | +        callbacks=TestCallback() if should_fail else None, |
| 960 | +    ) |
| 961 | + |
| 962 | +    if should_fail: |
| 963 | +        with pytest.raises(CustomException): |
| 964 | +            trainer.fit(model, train_dataloader, val_dataloader) |
| 965 | +    else: |
| 966 | +        trainer.fit(model, train_dataloader, val_dataloader) |
| 967 | + |
| 968 | +    assert train_dataloader.count_shutdown_workers == (2 if should_fail else (2 if persistent_workers else max_epochs)) |
949 | 969 |     # on sanity checking end, the workers are being deleted too. |
950 | | -    assert val_dataloader.count_shutdown_workers == (2 if persistent_workers else max_epochs + 1) |
| 970 | +    assert val_dataloader.count_shutdown_workers == (2 if persistent_workers else (3 if should_fail else max_epochs + 1)) |
951 | 971 |     assert train_dataloader._iterator is None |
952 | 972 |     assert val_dataloader._iterator is None |
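The hunk references two helpers whose definitions sit outside the shown context: `CustomException` (raised by `TestCallback` and expected by `pytest.raises`) and the `TestDataLoader` that exposes `count_shutdown_workers`. A minimal sketch of what they might look like follows; the names and the approach of wrapping the iterator's `_shutdown_workers` hook are assumptions for illustration, not necessarily the PR's exact implementation.

from torch.utils.data import DataLoader


class CustomException(Exception):
    """Assumed helper: raised by `TestCallback` to simulate a failure mid-training."""


class TestDataLoader(DataLoader):
    """Assumed helper: DataLoader that counts how often its worker processes are shut down."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.count_shutdown_workers = 0

    def _get_iterator(self):
        # With `num_workers=1` this returns a `_MultiProcessingDataLoaderIter`;
        # wrap its `_shutdown_workers` so every shutdown increments the counter.
        iterator = super()._get_iterator()
        original_shutdown = iterator._shutdown_workers

        def _shutdown_workers():
            self.count_shutdown_workers += 1
            original_shutdown()

        iterator._shutdown_workers = _shutdown_workers
        return iterator

Overriding `_get_iterator` (as the hunk header shows the test doing) lets the test observe worker shutdowns without touching the training loop itself, which is what the assertions on `count_shutdown_workers` rely on.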