Skip to content

Commit 4950787

Browse files
test: add sqllogictest runner
This change adds a sqllogictest runner + CLI which runs .slt files in the tests/sqllogictest directory. Right now, the runner uses a hardcoded distributed config (ex. `with_network_shuffle_tasks(2)` etc.) but can be extended in the future. Using sqllogictest will make it much easier to write tests (especially for `explain` and `explain (analyze)`). `explain (analyze)` tests will be added in a future commit as it is still being implemented. Also, this change deletes `tests/distributed_aggregation.rs` and moves the test cases to `.slt` files. Documentation: - https://sqlite.org/sqllogictest/doc/trunk/about.wiki - https://github.com/risinglightdb/sqllogictest-rs
1 parent d3b46ab commit 4950787

File tree

8 files changed

+574
-179
lines changed

8 files changed

+574
-179
lines changed

Cargo.lock

Lines changed: 266 additions & 24 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Cargo.toml

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,10 @@ parquet = { version = "55.2.0", optional = true }
3939
arrow = { version = "55.2.0", optional = true }
4040
tokio-stream = { version = "0.1.17", optional = true }
4141
hyper-util = { version = "0.1.16", optional = true }
42+
sqllogictest = { version = "0.20", optional = true }
43+
regex = { version = "1.0", optional = true }
44+
clap = { version = "4.0", features = ["derive"], optional = true }
45+
env_logger = { version = "0.10", optional = true }
4246
pin-project = "1.1.10"
4347

4448
[features]
@@ -50,6 +54,10 @@ integration = [
5054
"arrow",
5155
"tokio-stream",
5256
"hyper-util",
57+
"sqllogictest",
58+
"regex",
59+
"clap",
60+
"env_logger",
5361
]
5462

5563
tpch = ["integration"]

src/bin/logictest.rs

Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,98 @@
1+
use clap::Parser;
2+
use datafusion_distributed::test_utils::sqllogictest::DatafusionDistributedDB;
3+
use sqllogictest::Runner;
4+
use std::path::PathBuf;
5+
6+
// Command-line arguments for the sqllogictest runner binary.
//
// NOTE: the `///` doc comments on the fields below double as the clap-derived
// CLI help text (and `--override` even renames the flag), so they are kept
// exactly as shipped; commentary lives in `//` comments instead.
#[derive(Parser)]
#[command(name = "logictest")]
#[command(
    about = "A SQLLogicTest runner for DataFusion Distributed. Docs: https://sqlite.org/sqllogictest/doc/trunk/about.wiki"
)]
struct Args {
    /// Files or directories to run
    // Directories are expanded (non-recursively) to the `.slt` files inside.
    #[arg(required = true)]
    files: Vec<PathBuf>,

    /// Update test files with actual output rather than verifying the existing output
    // `override` is a Rust keyword, hence the explicit long-flag name.
    #[arg(long = "override")]
    override_mode: bool,

