Skip to content

Commit 1f7af3a

Browse files
authored
SQLite: fix transaction level accounting with bad custom command. (#3981)
In the previous code the worker would always assume that the custom command worked. However the higher level code would run a check and notice that a transaction was not actually started and raise an error without rolling back the transaction. This improves the code by moving the transaction check into the worker to ensure that the transaction depth tracker is only modified if the user's custom command actually started a transaction. Fixes: #3932
1 parent ce878ce commit 1f7af3a

File tree

5 files changed

+41
-20
lines changed

5 files changed

+41
-20
lines changed

sqlx-sqlite/src/connection/handle.rs

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,8 @@ use std::{io, ptr};
44

55
use crate::error::Error;
66
use libsqlite3_sys::{
7-
sqlite3, sqlite3_close, sqlite3_exec, sqlite3_extended_result_codes, sqlite3_last_insert_rowid,
8-
sqlite3_open_v2, SQLITE_OK,
7+
sqlite3, sqlite3_close, sqlite3_exec, sqlite3_extended_result_codes, sqlite3_get_autocommit,
8+
sqlite3_last_insert_rowid, sqlite3_open_v2, SQLITE_OK,
99
};
1010

1111
use crate::SqliteError;
@@ -78,6 +78,12 @@ impl ConnectionHandle {
7878
}
7979
}
8080

81+
pub(crate) fn in_transaction(&mut self) -> bool {
82+
// SAFETY: we have exclusive access to the database handle
83+
let ret = unsafe { sqlite3_get_autocommit(self.as_ptr()) };
84+
ret == 0
85+
}
86+
8187
pub(crate) fn last_insert_rowid(&mut self) -> i64 {
8288
// SAFETY: we have exclusive access to the database handle
8389
unsafe { sqlite3_last_insert_rowid(self.as_ptr()) }

sqlx-sqlite/src/connection/mod.rs

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,8 @@ use std::ptr::NonNull;
1010

1111
use futures_intrusive::sync::MutexGuard;
1212
use libsqlite3_sys::{
13-
sqlite3, sqlite3_commit_hook, sqlite3_get_autocommit, sqlite3_progress_handler,
14-
sqlite3_rollback_hook, sqlite3_update_hook, SQLITE_DELETE, SQLITE_INSERT, SQLITE_UPDATE,
13+
sqlite3, sqlite3_commit_hook, sqlite3_progress_handler, sqlite3_rollback_hook,
14+
sqlite3_update_hook, SQLITE_DELETE, SQLITE_INSERT, SQLITE_UPDATE,
1515
};
1616
#[cfg(feature = "preupdate-hook")]
1717
pub use preupdate_hook::*;
@@ -545,11 +545,6 @@ impl LockedSqliteHandle<'_> {
545545
pub fn last_error(&mut self) -> Option<SqliteError> {
546546
self.guard.handle.last_error()
547547
}
548-
549-
pub(crate) fn in_transaction(&mut self) -> bool {
550-
let ret = unsafe { sqlite3_get_autocommit(self.as_raw_handle().as_ptr()) };
551-
ret == 0
552-
}
553548
}
554549

555550
impl Drop for ConnectionState {

sqlx-sqlite/src/connection/worker.rs

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -213,6 +213,7 @@ impl ConnectionWorker {
213213
Command::Begin { tx, statement } => {
214214
let depth = shared.transaction_depth.load(Ordering::Acquire);
215215

216+
let is_custom_statement = statement.is_some();
216217
let statement = match statement {
217218
// custom `BEGIN` statements are not allowed if
218219
// we're already in a transaction (we need to
@@ -229,8 +230,14 @@ impl ConnectionWorker {
229230
let res =
230231
conn.handle
231232
.exec(statement.as_str())
232-
.map(|_| {
233+
.and_then(|res| {
234+
if is_custom_statement && !conn.handle.in_transaction() {
235+
return Err(Error::BeginFailed)
236+
}
237+
233238
shared.transaction_depth.fetch_add(1, Ordering::Release);
239+
240+
Ok(res)
234241
});
235242
let res_ok = res.is_ok();
236243

sqlx-sqlite/src/transaction.rs

Lines changed: 1 addition & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -12,16 +12,7 @@ impl TransactionManager for SqliteTransactionManager {
1212
type Database = Sqlite;
1313

1414
async fn begin(conn: &mut SqliteConnection, statement: Option<SqlStr>) -> Result<(), Error> {
15-
let is_custom_statement = statement.is_some();
16-
conn.worker.begin(statement).await?;
17-
if is_custom_statement {
18-
// Check that custom statement actually put the connection into a transaction.
19-
let mut handle = conn.lock_handle().await?;
20-
if !handle.in_transaction() {
21-
return Err(Error::BeginFailed);
22-
}
23-
}
24-
Ok(())
15+
conn.worker.begin(statement).await
2516
}
2617

2718
fn commit(conn: &mut SqliteConnection) -> impl Future<Output = Result<(), Error>> + Send + '_ {

tests/sqlite/sqlite.rs

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1375,6 +1375,28 @@ async fn it_can_use_transaction_options() -> anyhow::Result<()> {
13751375
Ok(())
13761376
}
13771377

1378+
#[sqlx_macros::test]
1379+
async fn it_can_recover_from_bad_transaction_begin() -> anyhow::Result<()> {
1380+
let mut conn = SqliteConnectOptions::new()
1381+
.in_memory(true)
1382+
.connect()
1383+
.await
1384+
.unwrap();
1385+
1386+
// This statement doesn't actually start a transaction.
1387+
assert!(conn.begin_with("SELECT 1").await.is_err());
1388+
1389+
// Transaction state bookkeeping should be correctly reset.
1390+
1391+
let mut tx = conn.begin_with("BEGIN IMMEDIATE").await?;
1392+
let value = sqlx::query_scalar::<_, i32>("SELECT 1")
1393+
.fetch_one(&mut *tx)
1394+
.await?;
1395+
assert_eq!(value, 1);
1396+
1397+
Ok(())
1398+
}
1399+
13781400
fn transaction_state(handle: &mut LockedSqliteHandle) -> SqliteTransactionState {
13791401
use libsqlite3_sys::{sqlite3_txn_state, SQLITE_TXN_NONE, SQLITE_TXN_READ, SQLITE_TXN_WRITE};
13801402

0 commit comments

Comments
 (0)