
Commit de40c71

feat: [sqlite] Batch insert using prepared statements (#453)
* feat: [sqlite] Batch insert using prepared statements
* Remove Decimal
* Remove Decimal
* Fixes
* Update core/tests/sqlite/mod.rs
  Signed-off-by: Luke Kim <[email protected]>
* Fixes
* Fix build issue
* Fix build
* Remove Decimal support
* Remove unused conversion
* Fix Clippy issues

---------

Signed-off-by: Luke Kim <[email protected]>
1 parent 10c974b commit de40c71

File tree

9 files changed: +1440, −22 lines changed
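The substance of the change — switching SQLite batch inserts from generated inline SQL to reusable prepared statements — lives in the SQLite writer, which is not part of this excerpt. As a rough illustration of the two strategies being compared (not the crate's actual code path, which goes through SqliteTableProviderFactory and Arrow RecordBatches), a hypothetical sketch using rusqlite directly, with a made-up `bench` table:

// Illustrative sketch only; the `bench` table and (id, name) columns
// are hypothetical, and error handling is kept minimal.
use rusqlite::{params, Connection, Result};

// OLD approach: render every value into one big INSERT string.
// The SQL text must be built, escaped, and re-parsed on every call.
fn insert_inline(conn: &Connection, rows: &[(i64, String)]) -> Result<()> {
    if rows.is_empty() {
        return Ok(());
    }
    let values: Vec<String> = rows
        .iter()
        .map(|(id, name)| format!("({id}, '{}')", name.replace('\'', "''")))
        .collect();
    conn.execute_batch(&format!(
        "INSERT INTO bench (id, name) VALUES {}",
        values.join(", ")
    ))
}

// NEW approach: compile the statement once, bind parameters per row,
// and commit the whole batch in a single transaction.
fn insert_prepared(conn: &mut Connection, rows: &[(i64, String)]) -> Result<()> {
    let tx = conn.transaction()?;
    {
        let mut stmt = tx.prepare("INSERT INTO bench (id, name) VALUES (?1, ?2)")?;
        for (id, name) in rows {
            stmt.execute(params![id, name])?;
        }
    }
    tx.commit()
}

The prepared path avoids re-generating and re-parsing SQL text for every batch: the INSERT is compiled once, each row only binds values, and the whole batch commits in one transaction. Parameter binding also sidesteps manual value escaping.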

benches/sqlite_insert_benchmark.rs

Lines changed: 265 additions & 0 deletions
@@ -0,0 +1,265 @@
+use std::sync::Arc;
+use std::time::Instant;
+
+use arrow::{
+    array::{Float64Array, Int64Array, RecordBatch, StringArray},
+    datatypes::{DataType, Field, Schema},
+};
+use datafusion::{
+    catalog::TableProviderFactory,
+    common::{Constraints, TableReference, ToDFSchema},
+    execution::context::SessionContext,
+    logical_expr::{dml::InsertOp, CreateExternalTable},
+    physical_plan::collect,
+};
+use datafusion_table_providers::{sqlite::SqliteTableProviderFactory, util::test::MockExec};
+
+#[derive(Debug, Clone, Copy, PartialEq, Eq)]
+enum InsertMethod {
+    Prepared,
+    Inline,
+}
+
+impl InsertMethod {
+    fn name(&self) -> &str {
+        match self {
+            InsertMethod::Prepared => "insert_batch_prepared (NEW)",
+            InsertMethod::Inline => "insert_batch (OLD)",
+        }
+    }
+}
+
+/// Benchmark for SQLite insert performance comparing prepared statements vs inline SQL
+///
+/// This benchmark measures the performance of inserting data into SQLite
+/// using both the new prepared statement approach and the old inline SQL approach.
+///
+/// Set the environment variable SQLITE_INSERT_METHOD to control which method to test:
+/// - "prepared": Use prepared statements
+/// - "inline": Use inline SQL generation
+/// - "both" (default): Test both methods and compare
+#[tokio::main]
+async fn main() {
+    println!("\n=== SQLite Insert Performance Benchmark ===\n");
+
+    // Determine which method(s) to test
+    let test_mode = std::env::var("SQLITE_INSERT_METHOD")
+        .unwrap_or_else(|_| "both".to_string())
+        .to_lowercase();
+
+    let methods_to_test = match test_mode.as_str() {
+        "inline" => vec![InsertMethod::Inline],
+        "prepared" => vec![InsertMethod::Prepared],
+        "both" => vec![InsertMethod::Inline, InsertMethod::Prepared],
+        _ => vec![InsertMethod::Inline, InsertMethod::Prepared],
+    };
+
+    // Test configurations: (num_batches, rows_per_batch)
+    let test_configs = vec![
+        (10, 1),
+        (10, 10),
+        (10, 100),
+        (1, 1000),
+        (10, 1000),
+        (100, 1000),
+        (10, 10000),
+        (5, 50000),
+        (5, 100000),
+        (5, 1000000),
+    ];
+
+    // Store results for comparison
+    type BenchmarkResults = Vec<(InsertMethod, Vec<(usize, f64, f64)>)>;
+    let mut results: BenchmarkResults = Vec::new();
+
+    for method in &methods_to_test {
+        println!("━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━");
+        println!("Testing Method: {}", method.name());
+        println!("━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\n");
+
+        let mut method_results = Vec::new();
+
+        for (num_batches, rows_per_batch) in &test_configs {
+            let total_rows = num_batches * rows_per_batch;
+            println!(
+                "  Config: {} batches × {} rows = {} total rows",
+                num_batches, rows_per_batch, total_rows
+            );
+
+            let duration = run_benchmark(*num_batches, *rows_per_batch, *method).await;
+            let rows_per_sec = total_rows as f64 / duration.as_secs_f64();
+            let time_per_row = duration.as_micros() as f64 / total_rows as f64;
+
+            println!("  ⏱️  Time taken: {:.3}s", duration.as_secs_f64());
+            println!("  🚀 Throughput: {:.0} rows/sec", rows_per_sec);
+            println!("  📊 Per-row time: {:.2}µs\n", time_per_row);
+
+            method_results.push((total_rows, rows_per_sec, time_per_row));
+        }
+
+        results.push((*method, method_results));
+        println!();
+    }
+
+    // Print comparison summary if both methods were tested
+    if methods_to_test.len() > 1 {
+        println!("━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━");
+        println!("Performance Comparison");
+        println!("━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\n");
+
+        println!(
+            "{:<15} {:<20} {:<20} {:<15}",
+            "Total Rows", "OLD (rows/sec)", "NEW (rows/sec)", "Speedup"
+        );
+        println!("{}", "─".repeat(75));
+
+        for i in 0..test_configs.len() {
+            let (total_rows, old_throughput, _) = results[0].1[i];
+            let (_, new_throughput, _) = results[1].1[i];
+            let speedup = new_throughput / old_throughput;
+
+            println!(
+                "{:<15} {:<20.0} {:<20.0} {:.2}x",
+                total_rows, old_throughput, new_throughput, speedup
+            );
+        }
+
+        println!("\n{}", "─".repeat(75));
+
+        // Calculate average speedup
+        let avg_speedup: f64 = (0..test_configs.len())
+            .map(|i| results[1].1[i].1 / results[0].1[i].1)
+            .sum::<f64>()
+            / test_configs.len() as f64;
+
+        println!(
+            "\n📊 Average speedup: {:.2}x faster with prepared statements",
+            avg_speedup
+        );
+        println!("━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\n");
+    }
+}
+
+async fn run_benchmark(
+    num_batches: usize,
+    rows_per_batch: usize,
+    method: InsertMethod,
+) -> std::time::Duration {
+    // Create schema with multiple column types
+    let schema = Arc::new(Schema::new(vec![
+        Field::new("id", DataType::Int64, false),
+        Field::new("name", DataType::Utf8, false),
+        Field::new("value", DataType::Float64, false),
+        Field::new("category", DataType::Utf8, true),
+        Field::new("count", DataType::Int64, true),
+    ]));
+
+    let df_schema = ToDFSchema::to_dfschema_ref(Arc::clone(&schema)).expect("df schema");
+
+    // Create a unique table name to avoid conflicts
+    let table_name = format!(
+        "bench_table_{}",
+        std::time::SystemTime::now()
+            .duration_since(std::time::UNIX_EPOCH)
+            .unwrap()
+            .as_millis()
+    );
+
+    let external_table = CreateExternalTable {
+        schema: df_schema,
+        name: TableReference::bare(table_name),
+        location: String::new(),
+        file_type: String::new(),
+        table_partition_cols: vec![],
+        if_not_exists: true,
+        definition: None,
+        order_exprs: vec![],
+        unbounded: false,
+        options: std::collections::HashMap::new(),
+        constraints: Constraints::new_unverified(vec![]),
+        column_defaults: std::collections::HashMap::default(),
+        temporary: false,
+    };
+
+    let ctx = SessionContext::new();
+
+    // Configure the factory based on which method we're testing
+    let use_prepared = match method {
+        InsertMethod::Prepared => true,
+        InsertMethod::Inline => false,
+    };
+
+    let table = SqliteTableProviderFactory::default()
+        .with_batch_insert_use_prepared_statements(use_prepared)
+        .create(&ctx.state(), &external_table)
+        .await
+        .expect("table should be created");
+
+    // Generate batches
+    let batches: Vec<Result<RecordBatch, datafusion::error::DataFusionError>> = (0..num_batches)
+        .map(|batch_idx| {
+            let start_id = batch_idx * rows_per_batch;
+
+            let ids: Vec<i64> = (start_id..(start_id + rows_per_batch))
+                .map(|i| i as i64)
+                .collect();
+
+            let names: Vec<String> = (start_id..(start_id + rows_per_batch))
+                .map(|i| format!("name_{}", i))
+                .collect();
+
+            let values: Vec<f64> = (start_id..(start_id + rows_per_batch))
+                .map(|i| (i as f64) * 1.5)
+                .collect();
+
+            let categories: Vec<Option<String>> = (start_id..(start_id + rows_per_batch))
+                .map(|i| Some(format!("category_{}", i % 10)))
+                .collect();
+
+            let counts: Vec<Option<i64>> = (start_id..(start_id + rows_per_batch))
+                .map(|i| {
+                    if i % 3 == 0 {
+                        Some((i % 100) as i64)
+                    } else {
+                        None
+                    }
+                })
+                .collect();
+
+            let id_array = Int64Array::from(ids);
+            let name_array = StringArray::from(names);
+            let value_array = Float64Array::from(values);
+            let category_array = StringArray::from(categories);
+            let count_array = Int64Array::from(counts);
+
+            Ok(RecordBatch::try_new(
+                Arc::clone(&schema),
+                vec![
+                    Arc::new(id_array),
+                    Arc::new(name_array),
+                    Arc::new(value_array),
+                    Arc::new(category_array),
+                    Arc::new(count_array),
+                ],
+            )
+            .expect("batch should be created"))
+        })
+        .collect();
+
+    let exec = MockExec::new(batches, schema);
+
+    // Start timing
+    let start = Instant::now();
+
+    let insertion = table
+        .insert_into(&ctx.state(), Arc::new(exec), InsertOp::Append)
+        .await
+        .expect("insertion should be successful");
+
+    collect(insertion, ctx.task_ctx())
+        .await
+        .expect("insert successful");
+
+    // End timing
+    start.elapsed()
+}
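Note that the timer starts only after all RecordBatches have been generated, so the measurement isolates the insert path itself rather than test-data construction.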

core/Cargo.toml

Lines changed: 5 additions & 0 deletions
@@ -182,3 +182,8 @@ required-features = ["sqlite"]
 name = "clickhouse"
 path = "examples/clickhouse.rs"
 required-features = ["clickhouse"]
+
+[[bin]]
+name = "sqlite_insert_benchmark"
+path = "../benches/sqlite_insert_benchmark.rs"
+required-features = ["sqlite"]
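Given this manifest entry, the benchmark should be runnable from the core crate with something like `cargo run --release --bin sqlite_insert_benchmark --features sqlite` (the exact invocation is inferred from the manifest, not stated on this page), with `SQLITE_INSERT_METHOD=inline`, `prepared`, or `both` selecting which code path to measure; when unset, it defaults to both.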

core/src/mysql.rs

Lines changed: 2 additions & 2 deletions
@@ -355,11 +355,11 @@ impl MySQL {
 
     async fn table_exists(&self, mysql_connection: &MySQLConnection) -> bool {
         let sql = format!(
-            r#"SELECT EXISTS (
+            "SELECT EXISTS (
             SELECT 1
             FROM information_schema.tables
             WHERE table_name = '{name}'
-            )"#,
+            )",
             name = self.table_name
         );
         tracing::trace!("{sql}");

core/src/postgres.rs

Lines changed: 3 additions & 3 deletions
@@ -394,12 +394,12 @@ impl Postgres {
     async fn table_exists(&self, postgres_conn: &PostgresConnection) -> bool {
         let sql = match self.table.schema() {
             Some(schema) => format!(
-                r#"SELECT EXISTS (SELECT 1 FROM information_schema.tables WHERE table_name = '{name}' AND table_schema = '{schema}')"#,
+                "SELECT EXISTS (SELECT 1 FROM information_schema.tables WHERE table_name = '{name}' AND table_schema = '{schema}')",
                 name = self.table.table(),
                 schema = schema
             ),
             None => format!(
-                r#"SELECT EXISTS (SELECT 1 FROM information_schema.tables WHERE table_name = '{name}')"#,
+                "SELECT EXISTS (SELECT 1 FROM information_schema.tables WHERE table_name = '{name}')",
                 name = self.table.table()
             ),
         };

@@ -439,7 +439,7 @@ impl Postgres {
     async fn delete_all_table_data(&self, transaction: &Transaction<'_>) -> Result<()> {
         transaction
             .execute(
-                format!(r#"DELETE FROM {}"#, self.table.to_quoted_string()).as_str(),
+                format!("DELETE FROM {}", self.table.to_quoted_string()).as_str(),
                 &[],
             )
             .await

core/src/sql/arrow_sql_gen/statement.rs

Lines changed: 1 addition & 1 deletion
@@ -1341,7 +1341,7 @@ pub(crate) fn map_data_type_to_column_type(data_type: &DataType) -> ColumnType {
         // This caused the error: "Row size too large. The maximum row size for the used table type, not counting BLOBs, is 65535.
         // This includes storage overhead, check the manual. You have to change some columns to TEXT or BLOBs."
         // Changing to Blob fixes this issue. This change does not affect Postgres, and for Sqlite, the mapping type changes from varbinary_blob to blob.
-        DataType::Binary | DataType::LargeBinary => ColumnType::Blob,
+        DataType::Binary | DataType::LargeBinary | DataType::BinaryView => ColumnType::Blob,
         DataType::FixedSizeBinary(num_bytes) => ColumnType::Binary(num_bytes.to_owned() as u32),
         DataType::Interval(_) => ColumnType::Interval(None, None),
         // Add more mappings here as needed

core/src/sql/sql_provider_datafusion/mod.rs

Lines changed: 2 additions & 2 deletions
@@ -481,7 +481,7 @@ mod tests {
     async fn test_sql_to_string() -> Result<(), Box<dyn Error + Send + Sync>> {
         let sql_table = new_sql_table("users", Some(Arc::new(SqliteDialect {})))?;
         let result = sql_table.scan_to_sql(Some(&vec![0]), &[], None)?;
-        assert_eq!(result, r#"SELECT `users`.`name` FROM `users`"#);
+        assert_eq!(result, "SELECT `users`.`name` FROM `users`");
         Ok(())
     }
 

@@ -493,7 +493,7 @@
         let result = sql_table.scan_to_sql(Some(&vec![0, 1]), &filters, Some(3))?;
         assert_eq!(
             result,
-            r#"SELECT `users`.`name`, `users`.`age` FROM `users` WHERE ((`users`.`age` >= 30) AND (`users`.`name` = 'x')) LIMIT 3"#
+            "SELECT `users`.`name`, `users`.`age` FROM `users` WHERE ((`users`.`age` >= 30) AND (`users`.`name` = 'x')) LIMIT 3"
         );
         Ok(())
     }