    /// Number of workers
    // Size of the localhost worker cluster each test connection spins up.
    #[arg(long, default_value = "3")]
    num_workers: usize,
}
24+
25+
async fn run<D, M>(paths: Vec<PathBuf>, runner: &mut Runner<D, M>, override_mode: bool)
26+
where
27+
D: sqllogictest::AsyncDB,
28+
M: sqllogictest::MakeConnection<Conn = D>,
29+
{
30+
let mut queue = paths;
31+
let mut idx = 0;
32+
while idx < queue.len() {
33+
let file_path = &queue[idx];
34+
idx += 1;
35+
if !file_path.is_file() {
36+
queue.extend(
37+
expand_directory(file_path).await.unwrap_or_else(|_| {
38+
panic!("Failed to expand directory: {}", file_path.display())
39+
}),
40+
);
41+
continue;
42+
}
43+
let file_path_str = file_path.to_str().expect("Invalid file path");
44+
45+
let result = match override_mode {
46+
true => {
47+
runner
48+
.update_test_file(
49+
file_path_str,
50+
" ",
51+
sqllogictest::default_validator,
52+
sqllogictest::default_column_validator,
53+
)
54+
.await
55+
}
56+
57+
false => runner.run_file_async(file_path).await.map_err(|e| e.into()),
58+
};
59+
match result {
60+
Ok(_) => println!("🟢 Success: {}", file_path.display()),
61+
Err(e) => eprintln!("🔴 Failure: {}:\n{e}", file_path.display()),
62+
}
63+
}
64+
}
65+
66+
#[tokio::main]
67+
async fn main() -> Result<(), Box<dyn std::error::Error>> {
68+
let args = Args::parse();
69+
70+
let mut runner =
71+
Runner::new(
72+
move || async move { Ok(DatafusionDistributedDB::new(args.num_workers).await) },
73+
);
74+
75+
run(args.files, &mut runner, args.override_mode).await;
76+
77+
Ok(())
78+
}
79+
80+
/// Collects the `.slt` files directly inside `dir_path`, sorted by path so
/// tests run in a deterministic order across platforms.
///
/// Non-`.slt` entries — including subdirectories — are ignored; the scan is
/// not recursive. Takes `&Path` rather than `&PathBuf` (deref coercion keeps
/// existing call sites working).
async fn expand_directory(
    dir_path: &std::path::Path,
) -> Result<Vec<PathBuf>, Box<dyn std::error::Error>> {
    let mut entries: Vec<PathBuf> = std::fs::read_dir(dir_path)?
        // Entries that fail to read (e.g. racing deletes) are skipped rather
        // than aborting the whole run.
        .filter_map(|entry| entry.ok())
        .map(|entry| entry.path())
        .filter(|path| path.extension().and_then(|ext| ext.to_str()) == Some("slt"))
        .collect();

    // Sort entries for consistent order
    entries.sort();

    Ok(entries)
}

src/test_utils/mod.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,4 +5,5 @@ pub mod metrics;
55
pub mod mock_exec;
66
pub mod parquet;
77
pub mod session_context;
8+
pub mod sqllogictest;
89
pub mod tpch;

src/test_utils/sqllogictest.rs

