Skip to content

Commit c1af4d0

Browse files
authored
Better graceful shutdown for KeyboardInterrupt (#19976)
1 parent b16e998 commit c1af4d0

File tree

11 files changed

+134
-22
lines changed

11 files changed

+134
-22
lines changed

src/lightning/fabric/CHANGELOG.md

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,40 @@ All notable changes to this project will be documented in this file.
55
The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
66

77

8+
## [unreleased] - YYYY-MM-DD
9+
10+
### Added
11+
12+
-
13+
14+
-
15+
16+
### Changed
17+
18+
-
19+
20+
-
21+
22+
### Deprecated
23+
24+
-
25+
26+
-
27+
28+
### Removed
29+
30+
-
31+
32+
-
33+
34+
### Fixed
35+
36+
-
37+
38+
-
39+
40+
41+
842
## [2.3.0] - 2024-06-13
943

1044
### Added

src/lightning/fabric/utilities/distributed.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import contextlib
33
import logging
44
import os
5+
import signal
56
import time
67
from contextlib import nullcontext
78
from datetime import timedelta
@@ -306,8 +307,11 @@ def _init_dist_connection(
306307

307308

308309
def _destroy_dist_connection() -> None:
310+
# Don't allow Ctrl+C to interrupt this handler
311+
signal.signal(signal.SIGINT, signal.SIG_IGN)
309312
if _distributed_is_initialized():
310313
torch.distributed.destroy_process_group()
314+
signal.signal(signal.SIGINT, signal.SIG_DFL)
311315

312316

313317
def _get_default_process_group_backend_for_device(device: torch.device) -> str:

src/lightning/pytorch/CHANGELOG.md

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,41 @@ All notable changes to this project will be documented in this file.
44

55
The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
66

7+
8+
## [unreleased] - YYYY-MM-DD
9+
10+
### Added
11+
12+
-
13+
14+
-
15+
16+
### Changed
17+
18+
- Triggering KeyboardInterrupt (Ctrl+C) during `.fit()`, `.validate()`, `.test()` or `.predict()` now terminates all processes launched by the Trainer and exits the program ([#19976](https://github.com/Lightning-AI/pytorch-lightning/pull/19976))
19+
20+
-
21+
22+
### Deprecated
23+
24+
-
25+
26+
-
27+
28+
### Removed
29+
30+
-
31+
32+
-
33+
34+
### Fixed
35+
36+
-
37+
38+
-
39+
40+
41+
742
## [2.3.0] - 2024-06-13
843

944
### Added

src/lightning/pytorch/strategies/launchers/multiprocessing.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -259,7 +259,7 @@ def update_main_process_results(self, trainer: "pl.Trainer", extra: Dict[str, An
259259
def kill(self, signum: _SIGNUM) -> None:
260260
for proc in self.procs:
261261
if proc.is_alive() and proc.pid is not None:
262-
log.info(f"pid {os.getpid()} killing {proc.pid} with {signum}")
262+
log.debug(f"Process {os.getpid()} is terminating {proc.pid} with {signum}")
263263
with suppress(ProcessLookupError):
264264
os.kill(proc.pid, signum)
265265

src/lightning/pytorch/strategies/launchers/subprocess_script.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,7 @@ def launch(self, function: Callable, *args: Any, trainer: Optional["pl.Trainer"]
107107
@override
108108
def kill(self, signum: _SIGNUM) -> None:
109109
for proc in self.procs:
110-
log.info(f"pid {os.getpid()} killing {proc.pid} with {signum}")
110+
log.debug(f"Process {os.getpid()} is terminating {proc.pid} with {signum}")
111111
# this skips subprocesses already terminated
112112
proc.send_signal(signum)
113113

src/lightning/pytorch/trainer/call.py

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414
import logging
15+
import signal
1516
from copy import deepcopy
1617
from typing import Any, Callable, Dict, Optional, Type, Union
1718

@@ -20,10 +21,12 @@
2021
import lightning.pytorch as pl
2122
from lightning.fabric.utilities.device_dtype_mixin import _DeviceDtypeModuleMixin
2223
from lightning.pytorch.callbacks import Checkpoint, EarlyStopping
24+
from lightning.pytorch.strategies.launchers import _SubprocessScriptLauncher
25+
from lightning.pytorch.trainer.connectors.signal_connector import _get_sigkill_signal
2326
from lightning.pytorch.trainer.states import TrainerStatus
2427
from lightning.pytorch.utilities.exceptions import _TunerExitException
2528
from lightning.pytorch.utilities.model_helpers import is_overridden
26-
from lightning.pytorch.utilities.rank_zero import rank_zero_warn
29+
from lightning.pytorch.utilities.rank_zero import rank_zero_info, rank_zero_warn
2730

2831
log = logging.getLogger(__name__)
2932

@@ -49,12 +52,17 @@ def _call_and_handle_interrupt(trainer: "pl.Trainer", trainer_fn: Callable, *arg
4952
trainer.state.status = TrainerStatus.FINISHED
5053
trainer.state.stage = None
5154

52-
# TODO: Unify both exceptions below, where `KeyboardError` doesn't re-raise
5355
except KeyboardInterrupt as exception:
54-
rank_zero_warn("Detected KeyboardInterrupt, attempting graceful shutdown...")
55-
# user could press Ctrl+c many times... only shutdown once
56-
if not trainer.interrupted:
57-
_interrupt(trainer, exception)
56+
rank_zero_info("\nDetected KeyboardInterrupt, attempting graceful shutdown ...")
57+
# user could press Ctrl+C many times, disable KeyboardInterrupt for shutdown
58+
signal.signal(signal.SIGINT, signal.SIG_IGN)
59+
_interrupt(trainer, exception)
60+
trainer._teardown()
61+
launcher = trainer.strategy.launcher
62+
if isinstance(launcher, _SubprocessScriptLauncher):
63+
launcher.kill(_get_sigkill_signal())
64+
exit(1)
65+
5866
except BaseException as exception:
5967
_interrupt(trainer, exception)
6068
trainer._teardown()

src/lightning/pytorch/trainer/connectors/signal_connector.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22
import os
33
import re
44
import signal
5-
import sys
65
import threading
76
from subprocess import call
87
from types import FrameType
@@ -54,7 +53,7 @@ def register_signal_handlers(self) -> None:
5453
sigterm_handlers.append(self._sigterm_handler_fn)
5554

5655
# Windows seems to have signal incompatibilities
57-
if not self._is_on_windows():
56+
if not _IS_WINDOWS:
5857
sigusr = environment.requeue_signal if isinstance(environment, SLURMEnvironment) else signal.SIGUSR1
5958
assert sigusr is not None
6059
if sigusr_handlers and not self._has_already_handler(sigusr):
@@ -155,10 +154,6 @@ def _valid_signals() -> Set[signal.Signals]:
155154
}
156155
return set(signal.Signals)
157156

158-
@staticmethod
159-
def _is_on_windows() -> bool:
160-
return sys.platform == "win32"
161-
162157
@staticmethod
163158
def _has_already_handler(signum: _SIGNUM) -> bool:
164159
return signal.getsignal(signum) not in (None, signal.SIG_DFL)
@@ -172,3 +167,7 @@ def __getstate__(self) -> Dict:
172167
state = self.__dict__.copy()
173168
state["_original_handlers"] = {}
174169
return state
170+
171+
172+
def _get_sigkill_signal() -> _SIGNUM:
173+
return signal.SIGTERM if _IS_WINDOWS else signal.SIGKILL

tests/tests_pytorch/callbacks/progress/test_rich_progress_bar.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -143,7 +143,7 @@ def on_train_start(self) -> None:
143143

144144
with mock.patch(
145145
"lightning.pytorch.callbacks.progress.rich_progress.Progress.stop", autospec=True
146-
) as mock_progress_stop:
146+
) as mock_progress_stop, pytest.raises(SystemExit):
147147
progress_bar = RichProgressBar()
148148
trainer = Trainer(
149149
default_root_dir=tmp_path,

tests/tests_pytorch/callbacks/test_lambda_function.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
# limitations under the License.
1414
from functools import partial
1515

16+
import pytest
1617
from lightning.pytorch import Trainer, seed_everything
1718
from lightning.pytorch.callbacks import Callback, LambdaCallback
1819
from lightning.pytorch.demos.boring_classes import BoringModel
@@ -23,10 +24,13 @@
2324
def test_lambda_call(tmp_path):
2425
seed_everything(42)
2526

27+
class CustomException(Exception):
28+
pass
29+
2630
class CustomModel(BoringModel):
2731
def on_train_epoch_start(self):
2832
if self.current_epoch > 1:
29-
raise KeyboardInterrupt
33+
raise CustomException("Custom exception to trigger `on_exception` hooks")
3034

3135
checker = set()
3236

@@ -59,7 +63,8 @@ def call(hook, *_, **__):
5963
limit_predict_batches=1,
6064
callbacks=[LambdaCallback(**hooks_args)],
6165
)
62-
trainer.fit(model, ckpt_path=ckpt_path)
66+
with pytest.raises(CustomException):
67+
trainer.fit(model, ckpt_path=ckpt_path)
6368
trainer.test(model)
6469
trainer.predict(model)
6570

tests/tests_pytorch/trainer/test_states.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -84,5 +84,6 @@ def on_train_batch_start(self, trainer, pl_module, batch, batch_idx):
8484

8585
trainer = Trainer(callbacks=[InterruptCallback()], default_root_dir=tmp_path, **extra_params)
8686

87-
trainer.fit(model)
87+
with pytest.raises(SystemExit):
88+
trainer.fit(model)
8889
assert trainer.interrupted

0 commit comments

Comments
 (0)