Skip to content

Commit 2ee7953

Browse files
committed
Add stale connection reconnect coverage
1 parent 2d5aa54 commit 2ee7953

File tree

7 files changed

+1027
-48
lines changed

7 files changed

+1027
-48
lines changed

infra/docker/docker-compose.test.yml

Lines changed: 60 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ services:
77
MSSQL_SA_PASSWORD: "TestPassword123!"
88
MSSQL_PID: "Developer"
99
ports:
10-
- "1434:1433"
10+
- "${MSSQL_PORT:-1434}:1433"
1111
healthcheck:
1212
test: /opt/mssql-tools18/bin/sqlcmd -S localhost -U sa -P "TestPassword123!" -C -Q "SELECT 1" || exit 1
1313
interval: 10s
@@ -22,7 +22,7 @@ services:
2222
POSTGRES_PASSWORD: "TestPassword123!"
2323
POSTGRES_DB: "test_sqlit"
2424
ports:
25-
- "5432:5432"
25+
- "${POSTGRES_PORT:-5432}:5432"
2626
healthcheck:
2727
test: ["CMD-SHELL", "pg_isready -U testuser -d test_sqlit"]
2828
interval: 5s
@@ -41,7 +41,7 @@ services:
4141
MYSQL_PASSWORD: "TestPassword123!"
4242
MYSQL_DATABASE: "test_sqlit"
4343
ports:
44-
- "3306:3306"
44+
- "${MYSQL_PORT:-3306}:3306"
4545
healthcheck:
4646
test: ["CMD", "mysqladmin", "ping", "-h", "localhost", "-u", "testuser", "-pTestPassword123!"]
4747
interval: 5s
@@ -70,14 +70,69 @@ services:
7070
APP_USER: "testuser"
7171
APP_USER_PASSWORD: "TestPassword123!"
7272
ports:
73-
- "1521:1521"
73+
- "${ORACLE_PORT:-1521}:1521"
7474
healthcheck:
7575
test: ["CMD", "healthcheck.sh"]
7676
interval: 10s
7777
timeout: 5s
7878
retries: 20
7979
start_period: 60s
8080

81+
oracle11g:
82+
image: wnameless/oracle-xe-11g
83+
container_name: sqlit-test-oracle11g
84+
environment:
85+
ORACLE_ALLOW_REMOTE: "true"
86+
ports:
87+
- "1522:1521"
88+
profiles:
89+
- enterprise
90+
91+
db2:
92+
image: icr.io/db2_community/db2:latest
93+
container_name: sqlit-test-db2
94+
privileged: true
95+
environment:
96+
LICENSE: "accept"
97+
DB2INST1_PASSWORD: "TestPassword123!"
98+
DBNAME: "testdb"
99+
ports:
100+
- "50000:50000"
101+
profiles:
102+
- enterprise
103+
104+
trino:
105+
image: trinodb/trino:latest
106+
container_name: sqlit-test-trino
107+
ports:
108+
- "8082:8080"
109+
volumes:
110+
- ../../tests/fixtures/trino/catalog:/etc/trino/catalog
111+
profiles:
112+
- enterprise
113+
healthcheck:
114+
test: ["CMD", "curl", "-f", "http://localhost:8080/v1/info"]
115+
interval: 5s
116+
timeout: 5s
117+
retries: 10
118+
start_period: 10s
119+
120+
presto:
121+
image: prestodb/presto:latest
122+
container_name: sqlit-test-presto
123+
ports:
124+
- "8083:8080"
125+
volumes:
126+
- ../../tests/fixtures/presto/catalog:/etc/presto/catalog
127+
profiles:
128+
- enterprise
129+
healthcheck:
130+
test: ["CMD", "curl", "-f", "http://localhost:8080/v1/info"]
131+
interval: 5s
132+
timeout: 5s
133+
retries: 10
134+
start_period: 10s
135+
81136
mariadb:
82137
image: mariadb:11
83138
container_name: sqlit-test-mariadb
@@ -87,7 +142,7 @@ services:
87142
MARIADB_PASSWORD: "TestPassword123!"
88143
MARIADB_DATABASE: "test_sqlit"
89144
ports:
90-
- "3307:3306"
145+
- "${MARIADB_PORT:-3307}:3306"
91146
healthcheck:
92147
test: ["CMD", "healthcheck.sh", "--connect", "--innodb_initialized"]
93148
interval: 5s

sqlit/domains/explorer/app/schema_service.py

Lines changed: 87 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
)
1717

1818
DbArgResolver = Callable[[str | None], str | None]
19+
ReconnectFn = Callable[[str | None], bool]
1920
ResultT = TypeVar("ResultT")
2021

2122

@@ -24,18 +25,54 @@ class ExplorerSchemaService:
2425
session: ConnectionSession
2526
object_cache: dict[str, dict[str, Any]]
2627
db_arg_resolver: DbArgResolver | None = None
28+
reconnect: ReconnectFn | None = None
2729

