Skip to content

Commit 095cef7

Browse files
authored
Merge pull request #122 from Maxteabag/feature/auto-reconnect-after-driver-install
Auto-reconnect after driver installation
2 parents ec6e5f5 + dd3d4ca commit 095cef7

File tree

7 files changed

+353
-3
lines changed

7 files changed

+353
-3
lines changed

sqlit/domains/connections/ui/connection_error_handlers.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,9 +36,19 @@ def can_handle(self, error: Exception) -> bool:
3636

3737
def handle(self, app: ConnectionErrorApp, error: Exception, config: ConnectionConfig) -> None:
3838
from sqlit.domains.connections.providers.exceptions import MissingDriverError
39+
from sqlit.shared.core.debug_events import emit_debug_event
3940

41+
from .restart_cache import write_pending_connection_cache
4042
from .screens import PackageSetupScreen
4143

44+
# Save pending connection for auto-reconnect after driver install restart
45+
if config.name:
46+
write_pending_connection_cache(config.name)
47+
emit_debug_event(
48+
"driver_install.pending_connection_saved",
49+
connection_name=config.name,
50+
)
51+
4252
# No on_success callback - uses default "Restart to apply" behavior
4353
app.push_screen(PackageSetupScreen(cast(MissingDriverError, error)))
4454

sqlit/domains/connections/ui/restart_cache.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,3 +28,18 @@ def clear_restart_cache() -> None:
2828
get_restart_cache_path().unlink(missing_ok=True)
2929
except Exception:
3030
pass
31+
32+
33+
def write_pending_connection_cache(connection_name: str) -> None:
34+
"""Cache a pending connection name for auto-reconnect after driver install restart.
35+
36+
This is used when a user tries to connect to a server but the driver is missing.
37+
After the driver is installed and the app restarts, it can auto-connect to this
38+
connection.
39+
"""
40+
payload = {
41+
"version": 2,
42+
"type": "pending_connection",
43+
"connection_name": connection_name,
44+
}
45+
write_restart_cache(payload)

sqlit/domains/explorer/ui/mixins/tree.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -216,6 +216,18 @@ def action_refresh_tree(self: TreeMixinHost) -> None:
216216
if hasattr(self, "_loading_nodes"):
217217
self._loading_nodes.clear()
218218
self._schema_service = None
219+
220+
# Reload saved connections from disk (in case added via CLI)
221+
try:
222+
services = getattr(self, "services", None)
223+
if services:
224+
store = getattr(services, "connection_store", None)
225+
if store:
226+
reloaded = store.load_all(load_credentials=False)
227+
self.connections = reloaded
228+
except Exception:
229+
pass # Keep existing connections if reload fails
230+
219231
self.refresh_tree()
220232
loader = getattr(self, "_load_schema_cache", None)
221233
if callable(loader):

sqlit/domains/shell/app/commands/debug.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,16 @@ def _set_debug_enabled(app: Any, enabled: bool) -> None:
4141
else:
4242
app._debug_events_enabled = bool(enabled)
4343

44+
# Persist the setting across sessions
45+
try:
46+
services = getattr(app, "services", None)
47+
if services:
48+
store = getattr(services, "settings_store", None)
49+
if store:
50+
store.set("debug_events_enabled", enabled)
51+
except Exception:
52+
pass
53+
4454
path = getattr(app, "_debug_event_log_path", None)
4555
suffix = f" (log: {path})" if path else ""
4656
state = "enabled" if enabled else "disabled"

sqlit/domains/shell/app/startup_flow.py

Lines changed: 89 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,10 @@ def run_on_mount(app: AppProtocol) -> None:
3636
app._startup_stamp("settings_loaded")
3737

