Skip to content

Commit ec6e5f5

Browse files
authored
Merge pull request #121 from Maxteabag/fix/mssql-create-databasef
fix: enable autocommit for MSSQL connections
2 parents 3930f3a + 23acb7a commit ec6e5f5

File tree

13 files changed

+485
-19
lines changed

13 files changed

+485
-19
lines changed

sqlit/domains/connections/providers/mssql/adapter.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -180,7 +180,10 @@ def connect(self, config: ConnectionConfig) -> Any:
180180
)
181181

182182
conn_str = self._build_connection_string(config)
183-
return mssql_python.connect(conn_str)
183+
conn = mssql_python.connect(conn_str)
184+
# Enable autocommit to allow DDL statements like CREATE DATABASE
185+
conn.autocommit = True
186+
return conn
184187

185188
def get_databases(self, conn: Any) -> list[str]:
186189
"""Get list of databases from SQL Server."""

sqlit/domains/connections/providers/oracle/adapter.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -72,8 +72,18 @@ def connect(self, config: ConnectionConfig) -> Any:
7272
if endpoint is None:
7373
raise ValueError("Oracle connections require a TCP-style endpoint.")
7474
port = int(endpoint.port or get_default_port("oracle"))
75-
# Use Easy Connect string format: host:port/service_name
76-
dsn = f"{endpoint.host}:{port}/{endpoint.database}"
75+
76+
# Determine connection type: service_name (default) or sid
77+
connection_type = config.get_option("oracle_connection_type", "service_name")
78+
79+
if connection_type == "sid":
80+
# SID format: host:port:sid (uses colon separator)
81+
# SID is stored in oracle_sid field, fall back to database for backward compat
82+
sid = config.get_option("oracle_sid") or endpoint.database
83+
dsn = f"{endpoint.host}:{port}:{sid}"
84+
else:
85+
# Service Name format: host:port/service_name (uses slash separator)
86+
dsn = f"{endpoint.host}:{port}/{endpoint.database}"
7787

7888
# Determine connection mode based on oracle_role
7989
oracle_role = config.get_option("oracle_role", "normal")

sqlit/domains/connections/providers/oracle/schema.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,21 @@ def _get_oracle_role_options() -> tuple[SelectOption, ...]:
2020
)
2121

2222

23+
def _get_oracle_connection_type_options() -> tuple[SelectOption, ...]:
24+
return (
25+
SelectOption("service_name", "Service Name"),
26+
SelectOption("sid", "SID"),
27+
)
28+
29+
30+
def _oracle_connection_type_is_service_name(values: dict) -> bool:
31+
return values.get("oracle_connection_type", "service_name") != "sid"
32+
33+
34+
def _oracle_connection_type_is_sid(values: dict) -> bool:
35+
return values.get("oracle_connection_type") == "sid"
36+
37+
2338
SCHEMA = ConnectionSchema(
2439
db_type="oracle",
2540
display_name="Oracle",
@@ -32,11 +47,26 @@ def _get_oracle_role_options() -> tuple[SelectOption, ...]:
3247
group="server_port",
3348
),
3449
_port_field("1521"),
50+
SchemaField(
51+
name="oracle_connection_type",
52+
label="Connection Type",
53+
field_type=FieldType.DROPDOWN,
54+
options=_get_oracle_connection_type_options(),
55+
default="service_name",
56+
),
3557
SchemaField(
3658
name="database",
3759
label="Service Name",
3860
placeholder="ORCL or XEPDB1",
3961
required=True,
62+
visible_when=_oracle_connection_type_is_service_name,
63+
),
64+
SchemaField(
65+
name="oracle_sid",
66+
label="SID",
67+
placeholder="ORCL",
68+
required=True,
69+
visible_when=_oracle_connection_type_is_sid,
4070
),
4171
_username_field(),
4272
_password_field(),

sqlit/domains/connections/providers/oracle_legacy/schema.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,21 @@ def _get_oracle_role_options() -> tuple[SelectOption, ...]:
2020
)
2121

2222

