|
5 | 5 |
|
6 | 6 | import argparse |
7 | 7 | import os |
| 8 | +import signal |
| 9 | +import subprocess |
8 | 10 | import sys |
9 | 11 | import time |
10 | 12 | from pathlib import Path |
| 13 | +from typing import Any |
11 | 14 |
|
12 | 15 | from sqlit.domains.connections.cli.helpers import add_schema_arguments, build_connection_config_from_args |
13 | 16 | from sqlit.domains.connections.domain.config import AuthType, ConnectionConfig, DatabaseType |
@@ -98,6 +101,72 @@ def _extract_connection_url(argv: list[str]) -> tuple[str | None, list[str]]: |
98 | 101 | return url, result_argv |
99 | 102 |
|
100 | 103 |
|
| 104 | +def _sane_tty() -> None: |
| 105 | + if os.name != "posix": |
| 106 | + return |
| 107 | + if not sys.stdin.isatty(): |
| 108 | + return |
| 109 | + try: |
| 110 | + subprocess.run( |
| 111 | + ["stty", "sane"], |
| 112 | + stdin=sys.stdin, |
| 113 | + stdout=subprocess.DEVNULL, |
| 114 | + stderr=subprocess.DEVNULL, |
| 115 | + check=False, |
| 116 | + ) |
| 117 | + except Exception: |
| 118 | + pass |
| 119 | + |
| 120 | + |
| 121 | +def _run_app(app: Any) -> int: |
| 122 | + exit_code: int | None = None |
| 123 | + handled_signals = [signal.SIGINT, signal.SIGTERM] |
| 124 | + for maybe_sig in (getattr(signal, "SIGHUP", None), getattr(signal, "SIGQUIT", None)): |
| 125 | + if isinstance(maybe_sig, signal.Signals): |
| 126 | + handled_signals.append(maybe_sig) |
| 127 | + |
| 128 | + previous_handlers: dict[signal.Signals, Any] = {} |
| 129 | + |
| 130 | + def _handle_signal(signum: int, _frame: Any) -> None: |
| 131 | + nonlocal exit_code |
| 132 | + exit_code = 128 + signum |
| 133 | + try: |
| 134 | + close_worker = getattr(app, "_close_process_worker_client", None) |
| 135 | + if callable(close_worker): |
| 136 | + close_worker() |
| 137 | + except Exception: |
| 138 | + pass |
| 139 | + try: |
| 140 | + app.exit() |
| 141 | + return |
| 142 | + except Exception: |
| 143 | + _sane_tty() |
| 144 | + raise KeyboardInterrupt |
| 145 | + |
| 146 | + for sig in handled_signals: |
| 147 | + try: |
| 148 | + previous_handlers[sig] = signal.getsignal(sig) |
| 149 | + signal.signal(sig, _handle_signal) |
| 150 | + except Exception: |
| 151 | + continue |
| 152 | + |
| 153 | + try: |
| 154 | + _sane_tty() |
| 155 | + app.run() |
| 156 | + except KeyboardInterrupt: |
| 157 | + _sane_tty() |
| 158 | + return exit_code if exit_code is not None else 130 |
| 159 | + finally: |
| 160 | + _sane_tty() |
| 161 | + for sig, handler in previous_handlers.items(): |
| 162 | + try: |
| 163 | + signal.signal(sig, handler) |
| 164 | + except Exception: |
| 165 | + pass |
| 166 | + |
| 167 | + return exit_code if exit_code is not None else 0 |
| 168 | + |
| 169 | + |
101 | 170 | def _parse_missing_drivers(value: str | None) -> set[str]: |
102 | 171 | if not value: |
103 | 172 | return set() |
@@ -550,7 +619,9 @@ def main() -> int: |
550 | 619 | startup_connection=startup_config, |
551 | 620 | exclusive_connection=exclusive_connection, |
552 | 621 | ) |
553 | | - app.run() |
| 622 | + exit_code = _run_app(app) |
| 623 | + if exit_code != 0: |
| 624 | + return exit_code |
554 | 625 | if getattr(app, "_restart_requested", False): |
555 | 626 | argv = getattr(app, "_restart_argv", None) or app._compute_restart_argv() |
556 | 627 | try: |
@@ -601,8 +672,7 @@ def main() -> int: |
601 | 672 | return 1 |
602 | 673 |
|
603 | 674 | app = SSMSTUI(services=services, startup_connection=temp_config) |
604 | | - app.run() |
605 | | - return 0 |
| 675 | + return _run_app(app) |
606 | 676 |
|
607 | 677 | if args.command in {"connections", "connection"}: |
608 | 678 | if args.conn_command == "list": |
|
0 commit comments