2830
def _resolve_db_arg(self, database: str | None) -> str | None:
2931
if self.db_arg_resolver:
3032
return self.db_arg_resolver(database)
3133
return database
3234

33-
def _run(self, fn: Callable[..., ResultT], *args: Any) -> ResultT:
34-
return self.session.executor.submit(fn, *args).result()
35+
def _run(self, fn: Callable[[], ResultT]) -> ResultT:
36+
return self.session.executor.submit(fn).result()
37+
38+
def _resolve_reconnect_db(self, database: str | None) -> str | None:
39+
if database:
40+
return database
41+
endpoint = self.session.config.tcp_endpoint
42+
if endpoint and endpoint.database:
43+
return endpoint.database
44+
return database
45+
46+
def _reconnect(self, database: str | None) -> bool:
47+
target_db = self._resolve_reconnect_db(database)
48+
if self.reconnect is not None:
49+
if self.reconnect(target_db):
50+
self.object_cache.clear()
51+
return True
52+
return False
53+
if not self.session.config.tcp_endpoint:
54+
return False
55+
try:
56+
self.session.switch_database(target_db or "")
57+
self.object_cache.clear()
58+
return True
59+
except Exception:
60+
return False
61+
62+
def _run_with_retry(self, fn: Callable[[], ResultT], database: str | None) -> ResultT:
63+
try:
64+
return self._run(fn)
65+
except Exception:
66+
if not self._reconnect(database):
67+
raise
68+
return self._run(fn)
3569

3670
def list_databases(self) -> list[str]:
3771
inspector = self.session.provider.schema_inspector
38-
return self._run(inspector.get_databases, self.session.connection)
72+
return self._run_with_retry(
73+
lambda: inspector.get_databases(self.session.connection),
74+
None,
75+
)
3976

4077
def list_columns(
4178
self,
@@ -45,7 +82,10 @@ def list_columns(
4582
) -> list[ColumnInfo]:
4683
inspector = self.session.provider.schema_inspector
4784
db_arg = self._resolve_db_arg(database)
48-
return self._run(inspector.get_columns, self.session.connection, name, db_arg, schema)
85+
return self._run_with_retry(
86+
lambda: inspector.get_columns(self.session.connection, name, db_arg, schema),
87+
database,
88+
)
4989

5090
def list_folder_items(self, folder_type: str, database: str | None) -> list[tuple[str, str, str]]:
5191
inspector = self.session.provider.schema_inspector
@@ -64,36 +104,61 @@ def cached(key: str, loader: Callable[[], Any]) -> Any:
64104
return data
65105

66106
if folder_type == "tables":
67-
raw_data = cached("tables", lambda: self._run(inspector.get_tables, self.session.connection, db_arg))
107+
raw_data = cached(
108+
"tables",
109+
lambda: self._run_with_retry(
110+
lambda: inspector.get_tables(self.session.connection, db_arg),
111+
database,
112+
),
113+
)
68114
return [("table", schema, name) for schema, name in raw_data]
69115
if folder_type == "views":
70-
raw_data = cached("views", lambda: self._run(inspector.get_views, self.session.connection, db_arg))
116+
raw_data = cached(
117+
"views",
118+
lambda: self._run_with_retry(
119+
lambda: inspector.get_views(self.session.connection, db_arg),
120+
database,
121+
),
122+
)
71123
return [("view", schema, name) for schema, name in raw_data]
72124
if folder_type == "indexes":
73125
if caps.supports_indexes and isinstance(inspector, IndexInspector):
74126
return [
75127
("index", item.name, item.table_name)
76-
for item in self._run(inspector.get_indexes, self.session.connection, db_arg)
128+
for item in self._run_with_retry(
129+
lambda: inspector.get_indexes(self.session.connection, db_arg),
130+
database,
131+
)
77132
]
78133
return []
79134
if folder_type == "triggers":
80135
if caps.supports_triggers and isinstance(inspector, TriggerInspector):
81136
return [
82137
("trigger", item.name, item.table_name)
83-
for item in self._run(inspector.get_triggers, self.session.connection, db_arg)
138+
for item in self._run_with_retry(
139+
lambda: inspector.get_triggers(self.session.connection, db_arg),
140+
database,
141+
)
84142
]
85143
return []
86144
if folder_type == "sequences":
87145
if caps.supports_sequences and isinstance(inspector, SequenceInspector):
88146
return [
89147
("sequence", item.name, "")
90-
for item in self._run(inspector.get_sequences, self.session.connection, db_arg)
148+
for item in self._run_with_retry(
149+
lambda: inspector.get_sequences(self.session.connection, db_arg),
150+
database,
151+
)
91152
]
92153
return []
93154
if folder_type == "procedures":
94155
if caps.supports_stored_procedures and isinstance(inspector, ProcedureInspector):
95156
raw_data = cached(
96-
"procedures", lambda: self._run(inspector.get_procedures, self.session.connection, db_arg)
157+
"procedures",
158+
lambda: self._run_with_retry(
159+
lambda: inspector.get_procedures(self.session.connection, db_arg),
160+
database,
161+
),
97162
)
98163
return [("procedure", "", name) for name in raw_data]
99164
return []
@@ -104,18 +169,27 @@ def get_index_definition(self, database: str | None, name: str, table_name: str)
104169
if not isinstance(inspector, IndexInspector):
105170
return None
106171
db_arg = self._resolve_db_arg(database)
107-
return self._run(inspector.get_index_definition, self.session.connection, name, table_name, db_arg)
172+
return self._run_with_retry(
173+
lambda: inspector.get_index_definition(self.session.connection, name, table_name, db_arg),
174+
database,
175+
)
108176

109177
def get_trigger_definition(self, database: str | None, name: str, table_name: str) -> dict[str, Any] | None:
110178
inspector = self.session.provider.schema_inspector
111179
if not isinstance(inspector, TriggerInspector):
112180
return None
113181
db_arg = self._resolve_db_arg(database)
114-
return self._run(inspector.get_trigger_definition, self.session.connection, name, table_name, db_arg)
182+
return self._run_with_retry(
183+
lambda: inspector.get_trigger_definition(self.session.connection, name, table_name, db_arg),
184+
database,
185+
)
115186

116187
def get_sequence_definition(self, database: str | None, name: str) -> dict[str, Any] | None:
117188
inspector = self.session.provider.schema_inspector
118189
if not isinstance(inspector, SequenceInspector):
119190
return None
120191
db_arg = self._resolve_db_arg(database)
121-
return self._run(inspector.get_sequence_definition, self.session.connection, name, db_arg)
192+
return self._run_with_retry(
193+
lambda: inspector.get_sequence_definition(self.session.connection, name, db_arg),
194+
database,
195+
)

sqlit/domains/explorer/ui/mixins/tree_schema.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,10 +35,30 @@ def _get_schema_service(self: TreeMixinHost) -> Any | None:
3535
db_arg_resolver = None
3636
else:
3737
db_arg_resolver = cast(DbArgResolver, db_arg_resolver)
38+
reconnect = None
39+
if hasattr(self, "_session"):
40+
def reconnect(database: str | None) -> bool:
41+
session = getattr(self, "_session", None)
42+
if session is None:
43+
return False
44+
target_db = database
45+
if not target_db and hasattr(self, "_get_effective_database"):
46+
target_db = self._get_effective_database()
47+
if target_db is None:
48+
target_db = ""
49+
try:
50+
session.switch_database(target_db)
51+
self.current_config = session.config
52+
self.current_connection = session.connection
53+
return True
54+
except Exception:
55+
return False
56+
3857
self._schema_service = ExplorerSchemaService(
3958
session=self._session,
4059
object_cache=self._get_object_cache(),
4160
db_arg_resolver=db_arg_resolver,
61+
reconnect=reconnect,
4262
)
4363
self._schema_service_session = self._session
4464
return self._schema_service

sqlit/domains/explorer/ui/tree/loaders.py

Lines changed: 2 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
)
2020
from sqlit.shared.ui.protocols import TreeMixinHost
2121

