Skip to content

Commit 4371112

Browse files
committed
Attach result table metadata for PK-safe updates
1 parent f3a1358 commit 4371112

File tree

7 files changed

+209
-16
lines changed

7 files changed

+209
-16
lines changed

sqlit/domains/explorer/ui/mixins/tree.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -295,6 +295,8 @@ def action_select_table(self: TreeMixinHost) -> None:
295295
"name": data.name,
296296
"columns": [],
297297
}
298+
# Stash per-result metadata so results can resolve PKs without relying on globals.
299+
self._pending_result_table_info = self._last_query_table
298300
self._prime_last_query_table_columns(data.database, data.schema, data.name)
299301

300302
self.query_input.text = self.current_provider.dialect.build_select_query(

sqlit/domains/query/ui/mixins/query_execution.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -417,6 +417,17 @@ async def _run_query_async(self: QueryMixinHost, query: str, keep_insert_mode: b
417417
executable_statements = [s for s in statements if not is_comment_only_statement(s)]
418418
is_multi_statement = len(executable_statements) > 1
419419

420+
if is_multi_statement:
421+
self._pending_result_table_info = None
422+
elif executable_statements:
423+
if getattr(self, "_pending_result_table_info", None) is None:
424+
table_info = self._infer_result_table_info(executable_statements[0])
425+
if table_info is not None:
426+
self._pending_result_table_info = table_info
427+
prime = getattr(self, "_prime_result_table_columns", None)
428+
if callable(prime):
429+
prime(table_info)
430+
420431
try:
421432
start_time = time.perf_counter()
422433
max_rows = self.services.runtime.max_rows or MAX_FETCH_ROWS

sqlit/domains/query/ui/mixins/query_results.py

Lines changed: 58 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -224,6 +224,7 @@ def _render_results_table_incremental(
224224
escape: bool,
225225
row_limit: int,
226226
render_token: int,
227+
table_info: dict[str, Any] | None = None,
227228
) -> None:
228229
initial_count = min(RESULTS_RENDER_INITIAL_ROWS, row_limit)
229230
initial_rows = rows[:initial_count] if initial_count > 0 else []
@@ -269,9 +270,16 @@ def _render_results_table_incremental(
269270
pass
270271
if render_token == getattr(self, "_results_render_token", 0):
271272
self._replace_results_table_with_data(columns, rows, escape=escape)
273+
if table_info is not None:
274+
try:
275+
self.results_table.result_table_info = table_info
276+
except Exception:
277+
pass
272278
return
273279
if render_token != getattr(self, "_results_render_token", 0):
274280
return
281+
if table_info is not None:
282+
table.result_table_info = table_info
275283
self._replace_results_table_with_table(table)
276284
self._schedule_results_render(
277285
table,
@@ -291,6 +299,7 @@ async def _display_query_results(
291299
self._last_result_columns = columns
292300
self._last_result_rows = rows
293301
self._last_result_row_count = row_count
302+
table_info = getattr(self, "_pending_result_table_info", None)
294303

295304
# Switch to single result mode (in case we were showing stacked results)
296305
self._show_single_result_mode()
@@ -304,12 +313,15 @@ async def _display_query_results(
304313
escape=True,
305314
row_limit=row_limit,
306315
render_token=render_token,
316+
table_info=table_info,
307317
)
308318
else:
309319
render_rows = rows[:row_limit] if row_limit else []
310320
table = self._build_results_table(columns, render_rows, escape=True)
311321
if render_token != getattr(self, "_results_render_token", 0):
312322
return
323+
if table_info is not None:
324+
table.result_table_info = table_info
313325
self._replace_results_table_with_table(table)
314326

315327
time_str = format_duration_ms(elapsed_ms)
@@ -320,9 +332,15 @@ async def _display_query_results(
320332
)
321333
else:
322334
self.notify(f"Query returned {row_count} rows in {time_str}")
335+
if table_info is not None:
336+
prime = getattr(self, "_prime_result_table_columns", None)
337+
if callable(prime):
338+
prime(table_info)
339+
self._pending_result_table_info = None
323340

324341
def _display_non_query_result(self: QueryMixinHost, affected: int, elapsed_ms: float) -> None:
325342
"""Display non-query result (called on main thread)."""
343+
self._pending_result_table_info = None
326344
self._last_result_columns = ["Result"]
327345
self._last_result_rows = [(f"{affected} row(s) affected",)]
328346
self._last_result_row_count = 1
@@ -337,6 +355,7 @@ def _display_non_query_result(self: QueryMixinHost, affected: int, elapsed_ms: f
337355
def _display_query_error(self: QueryMixinHost, error_message: str) -> None:
338356
"""Display query error (called on main thread)."""
339357
self._cancel_results_render()
358+
self._pending_result_table_info = None
340359
# notify(severity="error") handles displaying the error in results via _show_error_in_results
341360
self.notify(f"Query error: {error_message}", severity="error")
342361

@@ -360,7 +379,17 @@ def _display_multi_statement_results(
360379

361380
# Add each result section
362381
for i, stmt_result in enumerate(multi_result.results):
363-
container.add_result_section(stmt_result, i, auto_collapse=auto_collapse)
382+
table_info = self._infer_result_table_info(stmt_result.statement)
383+
if table_info is not None:
384+
prime = getattr(self, "_prime_result_table_columns", None)
385+
if callable(prime):
386+
prime(table_info)
387+
container.add_result_section(
388+
stmt_result,
389+
i,
390+
auto_collapse=auto_collapse,
391+
table_info=table_info,
392+
)
364393

365394
# Show the stacked results container, hide single result table
366395
self._show_stacked_results_mode()
@@ -378,6 +407,7 @@ def _display_multi_statement_results(
378407
)
379408
else:
380409
self.notify(f"Executed {total} statements in {time_str}")
410+
self._pending_result_table_info = None
381411

382412
def _get_stacked_results_container(self: QueryMixinHost) -> Any:
383413
"""Get the stacked results container."""
@@ -410,3 +440,30 @@ def _show_single_result_mode(self: QueryMixinHost) -> None:
410440
stacked.remove_class("active")
411441
except Exception:
412442
pass
443+
444+
def _infer_result_table_info(self: QueryMixinHost, sql: str) -> dict[str, Any] | None:
445+
"""Best-effort inference of a single source table for query results."""
446+
from sqlit.domains.query.completion import extract_table_refs
447+
448+
refs = extract_table_refs(sql)
449+
if len(refs) != 1:
450+
return None
451+
ref = refs[0]
452+
schema = ref.schema
453+
name = ref.name
454+
database = None
455+
table_metadata = getattr(self, "_table_metadata", {}) or {}
456+
key_candidates = [name.lower()]
457+
if schema:
458+
key_candidates.insert(0, f"{schema}.{name}".lower())
459+
for key in key_candidates:
460+
metadata = table_metadata.get(key)
461+
if metadata:
462+
schema, name, database = metadata
463+
break
464+
return {
465+
"database": database,
466+
"schema": schema,
467+
"name": name,
468+
"columns": [],
469+
}

sqlit/domains/results/ui/mixins/results.py

Lines changed: 131 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@
88
from sqlit.shared.ui.protocols import ResultsMixinHost
99
from sqlit.shared.ui.widgets import SqlitDataTable
1010

11+
MIN_TIMER_DELAY_S = 0.001
12+
1113

1214
class ResultsMixin:
1315
"""Mixin providing results handling functionality."""
@@ -19,6 +21,122 @@ class ResultsMixin:
1921
_tooltip_showing: bool = False
2022
_tooltip_timer: Any | None = None
2123

24+
def _schedule_results_timer(self: ResultsMixinHost, delay_s: float, callback: Any) -> Any | None:
25+
set_timer = getattr(self, "set_timer", None)
26+
if callable(set_timer):
27+
return set_timer(delay_s, callback)
28+
call_later = getattr(self, "call_later", None)
29+
if callable(call_later):
30+
try:
31+
call_later(callback)
32+
return None
33+
except Exception:
34+
pass
35+
try:
36+
callback()
37+
except Exception:
38+
pass
39+
return None
40+
41+
def _apply_result_table_columns(
42+
self: ResultsMixinHost,
43+
table_info: dict[str, Any],
44+
token: int,
45+
columns: list[Any],
46+
) -> None:
47+
if table_info.get("_columns_token") != token:
48+
return
49+
table_info["columns"] = columns
50+
51+
def _prime_result_table_columns(self: ResultsMixinHost, table_info: dict[str, Any] | None) -> None:
52+
if not table_info:
53+
return
54+
if table_info.get("columns"):
55+
return
56+
name = table_info.get("name")
57+
if not name:
58+
return
59+
database = table_info.get("database")
60+
schema = table_info.get("schema")
61+
token = int(table_info.get("_columns_token", 0)) + 1
62+
table_info["_columns_token"] = token
63+
64+
async def work_async() -> None:
65+
import asyncio
66+
67+
columns: list[Any] = []
68+
try:
69+
runtime = getattr(self.services, "runtime", None)
70+
use_worker = bool(getattr(runtime, "process_worker", False)) and not bool(
71+
getattr(getattr(runtime, "mock", None), "enabled", False)
72+
)
73+
client = None
74+
if use_worker and hasattr(self, "_get_process_worker_client_async"):
75+
client = await self._get_process_worker_client_async() # type: ignore[attr-defined]
76+
77+
if client is not None and hasattr(client, "list_columns") and self.current_config is not None:
78+
outcome = await asyncio.to_thread(
79+
client.list_columns,
80+
config=self.current_config,
81+
database=database,
82+
schema=schema,
83+
name=name,
84+
)
85+
if getattr(outcome, "cancelled", False):
86+
return
87+
error = getattr(outcome, "error", None)
88+
if error:
89+
raise RuntimeError(error)
90+
columns = outcome.columns or []
91+
else:
92+
schema_service = getattr(self, "_get_schema_service", None)
93+
if callable(schema_service):
94+
service = self._get_schema_service()
95+
if service:
96+
columns = await asyncio.to_thread(
97+
service.list_columns,
98+
database,
99+
schema,
100+
name,
101+
)
102+
except Exception:
103+
columns = []
104+
105+
self._schedule_results_timer(
106+
MIN_TIMER_DELAY_S,
107+
lambda: self._apply_result_table_columns(table_info, token, columns),
108+
)
109+
110+
self.run_worker(work_async(), name=f"prime-result-columns-{name}", exclusive=False)
111+
112+
def _normalize_column_name(self: ResultsMixinHost, name: str) -> str:
113+
trimmed = name.strip()
114+
if len(trimmed) >= 2:
115+
if trimmed[0] == trimmed[-1] and trimmed[0] in ("\"", "`"):
116+
trimmed = trimmed[1:-1]
117+
elif trimmed[0] == "[" and trimmed[-1] == "]":
118+
trimmed = trimmed[1:-1]
119+
if "." in trimmed and not any(q in trimmed for q in ("\"", "`", "[")):
120+
trimmed = trimmed.split(".")[-1]
121+
return trimmed.lower()
122+
123+
def _get_active_results_table_info(
124+
self: ResultsMixinHost,
125+
table: SqlitDataTable | None,
126+
stacked: bool,
127+
) -> dict[str, Any] | None:
128+
if not table:
129+
return None
130+
if stacked:
131+
section = self._find_results_section(table)
132+
table_info = getattr(section, "result_table_info", None)
133+
if table_info:
134+
return table_info
135+
table_info = getattr(table, "result_table_info", None)
136+
if table_info:
137+
return table_info
138+
return getattr(self, "_last_query_table", None)
139+
22140
def _copy_text(self: ResultsMixinHost, text: str) -> bool:
23141
"""Copy text to clipboard if possible, otherwise store internally."""
24142
self._internal_clipboard = text
@@ -610,21 +728,20 @@ def sql_value(v: object) -> str:
610728
# Get table name and primary key columns
611729
table_name = "<table>"
612730
pk_column_names: set[str] = set()
613-
614-
if hasattr(self, "_last_query_table") and self._last_query_table:
615-
table_info = self._last_query_table
616-
table_name = table_info["name"]
731+
table_info = self._get_active_results_table_info(table, _stacked)
732+
if table_info:
733+
table_name = table_info.get("name", table_name)
617734
# Get PK columns from column info
618735
for col in table_info.get("columns", []):
619736
if col.is_primary_key:
620-
pk_column_names.add(col.name)
737+
pk_column_names.add(self._normalize_column_name(col.name))
621738

622739
# Build WHERE clause - prefer PK columns, fall back to all columns
623740
where_parts = []
624741
for i, col in enumerate(columns):
625742
if i < len(row_values):
626743
# If we have PK info, only use PK columns; otherwise use all columns
627-
if pk_column_names and col not in pk_column_names:
744+
if pk_column_names and self._normalize_column_name(col) not in pk_column_names:
628745
continue
629746
val = row_values[i]
630747
if val is None:
@@ -685,9 +802,10 @@ def action_edit_cell(self: ResultsMixinHost) -> None:
685802
column_name = columns[cursor_col]
686803

687804
# Check if this column is a primary key - don't allow editing PKs
688-
if hasattr(self, "_last_query_table") and self._last_query_table:
689-
for col in self._last_query_table.get("columns", []):
690-
if col.name == column_name and col.is_primary_key:
805+
table_info = self._get_active_results_table_info(table, _stacked)
806+
if table_info:
807+
for col in table_info.get("columns", []):
808+
if col.is_primary_key and self._normalize_column_name(col.name) == self._normalize_column_name(column_name):
691809
self.notify("Cannot edit primary key column", severity="warning")
692810
return
693811

@@ -705,21 +823,19 @@ def sql_value(v: object) -> str:
705823
# Get table name and primary key columns
706824
table_name = "<table>"
707825
pk_column_names: set[str] = set()
708-
709-
if hasattr(self, "_last_query_table") and self._last_query_table:
710-
table_info = self._last_query_table
711-
table_name = table_info["name"]
826+
if table_info:
827+
table_name = table_info.get("name", table_name)
712828
# Get PK columns from column info
713829
for col in table_info.get("columns", []):
714830
if col.is_primary_key:
715-
pk_column_names.add(col.name)
831+
pk_column_names.add(self._normalize_column_name(col.name))
716832

717833
# Build WHERE clause - prefer PK columns, fall back to all columns
718834
where_parts = []
719835
for i, col in enumerate(columns):
720836
if i < len(row_values):
721837
# If we have PK info, only use PK columns; otherwise use all columns
722-
if pk_column_names and col not in pk_column_names:
838+
if pk_column_names and self._normalize_column_name(col) not in pk_column_names:
723839
continue
724840
val = row_values[i]
725841
if val is None:

sqlit/domains/shell/app/main.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -197,6 +197,7 @@ def __init__(
197197
self._columns_loading: set[str] = set()
198198
self._state_machine = UIStateMachine()
199199
self._last_query_table: dict[str, Any] | None = None
200+
self._pending_result_table_info: dict[str, Any] | None = None
200201
self._query_target_database: str | None = None # Target DB for auto-generated queries
201202
self._restart_requested: bool = False
202203
# Idle scheduler for background work

sqlit/shared/ui/protocols/results.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ class ResultsStateProtocol(Protocol):
1616
_last_result_row_count: int
1717
_internal_clipboard: str
1818
_last_query_table: dict[str, Any] | None
19+
_pending_result_table_info: dict[str, Any] | None
1920
_results_table_counter: int
2021
_results_filter_visible: bool
2122
_results_filter_text: str

0 commit comments

Comments
 (0)