Skip to content

Commit 391e0d6

Browse files
tchaton authored and lexierule committed
shutdown workers on failure (#10463)
1 parent 6baa5cc commit 391e0d6

File tree

3 files changed

+32
-7
lines changed

3 files changed

+32
-7
lines changed

CHANGELOG.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
2727
- Fixed sampler replacement logic with `overfit_batches` to only replace the sample when `SequentialSampler` is not used ([#10486](https://github.com/PyTorchLightning/pytorch-lightning/issues/10486))
2828

2929

30+
- Fixed an issue that prevented the Trainer to shutdown workers when execution is interrupted due to failure ([#10463](https://github.com/PyTorchLightning/pytorch-lightning/issues/10463))
31+
32+
3033
-
3134

3235

pytorch_lightning/trainer/trainer.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -697,6 +697,8 @@ def _call_and_handle_interrupt(self, trainer_fn: Callable, *args: Any, **kwargs:
697697
# reset bookkeeping
698698
self.state.stage = None
699699
self.on_exception(exception)
700+
# shutdown workers
701+
self._data_connector.teardown()
700702
raise
701703

702704
def fit(

tests/loops/test_loops.py

Lines changed: 27 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424

2525
from pl_examples.bug_report_model import RandomDataset
2626
from pytorch_lightning import LightningModule, Trainer
27-
from pytorch_lightning.callbacks import ModelCheckpoint
27+
from pytorch_lightning.callbacks import Callback, ModelCheckpoint
2828
from pytorch_lightning.loops import Loop, TrainingBatchLoop
2929
from pytorch_lightning.trainer.progress import BaseProgress
3030
from tests.helpers import BoringModel
@@ -912,8 +912,10 @@ def val_dataloader(self):
912912

913913

914914
@RunIf(min_torch="1.8.0")
915-
@pytest.mark.parametrize("persistent_workers", (False, True))
916-
def test_workers_are_shutdown(tmpdir, persistent_workers):
915+
@pytest.mark.parametrize("should_fail", [False, True])
916+
# False is de-activated due to slowness
917+
@pytest.mark.parametrize("persistent_workers", [True])
918+
def test_workers_are_shutdown(tmpdir, should_fail, persistent_workers):
917919
# `num_workers == 1` uses `_MultiProcessingDataLoaderIter`
918920
# `persistent_workers` makes sure `self._iterator` gets set on the `DataLoader` instance
919921

@@ -941,12 +943,30 @@ def _get_iterator(self):
941943
train_dataloader = TestDataLoader(RandomDataset(32, 64), num_workers=1, persistent_workers=persistent_workers)
942944
val_dataloader = TestDataLoader(RandomDataset(32, 64), num_workers=1, persistent_workers=persistent_workers)
943945

946+
class TestCallback(Callback):
947+
def on_train_epoch_end(self, trainer, *_):
948+
if trainer.current_epoch == 1:
949+
raise CustomException
950+
944951
max_epochs = 3
952+
945953
model = BoringModel()
946-
trainer = Trainer(default_root_dir=tmpdir, limit_train_batches=2, limit_val_batches=2, max_epochs=max_epochs)
947-
trainer.fit(model, train_dataloader, val_dataloader)
948-
assert train_dataloader.count_shutdown_workers == (2 if persistent_workers else max_epochs)
954+
trainer = Trainer(
955+
default_root_dir=tmpdir,
956+
limit_train_batches=2,
957+
limit_val_batches=2,
958+
max_epochs=max_epochs,
959+
callbacks=TestCallback() if should_fail else None,
960+
)
961+
962+
if should_fail:
963+
with pytest.raises(CustomException):
964+
trainer.fit(model, train_dataloader, val_dataloader)
965+
else:
966+
trainer.fit(model, train_dataloader, val_dataloader)
967+
968+
assert train_dataloader.count_shutdown_workers == 2 if should_fail else (2 if persistent_workers else max_epochs)
949969
# on sanity checking end, the workers are being deleted too.
950-
assert val_dataloader.count_shutdown_workers == (2 if persistent_workers else max_epochs + 1)
970+
assert val_dataloader.count_shutdown_workers == 2 if persistent_workers else (3 if should_fail else max_epochs + 1)
951971
assert train_dataloader._iterator is None
952972
assert val_dataloader._iterator is None

0 commit comments

Comments (0)