3838
app._expanded_paths = set(settings.get("expanded_nodes", []))
39+
if settings.get("debug_events_enabled"):
40+
setter = getattr(app, "_set_debug_events_enabled", None)
41+
if callable(setter):
42+
setter(True)
3943
if "process_worker" in settings:
4044
app.services.runtime.process_worker = bool(settings.get("process_worker"))
4145
if "process_worker_warm_on_idle" in settings:
@@ -83,6 +87,9 @@ def run_on_mount(app: AppProtocol) -> None:
8387
app.object_tree.cursor_line = 0
8488
app._update_section_labels()
8589
maybe_restore_connection_screen(app)
90+
# Auto-connect to pending connection after driver install (if not already connecting)
91+
if app._startup_connect_config is None:
92+
maybe_auto_connect_pending(app)
8693
app._startup_stamp("restore_checked")
8794
if app._debug_mode:
8895
app.call_after_refresh(app._record_launch_ms)
@@ -224,6 +231,83 @@ def _get_restart_cache_path() -> Path:
224231
return Path(tempfile.gettempdir()) / "sqlit-driver-install-restore.json"
225232

226233

234+
def maybe_auto_connect_pending(app: AppProtocol) -> bool:
235+
"""Auto-connect to a pending connection after driver install restart.
236+
237+
Returns True if a connection was initiated, False otherwise.
238+
"""
239+
from sqlit.shared.core.debug_events import emit_debug_event
240+
241+
from sqlit.domains.connections.ui.restart_cache import (
242+
clear_restart_cache,
243+
get_restart_cache_path,
244+
)
245+
246+
cache_path = get_restart_cache_path()
247+
emit_debug_event(
248+
"startup.pending_connection_check",
249+
cache_path=str(cache_path),
250+
exists=cache_path.exists(),
251+
)
252+
if not cache_path.exists():
253+
return False
254+
255+
emit_debug_event(
256+
"startup.pending_connection_found",
257+
contents=cache_path.read_text(),
258+
)
259+
260+
try:
261+
payload = json.loads(cache_path.read_text(encoding="utf-8"))
262+
except Exception as e:
263+
emit_debug_event("startup.pending_connection_parse_error", error=str(e))
264+
clear_restart_cache()
265+
return False
266+
267+
# Always clear cache after reading
268+
clear_restart_cache()
269+
270+
# Check for version 2 pending_connection type
271+
if not isinstance(payload, dict):
272+
emit_debug_event("startup.pending_connection_invalid", reason="not a dict")
273+
return False
274+
if payload.get("version") != 2:
275+
emit_debug_event("startup.pending_connection_invalid", reason="wrong version", version=payload.get("version"))
276+
return False
277+
if payload.get("type") != "pending_connection":
278+
emit_debug_event("startup.pending_connection_invalid", reason="wrong type", type=payload.get("type"))
279+
return False
280+
281+
connection_name = payload.get("connection_name")
282+
if not connection_name:
283+
emit_debug_event("startup.pending_connection_invalid", reason="no connection_name")
284+
return False
285+
286+
emit_debug_event(
287+
"startup.pending_connection_lookup",
288+
connection_name=connection_name,
289+
available_connections=[getattr(c, "name", None) for c in app.connections],
290+
)
291+
292+
# Find the connection by name
293+
config = next(
294+
(c for c in app.connections if getattr(c, "name", None) == connection_name),
295+
None,
296+
)
297+
if config is None:
298+
emit_debug_event("startup.pending_connection_not_found", connection_name=connection_name)
299+
return False
300+
301+
emit_debug_event("startup.pending_connection_connecting", connection_name=connection_name)
302+
303+
# Auto-connect after refresh (same pattern as startup_connect_config)
304+
def _connect_pending() -> None:
305+
app.connect_to_server(config)
306+
307+
app.call_after_refresh(_connect_pending)
308+
return True
309+
310+
227311
def maybe_restore_connection_screen(app: AppProtocol) -> None:
228312
"""Restore an in-progress connection form after a driver-install restart."""
229313
cache_path = _get_restart_cache_path()
@@ -239,14 +323,16 @@ def maybe_restore_connection_screen(app: AppProtocol) -> None:
239323
pass
240324
return
241325

326+
# Only handle version 1 (connection form restore), leave version 2 for maybe_auto_connect_pending
327+
if not isinstance(payload, dict) or payload.get("version") != 1:
328+
return
329+
330+
# Clear cache only for version 1
242331
try:
243332
cache_path.unlink(missing_ok=True)
244333
except Exception:
245334
pass
246335

