Skip to content

Commit e9e4634

Browse files
committed
feat: fix g+r/g+t keybindings and add multi-statement support to atomic execution
- Add missing action_g_execute_query and action_g_execute_query_atomic handlers so g+r and g+t work from the g menu - Modify atomic_execute to split and execute statements individually, collecting results into MultiStatementResult - Update _run_query_atomic_async to display multiple result tables - All statements run in same transaction; rollback on any failure
1 parent 4318107 commit e9e4634

File tree

4 files changed

+85
-13
lines changed

4 files changed

+85
-13
lines changed

sqlit/domains/query/app/transaction.py

Lines changed: 60 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from dataclasses import dataclass
1414
from typing import TYPE_CHECKING, Any
1515

16+
from .multi_statement import MultiStatementResult
1617
from .query_service import KeywordQueryAnalyzer, NonQueryResult, QueryKind, QueryResult
1718

1819
if TYPE_CHECKING:
@@ -254,25 +255,35 @@ def _execute_on_connection(
254255
rows_affected = self.provider.query_executor.execute_non_query(conn, sql)
255256
return NonQueryResult(rows_affected=rows_affected)
256257

257-
def atomic_execute(self, sql: str, max_rows: int | None = None) -> QueryResult | NonQueryResult:
258+
def atomic_execute(
259+
self, sql: str, max_rows: int | None = None
260+
) -> QueryResult | NonQueryResult | MultiStatementResult:
258261
"""Execute SQL atomically (all-or-nothing).
259262
260263
Wraps the SQL in BEGIN/COMMIT and rolls back on any error.
264+
Supports multiple statements, returning results for each.
261265
262266
Args:
263267
sql: SQL statement(s) to execute atomically.
264268
max_rows: Maximum rows to fetch for SELECT queries.
265269
266270
Returns:
267-
Result of the last statement.
271+
For single statement: QueryResult or NonQueryResult.
272+
For multiple statements: MultiStatementResult with all results.
268273
269274
Raises:
270275
Exception: If any statement fails (after rollback).
271276
"""
272-
from .multi_statement import normalize_for_execution
277+
from .multi_statement import (
278+
MultiStatementResult,
279+
StatementResult,
280+
normalize_for_execution,
281+
split_statements,
282+
)
273283

274284
# Normalize SQL: convert blank-line-separated to semicolon-separated
275285
sql = normalize_for_execution(sql)
286+
statements = split_statements(sql)
276287

277288
# Create a dedicated connection for this atomic operation
278289
conn = self.provider.connection_factory.connect(self.config)
@@ -285,16 +296,56 @@ def atomic_execute(self, sql: str, max_rows: int | None = None) -> QueryResult |
285296
# Start transaction
286297
self.provider.query_executor.execute_non_query(conn, "BEGIN")
287298

288-
# Execute the SQL
289-
result = self._execute_on_connection(conn, sql, max_rows)
299+
# Single statement - return simple result for backwards compatibility
300+
if len(statements) <= 1:
301+
result = self._execute_on_connection(conn, sql, max_rows)
302+
self.provider.query_executor.execute_non_query(conn, "COMMIT")
303+
return result
290304

291-
# Commit
305+
# Multiple statements - execute each and collect results
306+
results: list[StatementResult] = []
307+
for i, statement in enumerate(statements):
308+
try:
309+
result = self._execute_on_connection(conn, statement, max_rows)
310+
results.append(
311+
StatementResult(
312+
statement=statement,
313+
result=result,
314+
success=True,
315+
error=None,
316+
)
317+
)
318+
except Exception as e:
319+
# Record the error
320+
results.append(
321+
StatementResult(
322+
statement=statement,
323+
result=None,
324+
success=False,
325+
error=str(e),
326+
)
327+
)
328+
# Rollback and return partial results
329+
try:
330+
self.provider.query_executor.execute_non_query(conn, "ROLLBACK")
331+
except Exception:
332+
pass
333+
return MultiStatementResult(
334+
results=results,
335+
completed=False,
336+
error_index=i,
337+
)
338+
339+
# All succeeded - commit
292340
self.provider.query_executor.execute_non_query(conn, "COMMIT")
293-
294-
return result
341+
return MultiStatementResult(
342+
results=results,
343+
completed=True,
344+
error_index=None,
345+
)
295346

296347
except Exception:
297-
# Rollback on any error
348+
# Rollback on any error (e.g., BEGIN or COMMIT failed)
298349
try:
299350
self.provider.query_executor.execute_non_query(conn, "ROLLBACK")
300351
except Exception:

sqlit/domains/query/ui/mixins/query_editing_cursor.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,16 @@ def action_g_execute_single_statement(self: QueryMixinHost) -> None:
4545
self._clear_leader_pending()
4646
self.action_execute_single_statement()
4747

48+
def action_g_execute_query(self: QueryMixinHost) -> None:
49+
"""Execute query via g menu (gr)."""
50+
self._clear_leader_pending()
51+
self.action_execute_query()
52+
53+
def action_g_execute_query_atomic(self: QueryMixinHost) -> None:
54+
"""Execute query as transaction via g menu (gt)."""
55+
self._clear_leader_pending()
56+
self.action_execute_query_atomic()
57+
4858
def action_cursor_left(self: QueryMixinHost) -> None:
4959
"""Move cursor left (h in normal mode)."""
5060
row, col = self.query_input.cursor_location

sqlit/domains/query/ui/mixins/query_execution.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -439,6 +439,7 @@ async def _run_query_atomic_async(self: QueryMixinHost, query: str) -> None:
439439
import asyncio
440440
import time
441441

442+
from sqlit.domains.query.app.multi_statement import MultiStatementResult
442443
from sqlit.domains.query.app.query_service import QueryResult
443444
from sqlit.domains.query.app.transaction import TransactionExecutor
444445

@@ -476,14 +477,21 @@ async def _run_query_atomic_async(self: QueryMixinHost, query: str) -> None:
476477
except Exception:
477478
pass
478479

479-
if isinstance(result, QueryResult):
480+
if isinstance(result, MultiStatementResult):
481+
# Multi-statement atomic execution
482+
self._display_multi_statement_results(result, elapsed_ms)
483+
if result.has_error:
484+
self.notify("Transaction rolled back (error in statement)", severity="error")
485+
else:
486+
self.notify("Query executed atomically (committed)", severity="information")
487+
elif isinstance(result, QueryResult):
480488
await self._display_query_results(
481489
result.columns, result.rows, result.row_count, result.truncated, elapsed_ms
482490
)
491+
self.notify("Query executed atomically (committed)", severity="information")
483492
else:
484493
self._display_non_query_result(result.rows_affected, elapsed_ms)
485-
486-
self.notify("Query executed atomically (committed)", severity="information")
494+
self.notify("Query executed atomically (committed)", severity="information")
487495

488496
except Exception as e:
489497
self._display_query_error(f"Transaction rolled back: {e}")

sqlit/shared/ui/protocols/query.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,10 @@ class QueryActionsProtocol(Protocol):
4343
def action_execute_query(self) -> None:
4444
...
4545

46-
def action_execute_single_statement(self) -> None:
46+
def action_execute_query_atomic(self) -> None:
47+
...
48+
49+
def action_execute_single_statement(self) -> None:
4750
...
4851

4952
def _get_history_store(self) -> Any:

0 commit comments

Comments
 (0)