Skip to content

Commit 8645704

Browse files
authored
Support multi-statement queries (#598)
Inspired by the [Go client](https://github.com/mlafeldt/go-duckdb/blob/2227e78d82a5e48b320f412cae5e572a0620d98e/connection.go#L203-L235), this enhances `prepare` to handle both single and multi-statement queries. Among other things, this enables the use of `PIVOT` (which internally expands to multiple statements) and also enables multiple statements per line in the example REPL. 🎉 Fixes #425 Fixes #310
2 parents c061eee + 465e881 commit 8645704

File tree

4 files changed

+159
-27
lines changed

4 files changed

+159
-27
lines changed

crates/duckdb/examples/repl.rs

Lines changed: 8 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -132,28 +132,14 @@ impl SqlRepl {
132132
}
133133

134134
fn execute_sql(&self, sql: &str) -> DuckResult<()> {
135-
// Check if it's a statement that returns results
136-
let sql_upper = sql.trim().to_uppercase();
137-
let is_query = sql_upper.starts_with("SELECT")
138-
|| sql_upper.starts_with("FROM")
139-
|| sql_upper.starts_with("SHOW")
140-
|| sql_upper.starts_with("DESCRIBE")
141-
|| sql_upper.starts_with("EXPLAIN")
142-
|| sql_upper.starts_with("PRAGMA")
143-
|| sql_upper.starts_with("WITH");
144-
145-
if is_query {
146-
let mut stmt = self.conn.prepare(sql)?;
147-
let rbs: Vec<RecordBatch> = stmt.query_arrow([])?.collect();
148-
149-
if rbs.is_empty() || rbs[0].num_rows() == 0 {
150-
println!("No results returned.");
151-
} else {
152-
self.print_records(&rbs);
153-
}
154-
} else {
155-
// Execute non-query statements
156-
self.conn.execute_batch(sql)?;
135+
let mut stmt = self.conn.prepare(sql)?;
136+
let rbs: Vec<RecordBatch> = stmt.query_arrow([])?.collect();
137+
138+
// NOTE: When executing multi-statement queries (e.g., "SELECT 1; SELECT 2;"),
139+
// only the result from the final statement will be displayed. This differs from
140+
// the DuckDB CLI which shows results from all statements.
141+
if !rbs.is_empty() && rbs[0].num_rows() > 0 {
142+
self.print_records(&rbs);
157143
}
158144

159145
Ok(())

crates/duckdb/src/error.rs

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -286,3 +286,29 @@ pub fn result_from_duckdb_arrow(code: ffi::duckdb_state, mut out: ffi::duckdb_ar
286286
error_from_duckdb_code(code, message)
287287
}
288288
}
289+
290+
#[cold]
291+
#[inline]
292+
pub fn result_from_duckdb_extract(
293+
num_statements: ffi::idx_t,
294+
mut extracted: ffi::duckdb_extracted_statements,
295+
) -> Result<()> {
296+
if num_statements > 0 {
297+
return Ok(());
298+
}
299+
unsafe {
300+
let message = if extracted.is_null() {
301+
Some("extracted statements are null".to_string())
302+
} else {
303+
let c_err = ffi::duckdb_extract_statements_error(extracted);
304+
let message = if c_err.is_null() {
305+
None
306+
} else {
307+
Some(CStr::from_ptr(c_err).to_string_lossy().to_string())
308+
};
309+
ffi::duckdb_destroy_extracted(&mut extracted);
310+
message
311+
};
312+
error_from_duckdb_code(ffi::DuckDBError, message)
313+
}
314+
}

crates/duckdb/src/inner_connection.rs

Lines changed: 73 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,10 @@ use std::{
88

99
use super::{ffi, Appender, Config, Connection, Result};
1010
use crate::{
11-
error::{result_from_duckdb_appender, result_from_duckdb_arrow, result_from_duckdb_prepare, Error},
11+
error::{
12+
result_from_duckdb_appender, result_from_duckdb_arrow, result_from_duckdb_extract, result_from_duckdb_prepare,
13+
Error,
14+
},
1215
raw_statement::RawStatement,
1316
statement::Statement,
1417
};
@@ -93,11 +96,68 @@ impl InnerConnection {
9396
}
9497

9598
pub fn prepare<'a>(&mut self, conn: &'a Connection, sql: &str) -> Result<Statement<'a>> {
96-
let mut c_stmt: ffi::duckdb_prepared_statement = ptr::null_mut();
9799
let c_str = CString::new(sql).unwrap();
98-
let r = unsafe { ffi::duckdb_prepare(self.con, c_str.as_ptr() as *const c_char, &mut c_stmt) };
99-
result_from_duckdb_prepare(r, c_stmt)?;
100-
Ok(Statement::new(conn, unsafe { RawStatement::new(c_stmt) }))
100+
101+
// Extract statements (handles both single and multi-statement queries)
102+
let mut extracted = ptr::null_mut();
103+
let num_stmts =
104+
unsafe { ffi::duckdb_extract_statements(self.con, c_str.as_ptr() as *const c_char, &mut extracted) };
105+
result_from_duckdb_extract(num_stmts, extracted)?;
106+
107+
// Auto-cleanup on drop
108+
let _guard = ExtractedStatementsGuard(extracted);
109+
110+
// Execute all intermediate statements
111+
for i in 0..num_stmts - 1 {
112+
self.execute_extracted_statement(extracted, i)?;
113+
}
114+
115+
// Prepare and return final statement
116+
let final_stmt = self.prepare_extracted_statement(extracted, num_stmts - 1)?;
117+
Ok(Statement::new(conn, unsafe { RawStatement::new(final_stmt) }))
118+
}
119+
120+
fn prepare_extracted_statement(
121+
&self,
122+
extracted: ffi::duckdb_extracted_statements,
123+
index: ffi::idx_t,
124+
) -> Result<ffi::duckdb_prepared_statement> {
125+
let mut stmt = ptr::null_mut();
126+
let res = unsafe { ffi::duckdb_prepare_extracted_statement(self.con, extracted, index, &mut stmt) };
127+
result_from_duckdb_prepare(res, stmt)?;
128+
Ok(stmt)
129+
}
130+
131+
fn execute_extracted_statement(
132+
&self,
133+
extracted: ffi::duckdb_extracted_statements,
134+
index: ffi::idx_t,
135+
) -> Result<()> {
136+
let mut stmt = self.prepare_extracted_statement(extracted, index)?;
137+
138+
let mut result = unsafe { mem::zeroed() };
139+
let rc = unsafe { ffi::duckdb_execute_prepared(stmt, &mut result) };
140+
141+
let error = if rc != ffi::DuckDBSuccess {
142+
unsafe {
143+
let c_err = ffi::duckdb_result_error(&mut result as *mut _);
144+
let msg = if c_err.is_null() {
145+
None
146+
} else {
147+
Some(CStr::from_ptr(c_err).to_string_lossy().to_string())
148+
};
149+
Some(Error::DuckDBFailure(ffi::Error::new(rc), msg))
150+
}
151+
} else {
152+
None
153+
};
154+
155+
unsafe {
156+
ffi::duckdb_destroy_prepare(&mut stmt);
157+
ffi::duckdb_destroy_result(&mut result);
158+
}
159+
160+
error.map_or(Ok(()), Err)
101161
}
102162

103163
pub fn appender<'a>(&mut self, conn: &'a Connection, table: &str, schema: &str) -> Result<Appender<'a>> {
@@ -126,6 +186,14 @@ impl InnerConnection {
126186
}
127187
}
128188