247-
if not isinstance(payload, dict) or payload.get("version") != 1:
248-
return
249-
250336
values = payload.get("values")
251337
if not isinstance(values, dict):
252338
return
Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,100 @@
1+
"""Test auto-reconnect after driver installation restart."""
2+
3+
from __future__ import annotations
4+
5+
import json
6+
from unittest.mock import MagicMock
7+
8+
from sqlit.domains.connections.domain.config import ConnectionConfig
9+
10+
11+
class TestAutoReconnectAfterDriverInstall:
12+
"""Test that app auto-connects after driver install restart."""
13+
14+
def test_pending_connection_cache_written_on_missing_driver(self):
15+
"""
16+
When user tries to connect but driver is missing,
17+
the connection name should be cached for auto-reconnect after restart.
18+
"""
19+
from sqlit.domains.connections.ui.restart_cache import (
20+
get_restart_cache_path,
21+
write_pending_connection_cache,
22+
)
23+
24+
config = ConnectionConfig(name="my-mssql-server", db_type="mssql")
25+
26+
# Write the pending connection cache
27+
write_pending_connection_cache(config.name)
28+
29+
# Verify cache was written
30+
cache_path = get_restart_cache_path()
31+
assert cache_path.exists()
32+
33+
payload = json.loads(cache_path.read_text())
34+
assert payload["version"] == 2
35+
assert payload["type"] == "pending_connection"
36+
assert payload["connection_name"] == "my-mssql-server"
37+
38+
# Cleanup
39+
cache_path.unlink(missing_ok=True)
40+
41+
def test_startup_reads_pending_connection_and_connects(self):
42+
"""
43+
On startup, if pending_connection cache exists,
44+
app should auto-connect to that connection.
45+
"""
46+
from sqlit.domains.connections.ui.restart_cache import (
47+
get_restart_cache_path,
48+
write_pending_connection_cache,
49+
)
50+
from sqlit.domains.shell.app.startup_flow import maybe_auto_connect_pending
51+
52+
# Setup: Write pending connection cache
53+
write_pending_connection_cache("my-mssql-server")
54+
55+
# Mock app with the saved connection
56+
mock_app = MagicMock()
57+
saved_config = ConnectionConfig(name="my-mssql-server", db_type="mssql")
58+
mock_app.connections = [saved_config]
59+
mock_app.connect_to_server = MagicMock()
60+
mock_app.call_after_refresh = MagicMock()
61+
62+
# Call the startup function
63+
result = maybe_auto_connect_pending(mock_app)
64+
65+
# Should have scheduled a connection via call_after_refresh
66+
assert result is True
67+
mock_app.call_after_refresh.assert_called_once()
68+
69+
# Execute the callback to verify it calls connect_to_server
70+
callback = mock_app.call_after_refresh.call_args[0][0]
71+
callback()
72+
mock_app.connect_to_server.assert_called_once_with(saved_config)
73+
74+
# Cache should be cleared
75+
assert not get_restart_cache_path().exists()
76+
77+
def test_startup_ignores_missing_connection(self):
78+
"""
79+
If the cached connection no longer exists, don't crash.
80+
"""
81+
from sqlit.domains.connections.ui.restart_cache import (
82+
get_restart_cache_path,
83+
write_pending_connection_cache,
84+
)
85+
from sqlit.domains.shell.app.startup_flow import maybe_auto_connect_pending
86+
87+
write_pending_connection_cache("deleted-connection")
88+
89+
mock_app = MagicMock()
90+
mock_app.connections = [] # No connections
91+
mock_app.connect_to_server = MagicMock()
92+
93+
result = maybe_auto_connect_pending(mock_app)
94+
95+
# Should return False (no connection made)
96+
assert result is False
97+
mock_app.connect_to_server.assert_not_called()
98+
99+
# Cache should still be cleared
100+
assert not get_restart_cache_path().exists()

0 commit comments

Comments
 (0)