Skip to content

Commit 599fef8

Browse files
committed
fix sqlite_step function
1 parent 9e6002f commit 599fef8

File tree

4 files changed

+80
-100
lines changed

4 files changed

+80
-100
lines changed

src/lib.rs

Lines changed: 35 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -8,17 +8,17 @@ use std::{
88
};
99

1010
use sqlite::{
11-
push_error, reset_txn_on_db, ExecutionState, SQLite3, SQLite3PreparedStmt, Value, SQLITE_BUSY,
12-
SQLITE_CANTOPEN, SQLITE_DONE, SQLITE_ERROR, SQLITE_FLOAT, SQLITE_INTEGER, SQLITE_MISUSE,
13-
SQLITE_NULL, SQLITE_OK, SQLITE_RANGE, SQLITE_TEXT,
11+
push_error, reset_txn_on_db, ExecutionState, SQLite3, SQLite3ExecCallback, SQLite3PreparedStmt,
12+
Value, SQLITE_BUSY, SQLITE_CANTOPEN, SQLITE_DONE, SQLITE_ERROR, SQLITE_FLOAT, SQLITE_INTEGER,
13+
SQLITE_MISUSE, SQLITE_NULL, SQLITE_OK, SQLITE_RANGE, SQLITE_TEXT,
1414
};
1515
use utils::execute_async_task;
1616

1717
use crate::{
1818
auth::{DbAuthStrategy, GlobeStrategy},
1919
utils::{
20-
count_parameters, extract_column_names, get_tokio, is_aligned, sql_is_begin_transaction,
21-
sql_is_commit, sql_is_pragma, sql_is_rollback,
20+
count_parameters, get_tokio, is_aligned, sql_is_begin_transaction, sql_is_commit,
21+
sql_is_pragma, sql_is_rollback,
2222
},
2323
};
2424

@@ -139,7 +139,6 @@ pub unsafe extern "C" fn sqlite3_prepare_v3(
139139
};
140140

141141
let param_count = count_parameters(&sql);
142-
let column_names = extract_column_names(&sql);
143142

144143
// Mock unparsed portion of SQL
145144
if !pz_tail.is_null() {
@@ -156,7 +155,7 @@ pub unsafe extern "C" fn sqlite3_prepare_v3(
156155
execution_state: Mutex::new(ExecutionState::Prepared), // Start in the "Prepared" state
157156
result_rows: Mutex::new(vec![]), // Initialize an empty result set
158157
current_row: Mutex::new(None), // No current row initially
159-
column_names,
158+
column_names: Vec::new(),
160159
});
161160
*pp_stmt = Box::into_raw(stmt);
162161

@@ -277,21 +276,36 @@ pub unsafe extern "C" fn sqlite3_step(stmt_ptr: *mut SQLite3PreparedStmt) -> c_i
277276
}
278277
drop(exec_state);
279278

280-
let sql = stmt.sql.to_uppercase();
281-
if sql.starts_with("SELECT") {
282-
return execute_async_task(stmt.db, sqlite::handle_select(stmt));
283-
} else if sql_is_begin_transaction(&sql) {
284-
return execute_async_task(stmt.db, sqlite::begin_tnx_on_db(stmt.db));
285-
} else if sql_is_commit(&sql) {
286-
return execute_async_task(stmt.db, sqlite::commit_tnx_on_db(stmt.db));
279+
let needs_execution = stmt.result_rows.lock().unwrap().is_empty();
280+
if needs_execution {
281+
let sql = stmt.sql.to_uppercase();
282+
let sql_result_code = {
283+
if sql_is_begin_transaction(&sql) {
284+
execute_async_task(stmt.db, sqlite::begin_tnx_on_db(stmt.db, &sql))
285+
} else if sql_is_commit(&sql) {
286+
execute_async_task(stmt.db, sqlite::commit_tnx_on_db(stmt.db, &sql))
287+
} else {
288+
execute_async_task(stmt.db, sqlite::execute_stmt(stmt))
289+
}
290+
};
291+
292+
if sql_result_code != SQLITE_OK {
293+
return sql_result_code;
294+
}
287295
}
288296

289-
execute_async_task(stmt.db, sqlite::execute_statement(stmt))
297+
let sql_result_code = sqlite::iterate_rows(stmt);
298+
if let Err(error) = sql_result_code {
299+
push_error(stmt.db, (error.to_string(), SQLITE_ERROR));
300+
return SQLITE_ERROR;
301+
}
302+
303+
sql_result_code.unwrap()
290304
}
291305

