Skip to content

Commit 5d7bf06

Browse files
committed
Harden shutdown cleanup for TTY and worker
1 parent 696f6b2 commit 5d7bf06

File tree

2 files changed

+76
-3
lines changed

2 files changed

+76
-3
lines changed

sqlit/cli.py

Lines changed: 73 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,12 @@
55

66
import argparse
77
import os
8+
import signal
9+
import subprocess
810
import sys
911
import time
1012
from pathlib import Path
13+
from typing import Any
1114

1215
from sqlit.domains.connections.cli.helpers import add_schema_arguments, build_connection_config_from_args
1316
from sqlit.domains.connections.domain.config import AuthType, ConnectionConfig, DatabaseType
@@ -98,6 +101,72 @@ def _extract_connection_url(argv: list[str]) -> tuple[str | None, list[str]]:
98101
return url, result_argv
99102

100103

104+
def _sane_tty() -> None:
105+
if os.name != "posix":
106+
return
107+
if not sys.stdin.isatty():
108+
return
109+
try:
110+
subprocess.run(
111+
["stty", "sane"],
112+
stdin=sys.stdin,
113+
stdout=subprocess.DEVNULL,
114+
stderr=subprocess.DEVNULL,
115+
check=False,
116+
)
117+
except Exception:
118+
pass
119+
120+
121+
def _run_app(app: Any) -> int:
122+
exit_code: int | None = None
123+
handled_signals = [signal.SIGINT, signal.SIGTERM]
124+
for maybe_sig in (getattr(signal, "SIGHUP", None), getattr(signal, "SIGQUIT", None)):
125+
if isinstance(maybe_sig, signal.Signals):
126+
handled_signals.append(maybe_sig)
127+
128+
previous_handlers: dict[signal.Signals, Any] = {}
129+
130+
def _handle_signal(signum: int, _frame: Any) -> None:
131+
nonlocal exit_code
132+
exit_code = 128 + signum
133+
try:
134+
close_worker = getattr(app, "_close_process_worker_client", None)
135+
if callable(close_worker):
136+
close_worker()
137+
except Exception:
138+
pass
139+
try:
140+
app.exit()
141+
return
142+
except Exception:
143+
_sane_tty()
144+
raise KeyboardInterrupt
145+
146+
for sig in handled_signals:
147+
try:
148+
previous_handlers[sig] = signal.getsignal(sig)
149+
signal.signal(sig, _handle_signal)
150+
except Exception:
151+
continue
152+
153+
try:
154+
_sane_tty()
155+
app.run()
156+
except KeyboardInterrupt:
157+
_sane_tty()
158+
return exit_code if exit_code is not None else 130
159+
finally:
160+
_sane_tty()
161+
for sig, handler in previous_handlers.items():
162+
try:
163+
signal.signal(sig, handler)
164+
except Exception:
165+
pass
166+
167+
return exit_code if exit_code is not None else 0
168+
169+
101170
def _parse_missing_drivers(value: str | None) -> set[str]:
102171
if not value:
103172
return set()
@@ -550,7 +619,9 @@ def main() -> int:
550619
startup_connection=startup_config,
551620
exclusive_connection=exclusive_connection,
552621
)
553-
app.run()
622+
exit_code = _run_app(app)
623+
if exit_code != 0:
624+
return exit_code
554625
if getattr(app, "_restart_requested", False):
555626
argv = getattr(app, "_restart_argv", None) or app._compute_restart_argv()
556627
try:
@@ -601,8 +672,7 @@ def main() -> int:
601672
return 1
602673

603674
app = SSMSTUI(services=services, startup_connection=temp_config)
604-
app.run()
605-
return 0
675+
return _run_app(app)
606676

607677
if args.command in {"connections", "connection"}:
608678
if args.conn_command == "list":

sqlit/domains/shell/app/main.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1167,6 +1167,9 @@ def on_unmount(self) -> None:
11671167
if self._ui_stall_watchdog_timer is not None:
11681168
self._ui_stall_watchdog_timer.stop()
11691169
self._ui_stall_watchdog_timer = None
1170+
close_worker = getattr(self, "_close_process_worker_client", None)
1171+
if callable(close_worker):
1172+
close_worker()
11701173

11711174
def _startup_stamp(self, name: str) -> None:
11721175
if not self._startup_profile:

0 commit comments

Comments
 (0)