23+
def _get_oracle_connection_type_options() -> tuple[SelectOption, ...]:
24+
return (
25+
SelectOption("service_name", "Service Name"),
26+
SelectOption("sid", "SID"),
27+
)
28+
29+
30+
def _oracle_connection_type_is_service_name(values: dict) -> bool:
31+
return values.get("oracle_connection_type", "service_name") != "sid"
32+
33+
34+
def _oracle_connection_type_is_sid(values: dict) -> bool:
35+
return values.get("oracle_connection_type") == "sid"
36+
37+
2338
def _get_oracle_client_mode_options() -> tuple[SelectOption, ...]:
2439
return (
2540
SelectOption("thick", "Thick (Instant Client)"),
@@ -43,11 +58,26 @@ def _oracle_thick_mode_enabled(values: dict) -> bool:
4358
group="server_port",
4459
),
4560
_port_field("1521"),
61+
SchemaField(
62+
name="oracle_connection_type",
63+
label="Connection Type",
64+
field_type=FieldType.DROPDOWN,
65+
options=_get_oracle_connection_type_options(),
66+
default="service_name",
67+
),
4668
SchemaField(
4769
name="database",
4870
label="Service Name",
71+
placeholder="ORCL or XEPDB1",
72+
required=True,
73+
visible_when=_oracle_connection_type_is_service_name,
74+
),
75+
SchemaField(
76+
name="oracle_sid",
77+
label="SID",
4978
placeholder="ORCL",
5079
required=True,
80+
visible_when=_oracle_connection_type_is_sid,
5181
),
5282
_username_field(),
5383
_password_field(),

sqlit/domains/explorer/ui/mixins/tree.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -295,6 +295,8 @@ def action_select_table(self: TreeMixinHost) -> None:
295295
"name": data.name,
296296
"columns": [],
297297
}
298+
# Stash per-result metadata so results can resolve PKs without relying on globals.
299+
self._pending_result_table_info = self._last_query_table
298300
self._prime_last_query_table_columns(data.database, data.schema, data.name)
299301

