diff --git a/Cargo.lock b/Cargo.lock index 6bc85e2a7d1f..b8845b6135ac 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2132,9 +2132,12 @@ dependencies = [ "bytes", "dashmap", "datafusion", + "datafusion-common", + "datafusion-expr", "datafusion-ffi", "datafusion-physical-expr-adapter", "datafusion-proto", + "datafusion-sql", "env_logger", "futures", "log", @@ -2143,6 +2146,7 @@ dependencies = [ "object_store", "prost 0.13.5", "rand 0.9.2", + "rand_distr", "serde_json", "tempfile", "test-utils", diff --git a/datafusion-examples/Cargo.toml b/datafusion-examples/Cargo.toml index 68bb5376a1ac..a0cb993eade9 100644 --- a/datafusion-examples/Cargo.toml +++ b/datafusion-examples/Cargo.toml @@ -56,6 +56,18 @@ path = "examples/external_dependency/query-aws-s3.rs" name = "custom_file_casts" path = "examples/custom_file_casts.rs" +[[example]] +name = "relation_planner_table_sample" +path = "examples/relation_planner/table_sample.rs" + +[[example]] +name = "relation_planner_match_recognize" +path = "examples/relation_planner/match_recognize.rs" + +[[example]] +name = "relation_planner_pivot_unpivot" +path = "examples/relation_planner/pivot_unpivot.rs" + [dev-dependencies] arrow = { workspace = true } # arrow_schema is required for record_batch! macro :sad: @@ -67,9 +79,12 @@ dashmap = { workspace = true } # note only use main datafusion crate for examples base64 = "0.22.1" datafusion = { workspace = true, default-features = true, features = ["parquet_encryption"] } +datafusion-common = { workspace = true } +datafusion-expr = { workspace = true } datafusion-ffi = { workspace = true } datafusion-physical-expr-adapter = { workspace = true } datafusion-proto = { workspace = true } +datafusion-sql = { workspace = true } env_logger = { workspace = true } futures = { workspace = true } log = { workspace = true } @@ -77,6 +92,7 @@ mimalloc = { version = "0.1", default-features = false } object_store = { workspace = true, features = ["aws", "http"] } prost = { workspace = true } rand = { workspace = true } +rand_distr = "0.5" serde_json = { workspace = true } tempfile = { workspace = true } test-utils = { path = "../test-utils" } diff --git a/datafusion-examples/examples/relation_planner/match_recognize.rs b/datafusion-examples/examples/relation_planner/match_recognize.rs new file mode 100644 index 000000000000..50dda35583bb --- /dev/null +++ b/datafusion-examples/examples/relation_planner/match_recognize.rs @@ -0,0 +1,344 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! This example demonstrates using custom relation planners to implement +//! MATCH_RECOGNIZE-style pattern matching on event streams. +//! +//! MATCH_RECOGNIZE is a SQL extension for pattern matching on ordered data, +//! similar to regular expressions but for relational data. 
This example shows +//! how to use custom planners to implement new SQL syntax. + +use std::{any::Any, cmp::Ordering, hash::Hasher, sync::Arc}; + +use arrow::array::{ArrayRef, Float64Array, Int32Array, StringArray}; +use arrow::record_batch::RecordBatch; +use datafusion::prelude::*; +use datafusion_common::{DFSchemaRef, DataFusionError, Result}; +use datafusion_expr::{ + logical_plan::{Extension, InvariantLevel, LogicalPlan}, + planner::{ + PlannedRelation, RelationPlanner, RelationPlannerContext, RelationPlanning, + }, + Expr, UserDefinedLogicalNode, +}; +use datafusion_sql::sqlparser::ast::TableFactor; + +#[tokio::main] +async fn main() -> Result<()> { + let ctx = SessionContext::new(); + + // Register sample data tables + register_sample_data(&ctx)?; + + // Register custom planner + ctx.register_relation_planner(Arc::new(MatchRecognizePlanner))?; + + println!("Custom Relation Planner: MATCH_RECOGNIZE Pattern Matching"); + println!("==========================================================\n"); + + // Example 1: Basic MATCH_RECOGNIZE with MEASURES and DEFINE clauses + // Shows: How to use MATCH_RECOGNIZE to find patterns in event data with aggregations + // Expected: Logical plan showing MiniMatchRecognize node with SUM and AVG measures + // Note: This demonstrates the logical planning phase - actual execution would require physical implementation + // Actual (Logical Plan): + // Projection: t.price + // SubqueryAlias: t + // MiniMatchRecognize measures=[total_price := sum(price), avg_price := avg(price)] define=[a := price > Int64(10)] + // EmptyRelation: rows=0 + run_example( + &ctx, + "Example 1: MATCH_RECOGNIZE with measures and definitions", + r#"SELECT * FROM events + MATCH_RECOGNIZE ( + PARTITION BY 1 + MEASURES SUM(price) AS total_price, AVG(price) AS avg_price + PATTERN (A) + DEFINE A AS price > 10 + ) AS t"#, + ) + .await?; + + // Example 2: Stock price pattern detection using MATCH_RECOGNIZE + // Shows: How to detect patterns in financial data (e.g., stocks above threshold) + // Expected: Logical plan showing MiniMatchRecognize with MIN, MAX, AVG measures on stock prices + // Note: Uses real stock data (DDOG prices: 150, 155, 152, 158) to find patterns above 151.0 + // Actual (Logical Plan): + // Projection: trends.column1, trends.column2 + // SubqueryAlias: trends + // MiniMatchRecognize measures=[min_price := min(column2), max_price := max(column2), avg_price := avg(column2)] define=[high := column2 > Float64(151)] + // Values: (Utf8("DDOG"), Float64(150)), (Utf8("DDOG"), Float64(155)), (Utf8("DDOG"), Float64(152)), (Utf8("DDOG"), Float64(158)) + run_example( + &ctx, + "Example 2: Detect stocks above threshold using MATCH_RECOGNIZE", + r#"SELECT * FROM stock_prices + MATCH_RECOGNIZE ( + MEASURES MIN(column2) AS min_price, + MAX(column2) AS max_price, + AVG(column2) AS avg_price + PATTERN (HIGH) + DEFINE HIGH AS column2 > 151.0 + ) AS trends"#, + ) + .await?; + + Ok(()) +} + +/// Register sample data tables for the examples +fn register_sample_data(ctx: &SessionContext) -> Result<()> { + // Create events table with price column + let price: ArrayRef = Arc::new(Int32Array::from(vec![5, 12, 8, 15, 20])); + let batch = RecordBatch::try_from_iter(vec![("price", price)])?; + ctx.register_batch("events", batch)?; + + // Create stock_prices table with symbol and price columns + let symbol: ArrayRef = + Arc::new(StringArray::from(vec!["DDOG", "DDOG", "DDOG", "DDOG"])); + let price: ArrayRef = Arc::new(Float64Array::from(vec![150.0, 155.0, 152.0, 158.0])); + let batch = + 
RecordBatch::try_from_iter(vec![("column1", symbol), ("column2", price)])?;
+    ctx.register_batch("stock_prices", batch)?;
+
+    Ok(())
+}
+
+async fn run_example(ctx: &SessionContext, title: &str, sql: &str) -> Result<()> {
+    println!("{title}:\n{sql}\n");
+    let plan = ctx.sql(sql).await?.into_unoptimized_plan();
+    println!("Logical Plan:");
+    println!("{}\n", plan.display_indent());
+    Ok(())
+}
+
+/// A custom logical plan node representing MATCH_RECOGNIZE operations
+#[derive(Debug)]
+struct MiniMatchRecognizeNode {
+    input: Arc<LogicalPlan>,
+    schema: DFSchemaRef,
+    measures: Vec<(String, Expr)>,
+    definitions: Vec<(String, Expr)>,
+}
+
+impl UserDefinedLogicalNode for MiniMatchRecognizeNode {
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
+
+    fn name(&self) -> &str {
+        "MiniMatchRecognize"
+    }
+
+    fn inputs(&self) -> Vec<&LogicalPlan> {
+        vec![self.input.as_ref()]
+    }
+
+    fn schema(&self) -> &DFSchemaRef {
+        &self.schema
+    }
+
+    fn check_invariants(&self, _check: InvariantLevel) -> Result<()> {
+        Ok(())
+    }
+
+    fn expressions(&self) -> Vec<Expr> {
+        self.measures
+            .iter()
+            .chain(&self.definitions)
+            .map(|(_, expr)| expr.clone())
+            .collect()
+    }
+
+    fn fmt_for_explain(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        write!(f, "MiniMatchRecognize")?;
+
+        if !self.measures.is_empty() {
+            write!(f, " measures=[")?;
+            for (idx, (alias, expr)) in self.measures.iter().enumerate() {
+                if idx > 0 {
+                    write!(f, ", ")?;
+                }
+                write!(f, "{alias} := {expr}")?;
+            }
+            write!(f, "]")?;
+        }
+
+        if !self.definitions.is_empty() {
+            write!(f, " define=[")?;
+            for (idx, (symbol, expr)) in self.definitions.iter().enumerate() {
+                if idx > 0 {
+                    write!(f, ", ")?;
+                }
+                write!(f, "{symbol} := {expr}")?;
+            }
+            write!(f, "]")?;
+        }
+
+        Ok(())
+    }
+
+    fn with_exprs_and_inputs(
+        &self,
+        exprs: Vec<Expr>,
+        inputs: Vec<LogicalPlan>,
+    ) -> Result<Arc<dyn UserDefinedLogicalNode>> {
+        if exprs.len() != self.measures.len() + self.definitions.len() {
+            return Err(DataFusionError::Internal(
+                "MiniMatchRecognize received an unexpected expression count".into(),
+            ));
+        }
+
+        let (measure_exprs, definition_exprs) = exprs.split_at(self.measures.len());
+
+        let Some(first_input) = inputs.into_iter().next() else {
+            return Err(DataFusionError::Internal(
+                "MiniMatchRecognize requires a single input".into(),
+            ));
+        };
+
+        let measures = self
+            .measures
+            .iter()
+            .zip(measure_exprs.iter())
+            .map(|((alias, _), expr)| (alias.clone(), expr.clone()))
+            .collect();
+
+        let definitions = self
+            .definitions
+            .iter()
+            .zip(definition_exprs.iter())
+            .map(|((symbol, _), expr)| (symbol.clone(), expr.clone()))
+            .collect();
+
+        Ok(Arc::new(Self {
+            input: Arc::new(first_input),
+            schema: Arc::clone(&self.schema),
+            measures,
+            definitions,
+        }))
+    }
+
+    fn dyn_hash(&self, state: &mut dyn Hasher) {
+        state.write_usize(Arc::as_ptr(&self.input) as usize);
+        state.write_usize(self.measures.len());
+        state.write_usize(self.definitions.len());
+    }
+
+    fn dyn_eq(&self, other: &dyn UserDefinedLogicalNode) -> bool {
+        other
+            .as_any()
+            .downcast_ref::<MiniMatchRecognizeNode>()
+            .map(|o| {
+                Arc::ptr_eq(&self.input, &o.input)
+                    && self.measures == o.measures
+                    && self.definitions == o.definitions
+            })
+            .unwrap_or(false)
+    }
+
+    fn dyn_ord(&self, other: &dyn UserDefinedLogicalNode) -> Option<Ordering> {
+        if self.dyn_eq(other) {
+            Some(Ordering::Equal)
+        } else {
+            None
+        }
+    }
+}
+
+/// Custom planner that handles MATCH_RECOGNIZE table factor syntax
+#[derive(Debug)]
+struct MatchRecognizePlanner;
+
+impl RelationPlanner for MatchRecognizePlanner {
+    fn plan_relation(
+        &self,
+        relation: TableFactor,
+        context: &mut dyn RelationPlannerContext,
+    ) -> Result<RelationPlanning> {
+        if let TableFactor::MatchRecognize {
+            table,
+            measures,
+            symbols,
+            alias,
+            ..
+        } = relation
+        {
+            println!("[MatchRecognizePlanner] Processing MATCH_RECOGNIZE clause");
+
+            // DEMONSTRATE context.plan(): Recursively plan the input table
+            println!("[MatchRecognizePlanner] Using context.plan() to plan input table");
+            let input = context.plan(*table)?;
+            let input_schema = input.schema().clone();
+            println!(
+                "[MatchRecognizePlanner] Input schema has {} fields",
+                input_schema.fields().len()
+            );
+
+            // DEMONSTRATE normalize_ident() and sql_to_expr(): Process MEASURES
+            let planned_measures = measures
+                .iter()
+                .map(|measure| {
+                    // Normalize the measure alias
+                    let alias = context.normalize_ident(measure.alias.clone());
+                    println!("[MatchRecognizePlanner] Normalized measure alias: {alias}");
+
+                    // Convert SQL expression to DataFusion expression
+                    let expr = context
+                        .sql_to_expr(measure.expr.clone(), input_schema.as_ref())?;
+                    println!(
+                        "[MatchRecognizePlanner] Planned measure expression: {expr:?}"
+                    );
+
+                    Ok((alias, expr))
+                })
+                .collect::<Result<Vec<_>>>()?;
+
+            // DEMONSTRATE normalize_ident() and sql_to_expr(): Process DEFINE
+            let planned_definitions = symbols
+                .iter()
+                .map(|symbol| {
+                    // Normalize the symbol name
+                    let name = context.normalize_ident(symbol.symbol.clone());
+                    println!("[MatchRecognizePlanner] Normalized symbol: {name}");
+
+                    // Convert SQL expression to DataFusion expression
+                    let expr = context
+                        .sql_to_expr(symbol.definition.clone(), input_schema.as_ref())?;
+                    println!("[MatchRecognizePlanner] Planned definition: {expr:?}");
+
+                    Ok((name, expr))
+                })
+                .collect::<Result<Vec<_>>>()?;
+
+            // Create the custom MATCH_RECOGNIZE node
+            let node = MiniMatchRecognizeNode {
+                schema: Arc::clone(&input_schema),
+                input: Arc::new(input),
+                measures: planned_measures,
+                definitions: planned_definitions,
+            };
+
+            let plan = LogicalPlan::Extension(Extension {
+                node: Arc::new(node),
+            });
+
+            println!("[MatchRecognizePlanner] Successfully created MATCH_RECOGNIZE plan");
+            return Ok(RelationPlanning::Planned(PlannedRelation::new(plan, alias)));
+        }
+
+        Ok(RelationPlanning::Original(relation))
+    }
+}
diff --git a/datafusion-examples/examples/relation_planner/pivot_unpivot.rs b/datafusion-examples/examples/relation_planner/pivot_unpivot.rs
new file mode 100644
index 000000000000..491986619042
--- /dev/null
+++ b/datafusion-examples/examples/relation_planner/pivot_unpivot.rs
@@ -0,0 +1,542 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! This example demonstrates using custom relation planners to implement
+//! PIVOT and UNPIVOT operations for reshaping data.
+//!
+//! PIVOT transforms rows into columns (wide format), while UNPIVOT does the
+//! reverse, transforming columns into rows (long format). This example shows
+//! how to use custom planners to implement these SQL clauses by rewriting them
+//! into equivalent standard SQL operations:
+//!
+//! - PIVOT is rewritten to GROUP BY with CASE expressions
+//! - UNPIVOT is rewritten to UNION ALL of projections
+
+use std::sync::Arc;
+
+use arrow::array::{ArrayRef, Int64Array, StringArray};
+use arrow::record_batch::RecordBatch;
+use datafusion::prelude::*;
+use datafusion_common::{DataFusionError, Result, ScalarValue};
+use datafusion_expr::{
+    case, lit,
+    logical_plan::builder::LogicalPlanBuilder,
+    planner::{
+        PlannedRelation, RelationPlanner, RelationPlannerContext, RelationPlanning,
+    },
+    Expr,
+};
+use datafusion_sql::sqlparser::ast::TableFactor;
+
+#[tokio::main]
+async fn main() -> Result<()> {
+    let ctx = SessionContext::new();
+
+    // Register sample data tables
+    register_sample_data(&ctx)?;
+
+    // Register custom planner
+    ctx.register_relation_planner(Arc::new(PivotUnpivotPlanner))?;
+
+    println!("Custom Relation Planner: PIVOT and UNPIVOT Operations");
+    println!("======================================================\n");
+
+    // Example 1: Basic PIVOT to transform quarterly sales data from rows to columns
+    // Shows: How to pivot sales data so each quarter becomes a column
+    // The PIVOT is rewritten to: SELECT region, SUM(CASE WHEN quarter = 'Q1' THEN amount END) as Q1,
+    //                                           SUM(CASE WHEN quarter = 'Q2' THEN amount END) as Q2
+    //                            FROM quarterly_sales GROUP BY region
+    // Expected Output:
+    // +--------+------+------+
+    // | region | Q1   | Q2   |
+    // +--------+------+------+
+    // | North  | 1000 | 1500 |
+    // | South  | 1200 | 1300 |
+    // +--------+------+------+
+    run_example(
+        &ctx,
+        "Example 1: Basic PIVOT - Transform quarters from rows to columns",
+        r#"SELECT * FROM quarterly_sales
+           PIVOT (
+               SUM(amount)
+               FOR quarter IN ('Q1', 'Q2')
+           ) AS pivoted"#,
+    )
+    .await?;
+
+    // Example 2: PIVOT with multiple aggregate functions
+    // Shows: How to apply multiple aggregations (SUM and AVG) during pivot
+    // Expected: A GROUP BY plan with one CASE-wrapped aggregate per
+    // (aggregate, pivot value) pair; because more than one aggregate is given,
+    // each output column is prefixed with the aggregate name (sum_Q1, avg_Q1, ...)
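+    // As a sketch, the rewrite the planner below applies to this query is
+    // equivalent to (aliases follow the `{agg_name}_{value}` scheme used when
+    // more than one aggregate function appears in the PIVOT clause):
+    //   SELECT region,
+    //          sum(CASE WHEN quarter = 'Q1' THEN amount END) AS sum_Q1,
+    //          sum(CASE WHEN quarter = 'Q2' THEN amount END) AS sum_Q2,
+    //          avg(CASE WHEN quarter = 'Q1' THEN amount END) AS avg_Q1,
+    //          avg(CASE WHEN quarter = 'Q2' THEN amount END) AS avg_Q2
+    //   FROM quarterly_sales
+    //   GROUP BY region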
+ run_example( + &ctx, + "Example 2: PIVOT with multiple aggregates (SUM and AVG)", + r#"SELECT * FROM quarterly_sales + PIVOT ( + SUM(amount), AVG(amount) + FOR quarter IN ('Q1', 'Q2') + ) AS pivoted"#, + ) + .await?; + + // Example 3: PIVOT with additional grouping columns + // Shows: How pivot works when there are multiple non-pivot columns + // The region and product both appear in GROUP BY + // Expected Output: + // +--------+-----------+------+------+ + // | region | product | Q1 | Q2 | + // +--------+-----------+------+------+ + // | North | ProductA | 500 | | + // | North | ProductB | 500 | | + // | South | ProductA | | 650 | + // +--------+-----------+------+------+ + run_example( + &ctx, + "Example 3: PIVOT with multiple grouping columns", + r#"SELECT * FROM product_sales + PIVOT ( + SUM(amount) + FOR quarter IN ('Q1', 'Q2') + ) AS pivoted"#, + ) + .await?; + + // Example 4: Basic UNPIVOT to transform columns back into rows + // Shows: How to unpivot wide-format data into long format + // The UNPIVOT is rewritten to: + // SELECT region, 'q1_label' as quarter, q1 as sales FROM wide_sales + // UNION ALL + // SELECT region, 'q2_label' as quarter, q2 as sales FROM wide_sales + // Expected Output: + // +--------+----------+-------+ + // | region | quarter | sales | + // +--------+----------+-------+ + // | North | q1_label | 1000 | + // | South | q1_label | 1200 | + // | North | q2_label | 1500 | + // | South | q2_label | 1300 | + // +--------+----------+-------+ + run_example( + &ctx, + "Example 4: Basic UNPIVOT - Transform columns to rows", + r#"SELECT * FROM wide_sales + UNPIVOT ( + sales FOR quarter IN (q1 AS 'q1_label', q2 AS 'q2_label') + ) AS unpivoted"#, + ) + .await?; + + // Example 5: UNPIVOT with INCLUDE NULLS + // Shows: How null handling works in UNPIVOT operations + // With INCLUDE NULLS, the filter `sales IS NOT NULL` is NOT added + // Expected: Same output as Example 4 (no nulls in this dataset anyway) + run_example( + &ctx, + "Example 5: UNPIVOT with INCLUDE NULLS", + r#"SELECT * FROM wide_sales + UNPIVOT INCLUDE NULLS ( + sales FOR quarter IN (q1 AS 'q1_label', q2 AS 'q2_label') + ) AS unpivoted"#, + ) + .await?; + + // Example 6: Simple PIVOT with projection + // Shows: PIVOT works seamlessly with other SQL operations like projection + // We can select specific columns after pivoting + run_example( + &ctx, + "Example 6: PIVOT with projection", + r#"SELECT region FROM quarterly_sales + PIVOT (SUM(amount) FOR quarter IN ('Q1', 'Q2')) AS pivoted"#, + ) + .await?; + + Ok(()) +} + +/// Register sample data tables for the examples +fn register_sample_data(ctx: &SessionContext) -> Result<()> { + // Create quarterly_sales table: region, quarter, amount + let region: ArrayRef = + Arc::new(StringArray::from(vec!["North", "North", "South", "South"])); + let quarter: ArrayRef = Arc::new(StringArray::from(vec!["Q1", "Q2", "Q1", "Q2"])); + let amount: ArrayRef = Arc::new(Int64Array::from(vec![1000, 1500, 1200, 1300])); + let batch = RecordBatch::try_from_iter(vec![ + ("region", region), + ("quarter", quarter), + ("amount", amount), + ])?; + ctx.register_batch("quarterly_sales", batch)?; + + // Create product_sales table: region, quarter, product, amount + let region: ArrayRef = Arc::new(StringArray::from(vec!["North", "North", "South"])); + let quarter: ArrayRef = Arc::new(StringArray::from(vec!["Q1", "Q1", "Q2"])); + let product: ArrayRef = + Arc::new(StringArray::from(vec!["ProductA", "ProductB", "ProductA"])); + let amount: ArrayRef = Arc::new(Int64Array::from(vec![500, 500, 
650]));
+    let batch = RecordBatch::try_from_iter(vec![
+        ("region", region),
+        ("quarter", quarter),
+        ("product", product),
+        ("amount", amount),
+    ])?;
+    ctx.register_batch("product_sales", batch)?;
+
+    // Create wide_sales table: region, q1, q2
+    let region: ArrayRef = Arc::new(StringArray::from(vec!["North", "South"]));
+    let q1: ArrayRef = Arc::new(Int64Array::from(vec![1000, 1200]));
+    let q2: ArrayRef = Arc::new(Int64Array::from(vec![1500, 1300]));
+    let batch =
+        RecordBatch::try_from_iter(vec![("region", region), ("q1", q1), ("q2", q2)])?;
+    ctx.register_batch("wide_sales", batch)?;
+
+    Ok(())
+}
+
+async fn run_example(ctx: &SessionContext, title: &str, sql: &str) -> Result<()> {
+    println!("{title}:\n{sql}\n");
+    let df = ctx.sql(sql).await?;
+
+    // Show the logical plan to demonstrate the rewrite
+    println!("Rewritten Logical Plan:");
+    println!("{}\n", df.logical_plan().display_indent());
+
+    // Execute and show results
+    println!("Results:");
+    df.show().await?;
+    println!();
+    Ok(())
+}
+
+/// Helper function to extract column name from an expression
+fn get_column_name(expr: &Expr) -> Option<String> {
+    match expr {
+        Expr::Column(col) => Some(col.name.clone()),
+        _ => None,
+    }
+}
+
+/// Custom planner that handles PIVOT and UNPIVOT table factor syntax
+#[derive(Debug)]
+struct PivotUnpivotPlanner;
+
+impl RelationPlanner for PivotUnpivotPlanner {
+    fn plan_relation(
+        &self,
+        relation: TableFactor,
+        context: &mut dyn RelationPlannerContext,
+    ) -> Result<RelationPlanning> {
+        match relation {
+            // Handle PIVOT operations
+            TableFactor::Pivot {
+                table,
+                aggregate_functions,
+                value_column,
+                value_source,
+                alias,
+                ..
+            } => {
+                println!("[PivotUnpivotPlanner] Processing PIVOT clause");
+
+                // Plan the input table
+                let input = context.plan(*table)?;
+                let input_schema = input.schema().clone();
+                println!(
+                    "[PivotUnpivotPlanner] Input schema has {} fields",
+                    input_schema.fields().len()
+                );
+
+                // Process aggregate functions
+                let aggregates = aggregate_functions
+                    .iter()
+                    .map(|agg| {
+                        let expr = context
+                            .sql_to_expr(agg.expr.clone(), input_schema.as_ref())?;
+                        Ok(expr)
+                    })
+                    .collect::<Result<Vec<_>>>()?;
+
+                // Get the pivot column (should be a single column for simple case)
+                if value_column.len() != 1 {
+                    return Err(DataFusionError::Plan(
+                        "Only single-column pivot supported in this example".into(),
+                    ));
+                }
+                let pivot_col = context
+                    .sql_to_expr(value_column[0].clone(), input_schema.as_ref())?;
+                let pivot_col_name = get_column_name(&pivot_col).ok_or_else(|| {
+                    DataFusionError::Plan(
+                        "Pivot column must be a column reference".into(),
+                    )
+                })?;
+
+                // Process pivot values
+                use datafusion_sql::sqlparser::ast::PivotValueSource;
+                let pivot_values = match value_source {
+                    PivotValueSource::List(list) => list
+                        .iter()
+                        .map(|item| {
+                            let alias = item
+                                .alias
+                                .as_ref()
+                                .map(|id| context.normalize_ident(id.clone()));
+                            let expr = context
+                                .sql_to_expr(item.expr.clone(), input_schema.as_ref())?;
+                            Ok((alias, expr))
+                        })
+                        .collect::<Result<Vec<_>>>()?,
+                    _ => {
+                        return Err(DataFusionError::Plan(
+                            "Dynamic pivot (ANY/Subquery) not supported in this example"
+                                .into(),
+                        ));
+                    }
+                };
+
+                // Build the rewritten plan: GROUP BY with CASE expressions
+                // For each aggregate and each pivot value, create: aggregate(CASE WHEN pivot_col = value THEN agg_input END)
+
+                let mut pivot_exprs = Vec::new();
+
+                // Determine grouping columns (all non-pivot columns, excluding aggregate inputs)
+                // Collect aggregate input column names
+                let agg_input_cols: Vec<String> = aggregates
+                    .iter()
+                    .filter_map(|agg| {
+                        if let
Expr::AggregateFunction(agg_fn) = agg { + agg_fn.params.args.first().and_then(get_column_name) + } else { + None + } + }) + .collect(); + + let mut group_by_cols = Vec::new(); + for field in input_schema.fields() { + let field_name = field.name(); + // Include in GROUP BY if it's not the pivot column and not an aggregate input + if field_name != &pivot_col_name + && !agg_input_cols.contains(field_name) + { + group_by_cols.push(col(field_name)); + } + } + + // Create CASE expressions for each (aggregate, pivot_value) pair + for agg_func in &aggregates { + for (alias_opt, pivot_value) in &pivot_values { + // Get the input to the aggregate function + let agg_input = if let Expr::AggregateFunction(agg_fn) = agg_func + { + if agg_fn.params.args.len() == 1 { + agg_fn.params.args[0].clone() + } else { + return Err(DataFusionError::Plan( + "Only single-argument aggregates supported in this example".into(), + )); + } + } else { + return Err(DataFusionError::Plan( + "Expected aggregate function".into(), + )); + }; + + // Create: CASE WHEN pivot_col = pivot_value THEN agg_input END + let case_expr = case(col(&pivot_col_name)) + .when(pivot_value.clone(), agg_input) + .end()?; + + // Wrap in aggregate function + let pivoted_expr = + if let Expr::AggregateFunction(agg_fn) = agg_func { + agg_fn.func.call(vec![case_expr]) + } else { + return Err(DataFusionError::Plan( + "Expected aggregate function".into(), + )); + }; + + // Determine the column alias + let value_part = alias_opt.clone().unwrap_or_else(|| { + if let Expr::Literal(ScalarValue::Utf8(Some(s)), _) = + pivot_value + { + s.clone() + } else if let Expr::Literal(lit, _) = pivot_value { + format!("{lit}") + } else { + format!("{pivot_value}") + } + }); + + // If there are multiple aggregates, prefix with function name + let col_alias = if aggregates.len() > 1 { + let agg_name = + if let Expr::AggregateFunction(agg_fn) = agg_func { + agg_fn.func.name() + } else { + "agg" + }; + format!("{agg_name}_{value_part}") + } else { + value_part + }; + + pivot_exprs.push(pivoted_expr.alias(col_alias)); + } + } + + // Build the final plan: GROUP BY with aggregations + let plan = LogicalPlanBuilder::from(input) + .aggregate(group_by_cols, pivot_exprs)? 
+                    .build()?;
+
+                println!("[PivotUnpivotPlanner] Successfully rewrote PIVOT to GROUP BY with CASE");
+                Ok(RelationPlanning::Planned(PlannedRelation::new(plan, alias)))
+            }
+
+            // Handle UNPIVOT operations
+            TableFactor::Unpivot {
+                table,
+                value,
+                name,
+                columns,
+                null_inclusion,
+                alias,
+            } => {
+                println!("[PivotUnpivotPlanner] Processing UNPIVOT clause");
+
+                // Plan the input table
+                let input = context.plan(*table)?;
+                let input_schema = input.schema().clone();
+                println!(
+                    "[PivotUnpivotPlanner] Input schema has {} fields",
+                    input_schema.fields().len()
+                );
+
+                // Get output column names
+                let value_col_name = format!("{value}");
+                let name_col_name = context.normalize_ident(name.clone());
+                println!(
+                    "[PivotUnpivotPlanner] Value column name (output): {value_col_name}"
+                );
+                println!(
+                    "[PivotUnpivotPlanner] Name column name (output): {name_col_name}"
+                );
+
+                // Process columns to unpivot
+                let unpivot_cols = columns
+                    .iter()
+                    .map(|col| {
+                        let label = col
+                            .alias
+                            .as_ref()
+                            .map(|id| context.normalize_ident(id.clone()))
+                            .unwrap_or_else(|| format!("{}", col.expr));
+
+                        let expr = context
+                            .sql_to_expr(col.expr.clone(), input_schema.as_ref())?;
+                        let col_name = get_column_name(&expr)
+                            .ok_or_else(|| DataFusionError::Plan("Unpivot column must be a column reference".into()))?;
+
+                        println!(
+                            "[PivotUnpivotPlanner] Will unpivot column '{col_name}' with label '{label}'"
+                        );
+
+                        Ok((col_name, label))
+                    })
+                    .collect::<Result<Vec<_>>>()?;
+
+                // Determine which columns to keep (not being unpivoted)
+                let keep_cols: Vec<String> = input_schema
+                    .fields()
+                    .iter()
+                    .filter_map(|field| {
+                        let field_name = field.name();
+                        if !unpivot_cols.iter().any(|(col, _)| col == field_name) {
+                            Some(field_name.clone())
+                        } else {
+                            None
+                        }
+                    })
+                    .collect();
+
+                // Build UNION ALL of projections
+                // For each unpivot column, create: SELECT keep_cols..., 'label' as name_col, col as value_col FROM input
+                let mut union_inputs = Vec::new();
+                for (col_name, label) in &unpivot_cols {
+                    let mut projection = Vec::new();
+
+                    // Add all columns we're keeping
+                    for keep_col in &keep_cols {
+                        projection.push(col(keep_col));
+                    }
+
+                    // Add the name column (constant label)
+                    projection.push(lit(label.clone()).alias(&name_col_name));
+
+                    // Add the value column (the column being unpivoted)
+                    projection.push(col(col_name).alias(&value_col_name));
+
+                    let projected = LogicalPlanBuilder::from(input.clone())
+                        .project(projection)?
+                        .build()?;
+
+                    union_inputs.push(projected);
+                }
+
+                // Build UNION ALL
+                if union_inputs.is_empty() {
+                    return Err(DataFusionError::Plan(
+                        "UNPIVOT requires at least one column".into(),
+                    ));
+                }
+
+                let mut union_iter = union_inputs.into_iter();
+                let mut plan = union_iter.next().unwrap();
+                for union_input in union_iter {
+                    plan = LogicalPlanBuilder::from(plan)
+                        .union(LogicalPlanBuilder::from(union_input).build()?)?
+                        .build()?;
+                }
+
+                // Handle EXCLUDE NULLS (default) by filtering out nulls
+                if null_inclusion.is_none()
+                    || matches!(
+                        null_inclusion,
+                        Some(datafusion_sql::sqlparser::ast::NullInclusion::ExcludeNulls)
+                    )
+                {
+                    plan = LogicalPlanBuilder::from(plan)
+                        .filter(col(&value_col_name).is_not_null())?
+ .build()?; + } + + println!( + "[PivotUnpivotPlanner] Successfully rewrote UNPIVOT to UNION ALL" + ); + Ok(RelationPlanning::Planned(PlannedRelation::new(plan, alias))) + } + + other => Ok(RelationPlanning::Original(other)), + } + } +} diff --git a/datafusion-examples/examples/relation_planner/table_sample.rs b/datafusion-examples/examples/relation_planner/table_sample.rs new file mode 100644 index 000000000000..dcc9222f627a --- /dev/null +++ b/datafusion-examples/examples/relation_planner/table_sample.rs @@ -0,0 +1,963 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! This example demonstrates using custom relation planners to implement +//! SQL TABLESAMPLE clause support. +//! +//! TABLESAMPLE allows sampling a fraction or number of rows from a table: +//! - `SELECT * FROM table TABLESAMPLE BERNOULLI(10)` - 10% sample +//! - `SELECT * FROM table TABLESAMPLE (100 ROWS)` - 100 rows +//! - `SELECT * FROM table TABLESAMPLE (10 PERCENT) REPEATABLE(42)` - Reproducible + +use std::{ + any::Any, + fmt::{self, Debug, Formatter}, + hash::{Hash, Hasher}, + ops::{Add, Div, Mul, Sub}, + pin::Pin, + str::FromStr, + sync::Arc, + task::{Context, Poll}, +}; + +use arrow::{ + array::{ArrayRef, Int32Array, RecordBatch, StringArray, UInt32Array}, + compute, +}; +use arrow_schema::SchemaRef; +use futures::{ + ready, + stream::{Stream, StreamExt}, +}; +use rand::{rngs::StdRng, Rng, SeedableRng}; +use rand_distr::{Distribution, Poisson}; +use tonic::async_trait; + +use datafusion::{ + execution::{ + context::QueryPlanner, RecordBatchStream, SendableRecordBatchStream, + SessionState, SessionStateBuilder, TaskContext, + }, + physical_expr::EquivalenceProperties, + physical_plan::{ + metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet, RecordOutput}, + DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties, + }, + physical_planner::{DefaultPhysicalPlanner, ExtensionPlanner, PhysicalPlanner}, + prelude::*, +}; +use datafusion_common::{ + internal_err, not_impl_err, plan_datafusion_err, plan_err, DFSchemaRef, + DataFusionError, Result, Statistics, +}; +use datafusion_expr::{ + logical_plan::{Extension, LogicalPlan, LogicalPlanBuilder}, + planner::{ + PlannedRelation, RelationPlanner, RelationPlannerContext, RelationPlanning, + }, + UserDefinedLogicalNode, UserDefinedLogicalNodeCore, +}; +use datafusion_sql::sqlparser::ast::{ + self, TableFactor, TableSampleMethod, TableSampleUnit, +}; + +#[tokio::main] +async fn main() -> Result<()> { + let state = SessionStateBuilder::new() + .with_default_features() + .with_query_planner(Arc::new(TableSampleQueryPlanner {})) + .build(); + + let ctx = SessionContext::new_with_state(state.clone()); + + // Register sample data table + register_sample_data(&ctx)?; + + // Register custom planner + 
ctx.register_relation_planner(Arc::new(TableSamplePlanner))?;
+
+    println!("Custom Relation Planner: TABLESAMPLE Support");
+    println!("============================================\n");
+    println!("Note: TABLESAMPLE is planned into a custom `TableSample` logical node,");
+    println!("which a custom QueryPlanner lowers to `SampleExec` for execution.\n");
+
+    // Example 1: Full table without any sampling (baseline)
+    // Shows: Complete dataset with all 10 rows (1-10 with row_1 to row_10)
+    // Expected: 10 rows showing the full sample_data table
+    // Actual:
+    // +---------+---------+
+    // | column1 | column2 |
+    // +---------+---------+
+    // | 1       | row_1   |
+    // | 2       | row_2   |
+    // | 3       | row_3   |
+    // | 4       | row_4   |
+    // | 5       | row_5   |
+    // | 6       | row_6   |
+    // | 7       | row_7   |
+    // | 8       | row_8   |
+    // | 9       | row_9   |
+    // | 10      | row_10  |
+    // +---------+---------+
+    run_example(
+        &ctx,
+        "Example 1: Full table (no sampling)",
+        "SELECT * FROM sample_data",
+    )
+    .await?;
+
+    // Example 2: TABLESAMPLE with BERNOULLI sampling at 30% probability
+    // Shows: Random sampling where each row has 30% chance of being selected
+    // Expected: ~3 rows (varies due to randomness) from the 10-row dataset
+    // Actual:
+    // +---------+---------+
+    // | column1 | column2 |
+    // +---------+---------+
+    // | 4       | row_4   |
+    // | 6       | row_6   |
+    // | 9       | row_9   |
+    // +---------+---------+
+    run_example(
+        &ctx,
+        "Example 2: TABLESAMPLE with percentage",
+        "SELECT * FROM sample_data TABLESAMPLE BERNOULLI(30 PERCENT)",
+    )
+    .await?;
+
+    // Example 3: TABLESAMPLE with fractional sampling (50% of data)
+    // Shows: Random sampling using decimal fraction instead of percentage
+    // Expected: ~5 rows (varies due to randomness) from the 10-row dataset
+    // Actual:
+    // +---------+---------+
+    // | column1 | column2 |
+    // +---------+---------+
+    // | 3       | row_3   |
+    // | 4       | row_4   |
+    // | 5       | row_5   |
+    // +---------+---------+
+    run_example(
+        &ctx,
+        "Example 3: TABLESAMPLE with fraction",
+        "SELECT * FROM sample_data TABLESAMPLE (0.5)",
+    )
+    .await?;
+
+    // Example 4: TABLESAMPLE with REPEATABLE seed for reproducible results
+    // Shows: Deterministic sampling using a fixed seed for consistent results
+    // Expected: Same rows selected each time due to fixed seed (42)
+    // Actual:
+    // +---------+---------+
+    // | column1 | column2 |
+    // +---------+---------+
+    // | 5       | row_5   |
+    // | 9       | row_9   |
+    // | 10      | row_10  |
+    // +---------+---------+
+    run_example(
+        &ctx,
+        "Example 4: TABLESAMPLE with REPEATABLE seed",
+        "SELECT * FROM sample_data TABLESAMPLE (0.3) REPEATABLE(42)",
+    )
+    .await?;
+
+    // Example 5: TABLESAMPLE with exact row count limit
+    // Shows: Sampling by limiting to a specific number of rows (not probabilistic)
+    // Expected: Exactly 3 rows (first 3 rows from the dataset)
+    // Actual:
+    // +---------+---------+
+    // | column1 | column2 |
+    // +---------+---------+
+    // | 1       | row_1   |
+    // | 2       | row_2   |
+    // | 3       | row_3   |
+    // +---------+---------+
+    run_example(
+        &ctx,
+        "Example 5: TABLESAMPLE with row count",
+        "SELECT * FROM sample_data TABLESAMPLE (3 ROWS)",
+    )
+    .await?;
+
+    // Example 6: TABLESAMPLE combined with WHERE clause filtering
+    // Shows: How sampling works with other query operations like filtering
+    // Expected: 3 rows where column1 > 2 (from the 5-row sample)
+    // Actual:
+    // +---------+---------+
+    // | column1 | column2 |
+    // +---------+---------+
+    // | 3       | row_3   |
+    // | 4       | row_4   |
+    // | 5       | row_5   |
+    // +---------+---------+
+    run_example(
+        &ctx,
+        "Example 6: TABLESAMPLE with WHERE clause",
+        r#"SELECT * FROM sample_data
+           TABLESAMPLE (5 ROWS)
+           WHERE column1 > 2"#,
+    )
+    .await?;
+
+    // Example 7: JOIN between two independently sampled tables
+    // Shows: How sampling works in complex queries with multiple table references
+    // Expected: Rows where both sampled tables have matching column1 values
+    // Actual:
+    // +---------+---------+---------+---------+
+    // | column1 | column1 | column2 | column2 |
+    // +---------+---------+---------+---------+
+    // | 2       | 2       | row_2   | row_2   |
+    // | 8       | 8       | row_8   | row_8   |
+    // | 10      | 10      | row_10  | row_10  |
+    // +---------+---------+---------+---------+
+    run_example(
+        &ctx,
+        "Example 7: JOIN between two different TABLESAMPLE tables",
+        r#"SELECT t1.column1, t2.column1, t1.column2, t2.column2
+           FROM sample_data t1 TABLESAMPLE (0.7)
+           JOIN sample_data t2 TABLESAMPLE (0.7)
+           ON t1.column1 = t2.column1"#,
+    )
+    .await?;
+
+    Ok(())
+}
+
+/// Register sample data table for the examples
+fn register_sample_data(ctx: &SessionContext) -> Result<()> {
+    // Create sample_data table with 10 rows: column1 (1-10), column2 (row_1 to row_10)
+    let column1: ArrayRef = Arc::new(Int32Array::from((1..=10).collect::<Vec<i32>>()));
+    let column2: ArrayRef = Arc::new(StringArray::from(
+        (1..=10)
+            .map(|i| format!("row_{i}"))
+            .collect::<Vec<String>>(),
+    ));
+    let batch =
+        RecordBatch::try_from_iter(vec![("column1", column1), ("column2", column2)])?;
+    ctx.register_batch("sample_data", batch)?;
+
+    Ok(())
+}
+
+async fn run_example(ctx: &SessionContext, title: &str, sql: &str) -> Result<()> {
+    println!("{title}:\n{sql}\n");
+    let df = ctx.sql(sql).await?;
+    println!("Logical Plan:\n{}\n", df.logical_plan().display_indent());
+    df.show().await?;
+    Ok(())
+}
+
+/// Hashable and comparable f64 for sampling bounds
+#[derive(Debug, Clone, Copy, PartialOrd)]
+struct Bound(f64);
+
+impl PartialEq for Bound {
+    fn eq(&self, other: &Self) -> bool {
+        (self.0 - other.0).abs() < f64::EPSILON
+    }
+}
+
+impl Eq for Bound {}
+
+impl Hash for Bound {
+    fn hash<H: Hasher>(&self, state: &mut H) {
+        // Hash the bits of the f64
+        self.0.to_bits().hash(state);
+    }
+}
+
+impl From<f64> for Bound {
+    fn from(value: f64) -> Self {
+        Self(value)
+    }
+}
+impl From<Bound> for f64 {
+    fn from(value: Bound) -> Self {
+        value.0
+    }
+}
+
+impl AsRef<f64> for Bound {
+    fn as_ref(&self) -> &f64 {
+        &self.0
+    }
+}
+
+#[derive(Debug, Clone, Hash, Eq, PartialEq, PartialOrd)]
+struct TableSamplePlanNode {
+    inner_plan: LogicalPlan,
+
+    lower_bound: Bound,
+    upper_bound: Bound,
+    with_replacement: bool,
+    seed: u64,
+}
+
+impl TableSamplePlanNode {
+    pub fn new(
+        input: LogicalPlan,
+        fraction: f64,
+        with_replacement: Option<bool>,
+        seed: Option<u64>,
+    ) -> Self {
+        TableSamplePlanNode {
+            inner_plan: input,
+            lower_bound: Bound::from(0.0),
+            upper_bound: Bound::from(fraction),
+            with_replacement: with_replacement.unwrap_or(false),
+            seed: seed.unwrap_or_else(rand::random),
+        }
+    }
+
+    pub fn into_plan(self) -> LogicalPlan {
+        LogicalPlan::Extension(Extension {
+            node: Arc::new(self),
+        })
+    }
+}
+
+impl UserDefinedLogicalNodeCore for TableSamplePlanNode {
+    fn name(&self) -> &str {
+        "TableSample"
+    }
+
+    fn inputs(&self) -> Vec<&LogicalPlan> {
+        vec![&self.inner_plan]
+    }
+
+    fn schema(&self) -> &DFSchemaRef {
+        self.inner_plan.schema()
+    }
+
+    fn expressions(&self) -> Vec<Expr> {
+        vec![]
+    }
+
+    fn fmt_for_explain(&self, f: &mut Formatter) -> fmt::Result {
+        f.write_fmt(format_args!(
+            "Sample: {:?} {:?} {:?}",
+            self.lower_bound, self.upper_bound, self.seed
+        ))
+    }
+
+    fn with_exprs_and_inputs(
+        &self,
+        _exprs: Vec<Expr>,
+        inputs: Vec<LogicalPlan>,
+    ) -> Result<Self> {
+        let input = inputs
+            .first()
+            .ok_or(DataFusionError::Plan("Should have input".into()))?;
+        Ok(Self {
+            inner_plan: input.clone(),
+            lower_bound: self.lower_bound,
+            upper_bound: self.upper_bound,
+            with_replacement: self.with_replacement,
+            seed: self.seed,
+        })
+    }
+}
+
+/// Execution planner with `SampleExec` for `TableSamplePlanNode`
+struct TableSampleExtensionPlanner {}
+
+impl TableSampleExtensionPlanner {
+    fn build_execution_plan(
+        &self,
+        specific_node: &TableSamplePlanNode,
+        physical_input: Arc<dyn ExecutionPlan>,
+    ) -> Result<Arc<dyn ExecutionPlan>> {
+        Ok(Arc::new(SampleExec {
+            input: physical_input.clone(),
+            lower_bound: 0.0,
+            upper_bound: specific_node.upper_bound.into(),
+            with_replacement: specific_node.with_replacement,
+            seed: specific_node.seed,
+            metrics: Default::default(),
+            cache: SampleExec::compute_properties(&physical_input),
+        }))
+    }
+}
+
+#[async_trait]
+impl ExtensionPlanner for TableSampleExtensionPlanner {
+    /// Create a physical plan for an extension node
+    async fn plan_extension(
+        &self,
+        _planner: &dyn PhysicalPlanner,
+        node: &dyn UserDefinedLogicalNode,
+        logical_inputs: &[&LogicalPlan],
+        physical_inputs: &[Arc<dyn ExecutionPlan>],
+        _session_state: &SessionState,
+    ) -> Result<Option<Arc<dyn ExecutionPlan>>> {
+        if let Some(specific_node) = node.as_any().downcast_ref::<TableSamplePlanNode>() {
+            println!("Extension planner plan_extension: {:?}", &logical_inputs);
+            assert_eq!(logical_inputs.len(), 1, "Inconsistent number of inputs");
+            assert_eq!(physical_inputs.len(), 1, "Inconsistent number of inputs");
+
+            let exec_plan =
+                self.build_execution_plan(specific_node, physical_inputs[0].clone())?;
+            Ok(Some(exec_plan))
+        } else {
+            Ok(None)
+        }
+    }
+}
+
+/// Query planner supporting a `TableSampleExtensionPlanner`
+#[derive(Debug)]
+struct TableSampleQueryPlanner {}
+
+#[async_trait]
+impl QueryPlanner for TableSampleQueryPlanner {
+    /// Given a `LogicalPlan` created from above, create an
+    /// `ExecutionPlan` suitable for execution
+    async fn create_physical_plan(
+        &self,
+        logical_plan: &LogicalPlan,
+        session_state: &SessionState,
+    ) -> Result<Arc<dyn ExecutionPlan>> {
+        // Additional extension for table sample node
+        let physical_planner =
+            DefaultPhysicalPlanner::with_extension_planners(vec![Arc::new(
+                TableSampleExtensionPlanner {},
+            )]);
+        // Delegate most work of physical planning to the default physical planner
+        physical_planner
+            .create_physical_plan(logical_plan, session_state)
+            .await
+    }
+}
+
+/// Physical plan implementation
+trait Sampler: Send + Sync {
+    fn sample(&mut self, batch: &RecordBatch) -> Result<RecordBatch>;
+}
+
+struct BernoulliSampler {
+    lower_bound: f64,
+    upper_bound: f64,
+    rng: StdRng,
+}
+
+impl BernoulliSampler {
+    fn new(lower_bound: f64, upper_bound: f64, seed: u64) -> Self {
+        Self {
+            lower_bound,
+            upper_bound,
+            rng: StdRng::seed_from_u64(seed),
+        }
+    }
+}
+
+impl Sampler for BernoulliSampler {
+    fn sample(&mut self, batch: &RecordBatch) -> Result<RecordBatch> {
+        if self.upper_bound <= self.lower_bound {
+            return Ok(RecordBatch::new_empty(batch.schema()));
+        }
+
+        let mut indices = Vec::new();
+
+        for i in 0..batch.num_rows() {
+            let rnd: f64 = self.rng.random();
+
+            if rnd >= self.lower_bound && rnd < self.upper_bound {
+                indices.push(i as u32);
+            }
+        }
+
+        if indices.is_empty() {
+            return Ok(RecordBatch::new_empty(batch.schema()));
+        }
+        let indices = UInt32Array::from(indices);
+        compute::take_record_batch(batch, &indices).map_err(|e| e.into())
+    }
+}
+
+struct PoissonSampler {
+    ratio: f64,
+    poisson: Poisson<f64>,
+    rng: StdRng,
+}
+
+impl PoissonSampler {
+    fn try_new(ratio: f64, seed: u64) -> Result<Self> {
+        let poisson = Poisson::new(ratio).map_err(|e| plan_datafusion_err!("{}", e))?;
+        Ok(Self {
+            ratio,
+            poisson,
+            rng: StdRng::seed_from_u64(seed),
+        })
+    }
+}
+
+impl Sampler for PoissonSampler {
+    fn sample(&mut self, batch: &RecordBatch) -> Result<RecordBatch> {
+        if self.ratio <= 0.0 {
+            return Ok(RecordBatch::new_empty(batch.schema()));
+        }
+
+        let mut indices = Vec::new();
+
+        for i in 0..batch.num_rows() {
+            let k = self.poisson.sample(&mut self.rng) as i32;
+            for _ in 0..k {
+                indices.push(i as u32);
+            }
+        }
+
+        if indices.is_empty() {
+            return Ok(RecordBatch::new_empty(batch.schema()));
+        }
+
+        let indices = UInt32Array::from(indices);
+        compute::take_record_batch(batch, &indices).map_err(|e| e.into())
+    }
+}
+
+/// SampleExec samples rows from its input based on a sampling method.
+/// This is used to implement SQL `SAMPLE` clause.
+#[derive(Debug, Clone)]
+pub struct SampleExec {
+    /// The input plan
+    input: Arc<dyn ExecutionPlan>,
+    /// The lower bound of the sampling ratio
+    lower_bound: f64,
+    /// The upper bound of the sampling ratio
+    upper_bound: f64,
+    /// Whether to sample with replacement
+    with_replacement: bool,
+    /// Random seed for reproducible sampling
+    seed: u64,
+    /// Execution metrics
+    metrics: ExecutionPlanMetricsSet,
+    /// Plan properties (equivalence properties, partitioning, etc.)
+    cache: PlanProperties,
+}
+
+impl SampleExec {
+    /// Create a new SampleExec with a custom sampling method
+    pub fn try_new(
+        input: Arc<dyn ExecutionPlan>,
+        lower_bound: f64,
+        upper_bound: f64,
+        with_replacement: bool,
+        seed: u64,
+    ) -> Result<Self> {
+        if lower_bound < 0.0 || upper_bound > 1.0 || lower_bound > upper_bound {
+            return internal_err!(
+                "Sampling bounds must be between 0.0 and 1.0, and lower_bound <= upper_bound, got [{}, {}]",
+                lower_bound, upper_bound
+            );
+        }
+
+        let cache = Self::compute_properties(&input);
+
+        Ok(Self {
+            input,
+            lower_bound,
+            upper_bound,
+            with_replacement,
+            seed,
+            metrics: ExecutionPlanMetricsSet::new(),
+            cache,
+        })
+    }
+
+    fn create_sampler(&self, partition: usize) -> Result<Box<dyn Sampler>> {
+        if self.with_replacement {
+            Ok(Box::new(PoissonSampler::try_new(
+                self.upper_bound - self.lower_bound,
+                self.seed + partition as u64,
+            )?))
+        } else {
+            Ok(Box::new(BernoulliSampler::new(
+                self.lower_bound,
+                self.upper_bound,
+                self.seed + partition as u64,
+            )))
+        }
+    }
+
+    /// Whether to sample with replacement
+    pub fn with_replacement(&self) -> bool {
+        self.with_replacement
+    }
+
+    /// The lower bound of the sampling ratio
+    pub fn lower_bound(&self) -> f64 {
+        self.lower_bound
+    }
+
+    /// The upper bound of the sampling ratio
+    pub fn upper_bound(&self) -> f64 {
+        self.upper_bound
+    }
+
+    /// The random seed
+    pub fn seed(&self) -> u64 {
+        self.seed
+    }
+
+    /// The input plan
+    pub fn input(&self) -> &Arc<dyn ExecutionPlan> {
+        &self.input
+    }
+
+    /// This function creates the cache object that stores the plan properties
+    /// such as schema, equivalence properties, ordering, partitioning, etc.
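+    /// Note: the sampled output keeps the input's partitioning, but the
+    /// equivalence properties are reset to an empty set for the input schema,
+    /// a conservative choice since row-level sampling does not preserve
+    /// ordering guarantees.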
+    fn compute_properties(input: &Arc<dyn ExecutionPlan>) -> PlanProperties {
+        input
+            .properties()
+            .clone()
+            .with_eq_properties(EquivalenceProperties::new(input.schema()))
+    }
+}
+
+impl DisplayAs for SampleExec {
+    fn fmt_as(&self, t: DisplayFormatType, f: &mut Formatter) -> fmt::Result {
+        match t {
+            DisplayFormatType::Default | DisplayFormatType::Verbose => {
+                write!(
+                    f,
+                    "SampleExec: lower_bound={}, upper_bound={}, with_replacement={}, seed={}",
+                    self.lower_bound, self.upper_bound, self.with_replacement, self.seed
+                )
+            }
+            DisplayFormatType::TreeRender => {
+                write!(
+                    f,
+                    "SampleExec: lower_bound={}, upper_bound={}, with_replacement={}, seed={}",
+                    self.lower_bound, self.upper_bound, self.with_replacement, self.seed
+                )
+            }
+        }
+    }
+}
+
+impl ExecutionPlan for SampleExec {
+    fn name(&self) -> &'static str {
+        "SampleExec"
+    }
+
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
+
+    fn properties(&self) -> &PlanProperties {
+        &self.cache
+    }
+
+    fn maintains_input_order(&self) -> Vec<bool> {
+        vec![false] // Sampling does not maintain input order
+    }
+
+    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
+        vec![&self.input]
+    }
+
+    fn with_new_children(
+        self: Arc<Self>,
+        children: Vec<Arc<dyn ExecutionPlan>>,
+    ) -> Result<Arc<dyn ExecutionPlan>> {
+        Ok(Arc::new(SampleExec::try_new(
+            Arc::clone(&children[0]),
+            self.lower_bound,
+            self.upper_bound,
+            self.with_replacement,
+            self.seed,
+        )?))
+    }
+
+    fn execute(
+        &self,
+        partition: usize,
+        context: Arc<TaskContext>,
+    ) -> Result<SendableRecordBatchStream> {
+        let input_stream = self.input.execute(partition, context)?;
+        let baseline_metrics = BaselineMetrics::new(&self.metrics, partition);
+
+        Ok(Box::pin(SampleExecStream {
+            input: input_stream,
+            sampler: self.create_sampler(partition)?,
+            baseline_metrics,
+        }))
+    }
+
+    fn metrics(&self) -> Option<MetricsSet> {
+        Some(self.metrics.clone_inner())
+    }
+
+    fn partition_statistics(&self, partition: Option<usize>) -> Result<Statistics> {
+        let input_stats = self.input.partition_statistics(partition)?;
+
+        // Apply sampling ratio to statistics
+        let mut stats = input_stats;
+        let ratio = self.upper_bound - self.lower_bound;
+
+        stats.num_rows = stats
+            .num_rows
+            .map(|nr| (nr as f64 * ratio) as usize)
+            .to_inexact();
+        stats.total_byte_size = stats
+            .total_byte_size
+            .map(|tb| (tb as f64 * ratio) as usize)
+            .to_inexact();
+
+        Ok(stats)
+    }
+}
+
+/// Stream for the SampleExec operator
+struct SampleExecStream {
+    /// The input stream
+    input: SendableRecordBatchStream,
+    /// The sampling method
+    sampler: Box<dyn Sampler>,
+    /// Runtime metrics recording
+    baseline_metrics: BaselineMetrics,
+}
+
+impl Stream for SampleExecStream {
+    type Item = Result<RecordBatch>;
+
+    fn poll_next(
+        mut self: Pin<&mut Self>,
+        cx: &mut Context<'_>,
+    ) -> Poll<Option<Self::Item>> {
+        match ready!(self.input.poll_next_unpin(cx)) {
+            Some(Ok(batch)) => {
+                let start = self.baseline_metrics.elapsed_compute().clone();
+                // Start the timer before sampling so the elapsed compute time
+                // is actually recorded when the guard is dropped
+                let _timer = start.timer();
+                let result = self.sampler.sample(&batch);
+                let result = result.record_output(&self.baseline_metrics);
+                Poll::Ready(Some(result))
+            }
+            Some(Err(e)) => Poll::Ready(Some(Err(e))),
+            None => Poll::Ready(None),
+        }
+    }
+}
+
+impl RecordBatchStream for SampleExecStream {
+    fn schema(&self) -> SchemaRef {
+        self.input.schema()
+    }
+}
+
+/// Helper to evaluate numeric SQL expressions
+fn evaluate_number<
+    T: FromStr + Add<Output = T> + Sub<Output = T> + Mul<Output = T> + Div<Output = T>,
+>(
+    expr: &ast::Expr,
+) -> Option<T> {
+    match expr {
+        ast::Expr::BinaryOp { left, op, right } => {
+            let left = evaluate_number::<T>(left);
+            let right = evaluate_number::<T>(right);
+            match (left, right) {
+                (Some(left), Some(right)) => match op {
+                    ast::BinaryOperator::Plus => Some(left + right),
+                    ast::BinaryOperator::Minus => Some(left - right),
+                    ast::BinaryOperator::Multiply => Some(left * right),
+                    ast::BinaryOperator::Divide => Some(left / right),
+                    _ => None,
+                },
+                _ => None,
+            }
+        }
+        ast::Expr::Value(value) => match &value.value {
+            ast::Value::Number(value, _) => {
+                let value = value.to_string();
+                let Ok(value) = value.parse::<T>() else {
+                    return None;
+                };
+                Some(value)
+            }
+            _ => None,
+        },
+        _ => None,
+    }
+}
+
+/// Custom relation planner that handles TABLESAMPLE clauses
+#[derive(Debug)]
+struct TableSamplePlanner;
+
+impl RelationPlanner for TableSamplePlanner {
+    fn plan_relation(
+        &self,
+        relation: TableFactor,
+        context: &mut dyn RelationPlannerContext,
+    ) -> Result<RelationPlanning> {
+        match relation {
+            TableFactor::Table {
+                sample: Some(sample),
+                alias,
+                name,
+                args,
+                with_hints,
+                version,
+                with_ordinality,
+                partitions,
+                json_path,
+                index_hints,
+            } => {
+                println!("[TableSamplePlanner] Processing TABLESAMPLE clause");
+
+                let sample = match sample {
+                    ast::TableSampleKind::BeforeTableAlias(sample) => sample,
+                    ast::TableSampleKind::AfterTableAlias(sample) => sample,
+                };
+                if let Some(name) = &sample.name {
+                    if *name != TableSampleMethod::Bernoulli
+                        && *name != TableSampleMethod::Row
+                    {
+                        // Postgres-style sample. Not supported because DataFusion does not have a concept of pages like PostgreSQL.
+                        return not_impl_err!("{} is not supported yet", name);
+                    }
+                }
+                if sample.offset.is_some() {
+                    // Clickhouse-style sample. Not supported because it requires knowing the total data size.
+                    return not_impl_err!("Offset sample is not supported yet");
+                }
+
+                let seed = sample
+                    .seed
+                    .map(|seed| {
+                        let Ok(seed) = seed.value.to_string().parse::<u64>() else {
+                            return plan_err!("seed must be a number: {}", seed.value);
+                        };
+                        Ok(seed)
+                    })
+                    .transpose()?;
+
+                let sampleless_relation = TableFactor::Table {
+                    sample: None,
+                    alias: alias.clone(),
+                    name: name.clone(),
+                    args: args.clone(),
+                    with_hints: with_hints.clone(),
+                    version: version.clone(),
+                    with_ordinality,
+                    partitions: partitions.clone(),
+                    json_path: json_path.clone(),
+                    index_hints: index_hints.clone(),
+                };
+                let input = context.plan(sampleless_relation)?;
+
+                if let Some(bucket) = sample.bucket {
+                    if bucket.on.is_some() {
+                        // Hive-style sample, only used when the Hive table is defined with CLUSTERED BY
+                        return not_impl_err!(
+                            "Bucket sample with ON is not supported yet"
+                        );
+                    }
+
+                    let Ok(bucket_num) = bucket.bucket.to_string().parse::<u64>() else {
+                        return plan_err!("bucket must be a number");
+                    };
+
+                    let Ok(total_num) = bucket.total.to_string().parse::<u64>() else {
+                        return plan_err!("total must be a number");
+                    };
+                    let value = bucket_num as f64 / total_num as f64;
+                    let plan =
+                        TableSamplePlanNode::new(input, value, None, seed).into_plan();
+                    return Ok(RelationPlanning::Planned(PlannedRelation::new(
+                        plan, alias,
+                    )));
+                }
+                if let Some(quantity) = sample.quantity {
+                    return match quantity.unit {
+                        Some(TableSampleUnit::Rows) => {
+                            let value = evaluate_number::<i64>(&quantity.value);
+                            if value.is_none() {
+                                return plan_err!(
+                                    "quantity must be a number: {:?}",
+                                    quantity.value
+                                );
+                            }
+                            let value = value.unwrap();
+                            if value < 0 {
+                                return plan_err!(
+                                    "quantity must be a non-negative number: {:?}",
+                                    quantity.value
+                                );
+                            }
+                            Ok(RelationPlanning::Planned(PlannedRelation::new(
+                                LogicalPlanBuilder::from(input)
+                                    .limit(0, Some(value as usize))?
+                                    .build()?,
+                                alias,
+                            )))
+                        }
+                        Some(TableSampleUnit::Percent) => {
+                            let value = evaluate_number::<f64>(&quantity.value);
+                            if value.is_none() {
+                                return plan_err!(
+                                    "quantity must be a number: {:?}",
+                                    quantity.value
+                                );
+                            }
+                            let value = value.unwrap() / 100.0;
+                            let plan = TableSamplePlanNode::new(input, value, None, seed)
+                                .into_plan();
+                            Ok(RelationPlanning::Planned(PlannedRelation::new(
+                                plan, alias,
+                            )))
+                        }
+                        None => {
+                            // Clickhouse-style sample
+                            let value = evaluate_number::<f64>(&quantity.value);
+                            if value.is_none() {
+                                return plan_err!(
+                                    "quantity must be a valid number: {:?}",
+                                    quantity.value
+                                );
+                            }
+                            let value = value.unwrap();
+                            if value < 0.0 {
+                                return plan_err!(
+                                    "quantity must be a non-negative number: {:?}",
+                                    quantity.value
+                                );
+                            }
+                            if value >= 1.0 {
+                                // If value is larger than 1, it is a row limit
+                                Ok(RelationPlanning::Planned(PlannedRelation::new(
+                                    LogicalPlanBuilder::from(input)
+                                        .limit(0, Some(value as usize))?
+                                        .build()?,
+                                    alias,
+                                )))
+                            } else {
+                                // If value is between 0.0 and 1.0, it is a fraction
+                                let plan =
+                                    TableSamplePlanNode::new(input, value, None, seed)
+                                        .into_plan();
+                                Ok(RelationPlanning::Planned(PlannedRelation::new(
+                                    plan, alias,
+                                )))
+                            }
+                        }
+                    };
+                }
+                plan_err!("Cannot plan sample SQL")
+            }
+            other => Ok(RelationPlanning::Original(other)),
+        }
+    }
+}
diff --git a/datafusion/core/src/execution/context/mod.rs b/datafusion/core/src/execution/context/mod.rs
index a8148b80495e..819b4afb8ee5 100644
--- a/datafusion/core/src/execution/context/mod.rs
+++ b/datafusion/core/src/execution/context/mod.rs
@@ -75,6 +75,8 @@ pub use datafusion_execution::config::SessionConfig;
 use datafusion_execution::registry::SerializerRegistry;
 pub use datafusion_execution::TaskContext;
 pub use datafusion_expr::execution_props::ExecutionProps;
+#[cfg(feature = "sql")]
+use datafusion_expr::planner::RelationPlanner;
 use datafusion_expr::{
     expr_rewriter::FunctionRewrite,
     logical_plan::{DdlStatement, Statement},
@@ -1325,6 +1327,18 @@ impl SessionContext {
         self.state.write().register_udwf(Arc::new(f)).ok();
     }
 
+    #[cfg(feature = "sql")]
+    /// Registers a [`RelationPlanner`] to customize SQL table-factor planning.
+    ///
+    /// Planners are invoked in reverse registration order, allowing newer
+    /// planners to take precedence over existing ones.
+    pub fn register_relation_planner(
+        &self,
+        planner: Arc<dyn RelationPlanner>,
+    ) -> Result<()> {
+        self.state.write().register_relation_planner(planner)
+    }
+
     /// Deregisters a UDF within this context.
     pub fn deregister_udf(&self, name: &str) {
         self.state.write().deregister_udf(name).ok();
diff --git a/datafusion/core/src/execution/session_state.rs b/datafusion/core/src/execution/session_state.rs
index b04004dd495c..51b895e5727a 100644
--- a/datafusion/core/src/execution/session_state.rs
+++ b/datafusion/core/src/execution/session_state.rs
@@ -52,7 +52,7 @@ use datafusion_expr::execution_props::ExecutionProps;
 use datafusion_expr::expr_rewriter::FunctionRewrite;
 use datafusion_expr::planner::ExprPlanner;
 #[cfg(feature = "sql")]
-use datafusion_expr::planner::TypePlanner;
+use datafusion_expr::planner::{RelationPlanner, TypePlanner};
 use datafusion_expr::registry::{FunctionRegistry, SerializerRegistry};
 use datafusion_expr::simplify::SimplifyInfo;
 #[cfg(feature = "sql")]
@@ -138,6 +138,8 @@ pub struct SessionState {
     analyzer: Analyzer,
     /// Provides support for customizing the SQL planner, e.g. to add support for custom operators like `->>` or `?`
     expr_planners: Vec<Arc<dyn ExprPlanner>>,
+    #[cfg(feature = "sql")]
+    relation_planners: Vec<Arc<dyn RelationPlanner>>,
     /// Provides support for customizing the SQL type planning
     #[cfg(feature = "sql")]
     type_planner: Option<Arc<dyn TypePlanner>>,
@@ -207,6 +209,9 @@ impl Debug for SessionState {
             .field("function_factory", &self.function_factory)
             .field("expr_planners", &self.expr_planners);
 
+        #[cfg(feature = "sql")]
+        let ret = ret.field("relation_planners", &self.relation_planners);
+
         #[cfg(feature = "sql")]
         let ret = ret.field("type_planner", &self.type_planner);
 
@@ -570,6 +575,24 @@ impl SessionState {
         &self.expr_planners
     }
 
+    #[cfg(feature = "sql")]
+    /// Returns the registered relation planners in priority order.
+    pub fn relation_planners(&self) -> &[Arc<dyn RelationPlanner>] {
+        &self.relation_planners
+    }
+
+    #[cfg(feature = "sql")]
+    /// Registers a [`RelationPlanner`] to customize SQL relation planning.
+    ///
+    /// Newly registered planners are given higher priority than existing ones.
+    pub fn register_relation_planner(
+        &mut self,
+        planner: Arc<dyn RelationPlanner>,
+    ) -> datafusion_common::Result<()> {
+        self.relation_planners.insert(0, planner);
+        Ok(())
+    }
+
     /// Returns the [`QueryPlanner`] for this session
     pub fn query_planner(&self) -> &Arc<dyn QueryPlanner> {
         &self.query_planner
@@ -913,6 +936,8 @@ pub struct SessionStateBuilder {
     analyzer: Option<Analyzer>,
     expr_planners: Option<Vec<Arc<dyn ExprPlanner>>>,
     #[cfg(feature = "sql")]
+    relation_planners: Option<Vec<Arc<dyn RelationPlanner>>>,
+    #[cfg(feature = "sql")]
     type_planner: Option<Arc<dyn TypePlanner>>,
     optimizer: Option<Optimizer>,
     physical_optimizers: Option<PhysicalOptimizer>,
@@ -950,6 +975,8 @@ impl SessionStateBuilder {
             analyzer: None,
             expr_planners: None,
             #[cfg(feature = "sql")]
+            relation_planners: None,
+            #[cfg(feature = "sql")]
             type_planner: None,
             optimizer: None,
             physical_optimizers: None,
@@ -1000,6 +1027,8 @@ impl SessionStateBuilder {
             analyzer: Some(existing.analyzer),
             expr_planners: Some(existing.expr_planners),
             #[cfg(feature = "sql")]
+            relation_planners: Some(existing.relation_planners),
+            #[cfg(feature = "sql")]
             type_planner: existing.type_planner,
             optimizer: Some(existing.optimizer),
             physical_optimizers: Some(existing.physical_optimizers),
@@ -1140,6 +1169,16 @@ impl SessionStateBuilder {
         self
     }
 
+    #[cfg(feature = "sql")]
+    /// Sets the [`RelationPlanner`]s used to customize SQL relation planning.
+    pub fn with_relation_planners(
+        mut self,
+        relation_planners: Vec<Arc<dyn RelationPlanner>>,
+    ) -> Self {
+        self.relation_planners = Some(relation_planners);
+        self
+    }
+
     /// Set the [`TypePlanner`] used to customize the behavior of the SQL planner.
     #[cfg(feature = "sql")]
     pub fn with_type_planner(mut self, type_planner: Arc<dyn TypePlanner>) -> Self {
@@ -1354,6 +1393,8 @@ impl SessionStateBuilder {
             analyzer,
             expr_planners,
             #[cfg(feature = "sql")]
+            relation_planners,
+            #[cfg(feature = "sql")]
             type_planner,
             optimizer,
             physical_optimizers,
@@ -1384,6 +1425,8 @@ impl SessionStateBuilder {
             analyzer: analyzer.unwrap_or_default(),
             expr_planners: expr_planners.unwrap_or_default(),
             #[cfg(feature = "sql")]
+            relation_planners: relation_planners.unwrap_or_default(),
+            #[cfg(feature = "sql")]
             type_planner,
             optimizer: optimizer.unwrap_or_default(),
             physical_optimizers: physical_optimizers.unwrap_or_default(),
@@ -1501,6 +1544,12 @@ impl SessionStateBuilder {
         &mut self.expr_planners
     }
 
+    #[cfg(feature = "sql")]
+    /// Returns a mutable reference to the current [`RelationPlanner`] list.
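+    /// The planners are consulted front-to-back during SQL planning, so
+    /// reordering this list changes which planner takes precedence.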
@@ -1501,6 +1544,12 @@ impl SessionStateBuilder {
         &mut self.expr_planners
     }
 
+    #[cfg(feature = "sql")]
+    /// Returns a mutable reference to the current [`RelationPlanner`] list.
+    pub fn relation_planners(
+        &mut self,
+    ) -> &mut Option<Vec<Arc<dyn RelationPlanner>>> {
+        &mut self.relation_planners
+    }
+
     /// Returns the current type_planner value
     #[cfg(feature = "sql")]
     pub fn type_planner(&mut self) -> &mut Option<Arc<dyn TypePlanner>> {
         &mut self.type_planner
@@ -1675,6 +1724,10 @@ impl ContextProvider for SessionContextProvider<'_> {
         self.state.expr_planners()
     }
 
+    fn get_relation_planners(&self) -> &[Arc<dyn RelationPlanner>] {
+        self.state.relation_planners()
+    }
+
     fn get_type_planner(&self) -> Option<Arc<dyn TypePlanner>> {
         if let Some(type_planner) = &self.state.type_planner {
             Some(Arc::clone(type_planner))
diff --git a/datafusion/core/tests/user_defined/mod.rs b/datafusion/core/tests/user_defined/mod.rs
index 5d84cdb69283..eb70e79dac44 100644
--- a/datafusion/core/tests/user_defined/mod.rs
+++ b/datafusion/core/tests/user_defined/mod.rs
@@ -33,5 +33,8 @@ mod user_defined_table_functions;
 /// Tests for Expression Planner
 mod expr_planner;
 
+/// Tests for Relation Planner extensions
+mod relation_planner;
+
 /// Tests for insert operations
 mod insert_operation;
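The new test file below exercises the planner contract end to end. As a compact restatement of that contract, a planner either claims a table factor or hands it back unchanged so later planners (or the default logic) can try it; a minimal sketch, where the table name `empty_rel` is made up:

```rust
use datafusion_common::Result;
use datafusion_expr::logical_plan::builder::LogicalPlanBuilder;
use datafusion_expr::planner::{
    PlannedRelation, RelationPlanner, RelationPlannerContext, RelationPlanning,
};
use datafusion_sql::sqlparser::ast::TableFactor;

/// Replaces the (illustrative) table name `empty_rel` with an empty relation;
/// everything else is returned unchanged for the next planner to inspect.
#[derive(Debug)]
struct EmptyRelPlanner;

impl RelationPlanner for EmptyRelPlanner {
    fn plan_relation(
        &self,
        relation: TableFactor,
        _context: &mut dyn RelationPlannerContext,
    ) -> Result<RelationPlanning> {
        match relation {
            TableFactor::Table { name, alias, .. }
                if name.to_string().eq_ignore_ascii_case("empty_rel") =>
            {
                let plan = LogicalPlanBuilder::empty(false).build()?;
                Ok(RelationPlanning::Planned(PlannedRelation::new(plan, alias)))
            }
            other => Ok(RelationPlanning::Original(other)),
        }
    }
}
```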
diff --git a/datafusion/core/tests/user_defined/relation_planner.rs b/datafusion/core/tests/user_defined/relation_planner.rs
new file mode 100644
index 000000000000..e5e46e8a6852
--- /dev/null
+++ b/datafusion/core/tests/user_defined/relation_planner.rs
@@ -0,0 +1,353 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use std::sync::Arc;
+
+use arrow::array::{Int64Array, RecordBatch, StringArray};
+use arrow::datatypes::{DataType, Field, Schema};
+use datafusion::catalog::memory::MemTable;
+use datafusion::common::test_util::batches_to_string;
+use datafusion::prelude::*;
+use datafusion_common::{Result, ScalarValue};
+use datafusion_expr::logical_plan::builder::LogicalPlanBuilder;
+use datafusion_expr::planner::{
+    PlannedRelation, RelationPlanner, RelationPlannerContext, RelationPlanning,
+};
+use datafusion_expr::Expr;
+use datafusion_sql::sqlparser::ast::TableFactor;
+
+/// A planner that creates an in-memory table with custom values
+#[derive(Debug)]
+struct CustomValuesPlanner;
+
+impl RelationPlanner for CustomValuesPlanner {
+    fn plan_relation(
+        &self,
+        relation: TableFactor,
+        _context: &mut dyn RelationPlannerContext,
+    ) -> Result<RelationPlanning> {
+        match relation {
+            TableFactor::Table { name, alias, .. }
+                if name.to_string().eq_ignore_ascii_case("custom_values") =>
+            {
+                let plan = LogicalPlanBuilder::values(vec![
+                    vec![Expr::Literal(ScalarValue::Int64(Some(1)), None)],
+                    vec![Expr::Literal(ScalarValue::Int64(Some(2)), None)],
+                    vec![Expr::Literal(ScalarValue::Int64(Some(3)), None)],
+                ])?
+                .build()?;
+                Ok(RelationPlanning::Planned(PlannedRelation::new(plan, alias)))
+            }
+            other => Ok(RelationPlanning::Original(other)),
+        }
+    }
+}
+
+/// A planner that handles string-based tables
+#[derive(Debug)]
+struct StringTablePlanner;
+
+impl RelationPlanner for StringTablePlanner {
+    fn plan_relation(
+        &self,
+        relation: TableFactor,
+        _context: &mut dyn RelationPlannerContext,
+    ) -> Result<RelationPlanning> {
+        match relation {
+            TableFactor::Table { name, alias, .. }
+                if name.to_string().eq_ignore_ascii_case("colors") =>
+            {
+                let plan = LogicalPlanBuilder::values(vec![
+                    vec![Expr::Literal(ScalarValue::Utf8(Some("red".into())), None)],
+                    vec![Expr::Literal(ScalarValue::Utf8(Some("green".into())), None)],
+                    vec![Expr::Literal(ScalarValue::Utf8(Some("blue".into())), None)],
+                ])?
+                .build()?;
+                Ok(RelationPlanning::Planned(PlannedRelation::new(plan, alias)))
+            }
+            other => Ok(RelationPlanning::Original(other)),
+        }
+    }
+}
+
+/// A planner that intercepts nested joins and plans them recursively
+#[derive(Debug)]
+struct RecursiveJoinPlanner;
+
+impl RelationPlanner for RecursiveJoinPlanner {
+    fn plan_relation(
+        &self,
+        relation: TableFactor,
+        context: &mut dyn RelationPlannerContext,
+    ) -> Result<RelationPlanning> {
+        match relation {
+            TableFactor::NestedJoin {
+                table_with_joins,
+                alias,
+                ..
+            } if table_with_joins.joins.len() == 1 => {
+                // Recursively plan both sides using context.plan()
+                let left = context.plan(table_with_joins.relation.clone())?;
+                let right = context.plan(table_with_joins.joins[0].relation.clone())?;
+
+                // Create a cross join
+                let plan = LogicalPlanBuilder::from(left).cross_join(right)?.build()?;
+                Ok(RelationPlanning::Planned(PlannedRelation::new(plan, alias)))
+            }
+            other => Ok(RelationPlanning::Original(other)),
+        }
+    }
+}
+
+/// A planner that always returns the relation unchanged, to test delegation
+#[derive(Debug)]
+struct PassThroughPlanner;
+
+impl RelationPlanner for PassThroughPlanner {
+    fn plan_relation(
+        &self,
+        relation: TableFactor,
+        _context: &mut dyn RelationPlannerContext,
+    ) -> Result<RelationPlanning> {
+        // Always return Original - delegates to next planner or default
+        Ok(RelationPlanning::Original(relation))
+    }
+}
+
+async fn collect_sql(ctx: &SessionContext, sql: &str) -> Vec<RecordBatch> {
+    ctx.sql(sql).await.unwrap().collect().await.unwrap()
+}
+
+#[tokio::test]
+async fn test_custom_planner_handles_relation() {
+    let ctx = SessionContext::new();
+    ctx.register_relation_planner(Arc::new(CustomValuesPlanner))
+        .unwrap();
+
+    let results = collect_sql(&ctx, "SELECT * FROM custom_values").await;
+
+    let expected = "\
++---------+
+| column1 |
++---------+
+| 1       |
+| 2       |
+| 3       |
++---------+";
+    assert_eq!(batches_to_string(&results), expected);
+}
+
+#[tokio::test]
+async fn test_multiple_planners_first_wins() {
+    let ctx = SessionContext::new();
+
+    // Register multiple planners - the first one to return Planned wins
+    ctx.register_relation_planner(Arc::new(CustomValuesPlanner))
+        .unwrap();
+    ctx.register_relation_planner(Arc::new(StringTablePlanner))
+        .unwrap();
+
+    // CustomValuesPlanner handles this
+    let results = collect_sql(&ctx, "SELECT * FROM custom_values").await;
+    let expected = "\
++---------+
+| column1 |
++---------+
+| 1       |
+| 2       |
+| 3       |
++---------+";
+    assert_eq!(batches_to_string(&results), expected);
+
+    // StringTablePlanner handles this
+    let results = collect_sql(&ctx, "SELECT * FROM colors").await;
+    let expected = "\
++---------+
+| column1 |
++---------+
+| red     |
+| green   |
+| blue    |
++---------+";
+    assert_eq!(batches_to_string(&results), expected);
+}
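Because `register_relation_planner` prepends, the probe order is the reverse of registration order. A sketch of that order using the two planners defined above in this file (not one of the tests; assumes an async context):

```rust
async fn probe_order_demo() -> Result<()> {
    let ctx = SessionContext::new();
    ctx.register_relation_planner(Arc::new(CustomValuesPlanner))?;
    ctx.register_relation_planner(Arc::new(StringTablePlanner))?;

    // StringTablePlanner (registered last) is probed first; it returns
    // Original for "custom_values", so CustomValuesPlanner plans it.
    let batches = ctx.sql("SELECT * FROM custom_values").await?.collect().await?;
    assert_eq!(batches.iter().map(|b| b.num_rows()).sum::<usize>(), 3);
    Ok(())
}
```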
+#[tokio::test]
+async fn test_planner_delegates_to_default() {
+    let ctx = SessionContext::new();
+
+    // Register a planner that never handles any relation
+    ctx.register_relation_planner(Arc::new(PassThroughPlanner))
+        .unwrap();
+
+    // Also register a real table
+    let schema = Arc::new(Schema::new(vec![Field::new(
+        "value",
+        DataType::Int64,
+        true,
+    )]));
+    let batch =
+        RecordBatch::try_new(schema.clone(), vec![Arc::new(Int64Array::from(vec![42]))])
+            .unwrap();
+    let table = MemTable::try_new(schema, vec![vec![batch]]).unwrap();
+    ctx.register_table("real_table", Arc::new(table)).unwrap();
+
+    // PassThroughPlanner returns Original, so planning falls through to the
+    // default relation handling
+    let results = collect_sql(&ctx, "SELECT * FROM real_table").await;
+    let expected = "\
++-------+
+| value |
++-------+
+| 42    |
++-------+";
+    assert_eq!(batches_to_string(&results), expected);
+}
+
+#[tokio::test]
+async fn test_planner_delegates_with_multiple_planners() {
+    let ctx = SessionContext::new();
+
+    // Register planners; later registrations are probed first
+    ctx.register_relation_planner(Arc::new(PassThroughPlanner))
+        .unwrap();
+    ctx.register_relation_planner(Arc::new(CustomValuesPlanner))
+        .unwrap();
+    ctx.register_relation_planner(Arc::new(StringTablePlanner))
+        .unwrap();
+
+    // StringTablePlanner passes "custom_values" through;
+    // CustomValuesPlanner handles it
+    let results = collect_sql(&ctx, "SELECT * FROM custom_values").await;
+    let expected = "\
++---------+
+| column1 |
++---------+
+| 1       |
+| 2       |
+| 3       |
++---------+";
+    assert_eq!(batches_to_string(&results), expected);
+
+    // StringTablePlanner is probed first and handles "colors" directly
+    let results = collect_sql(&ctx, "SELECT * FROM colors").await;
+    let expected = "\
++---------+
+| column1 |
++---------+
+| red     |
+| green   |
+| blue    |
++---------+";
+    assert_eq!(batches_to_string(&results), expected);
+}
+
+#[tokio::test]
+async fn test_recursive_planning_with_context_plan() {
+    let ctx = SessionContext::new();
+
+    // Register planners
+    ctx.register_relation_planner(Arc::new(CustomValuesPlanner))
+        .unwrap();
+    ctx.register_relation_planner(Arc::new(StringTablePlanner))
+        .unwrap();
+    ctx.register_relation_planner(Arc::new(RecursiveJoinPlanner))
+        .unwrap();
+
+    // RecursiveJoinPlanner calls context.plan() on both sides,
+    // which recursively invokes the planner pipeline
+    let results = collect_sql(
+        &ctx,
+        "SELECT * FROM custom_values AS nums JOIN colors AS c ON true",
+    )
+    .await;
+
+    // Should produce a cross join: 3 numbers × 3 colors = 9 rows
+    let expected = "\
++---------+---------+
+| column1 | column1 |
++---------+---------+
+| 1       | red     |
+| 1       | green   |
+| 1       | blue    |
+| 2       | red     |
+| 2       | green   |
+| 2       | blue    |
+| 3       | red     |
+| 3       | green   |
+| 3       | blue    |
++---------+---------+";
+    assert_eq!(batches_to_string(&results), expected);
+}
+
+#[tokio::test]
+async fn test_planner_with_filters_and_projections() {
+    let ctx = SessionContext::new();
+    ctx.register_relation_planner(Arc::new(CustomValuesPlanner))
+        .unwrap();
+
+    // Test that filters and projections work on custom-planned tables
+    let results = collect_sql(
+        &ctx,
+        "SELECT column1 * 10 AS scaled FROM custom_values WHERE column1 > 1",
+    )
+    .await;
+
+    let expected = "\
++--------+
+| scaled |
++--------+
+| 20     |
+| 30     |
++--------+";
+    assert_eq!(batches_to_string(&results), expected);
+}
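For reference, `TableFactor::NestedJoin` corresponds to a parenthesized join in the `FROM` clause, so the `RecursiveJoinPlanner` above can also be exercised explicitly with that shape. A sketch, assuming the same registrations as the join test:

```rust
async fn nested_join_demo(ctx: &SessionContext) -> Result<()> {
    // The parenthesized join parses to a NestedJoin table factor, which
    // RecursiveJoinPlanner matches and plans side by side via context.plan().
    let df = ctx
        .sql("SELECT * FROM (custom_values JOIN colors ON true)")
        .await?;
    df.show().await?;
    Ok(())
}
```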
+#[tokio::test]
+async fn test_planner_falls_back_to_default_for_unknown_table() {
+    let ctx = SessionContext::new();
+
+    ctx.register_relation_planner(Arc::new(CustomValuesPlanner))
+        .unwrap();
+
+    // Register a regular table
+    let schema = Arc::new(Schema::new(vec![
+        Field::new("id", DataType::Int64, false),
+        Field::new("name", DataType::Utf8, false),
+    ]));
+    let batch = RecordBatch::try_new(
+        schema.clone(),
+        vec![
+            Arc::new(Int64Array::from(vec![1, 2])),
+            Arc::new(StringArray::from(vec!["Alice", "Bob"])),
+        ],
+    )
+    .unwrap();
+    let table = MemTable::try_new(schema, vec![vec![batch]]).unwrap();
+    ctx.register_table("users", Arc::new(table)).unwrap();
+
+    // CustomValuesPlanner doesn't handle "users", so planning falls back to
+    // the default table resolution
+    let results = collect_sql(&ctx, "SELECT * FROM users ORDER BY id").await;
+
+    let expected = "\
++----+-------+
+| id | name  |
++----+-------+
+| 1  | Alice |
+| 2  | Bob   |
++----+-------+";
+    assert_eq!(batches_to_string(&results), expected);
+}
diff --git a/datafusion/expr/src/planner.rs b/datafusion/expr/src/planner.rs
index 25a0f83947ee..91dbcd71c5fd 100644
--- a/datafusion/expr/src/planner.rs
+++ b/datafusion/expr/src/planner.rs
@@ -21,6 +21,8 @@ use std::fmt::Debug;
 use std::sync::Arc;
 
 use crate::expr::NullTreatment;
+#[cfg(feature = "sql")]
+use crate::logical_plan::LogicalPlan;
 use crate::{
     AggregateUDF, Expr, GetFieldAccess, ScalarUDF, SortExpr, TableSource,
     WindowFrame, WindowFunctionDefinition, WindowUDF,
@@ -30,6 +32,8 @@ use datafusion_common::{
     config::ConfigOptions, file_options::file_type::FileType, not_impl_err,
     DFSchema, Result, TableReference,
 };
+#[cfg(feature = "sql")]
+use sqlparser::ast::{Expr as SQLExpr, Ident, ObjectName, TableAlias, TableFactor};
 
 /// Provides the `SQL` query planner meta-data about tables and
 /// functions referenced in SQL statements, without a direct dependency on the
@@ -83,6 +87,12 @@ pub trait ContextProvider {
         &[]
     }
 
+    /// Return [`RelationPlanner`] extensions for planning table factors
+    #[cfg(feature = "sql")]
+    fn get_relation_planners(&self) -> &[Arc<dyn RelationPlanner>] {
+        &[]
+    }
+
     /// Return [`TypePlanner`] extensions for planning data types
     #[cfg(feature = "sql")]
     fn get_type_planner(&self) -> Option<Arc<dyn TypePlanner>> {
@@ -324,6 +334,85 @@ pub enum PlannerResult<T> {
     Original(T),
 }
 
+/// Result of planning a relation with [`RelationPlanner`]
+#[cfg(feature = "sql")]
+#[derive(Debug, Clone)]
+pub struct PlannedRelation {
+    /// The logical plan for the relation
+    pub plan: LogicalPlan,
+    /// Optional table alias for the relation
+    pub alias: Option<TableAlias>,
+}
+
+#[cfg(feature = "sql")]
+impl PlannedRelation {
+    /// Create a new `PlannedRelation` with the given plan and alias
+    pub fn new(plan: LogicalPlan, alias: Option<TableAlias>) -> Self {
+        Self { plan, alias }
+    }
+}
+
+/// Result of attempting to plan a relation with extension planners
+#[cfg(feature = "sql")]
+#[derive(Debug)]
+pub enum RelationPlanning {
+    /// The relation was successfully planned by an extension planner
+    Planned(PlannedRelation),
+    /// No extension planner handled the relation, return it for default processing
+    Original(TableFactor),
+}
+
+/// Customize planning SQL table factors to [`LogicalPlan`]s.
+#[cfg(feature = "sql")]
+pub trait RelationPlanner: Debug + Send + Sync {
+    /// Plan a table factor into a [`LogicalPlan`].
+    ///
+    /// Returning `Ok(RelationPlanning::Planned(planned_relation))` short-circuits
+    /// further planning and uses the provided plan. Returning
+    /// `Ok(RelationPlanning::Original(relation))` allows the next registered
+    /// planner, or DataFusion's default logic, to handle the relation.
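+    ///
+    /// # Example
+    ///
+    /// A minimal sketch (the table name `my_table` is illustrative):
+    ///
+    /// ```ignore
+    /// #[derive(Debug)]
+    /// struct MyTablePlanner;
+    ///
+    /// impl RelationPlanner for MyTablePlanner {
+    ///     fn plan_relation(
+    ///         &self,
+    ///         relation: TableFactor,
+    ///         _context: &mut dyn RelationPlannerContext,
+    ///     ) -> Result<RelationPlanning> {
+    ///         match relation {
+    ///             TableFactor::Table { name, alias, .. }
+    ///                 if name.to_string() == "my_table" =>
+    ///             {
+    ///                 let plan = LogicalPlanBuilder::empty(false).build()?;
+    ///                 Ok(RelationPlanning::Planned(PlannedRelation::new(
+    ///                     plan, alias,
+    ///                 )))
+    ///             }
+    ///             other => Ok(RelationPlanning::Original(other)),
+    ///         }
+    ///     }
+    /// }
+    /// ```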
+    fn plan_relation(
+        &self,
+        relation: TableFactor,
+        context: &mut dyn RelationPlannerContext,
+    ) -> Result<RelationPlanning>;
+}
+
+/// Provides utilities for relation planners to interact with DataFusion's SQL
+/// planner.
+///
+/// This trait provides SQL planning utilities specific to relation planning,
+/// such as converting SQL expressions to logical expressions and normalizing
+/// identifiers. It uses composition to provide access to session context via
+/// [`ContextProvider`].
+#[cfg(feature = "sql")]
+pub trait RelationPlannerContext {
+    /// Provides access to the underlying context provider for reading session
+    /// configuration, accessing tables, functions, and other metadata.
+    fn context_provider(&self) -> &dyn ContextProvider;
+
+    /// Plans the specified relation through the full planner pipeline, starting
+    /// from the first registered relation planner.
+    fn plan(&mut self, relation: TableFactor) -> Result<LogicalPlan>;
+
+    /// Converts a SQL expression into a logical expression using the current
+    /// planner context.
+    fn sql_to_expr(&mut self, expr: SQLExpr, schema: &DFSchema) -> Result<Expr>;
+
+    /// Converts a SQL expression into a logical expression without DataFusion
+    /// rewrites.
+    fn sql_expr_to_logical_expr(
+        &mut self,
+        expr: SQLExpr,
+        schema: &DFSchema,
+    ) -> Result<Expr>;
+
+    /// Normalizes an identifier according to session settings.
+    fn normalize_ident(&self, ident: Ident) -> String;
+
+    /// Normalizes a SQL object name into a [`TableReference`].
+    fn object_name_to_table_reference(&self, name: ObjectName) -> Result<TableReference>;
+}
+
 /// Customize planning SQL types to DataFusion (Arrow) types.
 #[cfg(feature = "sql")]
 pub trait TypePlanner: Debug + Send + Sync {
diff --git a/datafusion/sql/src/relation/mod.rs b/datafusion/sql/src/relation/mod.rs
index 9dfa078701d3..9ebfb52a8bb3 100644
--- a/datafusion/sql/src/relation/mod.rs
+++ b/datafusion/sql/src/relation/mod.rs
@@ -24,12 +24,61 @@ use datafusion_common::{
     not_impl_err, plan_err, DFSchema, Diagnostic, Result, Span, Spans,
     TableReference,
 };
 use datafusion_expr::builder::subquery_alias;
+use datafusion_expr::planner::{
+    PlannedRelation, RelationPlannerContext, RelationPlanning,
+};
 use datafusion_expr::{expr::Unnest, Expr, LogicalPlan, LogicalPlanBuilder};
 use datafusion_expr::{Subquery, SubqueryAlias};
 use sqlparser::ast::{FunctionArg, FunctionArgExpr, Spanned, TableFactor};
 
 mod join;
 
+struct SqlToRelRelationContext<'a, 'b, S: ContextProvider> {
+    planner: &'a SqlToRel<'b, S>,
+    planner_context: &'a mut PlannerContext,
+}
+
+// Implement RelationPlannerContext by delegating to the SQL planner
+impl<'a, 'b, S: ContextProvider> RelationPlannerContext
+    for SqlToRelRelationContext<'a, 'b, S>
+{
+    fn context_provider(&self) -> &dyn ContextProvider {
+        self.planner.context_provider
+    }
+
+    fn plan(&mut self, relation: TableFactor) -> Result<LogicalPlan> {
+        self.planner.create_relation(relation, self.planner_context)
+    }
+
+    fn sql_to_expr(
+        &mut self,
+        expr: sqlparser::ast::Expr,
+        schema: &DFSchema,
+    ) -> Result<Expr> {
+        self.planner.sql_to_expr(expr, schema, self.planner_context)
+    }
+
+    fn sql_expr_to_logical_expr(
+        &mut self,
+        expr: sqlparser::ast::Expr,
+        schema: &DFSchema,
+    ) -> Result<Expr> {
+        self.planner
+            .sql_expr_to_logical_expr(expr, schema, self.planner_context)
+    }
+
+    fn normalize_ident(&self, ident: sqlparser::ast::Ident) -> String {
+        self.planner.ident_normalizer.normalize(ident)
+    }
+
+    fn object_name_to_table_reference(
+        &self,
+        name: sqlparser::ast::ObjectName,
+    ) -> Result<TableReference> {
+        self.planner.object_name_to_table_reference(name)
+    }
+}
+
 impl<S: ContextProvider> SqlToRel<'_, S> {
     /// Create a `LogicalPlan` that scans the named relation
     fn create_relation(
@@ -37,6 +86,57 @@ impl<S: ContextProvider> SqlToRel<'_, S> {
         relation: TableFactor,
         planner_context: &mut PlannerContext,
     ) -> Result<LogicalPlan> {
+        let planned_relation =
+            match self.create_extension_relation(relation, planner_context)? {
+                RelationPlanning::Planned(planned) => planned,
+                RelationPlanning::Original(original) => {
+                    self.create_default_relation(original, planner_context)?
+                }
+            };
+
+        let optimized_plan = optimize_subquery_sort(planned_relation.plan)?.data;
+        if let Some(alias) = planned_relation.alias {
+            self.apply_table_alias(optimized_plan, alias)
+        } else {
+            Ok(optimized_plan)
+        }
+    }
+
+    fn create_extension_relation(
+        &self,
+        relation: TableFactor,
+        planner_context: &mut PlannerContext,
+    ) -> Result<RelationPlanning> {
+        let planners = self.context_provider.get_relation_planners();
+        if planners.is_empty() {
+            return Ok(RelationPlanning::Original(relation));
+        }
+
+        let mut current_relation = relation;
+        for planner in planners.iter() {
+            let mut context = SqlToRelRelationContext {
+                planner: self,
+                planner_context,
+            };
+
+            match planner.plan_relation(current_relation, &mut context)? {
+                RelationPlanning::Planned(planned) => {
+                    return Ok(RelationPlanning::Planned(planned));
+                }
+                RelationPlanning::Original(original) => {
+                    current_relation = original;
+                }
+            }
+        }
+
+        Ok(RelationPlanning::Original(current_relation))
+    }
+
+    fn create_default_relation(
+        &self,
+        relation: TableFactor,
+        planner_context: &mut PlannerContext,
+    ) -> Result<PlannedRelation> {
         let relation_span = relation.span();
         let (plan, alias) = match relation {
             TableFactor::Table {
@@ -190,13 +290,7 @@ impl<S: ContextProvider> SqlToRel<'_, S> {
                 );
             }
         };
-
-        let optimized_plan = optimize_subquery_sort(plan)?.data;
-        if let Some(alias) = alias {
-            self.apply_table_alias(optimized_plan, alias)
-        } else {
-            Ok(optimized_plan)
-        }
+        Ok(PlannedRelation::new(plan, alias))
     }
 
     pub(crate) fn create_relation_subquery(