189+
struct ExtractedStatementsGuard(ffi::duckdb_extracted_statements);
190+
191+
impl Drop for ExtractedStatementsGuard {
192+
fn drop(&mut self) {
193+
unsafe { ffi::duckdb_destroy_extracted(&mut self.0) }
194+
}
195+
}
196+
129197
impl Drop for InnerConnection {
130198
#[allow(unused_must_use)]
131199
#[inline]

crates/duckdb/src/lib.rs

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1455,4 +1455,56 @@ mod test {
14551455

14561456
Ok(())
14571457
}
1458+
1459+
#[test]
1460+
fn test_prepare_multi_statement() -> Result<()> {
1461+
let db = checked_memory_handle();
1462+
1463+
{
1464+
let mut stmt =
1465+
db.prepare("CREATE TABLE test(x INTEGER); INSERT INTO test VALUES (42); SELECT x FROM test;")?;
1466+
let result: i32 = stmt.query_row([], |row| row.get(0))?;
1467+
assert_eq!(result, 42);
1468+
}
1469+
1470+
{
1471+
let mut stmt = db.prepare(
1472+
"CREATE TEMP TABLE temp_data(id INTEGER, value TEXT);
1473+
INSERT INTO temp_data VALUES (1, 'first'), (2, 'second');
1474+
SELECT COUNT(*) FROM temp_data;",
1475+
)?;
1476+
let count: i32 = stmt.query_row([], |row| row.get(0))?;
1477+
assert_eq!(count, 2);
1478+
}
1479+
1480+
Ok(())
1481+
}
1482+
1483+
#[test]
1484+
fn test_pivot_query() -> Result<()> {
1485+
let db = checked_memory_handle();
1486+
1487+
db.execute_batch(
1488+
"CREATE TABLE cities(city VARCHAR, year INTEGER, population INTEGER);
1489+
INSERT INTO cities VALUES
1490+
('Amsterdam', 2000, 1005),
1491+
('Amsterdam', 2010, 1065),
1492+
('Amsterdam', 2020, 1158),
1493+
('Berlin', 2000, 3382),
1494+
('Berlin', 2010, 3460),
1495+
('Berlin', 2020, 3576);",
1496+
)?;
1497+
1498+
// PIVOT queries internally expand to multiple statements
1499+
let mut stmt = db.prepare("PIVOT cities ON year USING sum(population);")?;
1500+
let mut rows = stmt.query([])?;
1501+
1502+
let mut row_count = 0;
1503+
while let Some(_row) = rows.next()? {
1504+
row_count += 1;
1505+
}
1506+
assert_eq!(row_count, 2);
1507+
1508+
Ok(())
1509+
}
14581510
}

0 commit comments

Comments
 (0)