Commit f9cf0f1

wip
1 parent d3b46ab commit f9cf0f1

8 files changed (+591, -179 lines)

Cargo.lock

Lines changed: 266 additions & 24 deletions
Generated file; diff not rendered by default.

Cargo.toml

Lines changed: 8 additions & 0 deletions
@@ -39,6 +39,10 @@ parquet = { version = "55.2.0", optional = true }
 arrow = { version = "55.2.0", optional = true }
 tokio-stream = { version = "0.1.17", optional = true }
 hyper-util = { version = "0.1.16", optional = true }
+sqllogictest = { version = "0.20", optional = true }
+regex = { version = "1.0", optional = true }
+clap = { version = "4.0", features = ["derive"], optional = true }
+env_logger = { version = "0.10", optional = true }
 pin-project = "1.1.10"
 
 [features]
@@ -50,6 +54,10 @@ integration = [
     "arrow",
     "tokio-stream",
     "hyper-util",
+    "sqllogictest",
+    "regex",
+    "clap",
+    "env_logger",
 ]
 
 tpch = ["integration"]
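
Note: the four new dependencies (sqllogictest, regex, clap, env_logger) are all optional and are enabled together by the `integration` feature. As a minimal sketch of the standard Cargo pattern this relies on (generic convention, not lines from this commit), code that touches the optional crates compiles only when the feature is on:

    // Sketch only: standard feature gating, assuming the `integration` feature
    // defined above. Without the feature, this module and its `sqllogictest`
    // dependency are compiled out entirely.
    #[cfg(feature = "integration")]
    pub mod sqllogictest_support {
        pub use sqllogictest::Runner; // resolves because the feature pulls in the dep
    }

The new binary below therefore needs the feature at build time, presumably something like `cargo run --bin logictest --features integration -- <files>` (invocation inferred from Cargo conventions, not stated in the commit).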

src/bin/logictest.rs

Lines changed: 122 additions & 0 deletions
@@ -0,0 +1,122 @@ (new file)
use clap::Parser;
use datafusion_distributed::test_utils::sqllogictest::DatafusionDistributedDB;
use sqllogictest::Runner;
use std::path::PathBuf;

#[derive(Parser)]
#[command(name = "logictest")]
#[command(about = "A SQLLogicTest runner for DataFusion Distributed")]
struct Args {
    /// Test files or directories to run
    #[arg(required = true)]
    files: Vec<PathBuf>,

    /// Override mode: update test files with actual output
    #[arg(long = "override")]
    override_mode: bool,

    /// Number of distributed nodes to start
    #[arg(long, default_value = "3")]
    nodes: usize,
}

#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
    let args = Args::parse();

    // Create a closure that creates new database connections
    let nodes = args.nodes;
    let mut runner =
        Runner::new(move || async move { Ok(DatafusionDistributedDB::new(nodes).await) });

    // Configure runner based on override mode
    if args.override_mode {
        // Override mode: use sqllogictest's built-in override functionality
        for file_path in &args.files {
            if file_path.is_file() {
                let file_path_str = file_path.to_str().expect("Invalid file path");

                // Use the built-in update_test_file with default comparison functions
                match runner
                    .update_test_file(
                        file_path_str,
                        file_path_str,
                        sqllogictest::default_validator,
                        sqllogictest::default_column_validator,
                    )
                    .await
                {
                    Ok(_) => println!("✅ {}: Generated", file_path.display()),
                    Err(e) => {
                        eprintln!("❌ {}: Failed to generate", file_path.display());
                        eprintln!(" Error: {}", e);
                    }
                }
            } else {
                eprintln!("Override mode only works with individual files, not directories");
            }
        }
    } else {
        // Verify mode: compare results against expected output
        for file_path in &args.files {
            if file_path.is_file() {
                match runner.run_file_async(file_path).await {
                    Ok(_) => println!("✅ {}: PASSED", file_path.display()),
                    Err(e) => {
                        eprintln!("❌ {}: FAILED", file_path.display());
                        eprintln!(" Error: {}", e);
                    }
                }
            } else if file_path.is_dir() {
                println!("Running tests in directory: {}", file_path.display());
                run_directory(&mut runner, file_path).await?;
            } else {
                eprintln!(
                    "Warning: {} is neither a file nor directory",
                    file_path.display()
                );
            }
        }
    }

    Ok(())
}

#[cfg(feature = "integration")]
async fn run_directory<D, M>(
    runner: &mut Runner<D, M>,
    dir_path: &PathBuf,
) -> Result<(), Box<dyn std::error::Error>>
where
    D: sqllogictest::AsyncDB,
    M: sqllogictest::MakeConnection<Conn = D>,
{
    let mut entries: Vec<_> = std::fs::read_dir(dir_path)?
        .filter_map(|entry| entry.ok())
        .filter(|entry| {
            entry
                .path()
                .extension()
                .and_then(|ext| ext.to_str())
                .map(|ext| ext == "slt")
                .unwrap_or(false)
        })
        .collect();

    // Sort entries for consistent order
    entries.sort_by_key(|entry| entry.path());

    for entry in entries {
        let file_path = entry.path();
        println!("Running test file: {}", file_path.display());
        match runner.run_file_async(&file_path).await {
            Ok(_) => println!("✅ {}: PASSED", file_path.display()),
            Err(e) => {
                eprintln!("❌ {}: FAILED", file_path.display());
                eprintln!(" Error: {}", e);
            }
        }
    }

    Ok(())
}
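
The clap-derived `Args` above drives both modes: `--override` rewrites expected output in place through `update_test_file`, while the default path verifies files and directories through `run_file_async`. A hypothetical parse check (not part of this commit; `Args::try_parse_from` is clap's standard test entry point, and `tests/basic.slt` is a made-up path) showing the flags the struct accepts:

    #[cfg(test)]
    mod args_tests {
        use super::Args;
        use clap::Parser;

        #[test]
        fn parses_override_and_nodes() {
            // First element is the argv[0] placeholder; the rest mirror real flags.
            let args = Args::try_parse_from([
                "logictest", "--override", "--nodes", "5", "tests/basic.slt",
            ])
            .unwrap();
            assert!(args.override_mode);
            assert_eq!(args.nodes, 5);
            assert_eq!(args.files, vec![std::path::PathBuf::from("tests/basic.slt")]);
        }
    }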

src/test_utils/mod.rs

Lines changed: 1 addition & 0 deletions
@@ -5,4 +5,5 @@ pub mod metrics;
 pub mod mock_exec;
 pub mod parquet;
 pub mod session_context;
+pub mod sqllogictest;
 pub mod tpch;

src/test_utils/sqllogictest.rs

Lines changed: 136 additions & 0 deletions
@@ -0,0 +1,136 @@ (new file)
use crate::DefaultSessionBuilder;
use crate::DistributedPhysicalOptimizerRule;
use crate::test_utils::localhost::start_localhost_context;
use crate::test_utils::parquet::register_parquet_tables;
use async_trait::async_trait;
use datafusion::arrow::array::RecordBatch;
use datafusion::arrow::array::{ArrayRef, StringArray};
use datafusion::arrow::datatypes::{DataType, Field, Schema};
use datafusion::arrow::util::display::array_value_to_string;
use datafusion::common::runtime::JoinSet;
use datafusion::error::DataFusionError;
use datafusion::execution::context::SessionContext;
use datafusion::physical_optimizer::PhysicalOptimizerRule;
use datafusion::physical_plan::displayable;
use sqllogictest::{AsyncDB, DBOutput, DefaultColumnType};
use std::sync::Arc;

pub struct DatafusionDistributedDB {
    ctx: SessionContext,
    _guard: JoinSet<()>,
}

impl DatafusionDistributedDB {
    pub async fn new(num_nodes: usize) -> Self {
        let (ctx, _guard) = start_localhost_context(num_nodes, DefaultSessionBuilder).await;
        register_parquet_tables(&ctx).await.unwrap();
        Self { ctx, _guard }
    }

    fn convert_batches_to_output(
        &self,
        batches: Vec<RecordBatch>,
    ) -> Result<DBOutput<DefaultColumnType>, DataFusionError> {
        if batches.is_empty() {
            return Ok(DBOutput::Rows {
                types: vec![],
                rows: vec![],
            });
        }

        let num_columns = batches[0].num_columns();
        let column_types = vec![DefaultColumnType::Text; num_columns]; // Everything as text

        let mut rows = Vec::new();
        for batch in batches {
            for row_idx in 0..batch.num_rows() {
                let mut row = Vec::new();
                for col_idx in 0..batch.num_columns() {
                    let column = batch.column(col_idx);
                    let value = array_value_to_string(column, row_idx)
                        .map_err(|e| DataFusionError::ArrowError(Box::new(e), None))?;
                    row.push(value);
                }
                rows.push(row);
            }
        }

        Ok(DBOutput::Rows {
            types: column_types,
            rows,
        })
    }

    async fn handle_explain_analyze(
        &mut self,
        _sql: &str,
    ) -> Result<DBOutput<DefaultColumnType>, DataFusionError> {
        unimplemented!();
    }

    async fn handle_explain(
        &mut self,
        sql: &str,
    ) -> Result<DBOutput<DefaultColumnType>, DataFusionError> {
        let query = sql.trim_start_matches("EXPLAIN").trim();
        let df = self.ctx.sql(query).await?;
        let physical_plan = df.create_physical_plan().await?;

        let physical_distributed = DistributedPhysicalOptimizerRule::default()
            .with_network_shuffle_tasks(2)
            .with_network_coalesce_tasks(2)
            .optimize(physical_plan, &Default::default())?;

        let physical_distributed_str = displayable(physical_distributed.as_ref())
            .indent(true)
            .to_string();

        let lines: Vec<String> = physical_distributed_str
            .lines()
            .map(|s| s.to_string())
            .collect();
        let schema = Arc::new(Schema::new(vec![Field::new("plan", DataType::Utf8, false)]));
        let batch =
            RecordBatch::try_new(schema, vec![Arc::new(StringArray::from(lines)) as ArrayRef])?;

        self.convert_batches_to_output(vec![batch])
    }
}

#[async_trait]
impl AsyncDB for DatafusionDistributedDB {
    type Error = DataFusionError;
    type ColumnType = DefaultColumnType;

    async fn run(&mut self, sql: &str) -> Result<DBOutput<Self::ColumnType>, Self::Error> {
        let sql = sql.trim();

        // Handle different types of SQL statements
        if sql.to_uppercase().starts_with("CREATE")
            || sql.to_uppercase().starts_with("INSERT")
            || sql.to_uppercase().starts_with("DROP")
        {
            // For DDL/DML statements, just return an empty result
            return Ok(DBOutput::StatementComplete(0));
        }

        // Handle EXPLAIN ANALYZE
        if sql.to_uppercase().starts_with("EXPLAIN ANALYZE") {
            return self.handle_explain_analyze(sql).await;
        }

        // Handle regular EXPLAIN - use distributed optimizer
        if sql.to_uppercase().starts_with("EXPLAIN") {
            return self.handle_explain(sql).await;
        }

        let df = self.ctx.sql(sql).await?;
        let batches = df.collect().await?;

        self.convert_batches_to_output(batches)
    }

    fn engine_name(&self) -> &str {
        "datafusion-distributed"
    }
}
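
Since `DatafusionDistributedDB` implements sqllogictest's `AsyncDB`, it can also be driven directly from a Rust test rather than the `logictest` binary. A hedged sketch of such an embedding (the `.slt` path is hypothetical; the `Runner` calls mirror those the binary uses):

    // Sketch: run one test file against a 3-node localhost cluster, matching
    // the binary's --nodes default. Requires the `integration` feature.
    #[tokio::test]
    async fn runs_one_slt_file() -> Result<(), Box<dyn std::error::Error>> {
        use datafusion_distributed::test_utils::sqllogictest::DatafusionDistributedDB;
        use sqllogictest::Runner;

        let mut runner = Runner::new(|| async { Ok(DatafusionDistributedDB::new(3).await) });
        runner.run_file_async("tests/sqllogictest/basic.slt").await?; // hypothetical path
        Ok(())
    }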
