Skip to content

Commit d044523

Browse files
authored
Make SSH packages optional (#19)
* Make SSH packages optional * Update gh worker and readme --------- Co-authored-by: Peter Adams <18162810+Maxteabag@users.noreply.github.com>
1 parent 971b7b2 commit d044523

File tree

7 files changed

+114
-10
lines changed

7 files changed

+114
-10
lines changed

.github/workflows/ci.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -414,7 +414,7 @@ jobs:
414414
- name: Install dependencies
415415
run: |
416416
python -m pip install --upgrade pip
417-
pip install -e ".[test]"
417+
pip install -e ".[test,ssh]"
418418
pip install psycopg2-binary
419419
420420
- name: Create Docker network

README.md

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -196,6 +196,18 @@ Most of the time you can just run `sqlit` and connect. If a Python driver is mis
196196

197197
**Note:** SQL Server also requires the platform-specific ODBC driver. On your first connection attempt, `sqlit` can help you install it if it's missing.
198198

199+
### SSH Tunnel Support
200+
201+
SSH tunnel functionality requires additional dependencies. Install with the `ssh` extra:
202+
203+
| Method | Command |
204+
| :--- | :--- |
205+
| pipx | `pipx install 'sqlit-tui[ssh]'` |
206+
| uv | `uv tool install 'sqlit-tui[ssh]'` |
207+
| pip | `pip install 'sqlit-tui[ssh]'` |
208+
209+
If you try to create an SSH connection without these dependencies, sqlit will detect this and show you the exact command to install them for your environment.
210+
199211
## License
200212

201213
MIT

mock-ssh-missing.json

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
{
2+
"_note": "Run: sqlit --settings mock-ssh-missing.json",
3+
"mock": {
4+
"enabled": true,
5+
"profile": "empty",
6+
"use_default_adapters": true,
7+
"drivers": {
8+
"missing": ["ssh"],
9+
"install_result": "real",
10+
"pipx": "auto"
11+
}
12+
}
13+
}

pyproject.toml

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,8 +29,6 @@ classifiers = [
2929
dependencies = [
3030
"textual[syntax]>=6.10.0",
3131
"pyperclip>=1.8.2",
32-
"sshtunnel>=0.4.0",
33-
"paramiko>=2.0.0,<4.0.0", # sshtunnel requires paramiko<4.0.0 (DSSKey removed in 4.0)
3432
"keyring>=24.0.0",
3533
]
3634

@@ -44,6 +42,8 @@ all = [
4442
"duckdb>=0.9.0",
4543
"requests>=2.0.0",
4644
"libsql-client>=0.1.0",
45+
"sshtunnel>=0.4.0",
46+
"paramiko>=2.0.0,<4.0.0",
4747
]
4848
postgres = ["psycopg2-binary>=2.9.0"]
4949
cockroachdb = ["psycopg2-binary>=2.9.0"]
@@ -54,6 +54,10 @@ oracle = ["oracledb>=2.0.0"]
5454
duckdb = ["duckdb>=0.9.0"]
5555
d1 = ["requests>=2.0.0"]
5656
turso = ["libsql-client>=0.1.0"]
57+
ssh = [
58+
"sshtunnel>=0.4.0",
59+
"paramiko>=2.0.0,<4.0.0",
60+
]
5761
test = [
5862
"pytest>=7.0",
5963
"pytest-timeout>=2.0",

sqlit/db/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
is_file_based,
1616
supports_ssh,
1717
)
18-
from .tunnel import create_ssh_tunnel
18+
from .tunnel import create_ssh_tunnel, ensure_ssh_tunnel_available
1919

2020
__all__ = [
2121
# Base
@@ -50,6 +50,7 @@
5050
"TursoAdapter",
5151
# Tunnel
5252
"create_ssh_tunnel",
53+
"ensure_ssh_tunnel_available",
5354
]
5455

5556
if TYPE_CHECKING:

sqlit/db/tunnel.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,23 @@
1010
from ..config import ConnectionConfig
1111

1212

13+
def ensure_ssh_tunnel_available() -> None:
14+
"""Ensure SSH tunnel dependencies are installed."""
15+
forced_missing = os.environ.get("SQLIT_MOCK_MISSING_DRIVERS", "").strip()
16+
if forced_missing:
17+
forced = {s.strip() for s in forced_missing.split(",") if s.strip()}
18+
if "ssh" in forced:
19+
from .exceptions import MissingDriverError
20+
21+
raise MissingDriverError("SSH tunnel", "ssh", "sshtunnel")
22+
try:
23+
import sshtunnel # noqa: F401
24+
except Exception as e:
25+
from .exceptions import MissingDriverError
26+
27+
raise MissingDriverError("SSH tunnel", "ssh", "sshtunnel") from e
28+
29+
1330
def create_ssh_tunnel(config: ConnectionConfig) -> tuple[Any, str, int]:
1431
"""Create an SSH tunnel for the connection if SSH is enabled.
1532
@@ -21,6 +38,8 @@ def create_ssh_tunnel(config: ConnectionConfig) -> tuple[Any, str, int]:
2138
port = int(config.port) if config.port else 0
2239
return None, config.server, port
2340

41+
ensure_ssh_tunnel_available()
42+
2443
from sshtunnel import SSHTunnelForwarder
2544

2645
# Parse remote database host and port

sqlit/ui/screens/connection.py

Lines changed: 61 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -277,6 +277,8 @@ def __init__(
277277
self.validation_state: ValidationState = ValidationState()
278278
self._saved_dialog_subtitle: str | None = None
279279
self._missing_driver_error: Any = None # Stores MissingDriverError if driver is missing
280+
self._missing_ssh_driver_error: Any = None # Stores MissingDriverError for SSH tunnel
281+
self._install_error: Any = None
280282
self._install_in_progress: bool = False
281283
self._install_spinner_timer: Timer | None = None
282284
self._install_spinner_index: int = 0
@@ -506,6 +508,36 @@ def _check_driver_availability(self, db_type: DatabaseType) -> None:
506508

507509
self._update_driver_status_ui()
508510

511+
def _check_ssh_driver_availability(self) -> None:
512+
from ...db import ensure_ssh_tunnel_available
513+
514+
self._missing_ssh_driver_error = None
515+
if not supports_ssh(self._current_db_type.value):
516+
self._update_driver_status_ui()
517+
return
518+
try:
519+
ensure_ssh_tunnel_available()
520+
except MissingDriverError as e:
521+
self._missing_ssh_driver_error = e
522+
523+
self._update_driver_status_ui()
524+
525+
def _get_active_tab(self) -> str:
526+
try:
527+
tabs = self.query_one("#connection-tabs", TabbedContent)
528+
return tabs.active
529+
except Exception:
530+
return "tab-general"
531+
532+
def _format_install_hint(self, strategy: Any, package_name: str) -> str:
533+
if strategy.kind == "pip":
534+
return f"pip install {package_name}"
535+
if strategy.kind == "pip-user":
536+
return f"pip install --user {package_name}"
537+
if strategy.kind == "pipx":
538+
return f"pipx inject sqlit-tui {package_name}"
539+
return strategy.manual_instructions.split("\n")[0].strip()
540+
509541
def _update_driver_status_ui(self) -> None:
510542
try:
511543
test_status = self.query_one("#test-status", Static)
@@ -518,8 +550,8 @@ def _update_driver_status_ui(self) -> None:
518550
except Exception:
519551
pass
520552

521-
if self._install_in_progress and self._missing_driver_error:
522-
error = self._missing_driver_error
553+
if self._install_in_progress and self._install_error:
554+
error = self._install_error
523555
spinner = self._INSTALL_SPINNER_FRAMES[self._install_spinner_index % len(self._INSTALL_SPINNER_FRAMES)]
524556
test_status.update(
525557
f"[yellow]⚠ Missing driver:[/] {error.package_name}\n"
@@ -528,11 +560,16 @@ def _update_driver_status_ui(self) -> None:
528560
dialog.border_subtitle = "[bold]Installing…[/] Cancel <esc>"
529561
return
530562

531-
if self._missing_driver_error:
563+
active_tab = self._get_active_tab()
564+
if active_tab == "tab-ssh":
565+
error = self._missing_ssh_driver_error
566+
else:
532567
error = self._missing_driver_error
568+
569+
if error:
533570
strategy = detect_strategy(extra_name=error.extra_name, package_name=error.package_name)
534571
if strategy.can_auto_install:
535-
install_cmd = strategy.manual_instructions.split("\n")[0].strip()
572+
install_cmd = self._format_install_hint(strategy, error.package_name)
536573
test_status.update(
537574
f"[yellow]⚠ Missing driver:[/] {error.package_name}\n"
538575
f"[dim]Install with:[/] {escape(install_cmd)}"
@@ -645,6 +682,7 @@ def _start_missing_driver_install(self, error: Any) -> None:
645682
return
646683

647684
self._install_in_progress = True
685+
self._install_error = error
648686
self._install_spinner_index = 0
649687
self._post_install_message = None
650688
self._update_driver_status_ui()
@@ -672,9 +710,13 @@ def _on_missing_driver_install_complete(self, success: bool, output: str, error:
672710

673711
self._stop_install_spinner()
674712
self._install_in_progress = False
713+
self._install_error = None
675714

676715
if success:
677-
self._check_driver_availability(self._current_db_type)
716+
if isinstance(error, MissingDriverError) and error.extra_name == "ssh":
717+
self._check_ssh_driver_availability()
718+
else:
719+
self._check_driver_availability(self._current_db_type)
678720
self._post_install_message = "Successfully installed driver"
679721
self._update_driver_status_ui()
680722

@@ -828,6 +870,8 @@ def _deferred_driver_check(self) -> None:
828870
t0 = time.perf_counter()
829871

830872
self._check_driver_availability(self._current_db_type)
873+
if self._get_active_tab() == "tab-ssh":
874+
self._check_ssh_driver_availability()
831875

832876
if debug:
833877
print(f"[DEBUG] _check_driver_availability: {(time.perf_counter() - t0)*1000:.1f}ms", file=sys.stderr)
@@ -882,6 +926,12 @@ def _apply_prefill_values(self) -> None:
882926
except Exception:
883927
pass
884928

929+
def on_tabbed_content_tab_activated(self, event: TabbedContent.TabActivated) -> None:
930+
if self._get_active_tab() == "tab-ssh":
931+
self._check_ssh_driver_availability()
932+
else:
933+
self._update_driver_status_ui()
934+
885935
def on_descendant_focus(self, event: Any) -> None:
886936
focused = self.focused
887937
if focused is None:
@@ -1235,9 +1285,14 @@ def action_open_odbc_setup(self) -> None:
12351285
def action_install_driver(self) -> None:
12361286
if self._install_in_progress:
12371287
return
1288+
active_tab = self._get_active_tab()
1289+
if active_tab == "tab-ssh" and self._missing_ssh_driver_error:
1290+
self._prompt_install_missing_driver(self._missing_ssh_driver_error)
1291+
return
12381292
if self._missing_driver_error:
12391293
self._prompt_install_missing_driver(self._missing_driver_error)
1240-
elif self._current_db_type.value == "mssql":
1294+
return
1295+
if self._current_db_type.value == "mssql":
12411296
self._open_odbc_driver_setup()
12421297

12431298
def _clear_field_error(self, name: str) -> None:

0 commit comments

Comments
 (0)