Skip to content

Commit b846bd2

Browse files
authored
fix: modify step and add reset for statement (#511)
Signed-off-by: bkioshn <[email protected]>
1 parent 3f91b2e commit b846bd2

File tree

3 files changed

+139
-102
lines changed
  • hermes/bin/src/runtime_extensions/hermes/sqlite/statement
  • wasm/wasi/wit/deps/hermes-sqlite

3 files changed

+139
-102
lines changed

hermes/bin/src/runtime_extensions/hermes/sqlite/statement/core.rs

Lines changed: 95 additions & 98 deletions
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,12 @@ use libsqlite3_sys::{
55
sqlite3_bind_blob, sqlite3_bind_double, sqlite3_bind_int, sqlite3_bind_int64,
66
sqlite3_bind_null, sqlite3_bind_text, sqlite3_column_blob, sqlite3_column_bytes,
77
sqlite3_column_double, sqlite3_column_int64, sqlite3_column_text, sqlite3_column_type,
8-
sqlite3_finalize, sqlite3_step, sqlite3_stmt, SQLITE_BLOB, SQLITE_DONE, SQLITE_FLOAT,
9-
SQLITE_INTEGER, SQLITE_NULL, SQLITE_OK, SQLITE_ROW, SQLITE_TEXT, SQLITE_TRANSIENT,
8+
sqlite3_finalize, sqlite3_reset, sqlite3_step, sqlite3_stmt, SQLITE_BLOB, SQLITE_DONE,
9+
SQLITE_FLOAT, SQLITE_INTEGER, SQLITE_NULL, SQLITE_OK, SQLITE_ROW, SQLITE_TEXT,
10+
SQLITE_TRANSIENT,
1011
};
1112

12-
use crate::runtime_extensions::bindings::hermes::sqlite::api::{Errno, Value};
13+
use crate::runtime_extensions::bindings::hermes::sqlite::api::{Errno, StepResult, Value};
1314

1415
/// Stores application data into parameters of the original SQL.
1516
pub(crate) fn bind(
@@ -60,13 +61,13 @@ pub(crate) fn bind(
6061
}
6162

6263
/// Advances a statement to the next result row or to completion.
63-
pub(crate) fn step(stmt_ptr: *mut sqlite3_stmt) -> Result<(), Errno> {
64+
pub(crate) fn step(stmt_ptr: *mut sqlite3_stmt) -> Result<StepResult, Errno> {
6465
let rc = unsafe { sqlite3_step(stmt_ptr) };
6566

66-
if rc != SQLITE_DONE && rc != SQLITE_ROW {
67-
Err(Errno::Sqlite(rc))
68-
} else {
69-
Ok(())
67+
match rc {
68+
SQLITE_DONE => Ok(StepResult::Done),
69+
SQLITE_ROW => Ok(StepResult::Row),
70+
_ => Err(Errno::Sqlite(rc)),
7071
}
7172
}
7273

@@ -114,6 +115,17 @@ pub(crate) fn column(
114115
Ok(value)
115116
}
116117

118+
/// Reset the prepared statement.
119+
pub(crate) fn reset(stmt_ptr: *mut sqlite3_stmt) -> Result<(), Errno> {
120+
let rc = unsafe { sqlite3_reset(stmt_ptr) };
121+
122+
if rc == SQLITE_OK {
123+
Ok(())
124+
} else {
125+
Err(Errno::Sqlite(rc))
126+
}
127+
}
128+
117129
/// Destroys a prepared statement object. If the most recent evaluation of the
118130
/// statement encountered no errors or if the statement is never been evaluated,
119131
/// then the function results without errors. If the most recent evaluation of
@@ -144,8 +156,7 @@ mod tests {
144156
const TMP_DIR: &str = "tmp-dir";
145157

146158
fn init() -> Result<*mut sqlite3, Errno> {
147-
let app_name = ApplicationName(String::from(TMP_DIR));
148-
159+
let app_name = ApplicationName(TMP_DIR.to_string());
149160
open(false, true, app_name)
150161
}
151162

@@ -155,138 +166,124 @@ mod tests {
155166
value: Value,
156167
) -> Result<(), Errno> {
157168
let sql = format!("CREATE TABLE Dummy(Id INTEGER PRIMARY KEY, Value {db_value_type});");
158-
159-
execute(db_ptr, sql.as_str())?;
169+
execute(db_ptr, &sql)?;
160170

161171
let sql = "INSERT INTO Dummy(Value) VALUES(?);";
162-
163172
let stmt_ptr = prepare(db_ptr, sql)?;
164-
165173
bind(stmt_ptr, 1, value)?;
166174
step(stmt_ptr)?;
167175
finalize(stmt_ptr)?;
168-
169176
Ok(())
170177
}
171178

172179
fn get_value(db_ptr: *mut sqlite3) -> Result<Value, Errno> {
173180
let sql = "SELECT Value FROM Dummy WHERE Id = 1;";
174181

175182
let stmt_ptr = prepare(db_ptr, sql)?;
176-
step(stmt_ptr)?;
177-
let col_result = column(stmt_ptr, 0);
178-
finalize(stmt_ptr)?;
179-
180-
col_result
183+
match step(stmt_ptr)? {
184+
StepResult::Row => {
185+
let val = column(stmt_ptr, 0)?;
186+
finalize(stmt_ptr)?;
187+
Ok(val)
188+
},
189+
StepResult::Done => {
190+
finalize(stmt_ptr)?;
191+
Err(Errno::Sqlite(SQLITE_DONE))
192+
},
193+
}
181194
}
182195

183-
#[test]
184-
fn test_value_double() -> Result<(), Errno> {
196+
fn test_single_value(
197+
db_type: &str,
198+
value: &Value,
199+
) -> Result<(), Errno> {
185200
let db_ptr = init()?;
186-
187-
let value = Value::Double(std::f64::consts::PI);
188-
init_value(db_ptr, "REAL", value.clone())?;
189-
let value_result = get_value(db_ptr);
190-
191-
assert!(
192-
matches!((value, value_result), (Value::Double(x), Ok(Value::Double(y))) if x.eq(&y))
193-
);
194-
201+
init_value(db_ptr, db_type, value.clone())?;
202+
let value_result = get_value(db_ptr)?;
203+
assert_eq!(format!("{value:?}",), format!("{value_result:?}",));
195204
close(db_ptr)
196205
}
197206

198207
#[test]
199-
fn test_value_bool() -> Result<(), Errno> {
200-
let db_ptr = init()?;
201-
202-
let value = Value::Int32(1);
203-
init_value(db_ptr, "BOOLEAN", value.clone())?;
204-
let value_result = get_value(db_ptr);
205-
206-
assert!(matches!((value, value_result), (Value::Int32(x), Ok(Value::Int32(y))) if x == y));
207-
208-
close(db_ptr)
208+
fn test_double() -> Result<(), Errno> {
209+
test_single_value("REAL", &Value::Double(std::f64::consts::PI))
209210
}
210-
211211
#[test]
212-
fn test_value_int32() -> Result<(), Errno> {
213-
let db_ptr = init()?;
214-
215-
let value = Value::Int32(i32::MAX);
216-
init_value(db_ptr, "MEDIUMINT", value.clone())?;
217-
let value_result = get_value(db_ptr);
218-
219-
assert!(matches!((value, value_result), (Value::Int32(x), Ok(Value::Int32(y))) if x == y));
220-
221-
close(db_ptr)
212+
fn test_bool() -> Result<(), Errno> {
213+
test_single_value("BOOLEAN", &Value::Int32(1))
222214
}
223-
224215
#[test]
225-
fn test_value_int32_nullable() -> Result<(), Errno> {
226-
let db_ptr = init()?;
227-
228-
let value = Value::Null;
229-
init_value(db_ptr, "MEDIUMINT", value.clone())?;
230-
let value_result = get_value(db_ptr);
231-
232-
assert!(matches!(
233-
(value, value_result),
234-
(Value::Null, Ok(Value::Null))
235-
));
236-
237-
close(db_ptr)
216+
fn test_int32() -> Result<(), Errno> {
217+
test_single_value("MEDIUMINT", &Value::Int32(i32::MAX))
238218
}
239-
240219
#[test]
241-
fn test_value_int64() -> Result<(), Errno> {
242-
let db_ptr = init()?;
243-
244-
let value = Value::Int64(i64::MAX);
245-
init_value(db_ptr, "BIGINT", value.clone())?;
246-
let value_result = get_value(db_ptr);
247-
248-
assert!(matches!((value, value_result), (Value::Int64(x), Ok(Value::Int64(y))) if x == y));
249-
250-
close(db_ptr)
220+
fn test_int32_nullable() -> Result<(), Errno> {
221+
test_single_value("MEDIUMINT", &Value::Null)
251222
}
252-
253223
#[test]
254-
fn test_value_text() -> Result<(), Errno> {
255-
let db_ptr = init()?;
256-
257-
let value = Value::Text(String::from("Hello, World!"));
258-
init_value(db_ptr, "TEXT", value.clone())?;
259-
let value_result = get_value(db_ptr);
260-
261-
assert!(matches!((value, value_result), (Value::Text(x), Ok(Value::Text(y))) if x == y));
262-
263-
close(db_ptr)
224+
fn test_int64() -> Result<(), Errno> {
225+
test_single_value("BIGINT", &Value::Int64(i64::MAX))
226+
}
227+
#[test]
228+
fn test_text() -> Result<(), Errno> {
229+
test_single_value("TEXT", &Value::Text("Hello, World!".to_string()))
230+
}
231+
#[test]
232+
fn test_blob() -> Result<(), Errno> {
233+
test_single_value("BLOB", &Value::Blob(vec![1, 2, 3, 4, 5]))
264234
}
265235

266236
#[test]
267-
fn test_value_blob() -> Result<(), Errno> {
237+
fn test_finalize_simple() -> Result<(), Errno> {
268238
let db_ptr = init()?;
269239

270-
let value = Value::Blob(vec![1, 2, 3, 4, 5]);
271-
init_value(db_ptr, "BLOB", value.clone())?;
272-
let value_result = get_value(db_ptr);
240+
let sql = "SELECT 1;";
241+
242+
let stmt_ptr = core::prepare(db_ptr, sql)?;
243+
244+
let result = finalize(stmt_ptr);
273245

274-
assert!(matches!((value, value_result), (Value::Blob(x), Ok(Value::Blob(y))) if x == y));
246+
assert!(result.is_ok());
275247

276248
close(db_ptr)
277249
}
278250

279251
#[test]
280-
fn test_finalize_simple() -> Result<(), Errno> {
252+
fn test_loop_over_rows() -> Result<(), Errno> {
281253
let db_ptr = init()?;
282254

283-
let sql = "SELECT 1;";
255+
let sql = "CREATE TABLE IF NOT EXISTS Dummy(Id INTEGER PRIMARY KEY, Value INTEGER);";
256+
execute(db_ptr, sql)?;
284257

285-
let stmt_ptr = core::prepare(db_ptr, sql)?;
258+
// Insert multiple rows
259+
let sql = "INSERT INTO Dummy(Value) VALUES(?);";
260+
let stmt_ptr = prepare(db_ptr, sql)?;
261+
for i in 1..=5 {
262+
bind(stmt_ptr, 1, Value::Int32(i))?;
263+
step(stmt_ptr)?;
264+
reset(stmt_ptr)?;
265+
}
266+
finalize(stmt_ptr)?;
286267

287-
let result = finalize(stmt_ptr);
268+
let sql = "SELECT Value FROM Dummy ORDER BY Id ASC;";
269+
let stmt_ptr = prepare(db_ptr, sql)?;
270+
let mut collected = Vec::new();
271+
272+
loop {
273+
let result = step(stmt_ptr)?;
274+
match result {
275+
StepResult::Row => collected.push(column(stmt_ptr, 0)?),
276+
StepResult::Done => break,
277+
}
278+
}
279+
finalize(stmt_ptr)?;
288280

289-
assert!(result.is_ok());
281+
for (i, val) in collected.iter().enumerate() {
282+
match val {
283+
Value::Int32(v) => assert_eq!(*v, i32::try_from(i).unwrap() + 1),
284+
_ => panic!("Unexpected value type: {val:?}"),
285+
}
286+
}
290287

291288
close(db_ptr)
292289
}

hermes/bin/src/runtime_extensions/hermes/sqlite/statement/host.rs

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,9 @@
33
use super::{super::state::get_statement_state, core};
44
use crate::{
55
runtime_context::HermesRuntimeContext,
6-
runtime_extensions::bindings::hermes::sqlite::api::{Errno, HostStatement, Statement, Value},
6+
runtime_extensions::bindings::hermes::sqlite::api::{
7+
Errno, HostStatement, Statement, StepResult, Value,
8+
},
79
};
810

911
impl HostStatement for HermesRuntimeContext {
@@ -29,10 +31,14 @@ impl HostStatement for HermesRuntimeContext {
2931
///
3032
/// After a prepared statement has been prepared, this function must be called one or
3133
/// more times to evaluate the statement.
34+
///
35+
/// ## Returns
36+
///
37+
/// A `step-result` indicating the status of the step.
3238
fn step(
3339
&mut self,
3440
resource: wasmtime::component::Resource<Statement>,
35-
) -> wasmtime::Result<Result<(), Errno>> {
41+
) -> wasmtime::Result<Result<StepResult, Errno>> {
3642
let mut app_state = get_statement_state().get_app_state(self.app_name())?;
3743
let stmt_ptr = app_state.get_object(&resource)?;
3844
Ok(core::step(*stmt_ptr as *mut _))
@@ -62,6 +68,21 @@ impl HostStatement for HermesRuntimeContext {
6268
Ok(core::column(*stmt_ptr as *mut _, index))
6369
}
6470

71+
/// Reset a prepared statement object back to its initial state, ready to be
72+
/// re-executed.
73+
///
74+
/// This function clears all previous bindings, resets the statement to the beginning,
75+
/// and prepares it for another execution. This must be called before reusing a
76+
/// statement with new parameter bindings.
77+
fn reset(
78+
&mut self,
79+
resource: wasmtime::component::Resource<Statement>,
80+
) -> wasmtime::Result<Result<(), Errno>> {
81+
let mut app_state = get_statement_state().get_app_state(self.app_name())?;
82+
let stmt_ptr = app_state.get_object(&resource)?;
83+
Ok(core::reset(*stmt_ptr as *mut _))
84+
}
85+
6586
/// Destroys a prepared statement object. If the most recent evaluation of the
6687
/// statement encountered no errors or if the statement is never been evaluated,
6788
/// then the function results without errors. If the most recent evaluation of

wasm/wasi/wit/deps/hermes-sqlite/api.wit

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,14 @@ interface api {
6161
text(string)
6262
}
6363

64+
/// The result of advancing a prepared SQLite statement by one step.
65+
enum step-result {
66+
/// Indicates that the statement has finished executing.
67+
done,
68+
/// Indicates that there is a new row of result.
69+
row,
70+
}
71+
6472
/// The database connection object.
6573
resource sqlite {
6674
/// Closes a database connection, destructor for `sqlite3`.
@@ -115,7 +123,11 @@ interface api {
115123
/// Advances a statement to the next result row or to completion.
116124
///
117125
/// After a prepared statement has been prepared, this function must be called one or more times to evaluate the statement.
118-
step: func() -> result<_, errno>;
126+
///
127+
/// ## Returns
128+
///
129+
/// A `step-result` indicating the status of the step.
130+
step: func() -> result<step-result, errno>;
119131

120132
/// Returns information about a single column of the current result row of a query.
121133
///
@@ -130,6 +142,13 @@ interface api {
130142
/// The value of a result column in a specific data format.
131143
column: func(index: u32) -> result<value, errno>;
132144

145+
/// Reset a prepared statement object back to its initial state, ready to be re-executed.
146+
///
147+
/// This function clears all previous bindings, resets the statement to the beginning,
148+
/// and prepares it for another execution. This must be called before reusing a statement
149+
/// with new parameter bindings.
150+
reset: func() -> result<_, errno>;
151+
133152
/// Destroys a prepared statement object. If the most recent evaluation of the statement encountered no errors or if the statement is never been evaluated,
134153
/// then the function results without errors. If the most recent evaluation of statement failed, then the function results the appropriate error code.
135154
///
@@ -155,4 +174,4 @@ interface api {
155174
/// World just for the Hermes 'sqlite' API.
156175
world sqlite-api {
157176
import api;
158-
}
177+
}

0 commit comments

Comments
 (0)