300302
self.query_input.text = self.current_provider.dialect.build_select_query(

sqlit/domains/query/ui/mixins/query_execution.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -417,6 +417,17 @@ async def _run_query_async(self: QueryMixinHost, query: str, keep_insert_mode: b
417417
executable_statements = [s for s in statements if not is_comment_only_statement(s)]
418418
is_multi_statement = len(executable_statements) > 1
419419

420+
if is_multi_statement:
421+
self._pending_result_table_info = None
422+
elif executable_statements:
423+
if getattr(self, "_pending_result_table_info", None) is None:
424+
table_info = self._infer_result_table_info(executable_statements[0])
425+
if table_info is not None:
426+
self._pending_result_table_info = table_info
427+
prime = getattr(self, "_prime_result_table_columns", None)
428+
if callable(prime):
429+
prime(table_info)
430+
420431
try:
421432
start_time = time.perf_counter()
422433
max_rows = self.services.runtime.max_rows or MAX_FETCH_ROWS

sqlit/domains/query/ui/mixins/query_results.py

Lines changed: 58 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -224,6 +224,7 @@ def _render_results_table_incremental(
224224
escape: bool,
225225
row_limit: int,
226226
render_token: int,
227+
table_info: dict[str, Any] | None = None,
227228
) -> None:
228229
initial_count = min(RESULTS_RENDER_INITIAL_ROWS, row_limit)
229230
initial_rows = rows[:initial_count] if initial_count > 0 else []
@@ -269,9 +270,16 @@ def _render_results_table_incremental(
269270
pass
270271
if render_token == getattr(self, "_results_render_token", 0):
271272
self._replace_results_table_with_data(columns, rows, escape=escape)
273+
if table_info is not None:
274+
try:
275+
self.results_table.result_table_info = table_info
276+
except Exception:
277+
pass
272278
return
273279
if render_token != getattr(self, "_results_render_token", 0):
274280
return
281+
if table_info is not None:
282+
table.result_table_info = table_info
275283
self._replace_results_table_with_table(table)
276284
self._schedule_results_render(
277285
table,
@@ -291,6 +299,7 @@ async def _display_query_results(
291299
self._last_result_columns = columns
292300
self._last_result_rows = rows
293301
self._last_result_row_count = row_count
302+
table_info = getattr(self, "_pending_result_table_info", None)
294303

295304
# Switch to single result mode (in case we were showing stacked results)
296305
self._show_single_result_mode()
@@ -304,12 +313,15 @@ async def _display_query_results(
304313
escape=True,
305314
row_limit=row_limit,
306315
render_token=render_token,
316+
table_info=table_info,
307317
)
308318
else:
309319
render_rows = rows[:row_limit] if row_limit else []
310320
table = self._build_results_table(columns, render_rows, escape=True)
311321
if render_token != getattr(self, "_results_render_token", 0):
312322
return
323+
if table_info is not None:
324+
table.result_table_info = table_info
313325
self._replace_results_table_with_table(table)
314326

315327
time_str = format_duration_ms(elapsed_ms)
@@ -320,9 +332,15 @@ async def _display_query_results(
320332
)
321333
else:
322334
self.notify(f"Query returned {row_count} rows in {time_str}")
335+
if table_info is not None:
336+
prime = getattr(self, "_prime_result_table_columns", None)
337+
if callable(prime):
338+
prime(table_info)
339+
self._pending_result_table_info = None
323340

324341
def _display_non_query_result(self: QueryMixinHost, affected: int, elapsed_ms: float) -> None:
325342
"""Display non-query result (called on main thread)."""
343+
self._pending_result_table_info = None
326344
self._last_result_columns = ["Result"]
327345
self._last_result_rows = [(f"{affected} row(s) affected",)]
328346
self._last_result_row_count = 1
@@ -337,6 +355,7 @@ def _display_non_query_result(self: QueryMixinHost, affected: int, elapsed_ms: f
337355
def _display_query_error(self: QueryMixinHost, error_message: str) -> None:
338356
"""Display query error (called on main thread)."""
339357
self._cancel_results_render()
358+
self._pending_result_table_info = None
340359
# notify(severity="error") handles displaying the error in results via _show_error_in_results
341360
self.notify(f"Query error: {error_message}", severity="error")
342361

@@ -360,7 +379,17 @@ def _display_multi_statement_results(
360379

361380
# Add each result section
362381
for i, stmt_result in enumerate(multi_result.results):
363-
container.add_result_section(stmt_result, i, auto_collapse=auto_collapse)
382+
table_info = self._infer_result_table_info(stmt_result.statement)
383+
if table_info is not None:
384+
prime = getattr(self, "_prime_result_table_columns", None)
385+
if callable(prime):
386+
prime(table_info)
387+
container.add_result_section(
388+
stmt_result,
389+
i,
390+
auto_collapse=auto_collapse,
391+
table_info=table_info,
392+
)
364393

365394
# Show the stacked results container, hide single result table
366395
self._show_stacked_results_mode()
@@ -378,6 +407,7 @@ def _display_multi_statement_results(
378407
)
379408
else:
380409
self.notify(f"Executed {total} statements in {time_str}")
410+
self._pending_result_table_info = None
381411

382412
def _get_stacked_results_container(self: QueryMixinHost) -> Any:
383413
"""Get the stacked results container."""
@@ -410,3 +440,30 @@ def _show_single_result_mode(self: QueryMixinHost) -> None:
410440
stacked.remove_class("active")
411441
except Exception:
412442
pass
443+
444+
def _infer_result_table_info(self: QueryMixinHost, sql: str) -> dict[str, Any] | None:
445+
"""Best-effort inference of a single source table for query results."""
446+
from sqlit.domains.query.completion import extract_table_refs
447+
448+
refs = extract_table_refs(sql)
449+
if len(refs) != 1:
450+
return None
451+
ref = refs[0]
452+
schema = ref.schema
453+
name = ref.name
454+
database = None
455+
table_metadata = getattr(self, "_table_metadata", {}) or {}
456+
key_candidates = [name.lower()]
457+
if schema:
458+
key_candidates.insert(0, f"{schema}.{name}".lower())
459+
for key in key_candidates:
460+
metadata = table_metadata.get(key)
461+
if metadata:
462+
schema, name, database = metadata
463+
break
464+
return {
465+
"database": database,
466+
"schema": schema,
467+
"name": name,
468+
"columns": [],
469+
}

0 commit comments

Comments
 (0)