292306
#[no_mangle]
293307
pub extern "C" fn sqlite3_column_count(stmt: *mut SQLite3PreparedStmt) -> i32 {
294-
if stmt.is_null() {
308+
if !is_aligned(stmt) {
295309
return 0;
296310
}
297311

@@ -595,9 +609,9 @@ pub extern "C" fn sqlite3_errstr(errcode: c_int) -> *const c_char {
595609
pub unsafe extern "C" fn sqlite3_exec(
596610
db: *mut SQLite3,
597611
sql: *const c_char,
598-
_: Option<extern "C" fn(*mut std::ffi::c_void, i32, *const *const i8, *const *const i8) -> i32>,
599-
_: *mut std::ffi::c_void,
600-
_: *mut *mut i8,
612+
callback: SQLite3ExecCallback, // Callback function
613+
arg: *mut c_void,
614+
errmsg: *mut *mut c_char,
601615
) -> c_int {
602616
if !is_aligned(db) {
603617
return SQLITE_CANTOPEN;
@@ -610,11 +624,11 @@ pub unsafe extern "C" fn sqlite3_exec(
610624
if sql_is_pragma(&sql) {
611625
return SQLITE_OK;
612626
} else if sql_is_begin_transaction(&sql) {
613-
return execute_async_task(db, sqlite::begin_tnx_on_db(db));
627+
return execute_async_task(db, sqlite::begin_tnx_on_db(db, &sql));
614628
} else if sql_is_rollback(&sql) {
615629
return reset_txn_on_db(db);
616630
} else if sql_is_commit(&sql) {
617-
return execute_async_task(db, sqlite::commit_tnx_on_db(db));
631+
return execute_async_task(db, sqlite::commit_tnx_on_db(db, &sql));
618632
}
619633

620634
execute_async_task(db, sqlite::handle_execute(db, &sql))

src/proxy.rs

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ pub struct RemoteCol {
3333
#[derive(Debug, Deserialize, Clone)]
3434
pub struct RemoteRow {
3535
pub r#type: String,
36-
pub value: serde_json::Value,
36+
pub value: Option<serde_json::Value>,
3737
}
3838

3939
#[derive(Debug, Deserialize, Clone)]
@@ -84,13 +84,14 @@ pub async fn execute_sql_and_params(
8484
pub async fn get_transaction_baton(
8585
client: &Client,
8686
config: &TursoConfig,
87+
sql: &str,
8788
) -> Result<String, Box<dyn Error>> {
8889
let request = serde_json::json!({
8990
"requests": [
9091
{
9192
"type": "execute",
9293
"stmt": {
93-
"sql": "BEGIN"
94+
"sql": sql
9495
}
9596
}
9697
]

src/sqlite.rs

Lines changed: 42 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,7 @@ use std::{
77

88
use crate::{
99
proxy::{
10-
convert_params_to_json, execute_sql_and_params, get_execution_result,
11-
get_transaction_baton, QueryResult,
10+
convert_params_to_json, execute_sql_and_params, get_execution_result, get_transaction_baton,
1211
},
1312
utils::TursoConfig,
1413
};
@@ -158,6 +157,15 @@ impl SQLite3PreparedStmt {
158157
}
159158
}
160159

160+
pub type SQLite3ExecCallback = Option<
161+
unsafe extern "C" fn(
162+
arg: *mut c_void,
163+
column_count: c_int,
164+
column_values: *mut *mut c_char,
165+
column_names: *mut *mut c_char,
166+
) -> c_int,
167+
>;
168+
161169
pub fn get_latest_error(db: &SQLite3) -> Option<(String, c_int)> {
162170
if let Ok(stack) = db.error_stack.lock() {
163171
stack.last().cloned()
@@ -196,30 +204,7 @@ pub fn push_error(db: *mut SQLite3, error: (String, c_int)) {
196204
}
197205
}
198206

199-
pub async unsafe fn execute_statement(
200-
stmt: &mut SQLite3PreparedStmt,
201-
) -> Result<c_int, Box<dyn Error>> {
202-
match execute_stmt(stmt).await {
203-
Ok(_) => Ok(SQLITE_OK),
204-
Err(e) => {
205-
push_error(stmt.db, (e.to_string(), SQLITE_ERROR));
206-
Err(e)
207-
}
208-
}
209-
}
210-
211-
pub async unsafe fn handle_select(stmt: &mut SQLite3PreparedStmt) -> Result<c_int, Box<dyn Error>> {
212-
let needs_execution = {
213-
let result_rows = stmt.result_rows.lock().map_err(|_| "lock error")?;
214-
result_rows.is_empty()
215-
};
216-
217-
if needs_execution {
218-
if let Err(err) = execute_stmt_and_populate_result_rows(stmt).await {
219-
return Err(err);
220-
}
221-
}
222-
207+
pub fn iterate_rows(stmt: &mut SQLite3PreparedStmt) -> Result<c_int, Box<dyn Error>> {
223208
let result_rows = stmt.result_rows.lock().unwrap();
224209
let mut current_row = stmt.current_row.lock().unwrap();
225210

@@ -266,7 +251,7 @@ pub async unsafe fn handle_select(stmt: &mut SQLite3PreparedStmt) -> Result<c_in
266251
}
267252
}
268253

269-
pub async unsafe fn handle_execute(db: *mut SQLite3, sql: &str) -> Result<c_int, Box<dyn Error>> {
254+
pub async fn handle_execute(db: *mut SQLite3, sql: &str) -> Result<c_int, Box<dyn Error>> {
270255
if db.is_null() {
271256
return Err("Database pointer is null".into());
272257
}
@@ -279,7 +264,7 @@ pub async unsafe fn handle_execute(db: *mut SQLite3, sql: &str) -> Result<c_int,
279264
}
280265
}
281266

282-
pub async fn begin_tnx_on_db(db: *mut SQLite3) -> Result<c_int, Box<dyn Error>> {
267+
pub async fn begin_tnx_on_db(db: *mut SQLite3, sql: &str) -> Result<c_int, Box<dyn Error>> {
283268
if db.is_null() {
284269
return Err("Database pointer is null".into());
285270
}
@@ -297,14 +282,14 @@ pub async fn begin_tnx_on_db(db: *mut SQLite3) -> Result<c_int, Box<dyn Error>>
297282
return Err("Database is busy".into());
298283
}
299284

300-
let baton_value = get_transaction_baton(&db.client, &db.turso_config).await?;
285+
let baton_value = get_transaction_baton(&db.client, &db.turso_config, &sql).await?;
301286
db.transaction_baton.lock().unwrap().replace(baton_value);
302287
*db.transaction_has_began.lock().unwrap() = true;
303288

304289
Ok(SQLITE_OK)
305290
}
306291

307-
pub async fn commit_tnx_on_db(db: *mut SQLite3) -> Result<c_int, Box<dyn Error>> {
292+
pub async fn commit_tnx_on_db(db: *mut SQLite3, sql: &str) -> Result<c_int, Box<dyn Error>> {
308293
if db.is_null() {
309294
return Err("Database pointer is null".into());
310295
}
@@ -324,7 +309,7 @@ pub async fn commit_tnx_on_db(db: *mut SQLite3) -> Result<c_int, Box<dyn Error>>
324309

325310
let baton = db.transaction_baton.lock().unwrap().clone();
326311

327-
execute_sql_and_params(db, "COMMIT", vec![], baton.as_ref()).await?;
312+
execute_sql_and_params(db, &sql, vec![], baton.as_ref()).await?;
328313

329314
db.transaction_baton.lock().unwrap().take();
330315

@@ -333,50 +318,46 @@ pub async fn commit_tnx_on_db(db: *mut SQLite3) -> Result<c_int, Box<dyn Error>>
333318
Ok(SQLITE_OK)
334319
}
335320

336-
async unsafe fn execute_stmt(
337-
stmt: &mut SQLite3PreparedStmt,
338-
) -> Result<QueryResult, Box<dyn Error>> {
339-
let db: &SQLite3 = &*stmt.db;
321+
pub async fn execute_stmt(stmt: &mut SQLite3PreparedStmt) -> Result<c_int, Box<dyn Error>> {
322+
let db: &SQLite3 = unsafe { &*stmt.db };
340323
let baton_str = {
341324
let baton = db.transaction_baton.lock().unwrap();
342325
baton.as_ref().map(|s| s.as_str()).map(|s| s.to_owned())
343326
};
344327

345328
let params = convert_params_to_json(&stmt.params);
346329
let response = execute_sql_and_params(db, &stmt.sql, params, baton_str.as_ref()).await?;
330+
let response = get_execution_result(db, &response)?;
347331

348-
let result = get_execution_result(db, &response)?;
332+
stmt.column_names = response.cols.iter().map(|col| col.name.clone()).collect();
349333

350-
Ok(result.clone())
351-
}
352-
353-
async unsafe fn execute_stmt_and_populate_result_rows(
354-
stmt: &mut SQLite3PreparedStmt,
355-
) -> Result<c_int, Box<dyn Error>> {
356-
let response = execute_stmt(stmt).await?;
357334
let mut result_rows = stmt.result_rows.lock().unwrap();
358-
359-
let rows = response.rows;
360-
let columns = response.cols;
361-
stmt.column_names = columns.iter().map(|col| col.name.clone()).collect();
362-
363-
*result_rows = rows
335+
*result_rows = response
336+
.rows
364337
.iter()
365338
.map(|row| {
366339
let result = row
367340
.iter()
368-
.map(|row| match row.r#type.as_str() {
369-
"integer" => match &row.value {
370-
serde_json::Value::String(s) => {
371-
Value::Integer(s.parse::<i64>().unwrap_or(0))
372-
}
373-
serde_json::Value::Number(n) => Value::Integer(n.as_i64().unwrap_or(0)),
374-
_ => Value::Integer(0),
375-
},
376-
"float" => Value::Real(row.value.as_f64().unwrap_or(0.0)),
377-
"text" => Value::Text(row.value.as_str().unwrap_or("").to_string()),
378-
"null" => Value::Null,
379-
_ => Value::Null,
341+
.map(|row| {
342+
if row.value.is_none() {
343+
return Value::Null;
344+
}
345+
346+
let value = row.value.as_ref().unwrap();
347+
348+
match row.r#type.as_str() {
349+
"integer" => match &value {
350+
serde_json::Value::String(s) => {
351+
Value::Integer(s.parse::<i64>().unwrap_or(0))
352+
}
353+
serde_json::Value::Number(n) => Value::Integer(n.as_i64().unwrap_or(0)),
354+
_ => Value::Integer(0),
355+
},
356+
"float" => Value::Real(value.as_f64().unwrap_or(0.0)),
357+
"text" => Value::Text(value.as_str().unwrap_or("").to_string()),
358+
"null" => Value::Null,
359+
_ => Value::Null,
360+
}
380361
})
381362
.collect();
382363

src/utils.rs

Lines changed: 0 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -35,22 +35,6 @@ pub fn count_parameters(sql: &str) -> c_int {
3535
re.find_iter(&sql).count() as c_int
3636
}
3737

38-
pub fn extract_column_names(sql: &str) -> Vec<String> {
39-
let select_start = sql.to_uppercase().find("SELECT");
40-
let from_start = sql.to_uppercase().find("FROM");
41-
42-
if let (Some(start), Some(end)) = (select_start, from_start) {
43-
let columns_part = &sql[start + 6..end].trim();
44-
columns_part
45-
.split(',')
46-
.map(|col| col.split("AS").last().unwrap_or(col).trim().to_string())
47-
.collect()
48-
} else {
49-
// Default to unnamed columns if parsing fails
50-
vec![]
51-
}
52-
}
53-
5438
pub fn execute_async_task<F, R>(db: *mut SQLite3, task: F) -> c_int
5539
where
5640
F: std::future::Future<Output = Result<R, Box<dyn std::error::Error>>>,

0 commit comments

Comments
 (0)