Lines changed: 145 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,145 @@
1+
use crate::DefaultSessionBuilder;
2+
use crate::DistributedPhysicalOptimizerRule;
3+
use crate::test_utils::localhost::start_localhost_context;
4+
use crate::test_utils::parquet::register_parquet_tables;
5+
use async_trait::async_trait;
6+
use datafusion::arrow::array::RecordBatch;
7+
use datafusion::arrow::array::{ArrayRef, StringArray};
8+
use datafusion::arrow::datatypes::{DataType, Field, Schema};
9+
use datafusion::arrow::util::display::array_value_to_string;
10+
use datafusion::common::runtime::JoinSet;
11+
use datafusion::error::DataFusionError;
12+
use datafusion::execution::context::SessionContext;
13+
use datafusion::physical_optimizer::PhysicalOptimizerRule;
14+
use datafusion::physical_plan::ExecutionPlan;
15+
use datafusion::physical_plan::displayable;
16+
use datafusion::physical_plan::execution_plan::collect;
17+
use sqllogictest::{AsyncDB, DBOutput, DefaultColumnType};
18+
use std::sync::Arc;
19+
20+
/// A sqllogictest database adapter that plans queries, applies the
/// distributed physical optimizer, and executes them against localhost
/// workers.
pub struct DatafusionDistributedDB {
    // Session used to plan and execute SQL statements.
    ctx: SessionContext,
    // Keeps the spawned localhost worker tasks alive for the lifetime of this
    // DB; dropping it tears the workers down.
    _guard: JoinSet<()>,
}
24+
25+
impl DatafusionDistributedDB {
26+
pub async fn new(num_nodes: usize) -> Self {
27+
let (ctx, _guard) = start_localhost_context(num_nodes, DefaultSessionBuilder).await;
28+
register_parquet_tables(&ctx).await.unwrap();
29+
Self { ctx, _guard }
30+
}
31+
32+
pub fn optimize_distributed(
33+
plan: Arc<dyn ExecutionPlan>,
34+
) -> Result<Arc<dyn ExecutionPlan>, DataFusionError> {
35+
DistributedPhysicalOptimizerRule::default()
36+
.with_network_shuffle_tasks(2)
37+
.with_network_coalesce_tasks(2)
38+
.optimize(plan, &Default::default())
39+
}
40+
41+
fn convert_batches_to_output(
42+
&self,
43+
batches: Vec<RecordBatch>,
44+
) -> Result<DBOutput<DefaultColumnType>, datafusion::error::DataFusionError> {
45+
if batches.is_empty() {
46+
return Ok(DBOutput::Rows {
47+
types: vec![],
48+
rows: vec![],
49+
});
50+
}
51+
52+
let num_columns = batches[0].num_columns();
53+
let column_types = vec![DefaultColumnType::Text; num_columns]; // Everything as text
54+
55+
let mut rows = Vec::new();
56+
for batch in batches {
57+
for row_idx in 0..batch.num_rows() {
58+
let mut row = Vec::new();
59+
for col_idx in 0..batch.num_columns() {
60+
let column = batch.column(col_idx);
61+
let value = array_value_to_string(column, row_idx)
62+
.map_err(|e| DataFusionError::ArrowError(Box::new(e), None))?;
63+
row.push(value);
64+
}
65+
rows.push(row);
66+
}
67+
}
68+
69+
Ok(DBOutput::Rows {
70+
types: column_types,
71+
rows,
72+
})
73+
}
74+
75+
async fn handle_explain_analyze(
76+
&mut self,
77+
_sql: &str,
78+
) -> Result<DBOutput<DefaultColumnType>, datafusion::error::DataFusionError> {
79+
unimplemented!();
80+
}
81+
82+
async fn handle_explain(
83+
&mut self,
84+
sql: &str,
85+
) -> Result<DBOutput<DefaultColumnType>, datafusion::error::DataFusionError> {
86+
let query = sql.trim_start_matches("EXPLAIN").trim();
87+
let df = self.ctx.sql(query).await?;
88+
let physical_plan = df.create_physical_plan().await?;
89+
90+
let physical_distributed = Self::optimize_distributed(physical_plan)?;
91+
92+
let physical_distributed_str = displayable(physical_distributed.as_ref())
93+
.indent(true)
94+
.to_string();
95+
96+
let lines: Vec<String> = physical_distributed_str
97+
.lines()
98+
.map(|s| s.to_string())
99+
.collect();
100+
let schema = Arc::new(Schema::new(vec![Field::new("plan", DataType::Utf8, false)]));
101+
let batch =
102+
RecordBatch::try_new(schema, vec![Arc::new(StringArray::from(lines)) as ArrayRef])?;
103+
104+
self.convert_batches_to_output(vec![batch])
105+
}
106+
}
107+
108+
#[async_trait]
109+
impl AsyncDB for DatafusionDistributedDB {
110+
type Error = datafusion::error::DataFusionError;
111+
type ColumnType = DefaultColumnType;
112+
113+
async fn run(&mut self, sql: &str) -> Result<DBOutput<Self::ColumnType>, Self::Error> {
114+
let sql = sql.trim();
115+
116+
// Ignore DDL/DML
117+
if sql.to_uppercase().starts_with("CREATE")
118+
|| sql.to_uppercase().starts_with("INSERT")
119+
|| sql.to_uppercase().starts_with("DROP")
120+
{
121+
return Ok(DBOutput::StatementComplete(0));
122+
}
123+
124+
if sql.to_uppercase().starts_with("EXPLAIN ANALYZE") {
125+
return self.handle_explain_analyze(sql).await;
126+
}
127+
128+
if sql.to_uppercase().starts_with("EXPLAIN") {
129+
return self.handle_explain(sql).await;
130+
}
131+
132+
// Default: Execute SELECT statement
133+
let df = self.ctx.sql(sql).await?;
134+
let task_ctx = Arc::new(df.task_ctx());
135+
let plan = df.create_physical_plan().await?;
136+
let distributed_plan = Self::optimize_distributed(plan)?;
137+
let batches = collect(distributed_plan, task_ctx).await?;
138+
139+
self.convert_batches_to_output(batches)
140+
}
141+
142+
fn engine_name(&self) -> &str {
143+
"datafusion-distributed"
144+
}
145+
}

0 commit comments

Comments
 (0)