22-
from . import db_switching, expansion_state, schema_render
22+
from . import expansion_state, schema_render
2323

2424

2525
def ensure_loading_nodes(host: TreeMixinHost) -> set[str]:
@@ -111,10 +111,7 @@ def work() -> None:
111111

112112
host.call_from_thread(on_folder_loaded, host, node, db_name, folder_type, items)
113113
except Exception as error:
114-
if db_name:
115-
host.call_from_thread(fallback_reconnect_and_retry, host, node, data, db_name, error)
116-
else:
117-
host.call_from_thread(on_tree_load_error, host, node, f"Error loading: {error}")
114+
host.call_from_thread(on_tree_load_error, host, node, f"Error loading: {error}")
118115

119116
host.run_worker(work, name=f"load-folder-{folder_type}", thread=True, exclusive=False)
120117

@@ -161,26 +158,3 @@ def on_tree_load_error(host: TreeMixinHost, node: Any, error_message: str) -> No
161158
"""Handle tree load error on main thread."""
162159
clear_loading_state(host, node)
163160
host.notify(escape_markup(error_message), severity="error")
164-
165-
166-
def fallback_reconnect_and_retry(
167-
host: TreeMixinHost,
168-
node: Any,
169-
data: FolderNode,
170-
db_name: str,
171-
original_error: Exception,
172-
) -> None:
173-
"""Try reconnecting to database and retry loading. Show original error if this also fails."""
174-
clear_loading_state(host, node)
175-
176-
try:
177-
db_switching.reconnect_to_database(host, db_name)
178-
except Exception:
179-
host.notify(escape_markup(f"Error loading: {original_error}"), severity="error")
180-
return
181-
182-
node_path = expansion_state.get_node_path(host, node)
183-
loading_nodes = ensure_loading_nodes(host)
184-
loading_nodes.add(node_path)
185-
add_loading_placeholder(host, node)
186-
load_folder_async(host, node, data)

0 commit comments

Comments
 (0)