diff --git a/datafusion-cli/Cargo.toml b/datafusion-cli/Cargo.toml index f4e33fe2c19c..02352508dc4d 100644 --- a/datafusion-cli/Cargo.toml +++ b/datafusion-cli/Cargo.toml @@ -31,7 +31,7 @@ rust-version = { workspace = true } all-features = true [features] -default = [] +default = ["backtrace"] backtrace = ["datafusion/backtrace"] [dependencies] diff --git a/datafusion/common/src/config.rs b/datafusion/common/src/config.rs index 6618c6aeec28..b8a330a10e84 100644 --- a/datafusion/common/src/config.rs +++ b/datafusion/common/src/config.rs @@ -783,7 +783,7 @@ config_namespace! { pub skip_failed_rules: bool, default = false /// Number of times that the optimizer will attempt to optimize the plan - pub max_passes: usize, default = 3 + pub max_passes: usize, default = 1 /// When set to true, the physical plan optimizer will run a top down /// process to reorder the join keys diff --git a/datafusion/common/src/functional_dependencies.rs b/datafusion/common/src/functional_dependencies.rs index 63962998ad18..a3628a0cda55 100644 --- a/datafusion/common/src/functional_dependencies.rs +++ b/datafusion/common/src/functional_dependencies.rs @@ -360,7 +360,7 @@ impl FunctionalDependencies { left_func_dependencies.extend(right_func_dependencies); left_func_dependencies } - JoinType::LeftSemi | JoinType::LeftAnti | JoinType::LeftMark => { + JoinType::LeftSemi | JoinType::LeftAnti | JoinType::LeftMark | JoinType::LeftSingle => { // These joins preserve functional dependencies of the left side: left_func_dependencies } diff --git a/datafusion/common/src/join_type.rs b/datafusion/common/src/join_type.rs index d9a1478f0238..da295e85937a 100644 --- a/datafusion/common/src/join_type.rs +++ b/datafusion/common/src/join_type.rs @@ -72,6 +72,8 @@ pub enum JoinType { /// Same logic as the LeftMark Join above, however it returns a record for each record from the /// right input. 
RightMark, + + LeftSingle, } impl JoinType { @@ -94,6 +96,7 @@ impl JoinType { JoinType::RightAnti => JoinType::LeftAnti, JoinType::LeftMark => JoinType::RightMark, JoinType::RightMark => JoinType::LeftMark, + JoinType::LeftSingle => unreachable!(), // TODO: add right single support } } @@ -126,6 +129,7 @@ impl Display for JoinType { JoinType::RightAnti => "RightAnti", JoinType::LeftMark => "LeftMark", JoinType::RightMark => "RightMark", + JoinType::LeftSingle => "LeftSingle", }; write!(f, "{join_type}") } @@ -147,6 +151,7 @@ impl FromStr for JoinType { "RIGHTANTI" => Ok(JoinType::RightAnti), "LEFTMARK" => Ok(JoinType::LeftMark), "RIGHTMARK" => Ok(JoinType::RightMark), + "LEFTSINGLE" => Ok(JoinType::LeftSingle), _ => _not_impl_err!("The join type {s} does not exist or is not implemented"), } } diff --git a/datafusion/core/src/physical_planner.rs b/datafusion/core/src/physical_planner.rs index 293f2cfc9670..c3cea7c1454b 100644 --- a/datafusion/core/src/physical_planner.rs +++ b/datafusion/core/src/physical_planner.rs @@ -78,8 +78,8 @@ use datafusion_expr::expr::{ use datafusion_expr::expr_rewriter::unnormalize_cols; use datafusion_expr::logical_plan::builder::wrap_projection_for_join_if_necessary; use datafusion_expr::{ - Analyze, DescribeTable, DmlStatement, Explain, ExplainFormat, Extension, FetchType, - Filter, JoinType, RecursiveQuery, SkipType, StringifiedPlan, WindowFrame, + Analyze, DelimGet, DescribeTable, DmlStatement, Explain, ExplainFormat, Extension, + FetchType, Filter, JoinType, RecursiveQuery, SkipType, StringifiedPlan, WindowFrame, WindowFrameBound, WriteOp, }; use datafusion_physical_expr::aggregate::{AggregateExprBuilder, AggregateFunctionExpr}; @@ -1311,6 +1311,63 @@ impl DefaultPhysicalPlanner { "Unsupported logical plan: Analyze must be root of the plan" ) } + LogicalPlan::DependentJoin(_) => { + return internal_err!( + "Optimizers have not completely removed the dependent join" + ) + } + LogicalPlan::DelimGet(DelimGet { + table_name, + projected_schema, + .. + }) => { + let resolved = session_state.resolve_table_ref(table_name.clone()); + if let Ok(schema) = session_state.schema_for_ref(resolved.clone()) { + if let Some(table) = schema.table(&resolved.table).await? { + let mut proj = vec![]; + for (i, field) in table.schema().fields().iter().enumerate() { + for iter in projected_schema.as_ref().iter() { + if iter.1 == field { + proj.push(i); + } + } + } + + // First create the scan execution plan. + let scan_plan = + table.scan(session_state, Some(&proj), &[], None).await?; + + // Now add aggregation to eliminate duplicated rows. + // Create a PhysicalGroupBy that groups on every column of the scan (and has no aggregate expressions) + let schema = &scan_plan.schema(); + let group_exprs: Vec<(Arc<dyn PhysicalExpr>, String)> = (0 + ..schema.fields().len()) + .map(|i| { + let name = schema.field(i).name().to_string(); + let expr = Arc::new(Column::new(&name, i)) + as Arc<dyn PhysicalExpr>; + (expr, name) + }) + .collect(); + + let group_by = PhysicalGroupBy::new_single(group_exprs); + + // Create the AggregateExec with no aggregate expressions to deduplicate the rows + Arc::new(AggregateExec::try_new( + AggregateMode::Final, + group_by, + vec![], // No aggregate expressions, just grouping to deduplicate + vec![], // No filters + scan_plan.clone(), + scan_plan.schema(), + )?)
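+ // Taken together, this arm plans the rough equivalent of
+ // `SELECT DISTINCT <correlated columns> FROM <outer table>`: the scan fetches only
+ // the referenced columns and the grouping-only aggregate collapses duplicates, so
+ // each distinct combination of outer values is fed into the decorrelated subquery
+ // exactly once.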
+ } else { + return internal_err!("no table provider"); + } + } else { + return internal_err!("empty schema"); + } + } }; Ok(exec_node) } diff --git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs index 0749ff0e98b7..567875761297 100644 --- a/datafusion/expr/src/expr.rs +++ b/datafusion/expr/src/expr.rs @@ -3169,7 +3169,18 @@ pub const UNNEST_COLUMN_PREFIX: &str = "UNNEST"; impl Display for Expr { fn fmt(&self, f: &mut Formatter) -> fmt::Result { match self { - Expr::Alias(Alias { expr, name, .. }) => write!(f, "{expr} AS {name}"), + Expr::Alias(Alias { + expr, + relation, + name, + .. + }) => { + if let Some(relation) = relation { + write!(f, "{expr} AS {relation}.{name}") + } else { + write!(f, "{expr} AS {name}") + } + } Expr::Column(c) => write!(f, "{c}"), Expr::OuterReferenceColumn(_, c) => { write!(f, "{OUTER_REFERENCE_COLUMN_PREFIX}({c})") diff --git a/datafusion/expr/src/logical_plan/builder.rs b/datafusion/expr/src/logical_plan/builder.rs index 1ab5ffa75842..09066dc1e1b6 100644 --- a/datafusion/expr/src/logical_plan/builder.rs +++ b/datafusion/expr/src/logical_plan/builder.rs @@ -31,10 +31,10 @@ use crate::expr_rewriter::{ rewrite_sort_cols_by_aggs, }; use crate::logical_plan::{ - Aggregate, Analyze, Distinct, DistinctOn, EmptyRelation, Explain, Filter, Join, - JoinConstraint, JoinType, Limit, LogicalPlan, Partitioning, PlanType, Prepare, - Projection, Repartition, Sort, SubqueryAlias, TableScan, Union, Unnest, Values, - Window, + Aggregate, Analyze, DelimGet, DependentJoin, Distinct, DistinctOn, EmptyRelation, + Explain, Filter, Join, JoinConstraint, JoinType, Limit, LogicalPlan, Partitioning, + PlanType, Prepare, Projection, Repartition, Sort, SubqueryAlias, TableScan, Union, + Unnest, Values, Window, }; use crate::select_expr::SelectExpr; use crate::utils::{ @@ -48,6 +48,8 @@ use crate::{ }; use super::dml::InsertOp; +use super::plan::JoinKind; +use super::CorrelatedColumnInfo; use arrow::compute::can_cast_types; use arrow::datatypes::{DataType, Field, Fields, Schema, SchemaRef}; use datafusion_common::display::ToStringifiedPlan; @@ -397,6 +399,12 @@ impl LogicalPlanBuilder { Self::scan_with_filters(table_name, table_source, projection, vec![]) } + pub fn delim_get(correlated_columns: &Vec) -> Result { + Ok(Self::new(LogicalPlan::DelimGet(DelimGet::try_new( + correlated_columns, + )?))) + } + /// Create a [CopyTo] for copying the contents of this builder to the specified file(s) pub fn copy_to( input: LogicalPlan, @@ -882,6 +890,47 @@ impl LogicalPlanBuilder { )))) } + /// Build a dependent join provided a subquery plan + /// this function should only be used by the optimizor + /// a dependent join node will provides all columns belonging to the LHS + /// and one additional column as the result of evaluating the subquery on the RHS + /// under the name "subquery_name.output" + pub fn dependent_join( + self, + right: LogicalPlan, + correlated_columns: Vec, + subquery_expr: Option, + subquery_depth: usize, + subquery_name: String, + lateral_join_condition: Option<(JoinType, Expr)>, + ) -> Result { + let left = self.build()?; + let schema = left.schema(); + // TODO: for lateral join, output schema is similar to a normal join + let qualified_fields = schema + .iter() + .map(|(q, f)| (q.cloned(), Arc::clone(f))) + .chain( + subquery_expr + .iter() + .map(|expr| subquery_output_field(&subquery_name, expr)), + ) + .collect(); + let metadata = schema.metadata().clone(); + let dfschema = DFSchema::new_with_metadata(qualified_fields, metadata)?; + + 
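+ // For illustration (the table and subquery names are hypothetical): a dependent
+ // join between a scan of `t1(a, b)` and a scalar subquery named `__scalar_sq_1`
+ // gets the schema `[t1.a, t1.b, __scalar_sq_1]`, i.e. every LHS column followed
+ // by one unqualified field carrying the subquery result.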
Ok(Self::new(LogicalPlan::DependentJoin(DependentJoin { + schema: DFSchemaRef::new(dfschema), + left: Arc::new(left), + right: Arc::new(right), + correlated_columns, + subquery_expr, + subquery_name, + subquery_depth, + lateral_join_condition, + }))) + } + /// Apply a join to `right` using explicitly specified columns and an /// optional filter expression. /// @@ -911,6 +960,23 @@ impl LogicalPlanBuilder { ) } + pub fn delim_join( + self, + right: LogicalPlan, + join_type: JoinType, + join_keys: (Vec>, Vec>), + filter: Option, + ) -> Result { + self.join_detailed_with_join_kind( + right, + join_type, + join_keys, + filter, + NullEquality::NullEqualsNothing, + JoinKind::DelimJoin, + ) + } + /// Apply a join using the specified expressions. /// /// Note that DataFusion automatically optimizes joins, including @@ -996,6 +1062,25 @@ impl LogicalPlanBuilder { join_keys: (Vec>, Vec>), filter: Option, null_equality: NullEquality, + ) -> Result { + self.join_detailed_with_join_kind( + right, + join_type, + join_keys, + filter, + null_equality, + JoinKind::ComparisonJoin, + ) + } + + pub fn join_detailed_with_join_kind( + self, + right: LogicalPlan, + join_type: JoinType, + join_keys: (Vec>, Vec>), + filter: Option, + null_equality: NullEquality, + join_kind: JoinKind, ) -> Result { if join_keys.0.len() != join_keys.1.len() { return plan_err!("left_keys and right_keys were not the same length"); @@ -1113,6 +1198,7 @@ impl LogicalPlanBuilder { join_constraint: JoinConstraint::On, schema: DFSchemaRef::new(join_schema), null_equality, + join_kind, }))) } @@ -1570,6 +1656,28 @@ fn mark_field(schema: &DFSchema) -> (Option, Arc) { ) } +fn subquery_output_field( + subquery_alias: &str, + subquery_expr: &Expr, +) -> (Option, Arc) { + // TODO: check nullability + let field = match subquery_expr { + Expr::InSubquery(_) => { + Arc::new(Field::new(subquery_alias, DataType::Boolean, false)) + } + Expr::Exists(_) => Arc::new(Field::new(subquery_alias, DataType::Boolean, false)), + Expr::ScalarSubquery(sq) => { + let data_type = sq.subquery.schema().field(0).data_type().clone(); + Arc::new(Field::new(subquery_alias, data_type, false)) + } + _ => { + unreachable!() + } + }; + + (None, field) +} + /// Creates a schema for a join operation. /// The fields from the left side are first pub fn build_join_schema( @@ -1603,7 +1711,7 @@ pub fn build_join_schema( .collect::>(); left_fields.into_iter().chain(right_fields).collect() } - JoinType::Left => { + JoinType::Left | JoinType::LeftSingle => { // left then right, right set to nullable in case of not matched scenario let left_fields = left_fields .map(|(q, f)| (q.cloned(), Arc::clone(f))) @@ -2146,7 +2254,6 @@ mod tests { use crate::test::function_stub::sum; use datafusion_common::{Constraint, RecursionUnnestOption, SchemaError}; - use insta::assert_snapshot; #[test] fn plan_builder_simple() -> Result<()> { @@ -2156,11 +2263,11 @@ mod tests { .project(vec![col("id")])? 
.build()?; - assert_snapshot!(plan, @r#" - Projection: employee_csv.id - Filter: employee_csv.state = Utf8("CO") - TableScan: employee_csv projection=[id, state] - "#); + let expected = "Projection: employee_csv.id\ + \n Filter: employee_csv.state = Utf8(\"CO\")\ + \n TableScan: employee_csv projection=[id, state]"; + + assert_eq!(expected, format!("{plan}")); Ok(()) } @@ -2172,7 +2279,12 @@ mod tests { let plan = LogicalPlanBuilder::scan("employee_csv", table_source(&schema), projection) .unwrap(); - assert_snapshot!(plan.schema().as_ref(), @"fields:[employee_csv.id, employee_csv.first_name, employee_csv.last_name, employee_csv.state, employee_csv.salary], metadata:{}"); + let expected = DFSchema::try_from_qualified_schema( + TableReference::bare("employee_csv"), + &schema, + ) + .unwrap(); + assert_eq!(&expected, plan.schema().as_ref()); // Note scan of "EMPLOYEE_CSV" is treated as a SQL identifier // (and thus normalized to "employee"csv") as well @@ -2180,7 +2292,7 @@ mod tests { let plan = LogicalPlanBuilder::scan("EMPLOYEE_CSV", table_source(&schema), projection) .unwrap(); - assert_snapshot!(plan.schema().as_ref(), @"fields:[employee_csv.id, employee_csv.first_name, employee_csv.last_name, employee_csv.state, employee_csv.salary], metadata:{}"); + assert_eq!(&expected, plan.schema().as_ref()); } #[test] @@ -2189,9 +2301,9 @@ mod tests { let projection = None; let err = LogicalPlanBuilder::scan("", table_source(&schema), projection).unwrap_err(); - assert_snapshot!( + assert_eq!( err.strip_backtrace(), - @"Error during planning: table_name cannot be empty" + "Error during planning: table_name cannot be empty" ); } @@ -2205,10 +2317,10 @@ mod tests { ])? .build()?; - assert_snapshot!(plan, @r" - Sort: employee_csv.state ASC NULLS FIRST, employee_csv.salary DESC NULLS LAST - TableScan: employee_csv projection=[state, salary] - "); + let expected = "Sort: employee_csv.state ASC NULLS FIRST, employee_csv.salary DESC NULLS LAST\ + \n TableScan: employee_csv projection=[state, salary]"; + + assert_eq!(expected, format!("{plan}")); Ok(()) } @@ -2225,15 +2337,15 @@ mod tests { .union(plan.build()?)? .build()?; - assert_snapshot!(plan, @r" - Union - Union - Union - TableScan: employee_csv projection=[state, salary] - TableScan: employee_csv projection=[state, salary] - TableScan: employee_csv projection=[state, salary] - TableScan: employee_csv projection=[state, salary] - "); + let expected = "Union\ + \n Union\ + \n Union\ + \n TableScan: employee_csv projection=[state, salary]\ + \n TableScan: employee_csv projection=[state, salary]\ + \n TableScan: employee_csv projection=[state, salary]\ + \n TableScan: employee_csv projection=[state, salary]"; + + assert_eq!(expected, format!("{plan}")); Ok(()) } @@ -2250,18 +2362,19 @@ mod tests { .union_distinct(plan.build()?)? 
.build()?; - assert_snapshot!(plan, @r" - Distinct: - Union - Distinct: - Union - Distinct: - Union - TableScan: employee_csv projection=[state, salary] - TableScan: employee_csv projection=[state, salary] - TableScan: employee_csv projection=[state, salary] - TableScan: employee_csv projection=[state, salary] - "); + let expected = "\ + Distinct:\ + \n Union\ + \n Distinct:\ + \n Union\ + \n Distinct:\ + \n Union\ + \n TableScan: employee_csv projection=[state, salary]\ + \n TableScan: employee_csv projection=[state, salary]\ + \n TableScan: employee_csv projection=[state, salary]\ + \n TableScan: employee_csv projection=[state, salary]"; + + assert_eq!(expected, format!("{plan}")); Ok(()) } @@ -2275,12 +2388,13 @@ mod tests { .distinct()? .build()?; - assert_snapshot!(plan, @r#" - Distinct: - Projection: employee_csv.id - Filter: employee_csv.state = Utf8("CO") - TableScan: employee_csv projection=[id, state] - "#); + let expected = "\ + Distinct:\ + \n Projection: employee_csv.id\ + \n Filter: employee_csv.state = Utf8(\"CO\")\ + \n TableScan: employee_csv projection=[id, state]"; + + assert_eq!(expected, format!("{plan}")); Ok(()) } @@ -2300,15 +2414,14 @@ mod tests { .filter(exists(Arc::new(subquery)))? .build()?; - assert_snapshot!(outer_query, @r" - Filter: EXISTS () - Subquery: - Filter: foo.a = bar.a - Projection: foo.a - TableScan: foo - Projection: bar.a - TableScan: bar - "); + let expected = "Filter: EXISTS ()\ + \n Subquery:\ + \n Filter: foo.a = bar.a\ + \n Projection: foo.a\ + \n TableScan: foo\ + \n Projection: bar.a\ + \n TableScan: bar"; + assert_eq!(expected, format!("{outer_query}")); Ok(()) } @@ -2329,15 +2442,14 @@ mod tests { .filter(in_subquery(col("a"), Arc::new(subquery)))? .build()?; - assert_snapshot!(outer_query, @r" - Filter: bar.a IN () - Subquery: - Filter: foo.a = bar.a - Projection: foo.a - TableScan: foo - Projection: bar.a - TableScan: bar - "); + let expected = "Filter: bar.a IN ()\ + \n Subquery:\ + \n Filter: foo.a = bar.a\ + \n Projection: foo.a\ + \n TableScan: foo\ + \n Projection: bar.a\ + \n TableScan: bar"; + assert_eq!(expected, format!("{outer_query}")); Ok(()) } @@ -2357,14 +2469,13 @@ mod tests { .project(vec![scalar_subquery(Arc::new(subquery))])? .build()?; - assert_snapshot!(outer_query, @r" - Projection: () - Subquery: - Filter: foo.a = bar.a - Projection: foo.b - TableScan: foo - TableScan: bar - "); + let expected = "Projection: ()\ + \n Subquery:\ + \n Filter: foo.a = bar.a\ + \n Projection: foo.b\ + \n TableScan: foo\ + \n TableScan: bar"; + assert_eq!(expected, format!("{outer_query}")); Ok(()) } @@ -2462,11 +2573,13 @@ mod tests { let plan2 = table_scan(TableReference::none(), &employee_schema(), Some(vec![3, 4]))?; + let expected = "Error during planning: INTERSECT/EXCEPT query must have the same number of columns. \ + Left is 1 and right is 2."; let err_msg1 = LogicalPlanBuilder::intersect(plan1.build()?, plan2.build()?, true) .unwrap_err(); - assert_snapshot!(err_msg1.strip_backtrace(), @"Error during planning: INTERSECT/EXCEPT query must have the same number of columns. Left is 1 and right is 2."); + assert_eq!(err_msg1.strip_backtrace(), expected); Ok(()) } @@ -2477,29 +2590,19 @@ mod tests { let err = nested_table_scan("test_table")? 
.unnest_column("scalar") .unwrap_err(); - - let DataFusionError::Internal(desc) = err else { - return plan_err!("Plan should have returned an DataFusionError::Internal"); - }; - - let desc = desc - .split(DataFusionError::BACK_TRACE_SEP) - .collect::>() - .first() - .unwrap_or(&"") - .to_string(); - - assert_snapshot!(desc, @"trying to unnest on invalid data type UInt32"); + assert!(err + .to_string() + .starts_with("Internal error: trying to unnest on invalid data type UInt32")); // Unnesting the strings list. let plan = nested_table_scan("test_table")? .unnest_column("strings")? .build()?; - assert_snapshot!(plan, @r" - Unnest: lists[test_table.strings|depth=1] structs[] - TableScan: test_table - "); + let expected = "\ + Unnest: lists[test_table.strings|depth=1] structs[]\ + \n TableScan: test_table"; + assert_eq!(expected, format!("{plan}")); // Check unnested field is a scalar let field = plan.schema().field_with_name(None, "strings").unwrap(); @@ -2510,10 +2613,10 @@ mod tests { .unnest_column("struct_singular")? .build()?; - assert_snapshot!(plan, @r" - Unnest: lists[] structs[test_table.struct_singular] - TableScan: test_table - "); + let expected = "\ + Unnest: lists[] structs[test_table.struct_singular]\ + \n TableScan: test_table"; + assert_eq!(expected, format!("{plan}")); for field_name in &["a", "b"] { // Check unnested struct field is a scalar @@ -2531,12 +2634,12 @@ mod tests { .unnest_column("struct_singular")? .build()?; - assert_snapshot!(plan, @r" - Unnest: lists[] structs[test_table.struct_singular] - Unnest: lists[test_table.structs|depth=1] structs[] - Unnest: lists[test_table.strings|depth=1] structs[] - TableScan: test_table - "); + let expected = "\ + Unnest: lists[] structs[test_table.struct_singular]\ + \n Unnest: lists[test_table.structs|depth=1] structs[]\ + \n Unnest: lists[test_table.strings|depth=1] structs[]\ + \n TableScan: test_table"; + assert_eq!(expected, format!("{plan}")); // Check unnested struct list field should be a struct. let field = plan.schema().field_with_name(None, "structs").unwrap(); @@ -2552,10 +2655,10 @@ mod tests { .unnest_columns_with_options(cols, UnnestOptions::default())? .build()?; - assert_snapshot!(plan, @r" - Unnest: lists[test_table.strings|depth=1, test_table.structs|depth=1] structs[test_table.struct_singular] - TableScan: test_table - "); + let expected = "\ + Unnest: lists[test_table.strings|depth=1, test_table.structs|depth=1] structs[test_table.struct_singular]\ + \n TableScan: test_table"; + assert_eq!(expected, format!("{plan}")); // Unnesting missing column should fail. let plan = nested_table_scan("test_table")?.unnest_column("missing"); @@ -2579,10 +2682,10 @@ mod tests { )? .build()?; - assert_snapshot!(plan, @r" - Unnest: lists[test_table.stringss|depth=1, test_table.stringss|depth=2] structs[test_table.struct_singular] - TableScan: test_table - "); + let expected = "\ + Unnest: lists[test_table.stringss|depth=1, test_table.stringss|depth=2] structs[test_table.struct_singular]\ + \n TableScan: test_table"; + assert_eq!(expected, format!("{plan}")); // Check output columns has correct type let field = plan @@ -2654,24 +2757,10 @@ mod tests { let join = LogicalPlanBuilder::from(left).cross_join(right)?.build()?; - let plan = LogicalPlanBuilder::from(join.clone()) + let _ = LogicalPlanBuilder::from(join.clone()) .union(join)? 
.build()?; - assert_snapshot!(plan, @r" - Union - Cross Join: - SubqueryAlias: left - Values: (Int32(1)) - SubqueryAlias: right - Values: (Int32(1)) - Cross Join: - SubqueryAlias: left - Values: (Int32(1)) - SubqueryAlias: right - Values: (Int32(1)) - "); - Ok(()) } @@ -2731,10 +2820,10 @@ mod tests { .aggregate(vec![col("id")], vec![sum(col("salary"))])? .build()?; - assert_snapshot!(plan, @r" - Aggregate: groupBy=[[employee_csv.id]], aggr=[[sum(employee_csv.salary)]] - TableScan: employee_csv projection=[id, state, salary] - "); + let expected = + "Aggregate: groupBy=[[employee_csv.id]], aggr=[[sum(employee_csv.salary)]]\ + \n TableScan: employee_csv projection=[id, state, salary]"; + assert_eq!(expected, format!("{plan}")); Ok(()) } @@ -2753,10 +2842,10 @@ mod tests { .aggregate(vec![col("id")], vec![sum(col("salary"))])? .build()?; - assert_snapshot!(plan, @r" - Aggregate: groupBy=[[employee_csv.id, employee_csv.state, employee_csv.salary]], aggr=[[sum(employee_csv.salary)]] - TableScan: employee_csv projection=[id, state, salary] - "); + let expected = + "Aggregate: groupBy=[[employee_csv.id, employee_csv.state, employee_csv.salary]], aggr=[[sum(employee_csv.salary)]]\ + \n TableScan: employee_csv projection=[id, state, salary]"; + assert_eq!(expected, format!("{plan}")); Ok(()) } diff --git a/datafusion/expr/src/logical_plan/display.rs b/datafusion/expr/src/logical_plan/display.rs index cc3fbad7b0c2..f735ae974396 100644 --- a/datafusion/expr/src/logical_plan/display.rs +++ b/datafusion/expr/src/logical_plan/display.rs @@ -486,6 +486,7 @@ impl<'a, 'b> PgJsonVisitor<'a, 'b> { object } + LogicalPlan::DependentJoin(..) => json!({}), LogicalPlan::Join(Join { on: ref keys, filter, @@ -651,6 +652,7 @@ impl<'a, 'b> PgJsonVisitor<'a, 'b> { "StructColumn": expr_vec_fmt!(struct_type_columns), }) } + LogicalPlan::DelimGet(_) => todo!(), } } } diff --git a/datafusion/expr/src/logical_plan/invariants.rs b/datafusion/expr/src/logical_plan/invariants.rs index d8d6739b0e8f..25a02c731a68 100644 --- a/datafusion/expr/src/logical_plan/invariants.rs +++ b/datafusion/expr/src/logical_plan/invariants.rs @@ -201,20 +201,27 @@ pub fn check_subquery_expr( }?; match outer_plan { LogicalPlan::Projection(_) - | LogicalPlan::Filter(_) => Ok(()), - LogicalPlan::Aggregate(Aggregate { group_expr, aggr_expr, .. }) => { + | LogicalPlan::Filter(_) + | LogicalPlan::DependentJoin(_) => Ok(()), + LogicalPlan::Aggregate(Aggregate { + group_expr, + aggr_expr, + .. 
+ }) => { if group_expr.contains(expr) && !aggr_expr.contains(expr) { // TODO revisit this validation logic plan_err!( - "Correlated scalar subquery in the GROUP BY clause must also be in the aggregate expressions" + "Correlated scalar subquery in the GROUP BY clause must \ + also be in the aggregate expressions" ) } else { Ok(()) } } _ => plan_err!( - "Correlated scalar subquery can only be used in Projection, Filter, Aggregate plan nodes" - ) + "Correlated scalar subquery can only be used in Projection, Filter, \ + Aggregate, DependentJoin plan nodes" + ), }?; } check_correlations_in_subquery(inner_plan) @@ -235,11 +242,12 @@ pub fn check_subquery_expr( | LogicalPlan::TableScan(_) | LogicalPlan::Window(_) | LogicalPlan::Aggregate(_) - | LogicalPlan::Join(_) => Ok(()), + | LogicalPlan::Join(_) + | LogicalPlan::DependentJoin(_) => Ok(()), _ => plan_err!( "In/Exist subquery can only be used in \ - Projection, Filter, TableScan, Window functions, Aggregate and Join plan nodes, \ - but was used in [{}]", + Projection, Filter, TableScan, Window functions, Aggregate, Join and \ + Dependent Join plan nodes, but was used in [{}]", outer_plan.display() ), }?; @@ -306,7 +314,8 @@ fn check_inner_plan(inner_plan: &LogicalPlan) -> Result<()> { JoinType::Left | JoinType::LeftSemi | JoinType::LeftAnti - | JoinType::LeftMark => { + | JoinType::LeftMark + | JoinType::LeftSingle => { check_inner_plan(left)?; check_no_outer_references(right) } @@ -326,6 +335,7 @@ fn check_inner_plan(inner_plan: &LogicalPlan) -> Result<()> { } }, LogicalPlan::Extension(_) => Ok(()), + LogicalPlan::DependentJoin(_) => Ok(()), plan => check_no_outer_references(plan), } } diff --git a/datafusion/expr/src/logical_plan/mod.rs b/datafusion/expr/src/logical_plan/mod.rs index 4bbb9d7ada7e..5eb94269f525 100644 --- a/datafusion/expr/src/logical_plan/mod.rs +++ b/datafusion/expr/src/logical_plan/mod.rs @@ -38,11 +38,12 @@ pub use ddl::{ }; pub use dml::{DmlStatement, WriteOp}; pub use plan::{ - projection_schema, Aggregate, Analyze, ColumnUnnestList, DescribeTable, Distinct, - DistinctOn, EmptyRelation, Explain, ExplainFormat, ExplainOption, Extension, - FetchType, Filter, Join, JoinConstraint, JoinType, Limit, LogicalPlan, Partitioning, - PlanType, Projection, RecursiveQuery, Repartition, SkipType, Sort, StringifiedPlan, - Subquery, SubqueryAlias, TableScan, ToStringifiedPlan, Union, Unnest, Values, Window, + projection_schema, Aggregate, Analyze, ColumnUnnestList, CorrelatedColumnInfo, + DelimGet, DependentJoin, DescribeTable, Distinct, DistinctOn, EmptyRelation, Explain, + ExplainFormat, ExplainOption, Extension, FetchType, Filter, Join, JoinConstraint, + JoinKind, JoinType, Limit, LogicalPlan, Partitioning, PlanType, Projection, + RecursiveQuery, Repartition, SkipType, Sort, StringifiedPlan, Subquery, + SubqueryAlias, TableScan, ToStringifiedPlan, Union, Unnest, Values, Window, }; pub use statement::{ Deallocate, Execute, Prepare, SetVariable, Statement, TransactionAccessMode, diff --git a/datafusion/expr/src/logical_plan/plan.rs b/datafusion/expr/src/logical_plan/plan.rs index d68e6cd81272..ccace3bf8695 100644 --- a/datafusion/expr/src/logical_plan/plan.rs +++ b/datafusion/expr/src/logical_plan/plan.rs @@ -288,6 +288,208 @@ pub enum LogicalPlan { Unnest(Unnest), /// A variadic query (e.g. 
"Recursive CTEs") RecursiveQuery(RecursiveQuery), + /// A node type that only exist during subquery decorrelation + /// TODO: maybe we can avoid creating new type of LogicalPlan for this usecase + DependentJoin(DependentJoin), + DelimGet(DelimGet), +} + +#[derive(Clone, Debug, Eq, PartialOrd, Hash)] +pub struct CorrelatedColumnInfo { + pub col: Column, + // TODO: is data_type necessary? + pub data_type: DataType, + pub depth: usize, +} + +impl CorrelatedColumnInfo { + pub fn new(col: Column) -> Self { + Self { + col, + data_type: DataType::Null, + depth: 0, + } + } +} + +impl PartialEq for CorrelatedColumnInfo { + fn eq(&self, other: &Self) -> bool { + self.col == other.col + } +} + +#[derive(Debug, Clone, Eq)] +pub struct DelimGet { + // TODO: is it necessary to alias? + pub table_name: TableReference, + pub columns: Vec, + /// The schema description of the output + pub projected_schema: DFSchemaRef, + // TODO: add more variables as needed. +} + +impl DelimGet { + pub fn try_new(correlated_columns: &Vec) -> Result { + if correlated_columns.is_empty() { + // return plan_err!("failed to construct DelimGet: empty correlated columns"); + // TODO: revisit if dummy dependent join is nesessary. + return Ok(Self { + table_name: TableReference::bare("empty scan"), + columns: vec![], + projected_schema: Arc::new(DFSchema::empty()), + }); + } + + // Extract the first table reference to validate all columns come from the same table + let first_table_ref = correlated_columns[0].col.relation.clone(); + + // Validate all columns come from the same table + for column_info in correlated_columns.into_iter() { + if column_info.col.relation != first_table_ref { + return internal_err!( + "DelimGet requires all columns to be from the same table, found mixed table references"); + } + } + + let table_name = first_table_ref.ok_or_else(|| { + DataFusionError::Plan( + "DelimGet requires all columns to have a table reference".to_string(), + ) + })?; + + // Collect both table references and fields together + let qualified_fields: Vec<(Option, Arc)> = + correlated_columns + .iter() + .map(|c| { + let field = Field::new(c.col.name(), c.data_type.clone(), true); + (Some(table_name.clone()), Arc::new(field)) + }) + .collect(); + + let columns: Vec = + correlated_columns.iter().map(|c| c.col.clone()).collect(); + + let schema = DFSchema::new_with_metadata(qualified_fields, HashMap::new())?; + + Ok(DelimGet { + table_name, + columns, + projected_schema: Arc::new(schema), + }) + } +} + +impl PartialEq for DelimGet { + fn eq(&self, other: &Self) -> bool { + self.table_name == other.table_name && self.columns == other.columns + } +} + +impl Hash for DelimGet { + fn hash(&self, state: &mut H) { + self.table_name.hash(state); + self.columns.hash(state); + } +} + +impl PartialOrd for DelimGet { + fn partial_cmp(&self, other: &Self) -> Option { + match self.table_name.partial_cmp(&other.table_name) { + Some(Ordering::Equal) => self.columns.partial_cmp(&other.columns), + cmp => cmp, + } + } +} + +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub struct DependentJoin { + pub schema: DFSchemaRef, + // All combinations of (subquery depth,Column and its DataType) on the RHS (and its descendant) + // which points to a column on the LHS of this dependent join + // Note that not all outer_refs from the RHS are mentioned in this vectors + // because RHS may reference columns provided somewhere from the above parent dependent join. 
+ // Depths of each correlated_columns should always be gte current dependent join + // subquery_depth + pub correlated_columns: Vec, + // the upper expr that containing the subquery expr + // i.e for predicates: where outer = scalar_sq + 1 + // correlated exprs are `scalar_sq + 1` + pub subquery_expr: Option, + // begins with depth = 1 + pub subquery_depth: usize, + pub left: Arc, + // dependent side accessing columns from left hand side (and maybe columns) + // belong to the parent dependent join node in case of recursion) + pub right: Arc, + pub subquery_name: String, + + pub lateral_join_condition: Option<(JoinType, Expr)>, +} + +impl Display for DependentJoin { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + let correlated_str = self + .correlated_columns + .iter() + .map(|info| format!("{0} lvl {1}", info.col, info.depth)) + .collect::>() + .join(", "); + let lateral_join_info = + if let Some((join_type, join_expr)) = &self.lateral_join_condition { + format!(" lateral {join_type} join with {join_expr}") + } else { + "".to_string() + }; + let subquery_expr_str = if let Some(expr) = &self.subquery_expr { + format!(" with expr {expr}") + } else { + "".to_string() + }; + write!( + f, + "DependentJoin on [{correlated_str}]{subquery_expr_str}\ + {lateral_join_info} depth {0}", + self.subquery_depth, + ) + } +} + +impl PartialOrd for DependentJoin { + fn partial_cmp(&self, other: &Self) -> Option { + #[derive(PartialEq, PartialOrd)] + struct ComparableJoin<'a> { + correlated_columns: &'a Vec, + // the upper expr that containing the subquery expr + // i.e for predicates: where outer = scalar_sq + 1 + // correlated exprs are `scalar_sq + 1` + subquery_expr: &'a Option, + + depth: &'a usize, + left: &'a Arc, + // dependent side accessing columns from left hand side (and maybe columns) + // belong to the parent dependent join node in case of recursion) + right: &'a Arc, + lateral_join_condition: &'a Option<(JoinType, Expr)>, + } + let comparable_self = ComparableJoin { + left: &self.left, + right: &self.right, + correlated_columns: &self.correlated_columns, + subquery_expr: &self.subquery_expr, + depth: &self.subquery_depth, + lateral_join_condition: &self.lateral_join_condition, + }; + let comparable_other = ComparableJoin { + left: &other.left, + right: &other.right, + correlated_columns: &other.correlated_columns, + subquery_expr: &other.subquery_expr, + depth: &other.subquery_depth, + lateral_join_condition: &other.lateral_join_condition, + }; + comparable_self.partial_cmp(&comparable_other) + } } impl Default for LogicalPlan { @@ -319,6 +521,7 @@ impl LogicalPlan { /// Get a reference to the logical plan's schema pub fn schema(&self) -> &DFSchemaRef { match self { + LogicalPlan::DependentJoin(DependentJoin { schema, .. }) => schema, LogicalPlan::EmptyRelation(EmptyRelation { schema, .. }) => schema, LogicalPlan::Values(Values { schema, .. }) => schema, LogicalPlan::TableScan(TableScan { @@ -352,6 +555,9 @@ impl LogicalPlan { // we take the schema of the static term as the schema of the entire recursive query static_term.schema() } + LogicalPlan::DelimGet(DelimGet { + projected_schema, .. + }) => projected_schema, } } @@ -453,6 +659,9 @@ impl LogicalPlan { LogicalPlan::Aggregate(Aggregate { input, .. }) => vec![input], LogicalPlan::Sort(Sort { input, .. }) => vec![input], LogicalPlan::Join(Join { left, right, .. }) => vec![left, right], + LogicalPlan::DependentJoin(DependentJoin { left, right, .. }) => { + vec![left, right] + } LogicalPlan::Limit(Limit { input, .. 
}) => vec![input], LogicalPlan::Subquery(Subquery { subquery, .. }) => vec![subquery], LogicalPlan::SubqueryAlias(SubqueryAlias { input, .. }) => vec![input], @@ -479,7 +688,8 @@ impl LogicalPlan { LogicalPlan::TableScan { .. } | LogicalPlan::EmptyRelation { .. } | LogicalPlan::Values { .. } - | LogicalPlan::DescribeTable(_) => vec![], + | LogicalPlan::DescribeTable(_) + | LogicalPlan::DelimGet(_) => vec![], } } @@ -541,13 +751,14 @@ impl LogicalPlan { | LogicalPlan::Limit(Limit { input, .. }) | LogicalPlan::Repartition(Repartition { input, .. }) | LogicalPlan::Window(Window { input, .. }) => input.head_output_expr(), + LogicalPlan::DependentJoin(..) => todo!(), LogicalPlan::Join(Join { left, right, join_type, .. }) => match join_type { - JoinType::Inner | JoinType::Left | JoinType::Right | JoinType::Full => { + JoinType::Inner | JoinType::Left | JoinType::Right | JoinType::Full | JoinType::LeftSingle => { if left.schema().fields().is_empty() { right.head_output_expr() } else { @@ -593,6 +804,7 @@ impl LogicalPlan { | LogicalPlan::Ddl(_) | LogicalPlan::DescribeTable(_) | LogicalPlan::Unnest(_) => Ok(None), + LogicalPlan::DelimGet(_) => todo!(), } } @@ -650,6 +862,7 @@ impl LogicalPlan { }) => Aggregate::try_new(input, group_expr, aggr_expr) .map(LogicalPlan::Aggregate), LogicalPlan::Sort(_) => Ok(self), + LogicalPlan::DependentJoin(_) => todo!(), LogicalPlan::Join(Join { left, right, @@ -659,6 +872,7 @@ impl LogicalPlan { on, schema: _, null_equality, + join_kind, }) => { let schema = build_join_schema(left.schema(), right.schema(), &join_type)?; @@ -680,6 +894,7 @@ impl LogicalPlan { filter, schema: DFSchemaRef::new(schema), null_equality, + join_kind, })) } LogicalPlan::Subquery(_) => Ok(self), @@ -749,6 +964,7 @@ impl LogicalPlan { // Update schema with unnested column type. unnest_with_options(Arc::unwrap_or_clone(input), exec_columns, options) } + LogicalPlan::DelimGet(_) => Ok(self), } } @@ -899,6 +1115,7 @@ impl LogicalPlan { join_constraint, on, null_equality, + join_kind, .. }) => { let (left, right) = self.only_two_inputs(inputs)?; @@ -938,6 +1155,7 @@ impl LogicalPlan { filter: filter_expr, schema: DFSchemaRef::new(schema), null_equality: *null_equality, + join_kind: *join_kind, })) } LogicalPlan::Subquery(Subquery { @@ -1142,6 +1360,8 @@ impl LogicalPlan { unnest_with_options(input, columns.clone(), options.clone())?; Ok(new_plan) } + LogicalPlan::DependentJoin(_) => todo!(), + LogicalPlan::DelimGet(_) => todo!(), } } @@ -1294,6 +1514,7 @@ impl LogicalPlan { /// If `Some(n)` then the plan can return at most `n` rows but may return fewer. pub fn max_rows(self: &LogicalPlan) -> Option { match self { + LogicalPlan::DependentJoin(DependentJoin { left, .. }) => left.max_rows(), LogicalPlan::Projection(Projection { input, .. }) => input.max_rows(), LogicalPlan::Filter(filter) => { if filter.is_scalar() { @@ -1333,7 +1554,7 @@ impl LogicalPlan { .. }) => match join_type { JoinType::Inner => Some(left.max_rows()? 
* right.max_rows()?), - JoinType::Left | JoinType::Right | JoinType::Full => { + JoinType::Left | JoinType::Right | JoinType::Full | JoinType::LeftSingle => { match (left.max_rows()?, right.max_rows()?, join_type) { (0, 0, _) => Some(0), (max_rows, 0, JoinType::Left | JoinType::Full) => Some(max_rows), @@ -1377,6 +1598,7 @@ impl LogicalPlan { | LogicalPlan::DescribeTable(_) | LogicalPlan::Statement(_) | LogicalPlan::Extension(_) => None, + LogicalPlan::DelimGet(_) => todo!(), } } @@ -1820,6 +2042,16 @@ impl LogicalPlan { Ok(()) } + LogicalPlan::DelimGet(DelimGet{columns,..}) => { + write!(f, "DelimGet:")?; // TODO + for (i, expr_item) in columns.iter().enumerate() { + if i > 0 { + write!(f, ",")?; + } + write!(f, " {expr_item}")?; + } + Ok(()) + } LogicalPlan::Projection(Projection { ref expr, .. }) => { write!(f, "Projection:")?; for (i, expr_item) in expr.iter().enumerate() { @@ -1888,11 +2120,16 @@ impl LogicalPlan { Ok(()) } + + LogicalPlan::DependentJoin(dependent_join) => { + Display::fmt(dependent_join, f) + }, LogicalPlan::Join(Join { on: ref keys, filter, join_constraint, join_type, + join_kind, .. }) => { let join_expr: Vec = @@ -1906,12 +2143,18 @@ impl LogicalPlan { } else { join_type.to_string() }; + let join_kind = if matches!(join_kind, JoinKind::ComparisonJoin) { + "ComparisonJoin" + } else { + "DelimJoin" + }; match join_constraint { JoinConstraint::On => { write!( f, - "{} Join: {}{}", + "{} Join({}): {}{}", join_type, + join_kind, join_expr.join(", "), filter_expr ) @@ -3734,6 +3977,12 @@ pub struct Sort { pub fetch: Option, } +#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Hash)] +pub enum JoinKind { + ComparisonJoin, + DelimJoin, +} + /// Join two logical plans on one or more join columns #[derive(Debug, Clone, PartialEq, Eq, Hash)] pub struct Join { @@ -3753,6 +4002,9 @@ pub struct Join { pub schema: DFSchemaRef, /// Defines the null equality for the join. pub null_equality: NullEquality, + /// Join generated by decorrelation is DelimJoin kind. + // TODO: maybe it's better to add a new join logical plan? but i't almost the same. + pub join_kind: JoinKind, } impl Join { @@ -3794,9 +4046,16 @@ impl Join { join_constraint, schema: Arc::new(join_schema), null_equality, + join_kind: JoinKind::ComparisonJoin, }) } + pub fn is_cross_product(&self) -> bool { + self.filter.is_none() + && self.on.is_empty() + && matches!(self.join_type, JoinType::Inner) + } + /// Create Join with input which wrapped with projection, this method is used in physcial planning only to help /// create the physical join. 
pub fn try_new_with_project_input( @@ -3849,6 +4108,7 @@ impl Join { join_constraint: original_join.join_constraint, schema: Arc::new(join_schema), null_equality: original_join.null_equality, + join_kind: JoinKind::ComparisonJoin, }, requalified, )) @@ -5168,6 +5428,7 @@ mod tests { join_constraint: JoinConstraint::On, schema: Arc::new(left_schema.join(&right_schema)?), null_equality: NullEquality::NullEqualsNothing, + join_kind: JoinKind::ComparisonJoin, })) } diff --git a/datafusion/expr/src/logical_plan/tree_node.rs b/datafusion/expr/src/logical_plan/tree_node.rs index 47088370a1d9..eb9bc5df9f9e 100644 --- a/datafusion/expr/src/logical_plan/tree_node.rs +++ b/datafusion/expr/src/logical_plan/tree_node.rs @@ -39,9 +39,9 @@ use crate::{ dml::CopyTo, Aggregate, Analyze, CreateMemoryTable, CreateView, DdlStatement, - Distinct, DistinctOn, DmlStatement, Execute, Explain, Expr, Extension, Filter, Join, - Limit, LogicalPlan, Partitioning, Prepare, Projection, RecursiveQuery, Repartition, - Sort, Statement, Subquery, SubqueryAlias, TableScan, Union, Unnest, + DependentJoin, Distinct, DistinctOn, DmlStatement, Execute, Explain, Expr, Extension, + Filter, Join, Limit, LogicalPlan, Partitioning, Prepare, Projection, RecursiveQuery, + Repartition, Sort, Statement, Subquery, SubqueryAlias, TableScan, Union, Unnest, UserDefinedLogicalNode, Values, Window, }; use datafusion_common::tree_node::TreeNodeRefContainer; @@ -133,6 +133,7 @@ impl TreeNode for LogicalPlan { join_constraint, schema, null_equality, + join_kind, }) => (left, right).map_elements(f)?.update_data(|(left, right)| { LogicalPlan::Join(Join { left, @@ -143,6 +144,7 @@ impl TreeNode for LogicalPlan { join_constraint, schema, null_equality, + join_kind, }) }), LogicalPlan::Limit(Limit { skip, fetch, input }) => input @@ -349,7 +351,29 @@ impl TreeNode for LogicalPlan { LogicalPlan::TableScan { .. } | LogicalPlan::EmptyRelation { .. } | LogicalPlan::Values { .. } - | LogicalPlan::DescribeTable(_) => Transformed::no(self), + | LogicalPlan::DescribeTable(_) + | LogicalPlan::DelimGet(_) => Transformed::no(self), + LogicalPlan::DependentJoin(DependentJoin { + schema, + correlated_columns, + subquery_expr, + subquery_depth, + subquery_name, + lateral_join_condition, + left, + right, + }) => (left, right).map_elements(f)?.update_data(|(left, right)| { + LogicalPlan::DependentJoin(DependentJoin { + schema, + correlated_columns, + subquery_expr, + subquery_depth, + subquery_name, + lateral_join_condition, + left, + right, + }) + }), }) } } @@ -402,6 +426,22 @@ impl LogicalPlan { mut f: F, ) -> Result { match self { + LogicalPlan::DependentJoin(DependentJoin { + correlated_columns, + lateral_join_condition, + .. + }) => { + let correlated_column_exprs = correlated_columns + .iter() + .map(|info| Expr::Column(info.col.clone())) + .collect::>(); + let maybe_lateral_join_condition = lateral_join_condition + .as_ref() + .map(|(_, condition)| condition.clone()); + + (&correlated_column_exprs, &maybe_lateral_join_condition) + .apply_ref_elements(f) + } LogicalPlan::Projection(Projection { expr, .. }) => expr.apply_elements(f), LogicalPlan::Values(Values { values, .. }) => values.apply_elements(f), LogicalPlan::Filter(Filter { predicate, .. 
}) => f(predicate), @@ -473,7 +513,8 @@ impl LogicalPlan { | LogicalPlan::Dml(_) | LogicalPlan::Ddl(_) | LogicalPlan::Copy(_) - | LogicalPlan::DescribeTable(_) => Ok(TreeNodeRecursion::Continue), + | LogicalPlan::DescribeTable(_) + | LogicalPlan::DelimGet(_) => Ok(TreeNodeRecursion::Continue), } } @@ -564,6 +605,7 @@ impl LogicalPlan { join_constraint, schema, null_equality, + join_kind, }) => (on, filter).map_elements(f)?.update_data(|(on, filter)| { LogicalPlan::Join(Join { left, @@ -574,6 +616,7 @@ impl LogicalPlan { join_constraint, schema, null_equality, + join_kind, }) }), LogicalPlan::Sort(Sort { expr, input, fetch }) => expr @@ -653,7 +696,9 @@ impl LogicalPlan { | LogicalPlan::Dml(_) | LogicalPlan::Ddl(_) | LogicalPlan::Copy(_) - | LogicalPlan::DescribeTable(_) => Transformed::no(self), + | LogicalPlan::DescribeTable(_) + | LogicalPlan::DependentJoin(_) + | LogicalPlan::DelimGet(_) => Transformed::no(self), }) } diff --git a/datafusion/optimizer/Cargo.toml b/datafusion/optimizer/Cargo.toml index 6d43ab7e9d7b..8288f40743fd 100644 --- a/datafusion/optimizer/Cargo.toml +++ b/datafusion/optimizer/Cargo.toml @@ -39,12 +39,14 @@ name = "datafusion_optimizer" [features] recursive_protection = ["dep:recursive"] +backtrace = ["datafusion-common/backtrace"] [dependencies] arrow = { workspace = true } chrono = { workspace = true } datafusion-common = { workspace = true, default-features = true } datafusion-expr = { workspace = true } +datafusion-functions-window = { workspace = true } datafusion-expr-common = { workspace = true } datafusion-physical-expr = { workspace = true } indexmap = { workspace = true } diff --git a/datafusion/optimizer/src/common_subexpr_eliminate.rs b/datafusion/optimizer/src/common_subexpr_eliminate.rs index 88d51e1adea3..f8d8cdf01722 100644 --- a/datafusion/optimizer/src/common_subexpr_eliminate.rs +++ b/datafusion/optimizer/src/common_subexpr_eliminate.rs @@ -564,7 +564,9 @@ impl OptimizerRule for CommonSubexprEliminate { | LogicalPlan::Dml(_) | LogicalPlan::Copy(_) | LogicalPlan::Unnest(_) - | LogicalPlan::RecursiveQuery(_) => { + | LogicalPlan::RecursiveQuery(_) + | LogicalPlan::DependentJoin(_) + | LogicalPlan::DelimGet(_) => { // This rule handles recursion itself in a `ApplyOrder::TopDown` like // manner. plan.map_children(|c| self.rewrite(c, config))? diff --git a/datafusion/optimizer/src/decorrelate_dependent_join.rs b/datafusion/optimizer/src/decorrelate_dependent_join.rs new file mode 100644 index 000000000000..d632159a7586 --- /dev/null +++ b/datafusion/optimizer/src/decorrelate_dependent_join.rs @@ -0,0 +1,2533 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! 
[`DependentJoinRewriter`] converts correlated subqueries to `DependentJoin` + +use crate::rewrite_dependent_join::DependentJoinRewriter; +use crate::{ApplyOrder, OptimizerConfig, OptimizerRule}; +use std::ops::Deref; +use std::sync::Arc; + +use datafusion_common::tree_node::{Transformed, TreeNode, TreeNodeRecursion}; +use datafusion_common::{internal_datafusion_err, internal_err, Column, Result}; +use datafusion_expr::expr::{ + self, Exists, InSubquery, WindowFunction, WindowFunctionParams, +}; +use datafusion_expr::utils::conjunction; +use datafusion_expr::{ + binary_expr, col, lit, not, when, Aggregate, CorrelatedColumnInfo, DependentJoin, + Expr, FetchType, GroupingSet, Join, JoinType, LogicalPlan, LogicalPlanBuilder, + Operator, SkipType, WindowFrame, WindowFunctionDefinition, +}; + +use datafusion_functions_window::row_number::row_number_udwf; +use indexmap::{IndexMap, IndexSet}; +use itertools::Itertools; + +#[derive(Clone, Debug)] +pub struct DependentJoinDecorrelator { + // immutable, defined when this object is constructed + domains: IndexSet, + // for each domain column, the corresponding column in delim_get + correlated_map: IndexMap, + is_initial: bool, + + // top-most subquery DecorrelateDependentJoin has depth 1 and so on + // TODO: for now it has no usage + // depth: usize, + // all correlated columns in current depth and downward (if any) + correlated_columns: Vec, + // check if we have to replace any COUNT aggregates into "CASE WHEN X IS NULL THEN 0 ELSE COUNT END" + // store a mapping between a expr and its original index in the loglan output + replacement_map: IndexMap, + // if during the top down traversal, we observe any operator that requires + // joining all rows from the lhs with nullable rows on the rhs + any_join: bool, + delim_scan_id: usize, + dscan_cols: Vec, +} + +// normal join, but remove redundant columns +// i.e if we join two table with equi joins left=right +// only take the matching table on the right; +fn natural_join( + mut builder: LogicalPlanBuilder, + right: LogicalPlan, + join_type: JoinType, + conditions: Vec<(Column, Column)>, +) -> Result { + // let mut exclude_cols = IndexSet::new(); + let join_exprs: Vec<_> = conditions + .iter() + .map(|(lhs, rhs)| { + // exclude_cols.insert(rhs); + binary_expr( + Expr::Column(lhs.clone()), + Operator::IsNotDistinctFrom, + Expr::Column(rhs.clone()), + ) + }) + .collect(); + // let require_dedup = !join_exprs.is_empty(); + + builder = builder.delim_join( + right, + join_type, + (Vec::::new(), Vec::::new()), + conjunction(join_exprs).or(Some(lit(true))), + )?; + //if require_dedup { + // let remain_cols = builder.schema().columns().into_iter().filter_map(|c| { + // if exclude_cols.contains(&c) { + // None + // } else { + // Some(Expr::Column(c)) + // } + // }); + // builder.project(remain_cols) + //} else { + Ok(builder) + //} +} + +impl DependentJoinDecorrelator { + fn new_root() -> Self { + Self { + domains: IndexSet::new(), + correlated_map: IndexMap::new(), + is_initial: true, + correlated_columns: vec![], + replacement_map: IndexMap::new(), + any_join: true, + delim_scan_id: 0, + dscan_cols: vec![], + } + } + + fn new( + node: &DependentJoin, + correlated_columns_from_parent: &Vec, + is_initial: bool, + any_join: bool, + delim_scan_id: usize, + depth: usize, + ) -> Self { + // the correlated_columns may contains columns referenced by lower depth, filter them out + let current_depth_correlated_columns = + node.correlated_columns.iter().filter_map(|info| { + if depth == info.depth { + Some(info) + } else { 
+ None + } + }); + + // TODO: it's better if dependentjoin node store all outer ref on RHS itself + let all_outer_ref = node.right.all_out_ref_exprs(); + let parent_correlated_columns = + correlated_columns_from_parent.iter().filter(|info| { + all_outer_ref.contains(&Expr::OuterReferenceColumn( + info.data_type.clone(), + info.col.clone(), + )) + }); + + let domains: IndexSet<_> = current_depth_correlated_columns + .chain(parent_correlated_columns) + .unique() + .cloned() + .collect(); + + let mut merged_correlated_columns = correlated_columns_from_parent.clone(); + merged_correlated_columns.retain(|info| info.depth >= depth); + merged_correlated_columns.extend_from_slice(&node.correlated_columns); + + Self { + domains, + correlated_map: IndexMap::new(), + is_initial, + correlated_columns: merged_correlated_columns, + replacement_map: IndexMap::new(), + any_join, + delim_scan_id, + dscan_cols: vec![], + } + } + + fn decorrelate_independent(&mut self, plan: &LogicalPlan) -> Result { + let mut decorrelator = DependentJoinDecorrelator::new_root(); + + decorrelator.decorrelate(plan, true, 0) + } + + fn decorrelate( + &mut self, + plan: &LogicalPlan, + parent_propagate_nulls: bool, + lateral_depth: usize, + ) -> Result { + if let LogicalPlan::DependentJoin(djoin) = plan { + let perform_delim = true; + let left = djoin.left.as_ref(); + + let new_left = if !self.is_initial { + let mut has_correlated_expr = false; + detect_correlated_expressions( + plan, + &self.domains, + &mut has_correlated_expr, + )?; + let new_left = if !has_correlated_expr { + self.decorrelate_independent(left)? + } else { + self.push_down_dependent_join( + left, + parent_propagate_nulls, + lateral_depth, + )? + }; + + // TODO: duckdb does this redundant rewrite for no reason??? + // let mut new_plan = Self::rewrite_outer_ref_columns( + // new_left, + // &self.correlated_map, + // false, + // )?; + + let new_plan = Self::rewrite_outer_ref_columns( + new_left, + &self.correlated_map, + true, + )?; + new_plan + } else { + self.decorrelate(left, true, 0)? + }; + let lateral_depth = 0; + // let propagate_null_values = node.propagate_null_value(); + let _propagate_null_values = true; + let mut decorrelator = DependentJoinDecorrelator::new( + djoin, + &self.correlated_columns, + false, + false, + self.delim_scan_id, + djoin.subquery_depth, + ); + let right = decorrelator.push_down_dependent_join( + &djoin.right, + parent_propagate_nulls, + lateral_depth, + )?; + + let (join_condition, join_type, post_join_expr) = self + .delim_join_conditions( + djoin, + &decorrelator, + right.schema().columns(), + perform_delim, + )?; + + let mut builder = LogicalPlanBuilder::new(new_left).join( + right, + join_type, + (Vec::::new(), Vec::::new()), + Some(join_condition), + )?; + if let Some(subquery_proj_expr) = post_join_expr { + let new_exprs: Vec = builder + .schema() + .columns() + .into_iter() + // remove any "mark" columns output by the markjoin + .filter_map(|c| { + if c.name == "mark" { + None + } else { + Some(Expr::Column(c)) + } + }) + .chain(std::iter::once(subquery_proj_expr)) + .collect(); + builder = builder.project(new_exprs)?; + } + + self.delim_scan_id = decorrelator.delim_scan_id; + self.merge_child(&decorrelator); + + builder.build() + } else { + Ok(plan + .clone() + .map_children(|n| Ok(Transformed::yes(self.decorrelate(&n, true, 0)?)))? 
+ .data) + } + } + + fn merge_child(&mut self, child: &Self) { + self.delim_scan_id = child.delim_scan_id; + for entry in child.correlated_map.iter() { + self.correlated_map.insert(entry.0.clone(), entry.1.clone()); + } + } + + // TODO: support lateral join + // convert dependent join into delim join + fn delim_join_conditions( + &self, + node: &DependentJoin, + decorrelator: &DependentJoinDecorrelator, + right_columns: Vec, + _perform_delim: bool, + ) -> Result<(Expr, JoinType, Option)> { + if node.lateral_join_condition.is_some() { + unimplemented!() + } + + let mut join_conditions = vec![]; + // if this is set, a new expr will be added to the parent projection + // after delimJoin + // this is because some expr cannot be evaluated during the join, for example + // binary_expr(subquery_1,subquery_2) + // this will result into 2 consecutive delim_join + // project(binary_expr(result_subquery_1, result_subquery_2)) + // delim_join on subquery1 + // delim_join on subquery2 + let mut extra_expr_after_join = None; + let mut join_type = JoinType::Inner; + if let Some(ref expr) = node.subquery_expr { + match expr { + Expr::ScalarSubquery(_) => { + // TODO: support JoinType::Single + // That works similar to left outer join + // But having extra check that only for each entry on the LHS + // only at most 1 parter on the RHS matches + join_type = JoinType::LeftSingle; + + // The reason we does not make this as a condition inside the delim join + // is because the evaluation of scalar_subquery expr may be needed + // somewhere above + extra_expr_after_join = Some( + Expr::Column(right_columns.first().unwrap().clone()) + .alias(format!("{}", node.subquery_name)), + ); + } + Expr::Exists(Exists { negated, .. }) => { + join_type = JoinType::LeftMark; + if *negated { + extra_expr_after_join = Some( + not(col("mark")).alias(format!("{}", node.subquery_name)), + ); + } else { + extra_expr_after_join = + Some(col("mark").alias(format!("{}", node.subquery_name))); + } + } + Expr::InSubquery(InSubquery { expr, negated, .. }) => { + // TODO: looks like there is a comment that + // markjoin does not support fully null semantic for ANY/IN subquery + join_type = JoinType::LeftMark; + extra_expr_after_join = + Some(col("mark").alias(format!("{}", node.subquery_name))); + let op = if *negated { + Operator::NotEq + } else { + Operator::Eq + }; + join_conditions.push(binary_expr( + expr.deref().clone(), + op, + Expr::Column(right_columns.first().unwrap().clone()), + )); + } + _ => { + unreachable!() + } + } + } + + let curr_lv_correlated_cols = if self.is_initial { + node.correlated_columns + .iter() + .filter_map(|info| { + if node.subquery_depth == info.depth { + Some(info.clone()) + } else { + None + } + }) + .collect() + } else { + decorrelator.domains.clone() + }; + + // TODO: natural join? + for corr_col in curr_lv_correlated_cols.iter().unique() { + let right_col = Self::fetch_dscan_col_from_correlated_col( + &decorrelator.correlated_map, + &corr_col.col, + )?; + + join_conditions.push(binary_expr( + col(corr_col.col.clone()), + Operator::IsNotDistinctFrom, + col(right_col.clone()), + )); + } + Ok(( + conjunction(join_conditions).or(Some(lit(true))).unwrap(), + join_type, + extra_expr_after_join, + )) + } + + fn rewrite_current_plan_outer_ref_columns( + plan: LogicalPlan, + correlated_map: &IndexMap, + ) -> Result { + // replace correlated column in dependent with delimget's column + let new_plan = if let LogicalPlan::DependentJoin(DependentJoin { .. 
}) = plan { + return internal_err!( + "logical error, this function should not be called if one of the plan is still dependent join node"); + } else { + plan + }; + + new_plan + .map_expressions(|e| { + e.transform(|e| { + if let Expr::OuterReferenceColumn(_, outer_col) = &e { + if let Some(delim_col) = correlated_map.get(outer_col) { + return Ok(Transformed::yes(Expr::Column(delim_col.clone()))); + }else{ + return internal_err!("correlated map does not detect for outer reference of column {}",outer_col); + } + } + Ok(Transformed::no(e)) + }) + })? + .data + .recompute_schema() + } + + fn rewrite_outer_ref_columns( + plan: LogicalPlan, + correlated_map: &IndexMap, + recursive: bool, + ) -> Result { + // TODO: take depth into consideration + let new_plan = if recursive { + plan.transform_down(|child| { + Ok(Transformed::yes( + Self::rewrite_current_plan_outer_ref_columns(child, correlated_map)?, + )) + })? + .data + .recompute_schema()? + } else { + plan + }; + + Self::rewrite_current_plan_outer_ref_columns(new_plan, correlated_map) + } + + fn fetch_dscan_col_from_correlated_col( + correlated_map: &IndexMap, + original: &Column, + ) -> Result { + correlated_map + .get(original) + .ok_or(internal_datafusion_err!( + "correlated map does not have entry for {}", + original + )) + .cloned() + } + + fn build_delim_scan(&mut self) -> Result { + // Clear last dscan info every time we build new dscan. + self.dscan_cols.clear(); + + // Collect all correlated columns of different outer table. + let mut domains_by_table: IndexMap> = + IndexMap::new(); + + for domain in &self.domains { + let table_ref = domain + .col + .relation + .clone() + .ok_or(internal_datafusion_err!( + "TableRef should exists in correlatd column" + ))? + .clone(); + let domains = domains_by_table.entry(table_ref.to_string()).or_default(); + if !domains.iter().any(|existing| { + (&existing.col == &domain.col) + && (&existing.data_type == &domain.data_type) + }) { + domains.push(domain.clone()); + } + } + + // Collect all D from different tables. + let mut delim_scans = vec![]; + for (table_ref, table_domains) in domains_by_table { + self.delim_scan_id += 1; + let delim_scan_name = + format!("{0}_dscan_{1}", table_ref.clone(), self.delim_scan_id); + + let mut projection_exprs = vec![]; + table_domains.iter().for_each(|c| { + let dcol_name = c.col.flat_name().replace(".", "_"); + let dscan_col = Column::from_qualified_name(format!( + "{}.{dcol_name}", + delim_scan_name.clone(), + )); + + self.correlated_map.insert(c.col.clone(), dscan_col.clone()); + self.dscan_cols.push(dscan_col); + + // Construct alias for projection. + projection_exprs.push( + col(c.col.clone()) + .alias_qualified(delim_scan_name.clone().into(), dcol_name), + ); + }); + + // Apply projection to rename columns and then alias the entire plan. + delim_scans.push( + LogicalPlanBuilder::delim_get(&table_domains)? + .project(projection_exprs)? + .build()?, + ); + } + + // Join all delim_scans together. + let final_delim_scan = if delim_scans.len() == 1 { + delim_scans.into_iter().next().unwrap() + } else { + let mut iter = delim_scans.into_iter(); + let first = iter + .next() + .ok_or_else(|| internal_datafusion_err!("Empty delim_scans vector"))?; + iter.try_fold(first, |acc, delim_scan| { + LogicalPlanBuilder::new(acc) + .join( + delim_scan, + JoinType::Inner, + (Vec::::new(), Vec::::new()), + None, + )? + .build() + })? 
+ }; + + final_delim_scan.recompute_schema() + } + + fn rewrite_expr_from_replacement_map( + replacement: &IndexMap, + plan: LogicalPlan, + ) -> Result { + // TODO: not sure if rewrite should stop once found replacement expr + plan.transform_down(|p| { + if let LogicalPlan::DependentJoin(_) = &p { + return internal_err!( + "calling rewrite_correlated_exprs while some of \ + the plan is still dependent join plan" + ); + } + if let LogicalPlan::Projection(_proj) = &p { + p.map_expressions(|e| { + e.transform(|e| { + if let Some(to_replace) = replacement.get(&e.to_string()) { + Ok(Transformed::yes(to_replace.clone())) + } else { + Ok(Transformed::no(e)) + } + }) + }) + } else { + Ok(Transformed::no(p)) + // unimplemented!() + } + })? + .data + .recompute_schema() + } + + // on recursive rewrite, make sure to update any correlated_column + // TODO: make all of the delim join natural join + fn push_down_dependent_join_internal( + &mut self, + plan: &LogicalPlan, + parent_propagate_nulls: bool, + lateral_depth: usize, + ) -> Result { + // First check if the logical plan has correlated expressions. + let mut has_correlated_expr = false; + detect_correlated_expressions(plan, &self.domains, &mut has_correlated_expr)?; + + let mut exit_projection = false; + + if !has_correlated_expr { + // We reached a node without correlated expressions. + // We can eliminate the dependent join now and create a simple cross product. + // Now create the duplicate eliminated scan for this node. + match plan { + LogicalPlan::Projection(_) => { + // We want to keep the logical projection for positionality. + exit_projection = true; + } + LogicalPlan::RecursiveQuery(_) => { + // TODO: Add cte support. + unimplemented!("") + } + other => { + let delim_scan = self.build_delim_scan()?; + let left = self.decorrelate(other, true, 0)?; + return Ok(natural_join( + LogicalPlanBuilder::new(left), + delim_scan, + JoinType::Inner, + vec![], + )? + .build()?); + } + } + } + match plan { + LogicalPlan::Filter(old_filter) => { + // TODO: any join support + + let new_input = self.push_down_dependent_join_internal( + old_filter.input.as_ref(), + parent_propagate_nulls, + lateral_depth, + )?; + let mut filter = old_filter.clone(); + filter.input = Arc::new(new_input); + + return Ok(Self::rewrite_outer_ref_columns( + LogicalPlan::Filter(filter), + &self.correlated_map, + false, + )?); + } + LogicalPlan::Projection(old_proj) => { + // TODO: Take propagate_null_value into consideration. + + // If the node has no correlated expressions, push the cross product with the + // delim scan only below the projection. This will preserve positionality of the + // columns and prevent errors when reordering of delim scans is enabled. + let mut proj = old_proj.clone(); + proj.input = Arc::new(if exit_projection { + let delim_scan = self.build_delim_scan()?; + let new_left = self.decorrelate(proj.input.deref(), true, 0)?; + LogicalPlanBuilder::new(new_left) + .join( + delim_scan, + JoinType::Inner, + (Vec::::new(), Vec::::new()), + None, + )? + .build()? + } else { + self.push_down_dependent_join_internal( + proj.input.as_ref(), + parent_propagate_nulls, + lateral_depth, + )? + }); + + // Now we add all the columns of the delim scan to the projection list. 
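Concretely, keeping the delim-scan columns visible means extending the projection's expression list with one column per correlated (domain) column, looked up through `correlated_map`; the loop that follows does exactly this via `fetch_dscan_col_from_correlated_col`. A minimal sketch under the assumption that the domains are plain `Column`s (the helper name `append_delim_columns` is illustrative):

    use datafusion_common::{internal_datafusion_err, Column, Result};
    use datafusion_expr::{col, Expr};
    use indexmap::IndexMap;

    /// Sketch: append the delim-scan column registered for every correlated
    /// (domain) column to a projection's expression list, so the parent delim
    /// join can still reference the duplicate-eliminated values.
    fn append_delim_columns(
        proj_exprs: &mut Vec<Expr>,
        domains: &[Column],
        correlated_map: &IndexMap<Column, Column>,
    ) -> Result<()> {
        for domain in domains {
            let dscan_col = correlated_map.get(domain).ok_or_else(|| {
                internal_datafusion_err!("no delim-scan column registered for {domain}")
            })?;
            proj_exprs.push(col(dscan_col.clone()));
        }
        Ok(())
    }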
+ //for dcol in self.dscan_cols.iter() { + // proj.expr.push(col(dcol.clone())); + //} + + for domain_col in self.domains.iter() { + proj.expr + .push(col(Self::fetch_dscan_col_from_correlated_col( + &self.correlated_map, + &domain_col.col, + )?)); + } + + // Then we replace any correlated expressions with the corresponding entry in the + // correlated_map. + proj = match Self::rewrite_outer_ref_columns( + LogicalPlan::Projection(proj), + &self.correlated_map, + false, + )? { + LogicalPlan::Projection(projection) => projection, + _ => { + return internal_err!( + "Expected Projection after rewrite_outer_ref_columns" + ) + } + }; + + return Ok(LogicalPlan::Projection(proj)); + } + LogicalPlan::Aggregate(old_agg) => { + // TODO: support propagates null values + + // First we flatten the dependent join in the child of the projection. + let new_input = self.push_down_dependent_join_internal( + old_agg.input.as_ref(), + parent_propagate_nulls, + lateral_depth, + )?; + + // Then we replace any correlated expressions with the corresponding entry + // in the correlated_map. + let mut new_agg = old_agg.clone(); + new_agg.input = Arc::new(new_input); + let new_plan = Self::rewrite_outer_ref_columns( + LogicalPlan::Aggregate(new_agg), + &self.correlated_map, + false, + )?; + + // TODO: take perform_delim into consideration. + // Now we add all the correlated columns of current level's dependent join + // to the grouping operators AND the projection list. + let (mut group_expr, aggr_expr, input) = + if let LogicalPlan::Aggregate(Aggregate { + group_expr, + aggr_expr, + input, + .. + }) = new_plan + { + (group_expr.clone(), aggr_expr.clone(), input.clone()) + } else { + return Err(internal_datafusion_err!( + "Expected LogicalPlan::Aggregate" + )); + }; + + for c in self.domains.iter() { + let dcol = Self::fetch_dscan_col_from_correlated_col( + &self.correlated_map, + &c.col, + )?; + + for expr in &mut group_expr { + if let Expr::GroupingSet(grouping_set) = expr { + if let GroupingSet::GroupingSets(sets) = grouping_set { + for set in sets { + set.push(col(dcol.clone())) + } + } + } + } + } + + for c in self.domains.iter() { + group_expr.push(col(Self::fetch_dscan_col_from_correlated_col( + &self.correlated_map, + &c.col, + )?)); + } + + let ungroup_join = true; + if ungroup_join { + // We have to perform an INNER or LEFT OUTER JOIN between the result of this + // aggregate and the delim scan. + // This does not always have to be a LEFt OUTER JOIN, depending on whether + // aggr func return NULL or a value. + let join_type = if self.any_join || !parent_propagate_nulls { + JoinType::Left + } else { + JoinType::Inner + }; + + // Construct delim join condition. + // let mut join_conditions = vec![]; + let mut join_left_side = vec![]; + for corr in self.domains.iter() { + let delim_col = Self::fetch_dscan_col_from_correlated_col( + &self.correlated_map, + &corr.col, + )?; + join_left_side.push(delim_col); + } + + let dscan = self.build_delim_scan()?; + let mut join_right_side = vec![]; + for corr in self.domains.iter() { + let delim_col = Self::fetch_dscan_col_from_correlated_col( + &self.correlated_map, + &corr.col, + )?; + join_right_side.push(delim_col); + } + + let mut join_conditions = vec![]; + for (left_col, right_col) in + join_left_side.iter().zip(join_right_side.iter()) + { + join_conditions.push((left_col.clone(), right_col.clone())); + } + + // For any COUNT aggregate we replace reference to the column with: + // CASE WHTN COUNT (*) IS NULL THEN 0 ELSE COUNT(*) END. 
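A scalar COUNT over an empty correlated group has to come out as 0 rather than NULL once the aggregate side is joined back with a LEFT join, so references to the aggregate are swapped for a CASE expression; the loop right below records that replacement in `replacement_map`. A minimal standalone sketch of the rewritten expression (the function name is illustrative):

    use datafusion_common::Result;
    use datafusion_expr::{lit, when, Expr};

    /// Sketch: rewrite `count(...)` into
    /// `CASE WHEN count(...) IS NULL THEN 0 ELSE count(...) END`,
    /// the expression stored in the replacement map for later projections.
    fn null_count_to_zero(count_expr: Expr) -> Result<Expr> {
        when(count_expr.clone().is_null(), lit(0)).otherwise(count_expr)
    }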
+ for agg_expr in &aggr_expr { + match agg_expr { + Expr::AggregateFunction(expr::AggregateFunction { + func, + .. + }) => { + if func.name() == "count" { + let expr_name = agg_expr.to_string(); + let expr_to_replace = + when(agg_expr.clone().is_null(), lit(0)) + .otherwise(agg_expr.clone())?; + // Have to replace this expr with CASE exr. + self.replacement_map + .insert(expr_name, expr_to_replace); + continue; + } + } + _ => {} + } + } + + let new_agg = Aggregate::try_new(input, group_expr, aggr_expr)?; + let agg_output_cols = new_agg + .schema + .columns() + .into_iter() + .map(|c| Expr::Column(c)); + + let builder = + LogicalPlanBuilder::new(LogicalPlan::Aggregate(new_agg)) + // TODO: a hack to ensure aggregated expr are ordered first in the output + .project(agg_output_cols.rev())?; + natural_join(builder, dscan, join_type, join_conditions)?.build() + } else { + // TODO: handle this case + unimplemented!() + } + } + LogicalPlan::DependentJoin(_) => { + return self.decorrelate(plan, parent_propagate_nulls, lateral_depth); + } + LogicalPlan::Join(old_join) => { + let mut left_has_correlation = false; + detect_correlated_expressions( + old_join.left.as_ref(), + &self.domains, + &mut left_has_correlation, + )?; + let mut right_has_correlation = false; + detect_correlated_expressions( + old_join.right.as_ref(), + &self.domains, + &mut right_has_correlation, + )?; + + // Cross projuct, push into both sides of the plan. + if old_join.is_cross_product() { + if !right_has_correlation { + // Only left has correlation, push into left. + let new_left = self.push_down_dependent_join_internal( + old_join.left.as_ref(), + parent_propagate_nulls, + lateral_depth, + )?; + let new_right = + self.decorrelate_independent(old_join.right.as_ref())?; + + return self.join_without_correlation( + new_left, + new_right, + old_join.clone(), + ); + } else if !left_has_correlation { + // Only right has correlation, push into right. + let new_right = self.push_down_dependent_join_internal( + old_join.right.as_ref(), + parent_propagate_nulls, + lateral_depth, + )?; + let new_left = + self.decorrelate_independent(old_join.left.as_ref())?; + + return self.join_without_correlation( + new_left, + new_right, + old_join.clone(), + ); + } + + // Both sides have correlation, turn into an inner join. + let new_left = self.push_down_dependent_join_internal( + old_join.left.as_ref(), + parent_propagate_nulls, + lateral_depth, + )?; + let new_right = self.push_down_dependent_join_internal( + old_join.right.as_ref(), + parent_propagate_nulls, + lateral_depth, + )?; + + // Add the correlated columns to th join conditions. + return self.join_with_correlation( + new_left, + new_right, + old_join.clone(), + ); + } + + // If it's a comparison join. + match old_join.join_type { + JoinType::Inner => { + if !right_has_correlation { + // Only left has correlation, push info left. + let new_left = self.push_down_dependent_join_internal( + old_join.left.as_ref(), + parent_propagate_nulls, + lateral_depth, + )?; + let new_right = + self.decorrelate_independent(old_join.right.as_ref())?; + + return self.join_without_correlation( + new_left, + new_right, + old_join.clone(), + ); + } + + if !left_has_correlation { + // Only right has correlation, push into right. 
+ let new_right = self.push_down_dependent_join_internal( + old_join.right.as_ref(), + parent_propagate_nulls, + lateral_depth, + )?; + let new_left = + self.decorrelate_independent(old_join.left.as_ref())?; + + return self.join_without_correlation( + new_left, + new_right, + old_join.clone(), + ); + } + } + JoinType::Left => { + if !right_has_correlation { + // Only left has correlation, push info left. + let new_left = self.push_down_dependent_join_internal( + old_join.left.as_ref(), + parent_propagate_nulls, + lateral_depth, + )?; + let new_right = + self.decorrelate_independent(old_join.right.as_ref())?; + + return self.join_without_correlation( + new_left, + new_right, + old_join.clone(), + ); + } + } + JoinType::Right => { + if !left_has_correlation { + // Only right has correlation, push into right. + let new_right = self.push_down_dependent_join_internal( + old_join.right.as_ref(), + parent_propagate_nulls, + lateral_depth, + )?; + let new_left = + self.decorrelate_independent(old_join.left.as_ref())?; + + return self.join_without_correlation( + new_left, + new_right, + old_join.clone(), + ); + } + } + JoinType::LeftMark => { + // Push the child into the RHS. + let new_left = self.push_down_dependent_join_internal( + old_join.left.as_ref(), + parent_propagate_nulls, + lateral_depth, + )?; + let new_right = + self.decorrelate_independent(old_join.right.as_ref())?; + + let new_join = self.join_without_correlation( + new_left, + new_right, + old_join.clone(), + )?; + + return Self::rewrite_outer_ref_columns( + new_join, + &self.correlated_map, + false, + ); + } + _ => return internal_err!("unreachable"), + } + + // Both sides have correlation, push into both sides. + let new_left = self.push_down_dependent_join_internal( + old_join.left.as_ref(), + parent_propagate_nulls, + lateral_depth, + )?; + let left_dscan_cols = self.dscan_cols.clone(); + + let new_right = self.push_down_dependent_join_internal( + old_join.right.as_ref(), + parent_propagate_nulls, + lateral_depth, + )?; + let right_dscan_cols = self.dscan_cols.clone(); + + // NOTE: For OUTER JOINS it matters what the correlated column map is after the join: + // for the LEFT OUTER JOIN: we want the LEFT side to be the base map after we push, + // because the RIGHT might contains NULL values. + if old_join.join_type == JoinType::Left { + self.dscan_cols = left_dscan_cols.clone(); + } + + // Add the correlated columns to the join conditions. + let new_join = self.join_with_delim_scan( + new_left, + new_right, + old_join.clone(), + &left_dscan_cols, + &right_dscan_cols, + )?; + + // Then we replace any correlated expressions with the corresponding entry in the + // correlated_map. + return Self::rewrite_outer_ref_columns( + new_join, + &self.correlated_map, + false, + ); + } + LogicalPlan::Limit(old_limit) => { + let mut sort = None; + + // Check if the direct child of this LIMIT node is an ORDER BY node, if so, keep it + // separate. This is done for an optimization to avoid having to compute the total + // order. + let new_input = if let LogicalPlan::Sort(child) = old_limit.input.as_ref() + { + sort = Some(old_limit.input.as_ref().clone()); + self.push_down_dependent_join_internal( + &child.input, + parent_propagate_nulls, + lateral_depth, + )? + } else { + self.push_down_dependent_join_internal( + &old_limit.input, + parent_propagate_nulls, + lateral_depth, + )? 
+ }; + + let new_input_cols = new_input.schema().columns().clone(); + + // We push a row_number() OVER (PARTITION BY [correlated columns]) + // TODO: take perform delim into consideration + let mut partition_by = vec![]; + for corr_col in self.domains.iter() { + let delim_col = Self::fetch_dscan_col_from_correlated_col( + &self.correlated_map, + &corr_col.col, + )?; + partition_by.push(Expr::Column(delim_col)); + } + + let order_by = if let Some(LogicalPlan::Sort(sort)) = &sort { + // Optimization: if there is an ORDER BY node followed by a LIMIT rather than + // computing the entire order, we push the ORDER BY expressions into the + // row_num computation. This way the order only needs to be computed per + // partition. + sort.expr.clone() + } else { + vec![] + }; + + // Create row_number() window function. + let row_number_expr = Expr::WindowFunction(Box::new(WindowFunction { + fun: WindowFunctionDefinition::WindowUDF(row_number_udwf()), + params: WindowFunctionParams { + args: vec![], + partition_by, + order_by, + window_frame: WindowFrame::new(Some(false)), + null_treatment: None, + }, + })) + .alias("row_number"); + let mut window_exprs = vec![]; + window_exprs.push(row_number_expr); + + let window = LogicalPlanBuilder::new(new_input) + .window(window_exprs)? + .build()?; + + // Add filter based on row_number + // the filter we add is "row_number > offset AND row_number <= offset + limit" + let mut filter_conditions = vec![]; + + if let FetchType::Literal(Some(fetch)) = old_limit.get_fetch_type()? { + let upper_bound = + if let SkipType::Literal(skip) = old_limit.get_skip_type()? { + // Both offset and limit specified - upper bound is offset + limit. + fetch + skip + } else { + // No offset - upper bound is not only the limit. + fetch + }; + + filter_conditions + .push(col("row_number").lt_eq(lit(upper_bound as u64))); + } + + // We only need to add "row_number > offset" if offset is bigger than 0. + if let SkipType::Literal(skip) = old_limit.get_skip_type()? { + if skip > 0 { + filter_conditions.push(col("row_number").gt(lit(skip as u64))); + } + } + + let mut result_plan = window; + if !filter_conditions.is_empty() { + let filter_expr = filter_conditions + .into_iter() + .reduce(|acc, expr| acc.and(expr)) + .unwrap(); + + result_plan = LogicalPlanBuilder::new(result_plan) + .filter(filter_expr)? + .build()?; + } + + // Project away the row_number column, keeping only original columns + let final_exprs = new_input_cols + .iter() + .map(|c| col(c.clone())) + .collect::>(); + result_plan = LogicalPlanBuilder::new(result_plan) + .project(final_exprs)? + .build()?; + + return Ok(result_plan); + } + LogicalPlan::Distinct(old_distinct) => { + // Push down into child. + let new_input = self.push_down_dependent_join( + old_distinct.input().as_ref(), + parent_propagate_nulls, + lateral_depth, + )?; + // Add all correlated columns to the DISTINCT targets. + let mut distinct_exprs = old_distinct + .input() + .schema() + .columns() + .into_iter() + .map(|c| col(c.clone())) + .collect::>(); + + // Add correlated columns as additional columns for grouping + for domain_col in self.domains.iter() { + let delim_col = Self::fetch_dscan_col_from_correlated_col( + &self.correlated_map, + &domain_col.col, + )?; + distinct_exprs.push(col(delim_col)); + } + + // Create new distinct plan with additional correlated columns + let distinct = LogicalPlanBuilder::new(new_input) + .distinct_on(distinct_exprs, vec![], None)? 
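To recap the LIMIT rewrite above: a fetch/skip over a correlated subquery becomes a `row_number()` window partitioned by the delim-scan columns plus a range filter, so the limit is enforced per outer row instead of globally. A minimal sketch of that filter predicate, reusing the "row_number" alias from the code above (the function name and i64 parameters are illustrative):

    use datafusion_expr::{col, lit, Expr};

    /// Sketch: predicate placed on top of the row_number() window for
    /// `LIMIT fetch OFFSET skip`; with skip = 0 only the upper bound remains.
    fn limit_as_row_number_filter(fetch: i64, skip: i64) -> Expr {
        let mut predicate = col("row_number").lt_eq(lit((fetch + skip) as u64));
        if skip > 0 {
            predicate = predicate.and(col("row_number").gt(lit(skip as u64)));
        }
        predicate
    }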
+ .build()?; + + return Ok(distinct); + } + LogicalPlan::Sort(old_sort) => { + let new_input = self.push_down_dependent_join( + old_sort.input.as_ref(), + parent_propagate_nulls, + lateral_depth, + )?; + let mut sort = old_sort.clone(); + sort.input = Arc::new(new_input); + Ok(LogicalPlan::Sort(sort)) + } + LogicalPlan::TableScan(old_table_scan) => { + let delim_scan = self.build_delim_scan()?; + + // Add correlated columns to the table scan output + let mut projection_exprs: Vec = old_table_scan + .projected_schema + .columns() + .into_iter() + .map(|c| Expr::Column(c)) + .collect(); + + // Add delim columns to projection + for domain_col in self.domains.iter() { + let delim_col = Self::fetch_dscan_col_from_correlated_col( + &self.correlated_map, + &domain_col.col, + )?; + projection_exprs.push(Expr::Column(delim_col)); + } + + // Cross join with delim scan and project + let cross_join = LogicalPlanBuilder::new(LogicalPlan::TableScan( + old_table_scan.clone(), + )) + .join( + delim_scan, + JoinType::Inner, + (Vec::::new(), Vec::::new()), + None, + )? + .project(projection_exprs)? + .build()?; + + // Rewrite correlated expressions + Self::rewrite_outer_ref_columns(cross_join, &self.correlated_map, false) + } + LogicalPlan::Window(old_window) => { + // Push into children. + let new_input = self.push_down_dependent_join_internal( + &old_window.input, + parent_propagate_nulls, + lateral_depth, + )?; + + // Create new window expressions with updated partition clauses + let mut new_window_exprs = old_window.window_expr.clone(); + + // Add correlated columns to PARTITION BY clauses in each window expression + for window_expr in &mut new_window_exprs { + // Handle both direct window functions and aliased window functions + let window_func = match window_expr { + Expr::WindowFunction(ref mut window_func) => window_func, + Expr::Alias(alias) => { + if let Expr::WindowFunction(ref mut window_func) = + alias.expr.as_mut() + { + window_func + } else { + continue; // Skip if alias doesn't contain a window function + } + } + _ => continue, // Skip non-window expressions + }; + + // Add correlated columns to the partition by clause + for domain_col in self.domains.iter() { + let delim_col = Self::fetch_dscan_col_from_correlated_col( + &self.correlated_map, + &domain_col.col, + )?; + window_func + .params + .partition_by + .push(Expr::Column(delim_col)); + } + } + + // Create new window plan with updated expressions and input + let mut window = old_window.clone(); + window.input = Arc::new(new_input); + window.window_expr = new_window_exprs; + + // We replace any correlated expressions with the corresponding entry in the + // correlated_map. 
+ Self::rewrite_outer_ref_columns( + LogicalPlan::Window(window), + &self.correlated_map, + false, + ) + } + plan_ => { + unimplemented!("implement pushdown dependent join for node {plan_}") + } + } + } + + fn push_down_dependent_join( + &mut self, + node: &LogicalPlan, + parent_propagate_nulls: bool, + lateral_depth: usize, + ) -> Result { + let mut new_plan = self.push_down_dependent_join_internal( + node, + parent_propagate_nulls, + lateral_depth, + )?; + if !self.replacement_map.is_empty() { + new_plan = + Self::rewrite_expr_from_replacement_map(&self.replacement_map, new_plan)?; + } + + // let projected_expr = new_plan.schema().columns().into_iter().map(|c| { + // if let Some(alt_expr) = self.replacement_map.swap_remove(&c.name) { + // return alt_expr; + // } + // Expr::Column(c.clone()) + // }); + // new_plan = LogicalPlanBuilder::new(new_plan) + // .project(projected_expr)? + // .build()?; + Ok(new_plan) + } + + fn join_without_correlation( + &mut self, + left: LogicalPlan, + right: LogicalPlan, + join: Join, + ) -> Result { + let new_join = LogicalPlan::Join(Join::try_new( + Arc::new(left), + Arc::new(right), + join.on, + join.filter, + join.join_type, + join.join_constraint, + join.null_equality, + )?); + + Self::rewrite_outer_ref_columns(new_join, &self.correlated_map, false) + } + + fn join_with_correlation( + &mut self, + left: LogicalPlan, + right: LogicalPlan, + join: Join, + ) -> Result { + let mut join_conditions = vec![]; + if let Some(filter) = join.filter { + join_conditions.push(filter); + } + + for col_pair in &self.correlated_map { + join_conditions.push(binary_expr( + Expr::Column(col_pair.0.clone()), + Operator::IsNotDistinctFrom, + Expr::Column(col_pair.1.clone()), + )); + } + + let new_join = LogicalPlan::Join(Join::try_new( + Arc::new(left), + Arc::new(right), + join.on, + conjunction(join_conditions).or(Some(lit(true))), + join.join_type, + join.join_constraint, + join.null_equality, + )?); + + Self::rewrite_outer_ref_columns(new_join, &self.correlated_map, false) + } + + fn join_with_delim_scan( + &mut self, + left: LogicalPlan, + right: LogicalPlan, + join: Join, + left_dscan_cols: &Vec, + right_dscan_cols: &Vec, + ) -> Result { + let mut join_conditions = vec![]; + if let Some(filter) = join.filter { + join_conditions.push(filter); + } + + // Ensure left_dscan_cols and right_dscan_cols have the same length + if left_dscan_cols.len() != right_dscan_cols.len() { + return Err(internal_datafusion_err!( + "Mismatched dscan columns length: left_dscan_cols has {} elements, right_dscan_cols has {} elements", + left_dscan_cols.len(), + right_dscan_cols.len() + )); + } + + for (left_delim_col, right_delim_col) in + left_dscan_cols.iter().zip(right_dscan_cols.iter()) + { + join_conditions.push(binary_expr( + Expr::Column(left_delim_col.clone()), + Operator::IsNotDistinctFrom, + Expr::Column(right_delim_col.clone()), + )); + } + + let new_join = LogicalPlan::Join(Join::try_new( + Arc::new(left), + Arc::new(right), + join.on, + conjunction(join_conditions).or(Some(lit(true))), + join.join_type, + join.join_constraint, + join.null_equality, + )?); + + Self::rewrite_outer_ref_columns(new_join, &self.correlated_map, false) + } +} + +// TODO: take lateral into consideration +fn detect_correlated_expressions( + plan: &LogicalPlan, + correlated_columns: &IndexSet, + has_correlated_expressions: &mut bool, +) -> Result<()> { + plan.apply(|child| match child { + any_plan => { + for e in any_plan.all_out_ref_exprs().iter() { + if let Expr::OuterReferenceColumn(data_type, col) = e { 
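Both `join_with_correlation` and `join_with_delim_scan` above pair correlated columns with `IS NOT DISTINCT FROM` rather than `=`, so NULL values on the outer side still match their duplicate-eliminated counterparts. A minimal sketch of a single such condition (the column names and helper name are illustrative):

    use datafusion_expr::{binary_expr, col, Expr, Operator};

    /// Sketch: builds `outer_table.b IS NOT DISTINCT FROM
    /// outer_table_dscan_1.outer_table_b`, the null-safe equality used when
    /// stitching a decorrelated input back to its delim scan.
    fn null_safe_eq(outer: &str, delim: &str) -> Expr {
        binary_expr(col(outer), Operator::IsNotDistinctFrom, col(delim))
    }

    // e.g. null_safe_eq("outer_table.b", "outer_table_dscan_1.outer_table_b")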
+ if correlated_columns + .iter() + .any(|c| (&c.col == col) && (&c.data_type == data_type)) + { + *has_correlated_expressions = true; + return Ok(TreeNodeRecursion::Stop); + } + } + } + Ok(TreeNodeRecursion::Continue) + } + })?; + + Ok(()) +} + +/// Optimizer rule for rewriting any arbitrary subqueries +#[allow(dead_code)] +#[derive(Debug)] +pub struct DecorrelateDependentJoin {} + +impl DecorrelateDependentJoin { + pub fn new() -> Self { + return DecorrelateDependentJoin {}; + } +} + +impl OptimizerRule for DecorrelateDependentJoin { + fn supports_rewrite(&self) -> bool { + true + } + + // There will be 2 rewrites going on + // - Convert all subqueries (maybe including lateral join in the future) to temporary + // LogicalPlan node called DependentJoin + // - Decorrelate DependentJoin following top-down approach recursively + fn rewrite( + &self, + plan: LogicalPlan, + config: &dyn OptimizerConfig, + ) -> Result> { + let mut transformer = + DependentJoinRewriter::new(Arc::clone(config.alias_generator())); + let rewrite_result = transformer.rewrite_subqueries_into_dependent_joins(plan)?; + + if rewrite_result.transformed { + println!("{}", rewrite_result.data.display_indent_schema()); + let mut decorrelator = DependentJoinDecorrelator::new_root(); + return Ok(Transformed::yes(decorrelator.decorrelate( + &rewrite_result.data, + true, + 0, + )?)); + } + + Ok(rewrite_result) + } + + fn name(&self) -> &str { + "decorrelate_subquery" + } + + fn apply_order(&self) -> Option { + None + } +} + +#[cfg(test)] +mod tests { + + use crate::decorrelate_dependent_join::DecorrelateDependentJoin; + use crate::test::{test_table_scan_with_name, test_table_with_columns}; + use crate::Optimizer; + use crate::{ + assert_optimized_plan_eq_display_indent_snapshot, OptimizerConfig, + OptimizerContext, OptimizerRule, + }; + use arrow::datatypes::DataType as ArrowDataType; + use datafusion_common::{Column, Result}; + use datafusion_expr::expr::{WindowFunction, WindowFunctionParams}; + use datafusion_expr::{ + binary_expr, not_in_subquery, JoinType, Operator, WindowFrame, + WindowFunctionDefinition, + }; + use datafusion_expr::{ + exists, expr_fn::col, in_subquery, lit, out_ref_col, scalar_subquery, Expr, + LogicalPlan, LogicalPlanBuilder, + }; + use datafusion_functions_aggregate::{count::count, sum::sum}; + use datafusion_functions_window::row_number::row_number_udwf; + use std::sync::Arc; + fn print_optimize_tree(plan: &LogicalPlan) { + let rule: Arc = + Arc::new(DecorrelateDependentJoin::new()); + let optimizer = Optimizer::with_rules(vec![rule]); + let _optimized_plan = optimizer + .optimize(plan.clone(), &OptimizerContext::new(), |_, _| {}) + .expect("failed to optimize plan"); + } + + macro_rules! assert_decorrelate { + ( + $plan:expr, + @ $expected:literal $(,)? + ) => {{ + print_optimize_tree(&$plan); + let rule: Arc = Arc::new(DecorrelateDependentJoin::new()); + assert_optimized_plan_eq_display_indent_snapshot!( + rule, + $plan, + @ $expected, + )?; + }}; + } + + // TODO: This test is failing + #[test] + fn correlated_subquery_nested_in_uncorrelated_subquery() -> Result<()> { + let outer_table = test_table_scan_with_name("outer_table")?; + let inner_table_lv1 = test_table_scan_with_name("inner_table_lv1")?; + let inner_table_lv2 = test_table_scan_with_name("inner_table_lv2")?; + + let sq2 = Arc::new( + LogicalPlanBuilder::from(inner_table_lv2.clone()) + .filter( + col("inner_table_lv2.b") + .eq(out_ref_col(ArrowDataType::UInt32, "inner_table_1.b")), + )? 
+ .build()?, + ); + let sq1 = Arc::new( + LogicalPlanBuilder::from(inner_table_lv1.clone()) + .filter(exists(sq2))? + .build()?, + ); + + let _plan = LogicalPlanBuilder::from(outer_table.clone()) + .filter(exists(sq1))? + .build()?; + // assert_decorrelate!(plan, @r" + // Projection: outer_table.a, outer_table.b, outer_table.c [a:UInt32, b:UInt32, c:UInt32] + // Filter: __exists_sq_1.output AND __exists_sq_2.output [a:UInt32, b:UInt32, c:UInt32, __exists_sq_1.output:Boolean, __exists_sq_2.output:Boolean] + // Projection: outer_table.a, outer_table.b, outer_table.c, __exists_sq_1.output, mark AS __exists_sq_2.output [a:UInt32, b:UInt32, c:UInt32, __exists_sq_1.output:Boolean, __exists_sq_2.output:Boolean] + // LeftMark Join(ComparisonJoin): Filter: outer_table.c IS NOT DISTINCT FROM delim_scan_2.outer_table_c [a:UInt32, b:UInt32, c:UInt32, __exists_sq_1.output:Boolean, mark:Boolean] + // Projection: outer_table.a, outer_table.b, outer_table.c, mark AS __exists_sq_1.output [a:UInt32, b:UInt32, c:UInt32, __exists_sq_1.output:Boolean] + // LeftMark Join(ComparisonJoin): Filter: outer_table.b IS NOT DISTINCT FROM delim_scan_1.outer_table_b [a:UInt32, b:UInt32, c:UInt32, mark:Boolean] + // TableScan: outer_table [a:UInt32, b:UInt32, c:UInt32] + // Filter: inner_table_lv1.b = delim_scan_1.outer_table_b [a:UInt32, b:UInt32, c:UInt32, outer_table_b:UInt32;N] + // Inner Join(DelimJoin): Filter: Boolean(true) [a:UInt32, b:UInt32, c:UInt32, outer_table_b:UInt32;N] + // TableScan: inner_table_lv1 [a:UInt32, b:UInt32, c:UInt32] + // SubqueryAlias: delim_scan_1 [outer_table_b:UInt32;N] + // DelimGet: outer_table.b [outer_table_b:UInt32;N] + // Filter: inner_table_lv1.c = delim_scan_2.outer_table_c [a:UInt32, b:UInt32, c:UInt32, outer_table_c:UInt32;N] + // Inner Join(DelimJoin): Filter: Boolean(true) [a:UInt32, b:UInt32, c:UInt32, outer_table_c:UInt32;N] + // TableScan: inner_table_lv1 [a:UInt32, b:UInt32, c:UInt32] + // SubqueryAlias: delim_scan_2 [outer_table_c:UInt32;N] + // DelimGet: outer_table.c [outer_table_c:UInt32;N] + // "); + Ok(()) + } + + #[test] + fn two_dependent_joins_at_the_same_depth() -> Result<()> { + let outer_table = test_table_scan_with_name("outer_table")?; + let inner_table_lv1 = test_table_scan_with_name("inner_table_lv1")?; + + let sq1 = Arc::new( + LogicalPlanBuilder::from(inner_table_lv1.clone()) + .filter( + col("inner_table_lv1.b") + .eq(out_ref_col(ArrowDataType::UInt32, "outer_table.b")), + )? + .build()?, + ); + let sq2 = Arc::new( + LogicalPlanBuilder::from(inner_table_lv1.clone()) + .filter( + col("inner_table_lv1.c") + .eq(out_ref_col(ArrowDataType::UInt32, "outer_table.c")), + )? + .build()?, + ); + + let plan = LogicalPlanBuilder::from(outer_table.clone()) + .filter(exists(sq1).and(exists(sq2)))? 
+ .build()?; + + assert_decorrelate!(plan, @r" + Projection: outer_table.a, outer_table.b, outer_table.c [a:UInt32, b:UInt32, c:UInt32] + Filter: __exists_sq_1 AND __exists_sq_2 [a:UInt32, b:UInt32, c:UInt32, __exists_sq_1:Boolean, __exists_sq_2:Boolean] + Projection: outer_table.a, outer_table.b, outer_table.c, __exists_sq_1, mark AS __exists_sq_2 [a:UInt32, b:UInt32, c:UInt32, __exists_sq_1:Boolean, __exists_sq_2:Boolean] + LeftMark Join(ComparisonJoin): Filter: outer_table.c IS NOT DISTINCT FROM outer_table_dscan_2.outer_table_c [a:UInt32, b:UInt32, c:UInt32, __exists_sq_1:Boolean, mark:Boolean] + Projection: outer_table.a, outer_table.b, outer_table.c, mark AS __exists_sq_1 [a:UInt32, b:UInt32, c:UInt32, __exists_sq_1:Boolean] + LeftMark Join(ComparisonJoin): Filter: outer_table.b IS NOT DISTINCT FROM outer_table_dscan_1.outer_table_b [a:UInt32, b:UInt32, c:UInt32, mark:Boolean] + TableScan: outer_table [a:UInt32, b:UInt32, c:UInt32] + Filter: inner_table_lv1.b = outer_table_dscan_1.outer_table_b [a:UInt32, b:UInt32, c:UInt32, outer_table_b:UInt32;N] + Inner Join(DelimJoin): Filter: Boolean(true) [a:UInt32, b:UInt32, c:UInt32, outer_table_b:UInt32;N] + TableScan: inner_table_lv1 [a:UInt32, b:UInt32, c:UInt32] + Projection: outer_table.b AS outer_table_dscan_1.outer_table_b [outer_table_b:UInt32;N] + DelimGet: outer_table.b [b:UInt32;N] + Filter: inner_table_lv1.c = outer_table_dscan_2.outer_table_c [a:UInt32, b:UInt32, c:UInt32, outer_table_c:UInt32;N] + Inner Join(DelimJoin): Filter: Boolean(true) [a:UInt32, b:UInt32, c:UInt32, outer_table_c:UInt32;N] + TableScan: inner_table_lv1 [a:UInt32, b:UInt32, c:UInt32] + Projection: outer_table.c AS outer_table_dscan_2.outer_table_c [outer_table_c:UInt32;N] + DelimGet: outer_table.c [c:UInt32;N] + "); + Ok(()) + } + + // Given a plan with 2 level of subquery + // This test the fact that correlated columns from the top + // are propagated to the very bottom subquery + #[test] + fn correlated_column_ref_from_parent() -> Result<()> { + let outer_table = test_table_scan_with_name("outer_table")?; + let inner_table_lv1 = test_table_scan_with_name("inner_table_lv1")?; + + let inner_table_lv2 = test_table_scan_with_name("inner_table_lv2")?; + let scalar_sq_level2 = Arc::new( + LogicalPlanBuilder::from(inner_table_lv2) + .filter( + col("inner_table_lv2.a") + .eq(out_ref_col(ArrowDataType::UInt32, "outer_table.a")), + )? + .aggregate(Vec::::new(), vec![count(col("inner_table_lv2.a"))])? + .build()?, + ); + let scalar_sq_level1 = Arc::new( + LogicalPlanBuilder::from(inner_table_lv1.clone()) + .filter( + col("inner_table_lv1.c") + .eq(out_ref_col(ArrowDataType::UInt32, "outer_table.c")) + .and(scalar_subquery(scalar_sq_level2).eq(lit(1))), + )? + .aggregate(Vec::::new(), vec![count(col("inner_table_lv1.a"))])? + .build()?, + ); + + let plan = LogicalPlanBuilder::from(outer_table.clone()) + .filter(scalar_subquery(scalar_sq_level1).eq(col("outer_table.a")))? 
+ .build()?; + + assert_decorrelate!(plan, @r" + Projection: outer_table.a, outer_table.b, outer_table.c [a:UInt32, b:UInt32, c:UInt32] + Filter: __scalar_sq_2 = outer_table.a [a:UInt32, b:UInt32, c:UInt32, count(inner_table_lv1.a):Int64;N, outer_table_c:UInt32;N, outer_table_c:UInt32;N, __scalar_sq_2:Int64;N] + Projection: outer_table.a, outer_table.b, outer_table.c, count(inner_table_lv1.a), outer_table_dscan_1.outer_table_c, outer_table_dscan_4.outer_table_c, count(inner_table_lv1.a) AS __scalar_sq_2 [a:UInt32, b:UInt32, c:UInt32, count(inner_table_lv1.a):Int64;N, outer_table_c:UInt32;N, outer_table_c:UInt32;N, __scalar_sq_2:Int64;N] + Left Join(ComparisonJoin): Filter: outer_table.c IS NOT DISTINCT FROM outer_table_dscan_4.outer_table_c [a:UInt32, b:UInt32, c:UInt32, count(inner_table_lv1.a):Int64;N, outer_table_c:UInt32;N, outer_table_c:UInt32;N] + TableScan: outer_table [a:UInt32, b:UInt32, c:UInt32] + Inner Join(DelimJoin): Filter: outer_table_dscan_1.outer_table_c IS NOT DISTINCT FROM outer_table_dscan_4.outer_table_c [count(inner_table_lv1.a):Int64, outer_table_c:UInt32;N, outer_table_c:UInt32;N] + Projection: CASE WHEN count(inner_table_lv1.a) IS NULL THEN Int32(0) ELSE count(inner_table_lv1.a) END, outer_table_dscan_1.outer_table_c [count(inner_table_lv1.a):Int64, outer_table_c:UInt32;N] + Aggregate: groupBy=[[outer_table_dscan_1.outer_table_c]], aggr=[[count(inner_table_lv1.a)]] [outer_table_c:UInt32;N, count(inner_table_lv1.a):Int64] + Projection: inner_table_lv1.a, inner_table_lv1.b, inner_table_lv1.c, outer_table_dscan_1.outer_table_c [a:UInt32, b:UInt32, c:UInt32, outer_table_c:UInt32;N] + Filter: inner_table_lv1.c = outer_table_dscan_1.outer_table_c AND __scalar_sq_1 = Int32(1) [a:UInt32, b:UInt32, c:UInt32, count(inner_table_lv2.a):Int64;N, outer_table_a:UInt32;N, outer_table_a:UInt32;N, __scalar_sq_1:Int64;N, outer_table_c:UInt32;N] + Inner Join(DelimJoin): Filter: Boolean(true) [a:UInt32, b:UInt32, c:UInt32, count(inner_table_lv2.a):Int64;N, outer_table_a:UInt32;N, outer_table_a:UInt32;N, __scalar_sq_1:Int64;N, outer_table_c:UInt32;N] + Projection: inner_table_lv1.a, inner_table_lv1.b, inner_table_lv1.c, count(inner_table_lv2.a), outer_table_dscan_2.outer_table_a, outer_table_dscan_3.outer_table_a, count(inner_table_lv2.a) AS __scalar_sq_1 [a:UInt32, b:UInt32, c:UInt32, count(inner_table_lv2.a):Int64;N, outer_table_a:UInt32;N, outer_table_a:UInt32;N, __scalar_sq_1:Int64;N] + Left Join(ComparisonJoin): Filter: outer_table.a IS NOT DISTINCT FROM outer_table_dscan_3.outer_table_a [a:UInt32, b:UInt32, c:UInt32, count(inner_table_lv2.a):Int64;N, outer_table_a:UInt32;N, outer_table_a:UInt32;N] + TableScan: inner_table_lv1 [a:UInt32, b:UInt32, c:UInt32] + Inner Join(DelimJoin): Filter: outer_table_dscan_2.outer_table_a IS NOT DISTINCT FROM outer_table_dscan_3.outer_table_a [count(inner_table_lv2.a):Int64, outer_table_a:UInt32;N, outer_table_a:UInt32;N] + Projection: CASE WHEN count(inner_table_lv2.a) IS NULL THEN Int32(0) ELSE count(inner_table_lv2.a) END, outer_table_dscan_2.outer_table_a [count(inner_table_lv2.a):Int64, outer_table_a:UInt32;N] + Aggregate: groupBy=[[outer_table_dscan_2.outer_table_a]], aggr=[[count(inner_table_lv2.a)]] [outer_table_a:UInt32;N, count(inner_table_lv2.a):Int64] + Filter: inner_table_lv2.a = outer_table_dscan_2.outer_table_a [a:UInt32, b:UInt32, c:UInt32, outer_table_a:UInt32;N] + Inner Join(DelimJoin): Filter: Boolean(true) [a:UInt32, b:UInt32, c:UInt32, outer_table_a:UInt32;N] + TableScan: inner_table_lv2 [a:UInt32, b:UInt32, c:UInt32] + 
Projection: outer_table.a AS outer_table_dscan_2.outer_table_a [outer_table_a:UInt32;N] + DelimGet: outer_table.a [a:UInt32;N] + Projection: outer_table.a AS outer_table_dscan_3.outer_table_a [outer_table_a:UInt32;N] + DelimGet: outer_table.a [a:UInt32;N] + Projection: outer_table.c AS outer_table_dscan_1.outer_table_c [outer_table_c:UInt32;N] + DelimGet: outer_table.c [c:UInt32;N] + Projection: outer_table.c AS outer_table_dscan_4.outer_table_c [outer_table_c:UInt32;N] + DelimGet: outer_table.c [c:UInt32;N] + "); + Ok(()) + } + + #[test] + fn decorrelated_two_nested_subqueries() -> Result<()> { + let outer_table = test_table_scan_with_name("outer_table")?; + let inner_table_lv1 = test_table_scan_with_name("inner_table_lv1")?; + + let inner_table_lv2 = test_table_scan_with_name("inner_table_lv2")?; + let scalar_sq_level2 = + Arc::new( + LogicalPlanBuilder::from(inner_table_lv2) + .filter( + col("inner_table_lv2.a") + .eq(out_ref_col(ArrowDataType::UInt32, "outer_table.a")) + .and(col("inner_table_lv2.b").eq(out_ref_col( + ArrowDataType::UInt32, + "inner_table_lv1.b", + ))), + )? + .aggregate(Vec::::new(), vec![count(col("inner_table_lv2.a"))])? + .build()?, + ); + let scalar_sq_level1 = Arc::new( + LogicalPlanBuilder::from(inner_table_lv1.clone()) + .filter( + col("inner_table_lv1.c") + .eq(out_ref_col(ArrowDataType::UInt32, "outer_table.c")) + .and(scalar_subquery(scalar_sq_level2).eq(lit(1))), + )? + .aggregate(Vec::::new(), vec![count(col("inner_table_lv1.a"))])? + .build()?, + ); + + let plan = LogicalPlanBuilder::from(outer_table.clone()) + .filter( + col("outer_table.a") + .gt(lit(1)) + .and(scalar_subquery(scalar_sq_level1).eq(col("outer_table.a"))), + )? + .build()?; + + // Projection: outer_table.a, outer_table.b, outer_table.c + // Filter: outer_table.a > Int32(1) AND __scalar_sq_2.output = outer_table.a + // DependentJoin on [outer_table.a lvl 2, outer_table.c lvl 1] with expr () depth 1 + // TableScan: outer_table + // Aggregate: groupBy=[[]], aggr=[[count(inner_table_lv1.a)]] + // Projection: inner_table_lv1.a, inner_table_lv1.b, inner_table_lv1.c + // Filter: inner_table_lv1.c = outer_ref(outer_table.c) AND __scalar_sq_1.output = Int32(1) + // DependentJoin on [inner_table_lv1.b lvl 2] with expr () depth 2 + // TableScan: inner_table_lv1 + // Aggregate: groupBy=[[]], aggr=[[count(inner_table_lv2.a)]] + // Filter: inner_table_lv2.a = outer_ref(outer_table.a) AND inner_table_lv2.b = outer_ref(inner_table_lv1.b) + // TableScan: inner_table_lv2 + assert_decorrelate!(plan, @r" + Projection: outer_table.a, outer_table.b, outer_table.c [a:UInt32, b:UInt32, c:UInt32] + Filter: outer_table.a > Int32(1) AND __scalar_sq_2 = outer_table.a [a:UInt32, b:UInt32, c:UInt32, count(inner_table_lv1.a):Int64;N, outer_table_c:UInt32;N, outer_table_c:UInt32;N, __scalar_sq_2:Int64;N] + Projection: outer_table.a, outer_table.b, outer_table.c, count(inner_table_lv1.a), outer_table_dscan_1.outer_table_c, outer_table_dscan_6.outer_table_c, count(inner_table_lv1.a) AS __scalar_sq_2 [a:UInt32, b:UInt32, c:UInt32, count(inner_table_lv1.a):Int64;N, outer_table_c:UInt32;N, outer_table_c:UInt32;N, __scalar_sq_2:Int64;N] + Left Join(ComparisonJoin): Filter: outer_table.c IS NOT DISTINCT FROM outer_table_dscan_6.outer_table_c [a:UInt32, b:UInt32, c:UInt32, count(inner_table_lv1.a):Int64;N, outer_table_c:UInt32;N, outer_table_c:UInt32;N] + TableScan: outer_table [a:UInt32, b:UInt32, c:UInt32] + Inner Join(DelimJoin): Filter: outer_table_dscan_1.outer_table_c IS NOT DISTINCT FROM outer_table_dscan_6.outer_table_c 
[count(inner_table_lv1.a):Int64, outer_table_c:UInt32;N, outer_table_c:UInt32;N] + Projection: CASE WHEN count(inner_table_lv1.a) IS NULL THEN Int32(0) ELSE count(inner_table_lv1.a) END, outer_table_dscan_1.outer_table_c [count(inner_table_lv1.a):Int64, outer_table_c:UInt32;N] + Aggregate: groupBy=[[outer_table_dscan_1.outer_table_c]], aggr=[[count(inner_table_lv1.a)]] [outer_table_c:UInt32;N, count(inner_table_lv1.a):Int64] + Projection: inner_table_lv1.a, inner_table_lv1.b, inner_table_lv1.c, outer_table_dscan_1.outer_table_c [a:UInt32, b:UInt32, c:UInt32, outer_table_c:UInt32;N] + Filter: inner_table_lv1.c = outer_table_dscan_1.outer_table_c AND __scalar_sq_1 = Int32(1) [a:UInt32, b:UInt32, c:UInt32, count(inner_table_lv2.a):Int64;N, outer_table_a:UInt32;N, inner_table_lv1_b:UInt32;N, inner_table_lv1_b:UInt32;N, outer_table_a:UInt32;N, __scalar_sq_1:Int64;N, outer_table_c:UInt32;N] + Inner Join(DelimJoin): Filter: Boolean(true) [a:UInt32, b:UInt32, c:UInt32, count(inner_table_lv2.a):Int64;N, outer_table_a:UInt32;N, inner_table_lv1_b:UInt32;N, inner_table_lv1_b:UInt32;N, outer_table_a:UInt32;N, __scalar_sq_1:Int64;N, outer_table_c:UInt32;N] + Projection: inner_table_lv1.a, inner_table_lv1.b, inner_table_lv1.c, count(inner_table_lv2.a), outer_table_dscan_3.outer_table_a, inner_table_lv1_dscan_2.inner_table_lv1_b, inner_table_lv1_dscan_4.inner_table_lv1_b, outer_table_dscan_5.outer_table_a, count(inner_table_lv2.a) AS __scalar_sq_1 [a:UInt32, b:UInt32, c:UInt32, count(inner_table_lv2.a):Int64;N, outer_table_a:UInt32;N, inner_table_lv1_b:UInt32;N, inner_table_lv1_b:UInt32;N, outer_table_a:UInt32;N, __scalar_sq_1:Int64;N] + Left Join(ComparisonJoin): Filter: inner_table_lv1.b IS NOT DISTINCT FROM inner_table_lv1_dscan_4.inner_table_lv1_b AND outer_table.a IS NOT DISTINCT FROM outer_table_dscan_5.outer_table_a [a:UInt32, b:UInt32, c:UInt32, count(inner_table_lv2.a):Int64;N, outer_table_a:UInt32;N, inner_table_lv1_b:UInt32;N, inner_table_lv1_b:UInt32;N, outer_table_a:UInt32;N] + TableScan: inner_table_lv1 [a:UInt32, b:UInt32, c:UInt32] + Inner Join(DelimJoin): Filter: inner_table_lv1_dscan_2.inner_table_lv1_b IS NOT DISTINCT FROM inner_table_lv1_dscan_4.inner_table_lv1_b AND outer_table_dscan_3.outer_table_a IS NOT DISTINCT FROM outer_table_dscan_5.outer_table_a [count(inner_table_lv2.a):Int64, outer_table_a:UInt32;N, inner_table_lv1_b:UInt32;N, inner_table_lv1_b:UInt32;N, outer_table_a:UInt32;N] + Projection: CASE WHEN count(inner_table_lv2.a) IS NULL THEN Int32(0) ELSE count(inner_table_lv2.a) END, outer_table_dscan_3.outer_table_a, inner_table_lv1_dscan_2.inner_table_lv1_b [count(inner_table_lv2.a):Int64, outer_table_a:UInt32;N, inner_table_lv1_b:UInt32;N] + Aggregate: groupBy=[[inner_table_lv1_dscan_2.inner_table_lv1_b, outer_table_dscan_3.outer_table_a]], aggr=[[count(inner_table_lv2.a)]] [inner_table_lv1_b:UInt32;N, outer_table_a:UInt32;N, count(inner_table_lv2.a):Int64] + Filter: inner_table_lv2.a = outer_table_dscan_3.outer_table_a AND inner_table_lv2.b = inner_table_lv1_dscan_2.inner_table_lv1_b [a:UInt32, b:UInt32, c:UInt32, inner_table_lv1_b:UInt32;N, outer_table_a:UInt32;N] + Inner Join(DelimJoin): Filter: Boolean(true) [a:UInt32, b:UInt32, c:UInt32, inner_table_lv1_b:UInt32;N, outer_table_a:UInt32;N] + TableScan: inner_table_lv2 [a:UInt32, b:UInt32, c:UInt32] + Cross Join(ComparisonJoin): [inner_table_lv1_b:UInt32;N, outer_table_a:UInt32;N] + Projection: inner_table_lv1.b AS inner_table_lv1_dscan_2.inner_table_lv1_b [inner_table_lv1_b:UInt32;N] + DelimGet: inner_table_lv1.b 
[b:UInt32;N] + Projection: outer_table.a AS outer_table_dscan_3.outer_table_a [outer_table_a:UInt32;N] + DelimGet: outer_table.a [a:UInt32;N] + Cross Join(ComparisonJoin): [inner_table_lv1_b:UInt32;N, outer_table_a:UInt32;N] + Projection: inner_table_lv1.b AS inner_table_lv1_dscan_4.inner_table_lv1_b [inner_table_lv1_b:UInt32;N] + DelimGet: inner_table_lv1.b [b:UInt32;N] + Projection: outer_table.a AS outer_table_dscan_5.outer_table_a [outer_table_a:UInt32;N] + DelimGet: outer_table.a [a:UInt32;N] + Projection: outer_table.c AS outer_table_dscan_1.outer_table_c [outer_table_c:UInt32;N] + DelimGet: outer_table.c [c:UInt32;N] + Projection: outer_table.c AS outer_table_dscan_6.outer_table_c [outer_table_c:UInt32;N] + DelimGet: outer_table.c [c:UInt32;N] + "); + Ok(()) + } + + #[test] + fn decorrelate_join_in_subquery_with_count_depth_1() -> Result<()> { + let outer_table = test_table_scan_with_name("outer_table")?; + let inner_table_lv1 = test_table_scan_with_name("inner_table_lv1")?; + let sq_level1 = Arc::new( + LogicalPlanBuilder::from(inner_table_lv1) + .filter( + col("inner_table_lv1.a") + .eq(out_ref_col(ArrowDataType::UInt32, "outer_table.a")) + .and( + out_ref_col(ArrowDataType::UInt32, "outer_table.a") + .gt(col("inner_table_lv1.c")), + ) + .and(col("inner_table_lv1.b").eq(lit(1))) + .and( + out_ref_col(ArrowDataType::UInt32, "outer_table.b") + .eq(col("inner_table_lv1.b")), + ), + )? + .aggregate(Vec::::new(), vec![count(col("inner_table_lv1.a"))])? + // TODO: if uncomment this the test fail + // .project(vec![count(col("inner_table_lv1.a")).alias("count_a")])? + .build()?, + ); + + let plan = LogicalPlanBuilder::from(outer_table.clone()) + .filter( + col("outer_table.a") + .gt(lit(1)) + .and(in_subquery(col("outer_table.c"), sq_level1)), + )? 
+ .build()?; + // Projection: outer_table.a, outer_table.b, outer_table.c + // Filter: outer_table.a > Int32(1) AND __in_sq_1.output + // DependentJoin on [outer_table.a lvl 1, outer_table.b lvl 1] with expr outer_table.c IN () depth 1 + // TableScan: outer_table + // Aggregate: groupBy=[[]], aggr=[[count(inner_table_lv1.a)]] + // Filter: inner_table_lv1.a = outer_ref(outer_table.a) AND outer_ref(outer_table.a) > inner_table_lv1.c AND inner_table_lv1.b = Int32(1) AND outer_ref(outer_table.b) = inner_table_lv1.b + // TableScan: inner_table_lv1 + + assert_decorrelate!(plan, @r" + Projection: outer_table.a, outer_table.b, outer_table.c [a:UInt32, b:UInt32, c:UInt32] + Filter: outer_table.a > Int32(1) AND __in_sq_1 [a:UInt32, b:UInt32, c:UInt32, __in_sq_1:Boolean] + Projection: outer_table.a, outer_table.b, outer_table.c, mark AS __in_sq_1 [a:UInt32, b:UInt32, c:UInt32, __in_sq_1:Boolean] + LeftMark Join(ComparisonJoin): Filter: outer_table.c = count(inner_table_lv1.a) AND outer_table.a IS NOT DISTINCT FROM outer_table_dscan_2.outer_table_a AND outer_table.b IS NOT DISTINCT FROM outer_table_dscan_2.outer_table_b [a:UInt32, b:UInt32, c:UInt32, mark:Boolean] + TableScan: outer_table [a:UInt32, b:UInt32, c:UInt32] + Inner Join(DelimJoin): Filter: outer_table_dscan_1.outer_table_a IS NOT DISTINCT FROM outer_table_dscan_2.outer_table_a AND outer_table_dscan_1.outer_table_b IS NOT DISTINCT FROM outer_table_dscan_2.outer_table_b [count(inner_table_lv1.a):Int64, outer_table_b:UInt32;N, outer_table_a:UInt32;N, outer_table_a:UInt32;N, outer_table_b:UInt32;N] + Projection: CASE WHEN count(inner_table_lv1.a) IS NULL THEN Int32(0) ELSE count(inner_table_lv1.a) END, outer_table_dscan_1.outer_table_b, outer_table_dscan_1.outer_table_a [count(inner_table_lv1.a):Int64, outer_table_b:UInt32;N, outer_table_a:UInt32;N] + Aggregate: groupBy=[[outer_table_dscan_1.outer_table_a, outer_table_dscan_1.outer_table_b]], aggr=[[count(inner_table_lv1.a)]] [outer_table_a:UInt32;N, outer_table_b:UInt32;N, count(inner_table_lv1.a):Int64] + Filter: inner_table_lv1.a = outer_table_dscan_1.outer_table_a AND outer_table_dscan_1.outer_table_a > inner_table_lv1.c AND inner_table_lv1.b = Int32(1) AND outer_table_dscan_1.outer_table_b = inner_table_lv1.b [a:UInt32, b:UInt32, c:UInt32, outer_table_a:UInt32;N, outer_table_b:UInt32;N] + Inner Join(DelimJoin): Filter: Boolean(true) [a:UInt32, b:UInt32, c:UInt32, outer_table_a:UInt32;N, outer_table_b:UInt32;N] + TableScan: inner_table_lv1 [a:UInt32, b:UInt32, c:UInt32] + Projection: outer_table.a AS outer_table_dscan_1.outer_table_a, outer_table.b AS outer_table_dscan_1.outer_table_b [outer_table_a:UInt32;N, outer_table_b:UInt32;N] + DelimGet: outer_table.a, outer_table.b [a:UInt32;N, b:UInt32;N] + Projection: outer_table.a AS outer_table_dscan_2.outer_table_a, outer_table.b AS outer_table_dscan_2.outer_table_b [outer_table_a:UInt32;N, outer_table_b:UInt32;N] + DelimGet: outer_table.a, outer_table.b [a:UInt32;N, b:UInt32;N] + "); + Ok(()) + } + + // TODO: an issue with uncorrelated subquery making this fail + #[test] + fn one_correlated_subquery_and_one_uncorrelated_subquery_at_the_same_level( + ) -> Result<()> { + let outer_table = test_table_scan_with_name("outer_table")?; + let inner_table_lv1 = test_table_scan_with_name("inner_table_lv1")?; + let in_sq_level1 = Arc::new( + LogicalPlanBuilder::from(inner_table_lv1.clone()) + .filter(col("inner_table_lv1.c").eq(lit(2)))? + .project(vec![col("inner_table_lv1.a")])? 
+ .build()?, + ); + let exist_sq_level1 = Arc::new( + LogicalPlanBuilder::from(inner_table_lv1) + .filter( + col("inner_table_lv1.a").and(col("inner_table_lv1.b").eq(lit(1))), + )? + .build()?, + ); + + let _plan = LogicalPlanBuilder::from(outer_table.clone()) + .filter( + col("outer_table.a") + .gt(lit(1)) + .and(exists(exist_sq_level1)) + .and(in_subquery(col("outer_table.b"), in_sq_level1)), + )? + .build()?; + // println!("{plan}"); + // assert_decorrelate!(plan, @r" + // Projection: outer_table.a, outer_table.b, outer_table.c [a:UInt32, b:UInt32, c:UInt32] + // Filter: outer_table.a > Int32(1) AND __exists_sq_1.output AND __in_sq_2.output [a:UInt32, b:UInt32, c:UInt32, __exists_sq_1.output:Boolean, __in_sq_2.output:Boolean] + // Projection: outer_table.a, outer_table.b, outer_table.c, __exists_sq_1.output, inner_table_lv1.mark AS __in_sq_2.output [a:UInt32, b:UInt32, c:UInt32, __exists_sq_1.output:Boolean, __in_sq_2.output:Boolean] + // LeftMark Join(ComparisonJoin): Filter: outer_table.b = inner_table_lv1.a [a:UInt32, b:UInt32, c:UInt32, __exists_sq_1.output:Boolean, mark:Boolean] + // Projection: outer_table.a, outer_table.b, outer_table.c, inner_table_lv1.mark AS __exists_sq_1.output [a:UInt32, b:UInt32, c:UInt32, __exists_sq_1.output:Boolean] + // LeftMark Join(ComparisonJoin): Filter: Boolean(true) [a:UInt32, b:UInt32, c:UInt32, mark:Boolean] + // TableScan: outer_table [a:UInt32, b:UInt32, c:UInt32] + // Inner Join(DelimJoin): Filter: Boolean(true) [a:UInt32, b:UInt32, c:UInt32] + // Filter: inner_table_lv1.a AND inner_table_lv1.b = Int32(1) [a:UInt32, b:UInt32, c:UInt32] + // TableScan: inner_table_lv1 [a:UInt32, b:UInt32, c:UInt32] + // SubqueryAlias: delim_scan_1 [] + // DelimGet: [] + // Projection: inner_table_lv1.a [a:UInt32] + // Cross Join(ComparisonJoin): [a:UInt32, b:UInt32, c:UInt32] + // Filter: inner_table_lv1.c = Int32(2) [a:UInt32, b:UInt32, c:UInt32] + // TableScan: inner_table_lv1 [a:UInt32, b:UInt32, c:UInt32] + // SubqueryAlias: delim_scan_2 [] + // DelimGet: [] + // "); + Ok(()) + } + + #[test] + fn decorrelate_with_in_subquery_has_dependent_column() -> Result<()> { + let outer_table = test_table_scan_with_name("outer_table")?; + let inner_table_lv1 = test_table_scan_with_name("inner_table_lv1")?; + let sq_level1 = Arc::new( + LogicalPlanBuilder::from(inner_table_lv1) + .filter( + col("inner_table_lv1.a") + .eq(out_ref_col(ArrowDataType::UInt32, "outer_table.a")) + .and( + out_ref_col(ArrowDataType::UInt32, "outer_table.a") + .gt(col("inner_table_lv1.c")), + ) + .and(col("inner_table_lv1.b").eq(lit(1))) + .and( + out_ref_col(ArrowDataType::UInt32, "outer_table.b") + .eq(col("inner_table_lv1.b")), + ), + )? + .project(vec![col("inner_table_lv1.b")])? + .build()?, + ); + + let plan = LogicalPlanBuilder::from(outer_table.clone()) + .filter( + col("outer_table.a") + .gt(lit(1)) + .and(in_subquery(col("outer_table.c"), sq_level1)), + )? 
+ .build()?; + let dec = DecorrelateDependentJoin::new(); + let ctx: Box = Box::new(OptimizerContext::new()); + let plan = dec.rewrite(plan, ctx.as_ref())?.data; + + // Projection: outer_table.a, outer_table.b, outer_table.c + // Filter: outer_table.a > Int32(1) AND __in_sq_1.output + // DependentJoin on [outer_table.a lvl 1, outer_table.b lvl 1] with expr outer_table.c IN () depth 1 + // TableScan: outer_table + // Projection: inner_table_lv1.b + // Filter: inner_table_lv1.a = outer_ref(outer_table.a) AND outer_ref(outer_table.a) > inner_table_lv1.c AND inner_table_lv1.b = Int32(1) AND outer_ref(outer_table.b) = inner_table_lv1.b + // TableScan: inner_table_lv1 + assert_decorrelate!(plan, @r" + Projection: outer_table.a, outer_table.b, outer_table.c [a:UInt32, b:UInt32, c:UInt32] + Filter: outer_table.a > Int32(1) AND __in_sq_1 [a:UInt32, b:UInt32, c:UInt32, __in_sq_1:Boolean] + Projection: outer_table.a, outer_table.b, outer_table.c, mark AS __in_sq_1 [a:UInt32, b:UInt32, c:UInt32, __in_sq_1:Boolean] + LeftMark Join(ComparisonJoin): Filter: outer_table.c = inner_table_lv1.b AND outer_table.a IS NOT DISTINCT FROM outer_table_dscan_1.outer_table_a AND outer_table.b IS NOT DISTINCT FROM outer_table_dscan_1.outer_table_b [a:UInt32, b:UInt32, c:UInt32, mark:Boolean] + TableScan: outer_table [a:UInt32, b:UInt32, c:UInt32] + Projection: inner_table_lv1.b, outer_table_dscan_1.outer_table_a, outer_table_dscan_1.outer_table_b [b:UInt32, outer_table_a:UInt32;N, outer_table_b:UInt32;N] + Filter: inner_table_lv1.a = outer_table_dscan_1.outer_table_a AND outer_table_dscan_1.outer_table_a > inner_table_lv1.c AND inner_table_lv1.b = Int32(1) AND outer_table_dscan_1.outer_table_b = inner_table_lv1.b [a:UInt32, b:UInt32, c:UInt32, outer_table_a:UInt32;N, outer_table_b:UInt32;N] + Inner Join(DelimJoin): Filter: Boolean(true) [a:UInt32, b:UInt32, c:UInt32, outer_table_a:UInt32;N, outer_table_b:UInt32;N] + TableScan: inner_table_lv1 [a:UInt32, b:UInt32, c:UInt32] + Projection: outer_table.a AS outer_table_dscan_1.outer_table_a, outer_table.b AS outer_table_dscan_1.outer_table_b [outer_table_a:UInt32;N, outer_table_b:UInt32;N] + DelimGet: outer_table.a, outer_table.b [a:UInt32;N, b:UInt32;N] + "); + + Ok(()) + } + + // This query is inside the paper + #[test] + fn decorrelate_two_different_outer_tables() -> Result<()> { + let t1 = test_table_scan_with_name("T1")?; + let t2 = test_table_scan_with_name("T2")?; + + let t3 = test_table_scan_with_name("T3")?; + let scalar_sq_level2 = Arc::new( + LogicalPlanBuilder::from(t3) + .filter( + col("T3.b") + .eq(out_ref_col(ArrowDataType::UInt32, "T2.b")) + .and(col("T3.a").eq(out_ref_col(ArrowDataType::UInt32, "T1.a"))), + )? + .aggregate(Vec::::new(), vec![sum(col("T3.a"))])? + .build()?, + ); + let scalar_sq_level1 = Arc::new( + LogicalPlanBuilder::from(t2.clone()) + .filter( + col("T2.a") + .eq(out_ref_col(ArrowDataType::UInt32, "T1.a")) + .and(scalar_subquery(scalar_sq_level2).gt(lit(300000))), + )? + .aggregate(Vec::::new(), vec![count(col("T2.a"))])? + .build()?, + ); + + let plan = LogicalPlanBuilder::from(t1.clone()) + .filter( + col("T1.c") + .eq(lit(123)) + .and(scalar_subquery(scalar_sq_level1).gt(lit(5))), + )? 
+ .build()?; + + // Projection: t1.a, t1.b, t1.c [a:UInt32, b:UInt32, c:UInt32] + // Filter: t1.c = Int32(123) AND __scalar_sq_2.output > Int32(5) [a:UInt32, b:UInt32, c:UInt32, output:Int64] + // DependentJoin on [t1.a lvl 1, t1.a lvl 2] with expr () depth 1 [a:UInt32, b:UInt32, c:UInt32, output:Int64] + // TableScan: t1 [a:UInt32, b:UInt32, c:UInt32] + // Aggregate: groupBy=[[]], aggr=[[count(t2.a)]] [count(t2.a):Int64] + // Projection: t2.a, t2.b, t2.c [a:UInt32, b:UInt32, c:UInt32] + // Filter: t2.a = outer_ref(t1.a) AND __scalar_sq_1.output > Int32(300000) [a:UInt32, b:UInt32, c:UInt32, output:UInt64] + // DependentJoin on [t2.b lvl 2] with expr () depth 2 [a:UInt32, b:UInt32, c:UInt32, output:UInt64] + // TableScan: t2 [a:UInt32, b:UInt32, c:UInt32] + // Aggregate: groupBy=[[]], aggr=[[sum(t3.a)]] [sum(t3.a):UInt64;N] + // Filter: t3.b = outer_ref(t2.b) AND t3.a = outer_ref(t1.a) [a:UInt32, b:UInt32, c:UInt32] + // TableScan: t3 [a:UInt32, b:UInt32, c:UInt32] + + assert_decorrelate!(plan, @r" + Projection: t1.a, t1.b, t1.c [a:UInt32, b:UInt32, c:UInt32] + Filter: t1.c = Int32(123) AND __scalar_sq_2 > Int32(5) [a:UInt32, b:UInt32, c:UInt32, count(t2.a):Int64;N, t1_a:UInt32;N, t1_a:UInt32;N, __scalar_sq_2:Int64;N] + Projection: t1.a, t1.b, t1.c, count(t2.a), t1_dscan_5.t1_a, t1_dscan_6.t1_a, count(t2.a) AS __scalar_sq_2 [a:UInt32, b:UInt32, c:UInt32, count(t2.a):Int64;N, t1_a:UInt32;N, t1_a:UInt32;N, __scalar_sq_2:Int64;N] + Left Join(ComparisonJoin): Filter: t1.a IS NOT DISTINCT FROM t1_dscan_6.t1_a [a:UInt32, b:UInt32, c:UInt32, count(t2.a):Int64;N, t1_a:UInt32;N, t1_a:UInt32;N] + TableScan: t1 [a:UInt32, b:UInt32, c:UInt32] + Inner Join(DelimJoin): Filter: t1_dscan_5.t1_a IS NOT DISTINCT FROM t1_dscan_6.t1_a [count(t2.a):Int64, t1_a:UInt32;N, t1_a:UInt32;N] + Projection: CASE WHEN count(t2.a) IS NULL THEN Int32(0) ELSE count(t2.a) END, t1_dscan_5.t1_a [count(t2.a):Int64, t1_a:UInt32;N] + Aggregate: groupBy=[[t1_dscan_5.t1_a]], aggr=[[count(t2.a)]] [t1_a:UInt32;N, count(t2.a):Int64] + Projection: t2.a, t2.b, t2.c, t1_dscan_5.t1_a [a:UInt32, b:UInt32, c:UInt32, t1_a:UInt32;N] + Filter: t2.a = t1_dscan_5.t1_a AND __scalar_sq_1 > Int32(300000) [a:UInt32, b:UInt32, c:UInt32, t1_a:UInt32;N, sum(t3.a):UInt64;N, t1_a:UInt32;N, t2_b:UInt32;N, t2_b:UInt32;N, t1_a:UInt32;N, __scalar_sq_1:UInt64;N] + Projection: t2.a, t2.b, t2.c, t1_dscan_1.t1_a, sum(t3.a), t1_dscan_3.t1_a, t2_dscan_2.t2_b, t2_dscan_4.t2_b, t1_dscan_5.t1_a, sum(t3.a) AS __scalar_sq_1 [a:UInt32, b:UInt32, c:UInt32, t1_a:UInt32;N, sum(t3.a):UInt64;N, t1_a:UInt32;N, t2_b:UInt32;N, t2_b:UInt32;N, t1_a:UInt32;N, __scalar_sq_1:UInt64;N] + Left Join(ComparisonJoin): Filter: t2.b IS NOT DISTINCT FROM t2_dscan_4.t2_b AND t1.a IS NOT DISTINCT FROM t1_dscan_5.t1_a AND t1.a IS NOT DISTINCT FROM t1_dscan_5.t1_a [a:UInt32, b:UInt32, c:UInt32, t1_a:UInt32;N, sum(t3.a):UInt64;N, t1_a:UInt32;N, t2_b:UInt32;N, t2_b:UInt32;N, t1_a:UInt32;N] + Inner Join(DelimJoin): Filter: Boolean(true) [a:UInt32, b:UInt32, c:UInt32, t1_a:UInt32;N] + TableScan: t2 [a:UInt32, b:UInt32, c:UInt32] + Projection: t1.a AS t1_dscan_1.t1_a [t1_a:UInt32;N] + DelimGet: t1.a [a:UInt32;N] + Inner Join(DelimJoin): Filter: t2_dscan_2.t2_b IS NOT DISTINCT FROM t2_dscan_4.t2_b AND t1_dscan_3.t1_a IS NOT DISTINCT FROM t1_dscan_5.t1_a AND t1_dscan_3.t1_a IS NOT DISTINCT FROM t1_dscan_5.t1_a [sum(t3.a):UInt64;N, t1_a:UInt32;N, t2_b:UInt32;N, t2_b:UInt32;N, t1_a:UInt32;N] + Projection: sum(t3.a), t1_dscan_3.t1_a, t2_dscan_2.t2_b [sum(t3.a):UInt64;N, t1_a:UInt32;N, t2_b:UInt32;N] + 
Aggregate: groupBy=[[t2_dscan_2.t2_b, t1_dscan_3.t1_a, t1_dscan_3.t1_a]], aggr=[[sum(t3.a)]] [t2_b:UInt32;N, t1_a:UInt32;N, sum(t3.a):UInt64;N] + Filter: t3.b = t2_dscan_2.t2_b AND t3.a = t1_dscan_3.t1_a [a:UInt32, b:UInt32, c:UInt32, t2_b:UInt32;N, t1_a:UInt32;N] + Inner Join(DelimJoin): Filter: Boolean(true) [a:UInt32, b:UInt32, c:UInt32, t2_b:UInt32;N, t1_a:UInt32;N] + TableScan: t3 [a:UInt32, b:UInt32, c:UInt32] + Cross Join(ComparisonJoin): [t2_b:UInt32;N, t1_a:UInt32;N] + Projection: t2.b AS t2_dscan_2.t2_b [t2_b:UInt32;N] + DelimGet: t2.b [b:UInt32;N] + Projection: t1.a AS t1_dscan_3.t1_a [t1_a:UInt32;N] + DelimGet: t1.a [a:UInt32;N] + Cross Join(ComparisonJoin): [t2_b:UInt32;N, t1_a:UInt32;N] + Projection: t2.b AS t2_dscan_4.t2_b [t2_b:UInt32;N] + DelimGet: t2.b [b:UInt32;N] + Projection: t1.a AS t1_dscan_5.t1_a [t1_a:UInt32;N] + DelimGet: t1.a [a:UInt32;N] + Projection: t1.a AS t1_dscan_6.t1_a [t1_a:UInt32;N] + DelimGet: t1.a [a:UInt32;N] + "); + Ok(()) + } + + #[test] + fn decorrelate_inner_join_left() -> Result<()> { + let outer_table = test_table_scan_with_name("outer_table")?; + let inner_table_lv1 = test_table_scan_with_name("inner_table_lv1")?; + let inner_table_lv2 = test_table_scan_with_name("inner_table_lv2")?; + + let sq_level1 = Arc::new( + LogicalPlanBuilder::from(inner_table_lv1) + .join( + inner_table_lv2, + JoinType::Inner, + (Vec::::new(), Vec::::new()), + Some( + col("inner_table_lv1.a") + .eq(out_ref_col(ArrowDataType::UInt32, "outer_table.a")) + .and( + out_ref_col(ArrowDataType::UInt32, "outer_table.a") + .gt(col("inner_table_lv1.c")), + ) + .and(col("inner_table_lv1.b").eq(lit(1))) + .and( + out_ref_col(ArrowDataType::UInt32, "outer_table.b") + .eq(col("inner_table_lv1.b")), + ) + .and(col("inner_table_lv1.a").eq(col("inner_table_lv2.a"))), + ), + )? + .project(vec![col("inner_table_lv1.b")])? + .build()?, + ); + + let plan = LogicalPlanBuilder::from(outer_table.clone()) + .filter( + col("outer_table.a") + .gt(lit(1)) + .and(in_subquery(col("outer_table.c"), sq_level1)), + )? 
+ .build()?; + + // Projection: outer_table.a, outer_table.b, outer_table.c [a:UInt32, b:UInt32, c:UInt32] + // Filter: outer_table.a > Int32(1) AND __in_sq_1.output [a:UInt32, b:UInt32, c:UInt32, output:Boolean] + // DependentJoin on [outer_table.a lvl 1, outer_table.b lvl 1] with expr outer_table.c IN () depth 1 [a:UInt32, b:UInt32, c:UInt32, output:Boolean] + // TableScan: outer_table [a:UInt32, b:UInt32, c:UInt32] + // Projection: inner_table_lv1.b [b:UInt32] + // Inner Join(ComparisonJoin): Filter: inner_table_lv1.a = outer_ref(outer_table.a) AND outer_ref(outer_table.a) > inner_table_lv1.c AND inner_table_lv1.b = Int32(1) AND outer_ref(outer_table.b) = inner_table_lv1.b AND inner_table_lv1.a = inner_table_lv2.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] + // TableScan: inner_table_lv1 [a:UInt32, b:UInt32, c:UInt32] + // TableScan: inner_table_lv2 [a:UInt32, b:UInt32, c:UInt32] + + assert_decorrelate!(plan, @r" + Projection: outer_table.a, outer_table.b, outer_table.c [a:UInt32, b:UInt32, c:UInt32] + Filter: outer_table.a > Int32(1) AND __in_sq_1 [a:UInt32, b:UInt32, c:UInt32, __in_sq_1:Boolean] + Projection: outer_table.a, outer_table.b, outer_table.c, mark AS __in_sq_1 [a:UInt32, b:UInt32, c:UInt32, __in_sq_1:Boolean] + LeftMark Join(ComparisonJoin): Filter: outer_table.c = inner_table_lv1.b AND outer_table.a IS NOT DISTINCT FROM outer_table_dscan_1.outer_table_a AND outer_table.b IS NOT DISTINCT FROM outer_table_dscan_1.outer_table_b [a:UInt32, b:UInt32, c:UInt32, mark:Boolean] + TableScan: outer_table [a:UInt32, b:UInt32, c:UInt32] + Projection: inner_table_lv1.b, outer_table_dscan_1.outer_table_a, outer_table_dscan_1.outer_table_b [b:UInt32, outer_table_a:UInt32;N, outer_table_b:UInt32;N] + Inner Join(ComparisonJoin): Filter: inner_table_lv1.a = outer_table_dscan_1.outer_table_a AND outer_table_dscan_1.outer_table_a > inner_table_lv1.c AND inner_table_lv1.b = Int32(1) AND outer_table_dscan_1.outer_table_b = inner_table_lv1.b AND inner_table_lv1.a = inner_table_lv2.a [a:UInt32, b:UInt32, c:UInt32, outer_table_a:UInt32;N, outer_table_b:UInt32;N, a:UInt32, b:UInt32, c:UInt32] + Inner Join(DelimJoin): Filter: Boolean(true) [a:UInt32, b:UInt32, c:UInt32, outer_table_a:UInt32;N, outer_table_b:UInt32;N] + TableScan: inner_table_lv1 [a:UInt32, b:UInt32, c:UInt32] + Projection: outer_table.a AS outer_table_dscan_1.outer_table_a, outer_table.b AS outer_table_dscan_1.outer_table_b [outer_table_a:UInt32;N, outer_table_b:UInt32;N] + DelimGet: outer_table.a, outer_table.b [a:UInt32;N, b:UInt32;N] + TableScan: inner_table_lv2 [a:UInt32, b:UInt32, c:UInt32] + "); + + Ok(()) + } + + #[test] + fn decorrelate_in_subquery_with_sort_limit() -> Result<()> { + let outer_table = test_table_scan_with_name("customers")?; + let inner_table = test_table_scan_with_name("orders")?; + + let in_subquery_plan = Arc::new( + LogicalPlanBuilder::from(inner_table) + .filter( + col("orders.a") + .eq(out_ref_col(ArrowDataType::UInt32, "customers.a")) + .and(col("orders.b").eq(lit(1))), // status = 'completed' simplified as b = 1 + )? + .sort(vec![col("orders.c").sort(false, true)])? // ORDER BY order_amount DESC + .limit(0, Some(3))? // LIMIT 3 + .project(vec![col("orders.c")])? + .build()?, + ); + + // Outer query + let plan = LogicalPlanBuilder::from(outer_table.clone()) + .filter( + col("customers.a") + .gt(lit(100)) + .and(in_subquery(col("customers.a"), in_subquery_plan)), + )? 
+ .build()?; + + // Projection: customers.a, customers.b, customers.c + // Filter: customers.a > Int32(100) AND __in_sq_1.output + // DependentJoin on [customers.a lvl 1] with expr customers.a IN () depth 1 + // TableScan: customers + // Projection: orders.c + // Limit: skip=0, fetch=3 + // Sort: orders.c DESC NULLS FIRST + // Filter: orders.a = outer_ref(customers.a) AND orders.b = Int32(1) + // TableScan: orders + + assert_decorrelate!(plan, @r" + Projection: customers.a, customers.b, customers.c [a:UInt32, b:UInt32, c:UInt32] + Filter: customers.a > Int32(100) AND __in_sq_1 [a:UInt32, b:UInt32, c:UInt32, __in_sq_1:Boolean] + Projection: customers.a, customers.b, customers.c, mark AS __in_sq_1 [a:UInt32, b:UInt32, c:UInt32, __in_sq_1:Boolean] + LeftMark Join(ComparisonJoin): Filter: customers.a = orders.c AND customers.a IS NOT DISTINCT FROM customers_dscan_1.customers_a [a:UInt32, b:UInt32, c:UInt32, mark:Boolean] + TableScan: customers [a:UInt32, b:UInt32, c:UInt32] + Projection: orders.c, customers_dscan_1.customers_a [c:UInt32, customers_a:UInt32;N] + Projection: orders.a, orders.b, orders.c, customers_dscan_1.customers_a [a:UInt32, b:UInt32, c:UInt32, customers_a:UInt32;N] + Filter: row_number <= UInt64(3) [a:UInt32, b:UInt32, c:UInt32, customers_a:UInt32;N, row_number:UInt64] + WindowAggr: windowExpr=[[row_number() PARTITION BY [customers_dscan_1.customers_a] ORDER BY [orders.c DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW AS row_number]] [a:UInt32, b:UInt32, c:UInt32, customers_a:UInt32;N, row_number:UInt64] + Filter: orders.a = customers_dscan_1.customers_a AND orders.b = Int32(1) [a:UInt32, b:UInt32, c:UInt32, customers_a:UInt32;N] + Inner Join(DelimJoin): Filter: Boolean(true) [a:UInt32, b:UInt32, c:UInt32, customers_a:UInt32;N] + TableScan: orders [a:UInt32, b:UInt32, c:UInt32] + Projection: customers.a AS customers_dscan_1.customers_a [customers_a:UInt32;N] + DelimGet: customers.a [a:UInt32;N] + "); + + Ok(()) + } + + #[test] + fn decorrelate_subquery_with_window_function() -> Result<()> { + let outer_table = test_table_scan_with_name("outer_table")?; + let inner_table = test_table_scan_with_name("inner_table")?; + + // Create a subquery with window function + let window_expr = Expr::WindowFunction(Box::new(WindowFunction { + fun: WindowFunctionDefinition::WindowUDF(row_number_udwf()), + params: WindowFunctionParams { + args: vec![], + partition_by: vec![col("inner_table.b")], + order_by: vec![col("inner_table.c").sort(false, true)], + window_frame: WindowFrame::new(Some(false)), + null_treatment: None, + }, + })) + .alias("row_num"); + + let subquery = Arc::new( + LogicalPlanBuilder::from(inner_table) + .filter( + col("inner_table.a") + .eq(out_ref_col(ArrowDataType::UInt32, "outer_table.a")), + )? + .window(vec![window_expr])? + .filter(col("row_num").eq(lit(1)))? + .project(vec![col("inner_table.b")])? + .build()?, + ); + + let plan = LogicalPlanBuilder::from(outer_table) + .filter( + col("outer_table.a") + .gt(lit(1)) + .and(in_subquery(col("outer_table.c"), subquery)), + )? 
+ .build()?; + + // Projection: outer_table.a, outer_table.b, outer_table.c + // Filter: outer_table.a > Int32(1) AND __in_sq_1.output + // DependentJoin on [outer_table.a lvl 1] with expr outer_table.c IN () depth 1 + // TableScan: outer_table + // Projection: inner_table.b + // Filter: row_num = Int32(1) + // WindowAggr: windowExpr=[[row_number() PARTITION BY [inner_table.b] ORDER BY [inner_table.c DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW AS row_num]] + // Filter: inner_table.a = outer_ref(outer_table.a) + // TableScan: inner_table + + assert_decorrelate!(plan, @r" + Projection: outer_table.a, outer_table.b, outer_table.c [a:UInt32, b:UInt32, c:UInt32] + Filter: outer_table.a > Int32(1) AND __in_sq_1 [a:UInt32, b:UInt32, c:UInt32, __in_sq_1:Boolean] + Projection: outer_table.a, outer_table.b, outer_table.c, mark AS __in_sq_1 [a:UInt32, b:UInt32, c:UInt32, __in_sq_1:Boolean] + LeftMark Join(ComparisonJoin): Filter: outer_table.c = inner_table.b AND outer_table.a IS NOT DISTINCT FROM outer_table_dscan_1.outer_table_a [a:UInt32, b:UInt32, c:UInt32, mark:Boolean] + TableScan: outer_table [a:UInt32, b:UInt32, c:UInt32] + Projection: inner_table.b, outer_table_dscan_1.outer_table_a [b:UInt32, outer_table_a:UInt32;N] + Filter: row_num = Int32(1) [a:UInt32, b:UInt32, c:UInt32, outer_table_a:UInt32;N, row_num:UInt64] + WindowAggr: windowExpr=[[row_number() PARTITION BY [inner_table.b, outer_table_dscan_1.outer_table_a] ORDER BY [inner_table.c DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW AS row_num]] [a:UInt32, b:UInt32, c:UInt32, outer_table_a:UInt32;N, row_num:UInt64] + Filter: inner_table.a = outer_table_dscan_1.outer_table_a [a:UInt32, b:UInt32, c:UInt32, outer_table_a:UInt32;N] + Inner Join(DelimJoin): Filter: Boolean(true) [a:UInt32, b:UInt32, c:UInt32, outer_table_a:UInt32;N] + TableScan: inner_table [a:UInt32, b:UInt32, c:UInt32] + Projection: outer_table.a AS outer_table_dscan_1.outer_table_a [outer_table_a:UInt32;N] + DelimGet: outer_table.a [a:UInt32;N] + "); + + Ok(()) + } + + // TODO: support uncorrelated subquery + #[test] + fn subquery_slt_test1() -> Result<()> { + // Create test tables with custom column names + let t1 = test_table_with_columns( + "t1", + &[ + ("t1_id", ArrowDataType::UInt32), + ("t1_name", ArrowDataType::Utf8), + ("t1_int", ArrowDataType::Int32), + ], + )?; + + let t2 = test_table_with_columns( + "t2", + &[ + ("t2_id", ArrowDataType::UInt32), + ("t2_value", ArrowDataType::Utf8), + ], + )?; + + // Create the subquery plan (SELECT t2_id FROM t2) + let subquery = Arc::new( + LogicalPlanBuilder::from(t2) + .project(vec![col("t2_id")])? + .build()?, + ); + + // Create the main query plan + // SELECT t1_id, t1_name, t1_int FROM t1 WHERE t1_id IN (SELECT t2_id FROM t2) + let _plan = LogicalPlanBuilder::from(t1) + .filter(in_subquery(col("t1_id"), subquery))? + .project(vec![col("t1_id"), col("t1_name"), col("t1_int")])? 
+ .build()?; + + // Test the decorrelation transformation + // assert_decorrelate!(plan, @r""); + + Ok(()) + } + + #[test] + fn subquery_slt_test2() -> Result<()> { + // Create test tables with custom column names + let t1 = test_table_with_columns( + "t1", + &[ + ("t1_id", ArrowDataType::UInt32), + ("t1_name", ArrowDataType::Utf8), + ("t1_int", ArrowDataType::Int32), + ], + )?; + + let t2 = test_table_with_columns( + "t2", + &[ + ("t2_id", ArrowDataType::UInt32), + ("t2_value", ArrowDataType::Utf8), + ], + )?; + + // Create the subquery plan + // SELECT t2_id + 1 FROM t2 WHERE t1.t1_int > 0 + let subquery = Arc::new( + LogicalPlanBuilder::from(t2) + .filter(out_ref_col(ArrowDataType::Int32, "t1.t1_int").gt(lit(0)))? + .project(vec![binary_expr(col("t2_id"), Operator::Plus, lit(1))])? + .build()?, + ); + + // Create the main query plan + // SELECT t1_id, t1_name, t1_int FROM t1 WHERE t1_id + 12 NOT IN (SELECT t2_id + 1 FROM t2 WHERE t1.t1_int > 0) + let plan = LogicalPlanBuilder::from(t1) + .filter(not_in_subquery( + binary_expr(col("t1_id"), Operator::Plus, lit(12)), + subquery, + ))? + .project(vec![col("t1_id"), col("t1_name"), col("t1_int")])? + .build()?; + + // select t1.t1_id, + // t1.t1_name, + // t1.t1_int + // from t1 + // where t1.t1_id + 12 not in (select t2.t2_id + 1 from t2 where t1.t1_int > 0); + + // Projection: t1.t1_id, t1.t1_name, t1.t1_int [t1_id:UInt32, t1_name:Utf8, t1_int:Int32] + // Filter: t1.t1_id + Int32(12) NOT IN () [t1_id:UInt32, t1_name:Utf8, t1_int:Int32] + // Subquery: [t2.t2_id + Int32(1):Int64] + // Projection: t2.t2_id + Int32(1) [t2.t2_id + Int32(1):Int64] + // Filter: outer_ref(t1.t1_int) > Int32(0) [t2_id:UInt32, t2_value:Utf8] + // TableScan: t2 [t2_id:UInt32, t2_value:Utf8] + // TableScan: t1 [t1_id:UInt32, t1_name:Utf8, t1_int:Int32] + + // Projection: t1.t1_id, t1.t1_name, t1.t1_int [t1_id:UInt32, t1_name:Utf8, t1_int:Int32] + // Projection: t1.t1_id, t1.t1_name, t1.t1_int [t1_id:UInt32, t1_name:Utf8, t1_int:Int32] + // Filter: __in_sq_1.output [t1_id:UInt32, t1_name:Utf8, t1_int:Int32, output:Boolean] + // DependentJoin on [t1.t1_int lvl 1] with expr t1.t1_id + Int32(12) NOT IN () depth 1 [t1_id:UInt32, t1_name:Utf8, t1_int:Int32, output:Boolean] + // TableScan: t1 [t1_id:UInt32, t1_name:Utf8, t1_int:Int32] + // Projection: t2.t2_id + Int32(1) [t2.t2_id + Int32(1):Int64] + // Filter: outer_ref(t1.t1_int) > Int32(0) [t2_id:UInt32, t2_value:Utf8] + // TableScan: t2 [t2_id:UInt32, t2_value:Utf8] + + assert_decorrelate!(plan, @r" + Projection: t1.t1_id, t1.t1_name, t1.t1_int [t1_id:UInt32, t1_name:Utf8, t1_int:Int32] + Projection: t1.t1_id, t1.t1_name, t1.t1_int [t1_id:UInt32, t1_name:Utf8, t1_int:Int32] + Filter: NOT __in_sq_1 [t1_id:UInt32, t1_name:Utf8, t1_int:Int32, __in_sq_1:Boolean] + Projection: t1.t1_id, t1.t1_name, t1.t1_int, t1_dscan_1.mark AS __in_sq_1 [t1_id:UInt32, t1_name:Utf8, t1_int:Int32, __in_sq_1:Boolean] + LeftMark Join(ComparisonJoin): Filter: t1.t1_id + Int32(12) = t2.t2_id + Int32(1) AND t1.t1_int IS NOT DISTINCT FROM t1_dscan_1.t1_t1_int [t1_id:UInt32, t1_name:Utf8, t1_int:Int32, mark:Boolean] + TableScan: t1 [t1_id:UInt32, t1_name:Utf8, t1_int:Int32] + Projection: t2.t2_id + Int32(1), t1_dscan_1.t1_t1_int [t2.t2_id + Int32(1):Int64, t1_t1_int:Int32;N] + Filter: t1_dscan_1.t1_t1_int > Int32(0) [t2_id:UInt32, t2_value:Utf8, t1_t1_int:Int32;N] + Inner Join(DelimJoin): Filter: Boolean(true) [t2_id:UInt32, t2_value:Utf8, t1_t1_int:Int32;N] + TableScan: t2 [t2_id:UInt32, t2_value:Utf8] + Projection: t1.t1_int AS 
t1_dscan_1.t1_t1_int [t1_t1_int:Int32;N] + DelimGet: t1.t1_int [t1_int:Int32;N] + "); + + Ok(()) + } + + #[test] + fn subquery_slt_test3() -> Result<()> { + // Test case for: SELECT t1_id, (SELECT sum(t2_int) FROM t2 WHERE t2.t2_id = t1.t1_id) as t2_sum from t1 + + // Create test tables + let t1 = test_table_with_columns( + "t1", + &[ + ("t1_id", ArrowDataType::UInt32), + ("t1_name", ArrowDataType::Utf8), + ("t1_int", ArrowDataType::Int32), + ], + )?; + + let t2 = test_table_with_columns( + "t2", + &[ + ("t2_id", ArrowDataType::UInt32), + ("t2_int", ArrowDataType::Int32), + ("t2_value", ArrowDataType::Utf8), + ], + )?; + + // Create the scalar subquery: SELECT sum(t2_int) FROM t2 WHERE t2.t2_id = t1.t1_id + let scalar_sq = Arc::new( + LogicalPlanBuilder::from(t2) + .filter( + col("t2.t2_id").eq(out_ref_col(ArrowDataType::UInt32, "t1.t1_id")), + )? + .aggregate(Vec::::new(), vec![sum(col("t2_int"))])? + .build()?, + ); + + // Create the main query plan: SELECT t1_id, (subquery) as t2_sum FROM t1 + let plan = LogicalPlanBuilder::from(t1) + .project(vec![ + col("t1_id"), + scalar_subquery(scalar_sq).alias("t2_sum"), + ])? + .build()?; + + // Projection: t1.t1_id, __scalar_sq_1.output AS t2_sum [t1_id:UInt32, t2_sum:Int64] + // DependentJoin on [t1.t1_id lvl 1] with expr () depth 1 [t1_id:UInt32, t1_name:Utf8, t1_int:Int32, output:Int64] + // TableScan: t1 [t1_id:UInt32, t1_name:Utf8, t1_int:Int32] + // Aggregate: groupBy=[[]], aggr=[[sum(t2.t2_int)]] [sum(t2.t2_int):Int64;N] + // Filter: t2.t2_id = outer_ref(t1.t1_id) [t2_id:UInt32, t2_int:Int32, t2_value:Utf8] + // TableScan: t2 [t2_id:UInt32, t2_int:Int32, t2_value:Utf8] + + assert_decorrelate!(plan, @r" + Projection: t1.t1_id, __scalar_sq_1 AS t2_sum [t1_id:UInt32, t2_sum:Int64] + Projection: t1.t1_id, t1.t1_name, t1.t1_int, sum(t2.t2_int), t1_dscan_1.t1_t1_id, t1_dscan_2.t1_t1_id, sum(t2.t2_int) AS __scalar_sq_1 [t1_id:UInt32, t1_name:Utf8, t1_int:Int32, sum(t2.t2_int):Int64;N, t1_t1_id:UInt32;N, t1_t1_id:UInt32;N, __scalar_sq_1:Int64;N] + Left Join(ComparisonJoin): Filter: t1.t1_id IS NOT DISTINCT FROM t1_dscan_2.t1_t1_id [t1_id:UInt32, t1_name:Utf8, t1_int:Int32, sum(t2.t2_int):Int64;N, t1_t1_id:UInt32;N, t1_t1_id:UInt32;N] + TableScan: t1 [t1_id:UInt32, t1_name:Utf8, t1_int:Int32] + Inner Join(DelimJoin): Filter: t1_dscan_1.t1_t1_id IS NOT DISTINCT FROM t1_dscan_2.t1_t1_id [sum(t2.t2_int):Int64;N, t1_t1_id:UInt32;N, t1_t1_id:UInt32;N] + Projection: sum(t2.t2_int), t1_dscan_1.t1_t1_id [sum(t2.t2_int):Int64;N, t1_t1_id:UInt32;N] + Aggregate: groupBy=[[t1_dscan_1.t1_t1_id]], aggr=[[sum(t2.t2_int)]] [t1_t1_id:UInt32;N, sum(t2.t2_int):Int64;N] + Filter: t2.t2_id = t1_dscan_1.t1_t1_id [t2_id:UInt32, t2_int:Int32, t2_value:Utf8, t1_t1_id:UInt32;N] + Inner Join(DelimJoin): Filter: Boolean(true) [t2_id:UInt32, t2_int:Int32, t2_value:Utf8, t1_t1_id:UInt32;N] + TableScan: t2 [t2_id:UInt32, t2_int:Int32, t2_value:Utf8] + Projection: t1.t1_id AS t1_dscan_1.t1_t1_id [t1_t1_id:UInt32;N] + DelimGet: t1.t1_id [t1_id:UInt32;N] + Projection: t1.t1_id AS t1_dscan_2.t1_t1_id [t1_t1_id:UInt32;N] + DelimGet: t1.t1_id [t1_id:UInt32;N] + "); + + Ok(()) + } + + #[test] + fn subquery_slt_test4() -> Result<()> { + // Test case for: SELECT t1_id, t1_name FROM t1 WHERE EXISTS (SELECT * FROM t2 WHERE t2_id = t1_id LIMIT 1); + + // Create test tables matching the SQL schema + let t1 = test_table_with_columns( + "t1", + &[ + ("t1_id", ArrowDataType::Int32), + ("t1_name", ArrowDataType::Utf8), + ("t1_int", ArrowDataType::Int32), + ], + )?; + + let t2 = 
test_table_with_columns( + "t2", + &[ + ("t2_id", ArrowDataType::Int32), + ("t2_name", ArrowDataType::Utf8), + ("t2_int", ArrowDataType::Int32), + ], + )?; + + // Create the EXISTS subquery: SELECT * FROM t2 WHERE t2_id = t1_id LIMIT 1 + let exists_subquery = Arc::new( + LogicalPlanBuilder::from(t2) + .filter( + col("t2.t2_id").eq(out_ref_col(ArrowDataType::Int32, "t1.t1_id")), + )? + .limit(0, Some(1))? // LIMIT 1 + .build()?, + ); + + // Create the main query plan: SELECT t1_id, t1_name FROM t1 WHERE EXISTS (subquery) + let plan = LogicalPlanBuilder::from(t1) + .filter(exists(exists_subquery))? + .project(vec![col("t1_id"), col("t1_name")])? + .build()?; + + assert_decorrelate!(plan, @r" + Projection: t1.t1_id, t1.t1_name [t1_id:Int32, t1_name:Utf8] + Projection: t1.t1_id, t1.t1_name, t1.t1_int [t1_id:Int32, t1_name:Utf8, t1_int:Int32] + Filter: __exists_sq_1 [t1_id:Int32, t1_name:Utf8, t1_int:Int32, __exists_sq_1:Boolean] + Projection: t1.t1_id, t1.t1_name, t1.t1_int, mark AS __exists_sq_1 [t1_id:Int32, t1_name:Utf8, t1_int:Int32, __exists_sq_1:Boolean] + LeftMark Join(ComparisonJoin): Filter: t1.t1_id IS NOT DISTINCT FROM t1_dscan_1.t1_t1_id [t1_id:Int32, t1_name:Utf8, t1_int:Int32, mark:Boolean] + TableScan: t1 [t1_id:Int32, t1_name:Utf8, t1_int:Int32] + Projection: t2.t2_id, t2.t2_name, t2.t2_int, t1_dscan_1.t1_t1_id [t2_id:Int32, t2_name:Utf8, t2_int:Int32, t1_t1_id:Int32;N] + Filter: row_number <= UInt64(1) [t2_id:Int32, t2_name:Utf8, t2_int:Int32, t1_t1_id:Int32;N, row_number:UInt64] + WindowAggr: windowExpr=[[row_number() PARTITION BY [t1_dscan_1.t1_t1_id] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW AS row_number]] [t2_id:Int32, t2_name:Utf8, t2_int:Int32, t1_t1_id:Int32;N, row_number:UInt64] + Filter: t2.t2_id = t1_dscan_1.t1_t1_id [t2_id:Int32, t2_name:Utf8, t2_int:Int32, t1_t1_id:Int32;N] + Inner Join(DelimJoin): Filter: Boolean(true) [t2_id:Int32, t2_name:Utf8, t2_int:Int32, t1_t1_id:Int32;N] + TableScan: t2 [t2_id:Int32, t2_name:Utf8, t2_int:Int32] + Projection: t1.t1_id AS t1_dscan_1.t1_t1_id [t1_t1_id:Int32;N] + DelimGet: t1.t1_id [t1_id:Int32;N] + "); + + Ok(()) + } + + #[test] + fn subquery_slt_test5() -> Result<()> { + // Test case for: SELECT t1_id, (SELECT count(*) FROM t2 WHERE t2.t2_int = t1.t1_int) from t1; + + // Create test tables matching the SQL schema + let t1 = test_table_with_columns( + "t1", + &[ + ("t1_id", ArrowDataType::Int32), + ("t1_name", ArrowDataType::Utf8), + ("t1_int", ArrowDataType::Int32), + ], + )?; + + let t2 = test_table_with_columns( + "t2", + &[ + ("t2_id", ArrowDataType::Int32), + ("t2_name", ArrowDataType::Utf8), + ("t2_int", ArrowDataType::Int32), + ], + )?; + + // Create the scalar subquery: SELECT count(*) FROM t2 WHERE t2.t2_int = t1.t1_int + let scalar_sq = Arc::new( + LogicalPlanBuilder::from(t2) + .filter( + col("t2.t2_int").eq(out_ref_col(ArrowDataType::Int32, "t1.t1_int")), + )? + .aggregate(Vec::::new(), vec![count(lit(1))])? // count(*) is represented as count(1) + .build()?, + ); + + // Create the main query plan: SELECT t1_id, (subquery) FROM t1 + let plan = LogicalPlanBuilder::from(t1) + .project(vec![col("t1_id"), scalar_subquery(scalar_sq)])? 
+ .build()?; + + // Projection: t1.t1_id, __scalar_sq_1 [t1_id:Int32, __scalar_sq_1:Int64] + // DependentJoin on [t1.t1_int lvl 1] with expr () depth 1 [t1_id:Int32, t1_name:Utf8, t1_int:Int32, __scalar_sq_1:Int64] + // TableScan: t1 [t1_id:Int32, t1_name:Utf8, t1_int:Int32] + // Aggregate: groupBy=[[]], aggr=[[count(Int32(1))]] [count(Int32(1)):Int64] + // Filter: t2.t2_int = outer_ref(t1.t1_int) [t2_id:Int32, t2_name:Utf8, t2_int:Int32] + // TableScan: t2 [t2_id:Int32, t2_name:Utf8, t2_int:Int32] + + assert_decorrelate!(plan, @r""); + + Ok(()) + } +} diff --git a/datafusion/optimizer/src/decorrelate_predicate_subquery.rs b/datafusion/optimizer/src/decorrelate_predicate_subquery.rs index a72657bf689d..326eb0e28a63 100644 --- a/datafusion/optimizer/src/decorrelate_predicate_subquery.rs +++ b/datafusion/optimizer/src/decorrelate_predicate_subquery.rs @@ -471,8 +471,8 @@ mod tests { plan, @r" Projection: test.b [b:UInt32] - LeftSemi Join: Filter: test.b = __correlated_sq_2.c [a:UInt32, b:UInt32, c:UInt32] - LeftSemi Join: Filter: test.c = __correlated_sq_1.c [a:UInt32, b:UInt32, c:UInt32] + LeftSemi Join(ComparisonJoin): Filter: test.b = __correlated_sq_2.c [a:UInt32, b:UInt32, c:UInt32] + LeftSemi Join(ComparisonJoin): Filter: test.c = __correlated_sq_1.c [a:UInt32, b:UInt32, c:UInt32] TableScan: test [a:UInt32, b:UInt32, c:UInt32] SubqueryAlias: __correlated_sq_1 [c:UInt32] Projection: sq_1.c [c:UInt32] @@ -504,7 +504,7 @@ mod tests { @r" Projection: test.b [b:UInt32] Filter: test.a = UInt32(1) AND test.b < UInt32(30) [a:UInt32, b:UInt32, c:UInt32] - LeftSemi Join: Filter: test.c = __correlated_sq_1.c [a:UInt32, b:UInt32, c:UInt32] + LeftSemi Join(ComparisonJoin): Filter: test.c = __correlated_sq_1.c [a:UInt32, b:UInt32, c:UInt32] TableScan: test [a:UInt32, b:UInt32, c:UInt32] SubqueryAlias: __correlated_sq_1 [c:UInt32] Projection: sq.c [c:UInt32] @@ -532,11 +532,11 @@ mod tests { plan, @r" Projection: test.b [b:UInt32] - LeftSemi Join: Filter: test.b = __correlated_sq_2.a [a:UInt32, b:UInt32, c:UInt32] + LeftSemi Join(ComparisonJoin): Filter: test.b = __correlated_sq_2.a [a:UInt32, b:UInt32, c:UInt32] TableScan: test [a:UInt32, b:UInt32, c:UInt32] SubqueryAlias: __correlated_sq_2 [a:UInt32] Projection: sq.a [a:UInt32] - LeftSemi Join: Filter: sq.a = __correlated_sq_1.c [a:UInt32, b:UInt32, c:UInt32] + LeftSemi Join(ComparisonJoin): Filter: sq.a = __correlated_sq_1.c [a:UInt32, b:UInt32, c:UInt32] TableScan: sq [a:UInt32, b:UInt32, c:UInt32] SubqueryAlias: __correlated_sq_1 [c:UInt32] Projection: sq_nested.c [c:UInt32] @@ -569,10 +569,10 @@ mod tests { assert_optimized_plan_equal!( plan, - @r###" + @r" Projection: customer.c_custkey [c_custkey:Int64] - LeftSemi Join: Filter: customer.c_custkey = __correlated_sq_2.o_custkey [c_custkey:Int64, c_name:Utf8] - LeftSemi Join: Filter: customer.c_custkey = __correlated_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8] + LeftSemi Join(ComparisonJoin): Filter: customer.c_custkey = __correlated_sq_2.o_custkey [c_custkey:Int64, c_name:Utf8] + LeftSemi Join(ComparisonJoin): Filter: customer.c_custkey = __correlated_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8] TableScan: customer [c_custkey:Int64, c_name:Utf8] SubqueryAlias: __correlated_sq_1 [o_custkey:Int64] Projection: orders.o_custkey [o_custkey:Int64] @@ -580,7 +580,7 @@ mod tests { SubqueryAlias: __correlated_sq_2 [o_custkey:Int64] Projection: orders.o_custkey [o_custkey:Int64] TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N] - "### + " ) } @@ -619,11 
+619,11 @@ mod tests { plan, @r" Projection: customer.c_custkey [c_custkey:Int64] - LeftSemi Join: Filter: customer.c_custkey = __correlated_sq_2.o_custkey [c_custkey:Int64, c_name:Utf8] + LeftSemi Join(ComparisonJoin): Filter: customer.c_custkey = __correlated_sq_2.o_custkey [c_custkey:Int64, c_name:Utf8] TableScan: customer [c_custkey:Int64, c_name:Utf8] SubqueryAlias: __correlated_sq_2 [o_custkey:Int64] Projection: orders.o_custkey [o_custkey:Int64] - LeftSemi Join: Filter: orders.o_orderkey = __correlated_sq_1.l_orderkey [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N] + LeftSemi Join(ComparisonJoin): Filter: orders.o_orderkey = __correlated_sq_1.l_orderkey [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N] TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N] SubqueryAlias: __correlated_sq_1 [l_orderkey:Int64] Projection: lineitem.l_orderkey [l_orderkey:Int64] @@ -655,7 +655,7 @@ mod tests { plan, @r" Projection: customer.c_custkey [c_custkey:Int64] - LeftSemi Join: Filter: customer.c_custkey = __correlated_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8] + LeftSemi Join(ComparisonJoin): Filter: customer.c_custkey = __correlated_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8] TableScan: customer [c_custkey:Int64, c_name:Utf8] SubqueryAlias: __correlated_sq_1 [o_custkey:Int64] Projection: orders.o_custkey [o_custkey:Int64] @@ -687,7 +687,7 @@ mod tests { plan, @r" Projection: customer.c_custkey [c_custkey:Int64] - LeftSemi Join: Filter: customer.c_custkey = __correlated_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8] + LeftSemi Join(ComparisonJoin): Filter: customer.c_custkey = __correlated_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8] TableScan: customer [c_custkey:Int64, c_name:Utf8] SubqueryAlias: __correlated_sq_1 [o_custkey:Int64] Projection: orders.o_custkey [o_custkey:Int64] @@ -715,7 +715,7 @@ mod tests { plan, @r" Projection: customer.c_custkey [c_custkey:Int64] - LeftSemi Join: Filter: customer.c_custkey = __correlated_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8] + LeftSemi Join(ComparisonJoin): Filter: customer.c_custkey = __correlated_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8] TableScan: customer [c_custkey:Int64, c_name:Utf8] SubqueryAlias: __correlated_sq_1 [o_custkey:Int64] Projection: orders.o_custkey [o_custkey:Int64] @@ -747,7 +747,7 @@ mod tests { plan, @r" Projection: customer.c_custkey [c_custkey:Int64] - LeftSemi Join: Filter: customer.c_custkey = __correlated_sq_1.o_custkey AND customer.c_custkey != __correlated_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8] + LeftSemi Join(ComparisonJoin): Filter: customer.c_custkey = __correlated_sq_1.o_custkey AND customer.c_custkey != __correlated_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8] TableScan: customer [c_custkey:Int64, c_name:Utf8] SubqueryAlias: __correlated_sq_1 [o_custkey:Int64] Projection: orders.o_custkey [o_custkey:Int64] @@ -778,7 +778,7 @@ mod tests { plan, @r" Projection: customer.c_custkey [c_custkey:Int64] - LeftSemi Join: Filter: customer.c_custkey = __correlated_sq_1.o_custkey AND customer.c_custkey < __correlated_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8] + LeftSemi Join(ComparisonJoin): Filter: customer.c_custkey = __correlated_sq_1.o_custkey AND customer.c_custkey < __correlated_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8] TableScan: customer [c_custkey:Int64, c_name:Utf8] SubqueryAlias: __correlated_sq_1 [o_custkey:Int64] Projection: orders.o_custkey [o_custkey:Int64] @@ -810,7 +810,7 @@ mod 
tests { plan, @r" Projection: customer.c_custkey [c_custkey:Int64] - LeftSemi Join: Filter: customer.c_custkey = __correlated_sq_1.o_custkey AND (customer.c_custkey = __correlated_sq_1.o_custkey OR __correlated_sq_1.o_orderkey = Int32(1)) [c_custkey:Int64, c_name:Utf8] + LeftSemi Join(ComparisonJoin): Filter: customer.c_custkey = __correlated_sq_1.o_custkey AND (customer.c_custkey = __correlated_sq_1.o_custkey OR __correlated_sq_1.o_orderkey = Int32(1)) [c_custkey:Int64, c_name:Utf8] TableScan: customer [c_custkey:Int64, c_name:Utf8] SubqueryAlias: __correlated_sq_1 [o_custkey:Int64, o_orderkey:Int64] Projection: orders.o_custkey, orders.o_orderkey [o_custkey:Int64, o_orderkey:Int64] @@ -864,7 +864,7 @@ mod tests { plan, @r" Projection: customer.c_custkey [c_custkey:Int64] - LeftSemi Join: Filter: customer.c_custkey + Int32(1) = __correlated_sq_1.o_custkey AND customer.c_custkey = __correlated_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8] + LeftSemi Join(ComparisonJoin): Filter: customer.c_custkey + Int32(1) = __correlated_sq_1.o_custkey AND customer.c_custkey = __correlated_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8] TableScan: customer [c_custkey:Int64, c_name:Utf8] SubqueryAlias: __correlated_sq_1 [o_custkey:Int64] Projection: orders.o_custkey [o_custkey:Int64] @@ -895,7 +895,7 @@ mod tests { plan, @r" Projection: customer.c_custkey [c_custkey:Int64] - LeftSemi Join: Filter: customer.c_custkey = __correlated_sq_1.orders.o_custkey + Int32(1) AND customer.c_custkey = __correlated_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8] + LeftSemi Join(ComparisonJoin): Filter: customer.c_custkey = __correlated_sq_1.orders.o_custkey + Int32(1) AND customer.c_custkey = __correlated_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8] TableScan: customer [c_custkey:Int64, c_name:Utf8] SubqueryAlias: __correlated_sq_1 [orders.o_custkey + Int32(1):Int64, o_custkey:Int64] Projection: orders.o_custkey + Int32(1), orders.o_custkey [orders.o_custkey + Int32(1):Int64, o_custkey:Int64] @@ -959,7 +959,7 @@ mod tests { @r" Projection: customer.c_custkey [c_custkey:Int64] Filter: customer.c_custkey = Int32(1) [c_custkey:Int64, c_name:Utf8] - LeftSemi Join: Filter: customer.c_custkey = __correlated_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8] + LeftSemi Join(ComparisonJoin): Filter: customer.c_custkey = __correlated_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8] TableScan: customer [c_custkey:Int64, c_name:Utf8] SubqueryAlias: __correlated_sq_1 [o_custkey:Int64] Projection: orders.o_custkey [o_custkey:Int64] @@ -987,7 +987,7 @@ mod tests { plan, @r" Projection: test.b [b:UInt32] - LeftSemi Join: Filter: test.c = __correlated_sq_1.c AND test.a = __correlated_sq_1.a [a:UInt32, b:UInt32, c:UInt32] + LeftSemi Join(ComparisonJoin): Filter: test.c = __correlated_sq_1.c AND test.a = __correlated_sq_1.a [a:UInt32, b:UInt32, c:UInt32] TableScan: test [a:UInt32, b:UInt32, c:UInt32] SubqueryAlias: __correlated_sq_1 [c:UInt32, a:UInt32] Projection: sq.c, sq.a [c:UInt32, a:UInt32] @@ -1009,7 +1009,7 @@ mod tests { plan, @r" Projection: test.b [b:UInt32] - LeftSemi Join: Filter: test.c = __correlated_sq_1.c [a:UInt32, b:UInt32, c:UInt32] + LeftSemi Join(ComparisonJoin): Filter: test.c = __correlated_sq_1.c [a:UInt32, b:UInt32, c:UInt32] TableScan: test [a:UInt32, b:UInt32, c:UInt32] SubqueryAlias: __correlated_sq_1 [c:UInt32] Projection: sq.c [c:UInt32] @@ -1031,7 +1031,7 @@ mod tests { plan, @r" Projection: test.b [b:UInt32] - LeftAnti Join: Filter: test.c = __correlated_sq_1.c [a:UInt32, b:UInt32, c:UInt32] + LeftAnti 
Join(ComparisonJoin): Filter: test.c = __correlated_sq_1.c [a:UInt32, b:UInt32, c:UInt32] TableScan: test [a:UInt32, b:UInt32, c:UInt32] SubqueryAlias: __correlated_sq_1 [c:UInt32] Projection: sq.c [c:UInt32] @@ -1052,7 +1052,7 @@ mod tests { plan, @r" Projection: test.b [b:UInt32] - LeftAnti Join: Filter: test.c = __correlated_sq_1.c [a:UInt32, b:UInt32, c:UInt32] + LeftAnti Join(ComparisonJoin): Filter: test.c = __correlated_sq_1.c [a:UInt32, b:UInt32, c:UInt32] TableScan: test [a:UInt32, b:UInt32, c:UInt32] SubqueryAlias: __correlated_sq_1 [c:UInt32] Projection: sq.c [c:UInt32] @@ -1076,7 +1076,7 @@ mod tests { plan, @r" Projection: test.b [b:UInt32] - LeftSemi Join: Filter: test.c = __correlated_sq_1.c [a:UInt32, b:UInt32, c:UInt32] + LeftSemi Join(ComparisonJoin): Filter: test.c = __correlated_sq_1.c [a:UInt32, b:UInt32, c:UInt32] TableScan: test [a:UInt32, b:UInt32, c:UInt32] SubqueryAlias: __correlated_sq_1 [c:UInt32] Projection: sq.c [c:UInt32] @@ -1103,7 +1103,7 @@ mod tests { plan, @r" Projection: test.b [b:UInt32] - LeftSemi Join: Filter: test.c + UInt32(1) = __correlated_sq_1.sq.c * UInt32(2) [a:UInt32, b:UInt32, c:UInt32] + LeftSemi Join(ComparisonJoin): Filter: test.c + UInt32(1) = __correlated_sq_1.sq.c * UInt32(2) [a:UInt32, b:UInt32, c:UInt32] TableScan: test [a:UInt32, b:UInt32, c:UInt32] SubqueryAlias: __correlated_sq_1 [sq.c * UInt32(2):UInt32] Projection: sq.c * UInt32(2) [sq.c * UInt32(2):UInt32] @@ -1135,7 +1135,7 @@ mod tests { plan, @r" Projection: test.b [b:UInt32] - LeftSemi Join: Filter: test.c + UInt32(1) = __correlated_sq_1.sq.c * UInt32(2) AND test.a = __correlated_sq_1.a [a:UInt32, b:UInt32, c:UInt32] + LeftSemi Join(ComparisonJoin): Filter: test.c + UInt32(1) = __correlated_sq_1.sq.c * UInt32(2) AND test.a = __correlated_sq_1.a [a:UInt32, b:UInt32, c:UInt32] TableScan: test [a:UInt32, b:UInt32, c:UInt32] SubqueryAlias: __correlated_sq_1 [sq.c * UInt32(2):UInt32, a:UInt32] Projection: sq.c * UInt32(2), sq.a [sq.c * UInt32(2):UInt32, a:UInt32] @@ -1169,7 +1169,7 @@ mod tests { plan, @r" Projection: test.b [b:UInt32] - LeftSemi Join: Filter: test.c + UInt32(1) = __correlated_sq_1.sq.c * UInt32(2) AND test.a + test.b = __correlated_sq_1.a + __correlated_sq_1.b [a:UInt32, b:UInt32, c:UInt32] + LeftSemi Join(ComparisonJoin): Filter: test.c + UInt32(1) = __correlated_sq_1.sq.c * UInt32(2) AND test.a + test.b = __correlated_sq_1.a + __correlated_sq_1.b [a:UInt32, b:UInt32, c:UInt32] TableScan: test [a:UInt32, b:UInt32, c:UInt32] SubqueryAlias: __correlated_sq_1 [sq.c * UInt32(2):UInt32, a:UInt32, b:UInt32] Projection: sq.c * UInt32(2), sq.a, sq.b [sq.c * UInt32(2):UInt32, a:UInt32, b:UInt32] @@ -1210,8 +1210,8 @@ mod tests { @r" Projection: test.b [b:UInt32] Filter: test.c > UInt32(1) [a:UInt32, b:UInt32, c:UInt32] - LeftSemi Join: Filter: test.c * UInt32(2) = __correlated_sq_2.sq2.c * UInt32(2) AND test.a > __correlated_sq_2.a [a:UInt32, b:UInt32, c:UInt32] - LeftSemi Join: Filter: test.c + UInt32(1) = __correlated_sq_1.sq1.c * UInt32(2) AND test.a > __correlated_sq_1.a [a:UInt32, b:UInt32, c:UInt32] + LeftSemi Join(ComparisonJoin): Filter: test.c * UInt32(2) = __correlated_sq_2.sq2.c * UInt32(2) AND test.a > __correlated_sq_2.a [a:UInt32, b:UInt32, c:UInt32] + LeftSemi Join(ComparisonJoin): Filter: test.c + UInt32(1) = __correlated_sq_1.sq1.c * UInt32(2) AND test.a > __correlated_sq_1.a [a:UInt32, b:UInt32, c:UInt32] TableScan: test [a:UInt32, b:UInt32, c:UInt32] SubqueryAlias: __correlated_sq_1 [sq1.c * UInt32(2):UInt32, a:UInt32] Projection: sq1.c * UInt32(2), 
sq1.a [sq1.c * UInt32(2):UInt32, a:UInt32] @@ -1242,7 +1242,7 @@ mod tests { plan, @r" Projection: test.b [b:UInt32] - LeftSemi Join: Filter: test.a = __correlated_sq_1.c [a:UInt32, b:UInt32, c:UInt32] + LeftSemi Join(ComparisonJoin): Filter: test.a = __correlated_sq_1.c [a:UInt32, b:UInt32, c:UInt32] TableScan: test [a:UInt32, b:UInt32, c:UInt32] SubqueryAlias: __correlated_sq_1 [c:UInt32] Projection: test.c [c:UInt32] @@ -1274,8 +1274,8 @@ mod tests { plan, @r" Projection: customer.c_custkey [c_custkey:Int64] - LeftSemi Join: Filter: __correlated_sq_2.o_custkey = customer.c_custkey [c_custkey:Int64, c_name:Utf8] - LeftSemi Join: Filter: __correlated_sq_1.o_custkey = customer.c_custkey [c_custkey:Int64, c_name:Utf8] + LeftSemi Join(ComparisonJoin): Filter: __correlated_sq_2.o_custkey = customer.c_custkey [c_custkey:Int64, c_name:Utf8] + LeftSemi Join(ComparisonJoin): Filter: __correlated_sq_1.o_custkey = customer.c_custkey [c_custkey:Int64, c_name:Utf8] TableScan: customer [c_custkey:Int64, c_name:Utf8] SubqueryAlias: __correlated_sq_1 [o_custkey:Int64] Projection: orders.o_custkey [o_custkey:Int64] @@ -1321,11 +1321,11 @@ mod tests { plan, @r" Projection: customer.c_custkey [c_custkey:Int64] - LeftSemi Join: Filter: __correlated_sq_2.o_custkey = customer.c_custkey [c_custkey:Int64, c_name:Utf8] + LeftSemi Join(ComparisonJoin): Filter: __correlated_sq_2.o_custkey = customer.c_custkey [c_custkey:Int64, c_name:Utf8] TableScan: customer [c_custkey:Int64, c_name:Utf8] SubqueryAlias: __correlated_sq_2 [o_custkey:Int64] Projection: orders.o_custkey [o_custkey:Int64] - LeftSemi Join: Filter: __correlated_sq_1.l_orderkey = orders.o_orderkey [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N] + LeftSemi Join(ComparisonJoin): Filter: __correlated_sq_1.l_orderkey = orders.o_orderkey [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N] TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N] SubqueryAlias: __correlated_sq_1 [l_orderkey:Int64] Projection: lineitem.l_orderkey [l_orderkey:Int64] @@ -1357,7 +1357,7 @@ mod tests { plan, @r" Projection: customer.c_custkey [c_custkey:Int64] - LeftSemi Join: Filter: customer.c_custkey = __correlated_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8] + LeftSemi Join(ComparisonJoin): Filter: customer.c_custkey = __correlated_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8] TableScan: customer [c_custkey:Int64, c_name:Utf8] SubqueryAlias: __correlated_sq_1 [o_custkey:Int64] Projection: orders.o_custkey [o_custkey:Int64] @@ -1386,7 +1386,7 @@ mod tests { plan, @r" Projection: customer.c_custkey [c_custkey:Int64] - LeftSemi Join: Filter: customer.c_custkey = UInt32(1) [c_custkey:Int64, c_name:Utf8] + LeftSemi Join(ComparisonJoin): Filter: customer.c_custkey = UInt32(1) [c_custkey:Int64, c_name:Utf8] TableScan: customer [c_custkey:Int64, c_name:Utf8] SubqueryAlias: __correlated_sq_1 [o_custkey:Int64] Projection: orders.o_custkey [o_custkey:Int64] @@ -1414,7 +1414,7 @@ mod tests { plan, @r" Projection: customer.c_custkey [c_custkey:Int64] - LeftSemi Join: Filter: Boolean(true) [c_custkey:Int64, c_name:Utf8] + LeftSemi Join(ComparisonJoin): Filter: Boolean(true) [c_custkey:Int64, c_name:Utf8] TableScan: customer [c_custkey:Int64, c_name:Utf8] SubqueryAlias: __correlated_sq_1 [o_custkey:Int64] Projection: orders.o_custkey [o_custkey:Int64] @@ -1446,7 +1446,7 @@ mod tests { plan, @r" Projection: customer.c_custkey [c_custkey:Int64] - LeftSemi Join: Filter: customer.c_custkey != 
__correlated_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8] + LeftSemi Join(ComparisonJoin): Filter: customer.c_custkey != __correlated_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8] TableScan: customer [c_custkey:Int64, c_name:Utf8] SubqueryAlias: __correlated_sq_1 [o_custkey:Int64] Projection: orders.o_custkey [o_custkey:Int64] @@ -1477,7 +1477,7 @@ mod tests { plan, @r" Projection: customer.c_custkey [c_custkey:Int64] - LeftSemi Join: Filter: customer.c_custkey < __correlated_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8] + LeftSemi Join(ComparisonJoin): Filter: customer.c_custkey < __correlated_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8] TableScan: customer [c_custkey:Int64, c_name:Utf8] SubqueryAlias: __correlated_sq_1 [o_custkey:Int64] Projection: orders.o_custkey [o_custkey:Int64] @@ -1509,7 +1509,7 @@ mod tests { plan, @r" Projection: customer.c_custkey [c_custkey:Int64] - LeftSemi Join: Filter: customer.c_custkey = __correlated_sq_1.o_custkey OR __correlated_sq_1.o_orderkey = Int32(1) [c_custkey:Int64, c_name:Utf8] + LeftSemi Join(ComparisonJoin): Filter: customer.c_custkey = __correlated_sq_1.o_custkey OR __correlated_sq_1.o_orderkey = Int32(1) [c_custkey:Int64, c_name:Utf8] TableScan: customer [c_custkey:Int64, c_name:Utf8] SubqueryAlias: __correlated_sq_1 [o_custkey:Int64, o_orderkey:Int64] Projection: orders.o_custkey, orders.o_orderkey [o_custkey:Int64, o_orderkey:Int64] @@ -1539,7 +1539,7 @@ mod tests { plan, @r" Projection: customer.c_custkey [c_custkey:Int64] - LeftSemi Join: Filter: customer.c_custkey = __correlated_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8] + LeftSemi Join(ComparisonJoin): Filter: customer.c_custkey = __correlated_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8] TableScan: customer [c_custkey:Int64, c_name:Utf8] SubqueryAlias: __correlated_sq_1 [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N] TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N] @@ -1569,7 +1569,7 @@ mod tests { plan, @r" Projection: customer.c_custkey [c_custkey:Int64] - LeftSemi Join: Filter: customer.c_custkey = __correlated_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8] + LeftSemi Join(ComparisonJoin): Filter: customer.c_custkey = __correlated_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8] TableScan: customer [c_custkey:Int64, c_name:Utf8] SubqueryAlias: __correlated_sq_1 [orders.o_custkey + Int32(1):Int64, o_custkey:Int64] Projection: orders.o_custkey + Int32(1), orders.o_custkey [orders.o_custkey + Int32(1):Int64, o_custkey:Int64] @@ -1600,7 +1600,7 @@ mod tests { @r" Projection: customer.c_custkey [c_custkey:Int64] Filter: customer.c_custkey = Int32(1) [c_custkey:Int64, c_name:Utf8] - LeftSemi Join: Filter: customer.c_custkey = __correlated_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8] + LeftSemi Join(ComparisonJoin): Filter: customer.c_custkey = __correlated_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8] TableScan: customer [c_custkey:Int64, c_name:Utf8] SubqueryAlias: __correlated_sq_1 [o_custkey:Int64] Projection: orders.o_custkey [o_custkey:Int64] @@ -1629,7 +1629,7 @@ mod tests { @r" Projection: customer.c_custkey [c_custkey:Int64] Filter: __correlated_sq_1.mark OR customer.c_custkey = Int32(1) [c_custkey:Int64, c_name:Utf8, mark:Boolean] - LeftMark Join: Filter: Boolean(true) [c_custkey:Int64, c_name:Utf8, mark:Boolean] + LeftMark Join(ComparisonJoin): Filter: Boolean(true) [c_custkey:Int64, c_name:Utf8, mark:Boolean] TableScan: customer [c_custkey:Int64, c_name:Utf8] SubqueryAlias: __correlated_sq_1 
[o_custkey:Int64] Projection: orders.o_custkey [o_custkey:Int64] @@ -1658,7 +1658,7 @@ mod tests { plan, @r" Projection: test.c [c:UInt32] - LeftSemi Join: Filter: test.a = __correlated_sq_1.a [a:UInt32, b:UInt32, c:UInt32] + LeftSemi Join(ComparisonJoin): Filter: test.a = __correlated_sq_1.a [a:UInt32, b:UInt32, c:UInt32] TableScan: test [a:UInt32, b:UInt32, c:UInt32] SubqueryAlias: __correlated_sq_1 [c:UInt32, a:UInt32] Projection: sq.c, sq.a [c:UInt32, a:UInt32] @@ -1680,7 +1680,7 @@ mod tests { plan, @r" Projection: test.b [b:UInt32] - LeftSemi Join: Filter: Boolean(true) [a:UInt32, b:UInt32, c:UInt32] + LeftSemi Join(ComparisonJoin): Filter: Boolean(true) [a:UInt32, b:UInt32, c:UInt32] TableScan: test [a:UInt32, b:UInt32, c:UInt32] SubqueryAlias: __correlated_sq_1 [c:UInt32] Projection: sq.c [c:UInt32] @@ -1702,7 +1702,7 @@ mod tests { plan, @r" Projection: test.b [b:UInt32] - LeftAnti Join: Filter: Boolean(true) [a:UInt32, b:UInt32, c:UInt32] + LeftAnti Join(ComparisonJoin): Filter: Boolean(true) [a:UInt32, b:UInt32, c:UInt32] TableScan: test [a:UInt32, b:UInt32, c:UInt32] SubqueryAlias: __correlated_sq_1 [c:UInt32] Projection: sq.c [c:UInt32] @@ -1740,8 +1740,8 @@ mod tests { @r" Projection: test.b [b:UInt32] Filter: test.c > UInt32(1) [a:UInt32, b:UInt32, c:UInt32] - LeftSemi Join: Filter: test.a = __correlated_sq_2.a [a:UInt32, b:UInt32, c:UInt32] - LeftSemi Join: Filter: test.a = __correlated_sq_1.a [a:UInt32, b:UInt32, c:UInt32] + LeftSemi Join(ComparisonJoin): Filter: test.a = __correlated_sq_2.a [a:UInt32, b:UInt32, c:UInt32] + LeftSemi Join(ComparisonJoin): Filter: test.a = __correlated_sq_1.a [a:UInt32, b:UInt32, c:UInt32] TableScan: test [a:UInt32, b:UInt32, c:UInt32] SubqueryAlias: __correlated_sq_1 [c:UInt32, a:UInt32] Projection: sq1.c, sq1.a [c:UInt32, a:UInt32] @@ -1773,7 +1773,7 @@ mod tests { plan, @r" Projection: test.b [b:UInt32] - LeftSemi Join: Filter: UInt32(1) + __correlated_sq_1.a > test.a * UInt32(2) [a:UInt32, b:UInt32, c:UInt32] + LeftSemi Join(ComparisonJoin): Filter: UInt32(1) + __correlated_sq_1.a > test.a * UInt32(2) [a:UInt32, b:UInt32, c:UInt32] TableScan: test [a:UInt32, b:UInt32, c:UInt32] SubqueryAlias: __correlated_sq_1 [UInt32(1):UInt32, a:UInt32] Projection: UInt32(1), sq.a [UInt32(1):UInt32, a:UInt32] @@ -1801,7 +1801,7 @@ mod tests { plan, @r" Projection: test.b [b:UInt32] - LeftSemi Join: Filter: Boolean(true) [a:UInt32, b:UInt32, c:UInt32] + LeftSemi Join(ComparisonJoin): Filter: Boolean(true) [a:UInt32, b:UInt32, c:UInt32] TableScan: test [a:UInt32, b:UInt32, c:UInt32] SubqueryAlias: __correlated_sq_1 [c:UInt32] Projection: test.c [c:UInt32] @@ -1832,7 +1832,7 @@ mod tests { plan, @r" Projection: test.b [b:UInt32] - LeftSemi Join: Filter: UInt32(1) + __correlated_sq_1.a > test.a * UInt32(2) [a:UInt32, b:UInt32, c:UInt32] + LeftSemi Join(ComparisonJoin): Filter: UInt32(1) + __correlated_sq_1.a > test.a * UInt32(2) [a:UInt32, b:UInt32, c:UInt32] TableScan: test [a:UInt32, b:UInt32, c:UInt32] SubqueryAlias: __correlated_sq_1 [c:UInt32, a:UInt32] Distinct: [c:UInt32, a:UInt32] @@ -1863,7 +1863,7 @@ mod tests { plan, @r" Projection: test.b [b:UInt32] - LeftSemi Join: Filter: UInt32(1) + __correlated_sq_1.a > test.a * UInt32(2) [a:UInt32, b:UInt32, c:UInt32] + LeftSemi Join(ComparisonJoin): Filter: UInt32(1) + __correlated_sq_1.a > test.a * UInt32(2) [a:UInt32, b:UInt32, c:UInt32] TableScan: test [a:UInt32, b:UInt32, c:UInt32] SubqueryAlias: __correlated_sq_1 [sq.b + sq.c:UInt32, a:UInt32] Distinct: [sq.b + sq.c:UInt32, a:UInt32] @@ -1894,7 
+1894,7 @@ mod tests { plan, @r" Projection: test.b [b:UInt32] - LeftSemi Join: Filter: UInt32(1) + __correlated_sq_1.a > test.a * UInt32(2) [a:UInt32, b:UInt32, c:UInt32] + LeftSemi Join(ComparisonJoin): Filter: UInt32(1) + __correlated_sq_1.a > test.a * UInt32(2) [a:UInt32, b:UInt32, c:UInt32] TableScan: test [a:UInt32, b:UInt32, c:UInt32] SubqueryAlias: __correlated_sq_1 [UInt32(1):UInt32, c:UInt32, a:UInt32] Distinct: [UInt32(1):UInt32, c:UInt32, a:UInt32] @@ -1929,7 +1929,7 @@ mod tests { plan, @r#" Projection: test.b [b:UInt32] - LeftSemi Join: Filter: Boolean(true) [a:UInt32, b:UInt32, c:UInt32] + LeftSemi Join(ComparisonJoin): Filter: Boolean(true) [a:UInt32, b:UInt32, c:UInt32] TableScan: test [a:UInt32, b:UInt32, c:UInt32] SubqueryAlias: __correlated_sq_1 [arr:Int32;N] Unnest: lists[sq.arr|depth=1] structs[] [arr:Int32;N] @@ -1964,7 +1964,7 @@ mod tests { plan, @r#" Projection: test.b [b:UInt32] - LeftSemi Join: Filter: __correlated_sq_1.a = test.b [a:UInt32, b:UInt32, c:UInt32] + LeftSemi Join(ComparisonJoin): Filter: __correlated_sq_1.a = test.b [a:UInt32, b:UInt32, c:UInt32] TableScan: test [a:UInt32, b:UInt32, c:UInt32] SubqueryAlias: __correlated_sq_1 [a:UInt32;N] Unnest: lists[sq.a|depth=1] structs[] [a:UInt32;N] @@ -1998,7 +1998,7 @@ mod tests { plan, @r" Projection: TEST_A.B [B:UInt32] - LeftSemi Join: Filter: __correlated_sq_1.A = TEST_A.A [A:UInt32, B:UInt32] + LeftSemi Join(ComparisonJoin): Filter: __correlated_sq_1.A = TEST_A.A [A:UInt32, B:UInt32] TableScan: TEST_A [A:UInt32, B:UInt32] SubqueryAlias: __correlated_sq_1 [Int32(1):Int32, A:UInt32] Projection: Int32(1), TEST_B.A [Int32(1):Int32, A:UInt32] diff --git a/datafusion/optimizer/src/delim_candidate_rewriter.rs b/datafusion/optimizer/src/delim_candidate_rewriter.rs new file mode 100644 index 000000000000..79395193adb9 --- /dev/null +++ b/datafusion/optimizer/src/delim_candidate_rewriter.rs @@ -0,0 +1,119 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
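+
+//! Rewrite pass that applies the results of the delim-candidate analysis.
+//!
+//! [`DelimCandidateRewriter`] revisits the plan using the same pre-order node
+//! ids as `delim_candidates_collector`: candidates flagged as transformed are
+//! replaced with the plan stored on the candidate, joins against a `DelimGet`
+//! that were marked as removable are replaced with their replacement plans,
+//! and the stored ids of joins later in the traversal are shifted down to
+//! account for the removed nodes.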
+
+use datafusion_common::tree_node::{Transformed, TreeNodeRewriter};
+use datafusion_common::{internal_datafusion_err, Result};
+use datafusion_expr::LogicalPlan;
+use indexmap::IndexMap;
+
+use crate::delim_candidates_collector::{DelimCandidate, JoinWithDelimScan};
+
+type ID = usize;
+
+pub struct DelimCandidateRewriter {
+    candidates: IndexMap<ID, DelimCandidate>,
+    joins: IndexMap<ID, JoinWithDelimScan>,
+    cur_id: ID,
+    // all the node ids from root to the current node
+    stack: Vec<ID>,
+}
+
+impl DelimCandidateRewriter {
+    pub fn new(
+        candidates: IndexMap<ID, DelimCandidate>,
+        joins: IndexMap<ID, JoinWithDelimScan>,
+    ) -> Self {
+        Self {
+            candidates,
+            joins,
+            cur_id: 0,
+            stack: vec![],
+        }
+    }
+}
+
+impl TreeNodeRewriter for DelimCandidateRewriter {
+    type Node = LogicalPlan;
+
+    fn f_down(&mut self, plan: LogicalPlan) -> Result<Transformed<LogicalPlan>> {
+        self.stack.push(self.cur_id);
+        self.cur_id += 1;
+
+        Ok(Transformed::no(plan))
+    }
+
+    fn f_up(&mut self, plan: LogicalPlan) -> Result<Transformed<LogicalPlan>> {
+        let mut transformed = Transformed::no(plan);
+
+        let cur_id = self.stack.pop().ok_or(internal_datafusion_err!(
+            "stack cannot be empty during upward traversal"
+        ))?;
+
+        let mut diff = 0;
+        if let Some(candidate) = self.candidates.get(&cur_id) {
+            if candidate.is_transformed {
+                return Ok(Transformed::yes(candidate.node.plan.clone()));
+            }
+        } else if let Some(join_with_delim_scan) = self.joins.get(&cur_id) {
+            if join_with_delim_scan.can_be_eliminated {
+                // let prev_sub_plan_size = join_with_delim_scan.node.sub_plan_size;
+                // let mut cur_sub_plan_size = 1;
+                // if join_with_delim_scan.is_filter_generated {
+                //     cur_sub_plan_size += 1;
+                // }
+
+                // // prev_sub_plan_size should be larger than cur_sub_plan_size.
+                // diff = prev_sub_plan_size - cur_sub_plan_size;
+
+                diff = 2;
+                if join_with_delim_scan.is_filter_generated {
+                    diff = 1;
+                }
+
+                transformed = Transformed::yes(
+                    join_with_delim_scan
+                        .replacement_plan
+                        .clone()
+                        .ok_or(internal_datafusion_err!(
+                            "replacement plan must exist when the join can be eliminated"
+                        ))?
+                        .as_ref()
+                        .clone(),
+                );
+            }
+        }
+
+        // update tree nodes with id > cur_id.
+        if diff != 0 {
+            let keys_to_update: Vec<ID> = self
+                .joins
+                .keys()
+                .filter(|&&id| id > cur_id)
+                .copied()
+                .collect();
+            for old_id in keys_to_update {
+                if let Some(mut join) = self.joins.swap_remove(&old_id) {
+                    let new_id = old_id - diff;
+                    join.node.id = new_id;
+                    self.joins.insert(new_id, join);
+                }
+            }
+        }
+
+        Ok(transformed)
+    }
+}
diff --git a/datafusion/optimizer/src/delim_candidates_collector.rs b/datafusion/optimizer/src/delim_candidates_collector.rs
new file mode 100644
index 000000000000..3279dbf1a076
--- /dev/null
+++ b/datafusion/optimizer/src/delim_candidates_collector.rs
@@ -0,0 +1,560 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+ +use std::sync::Arc; + +use datafusion_common::tree_node::{TreeNode, TreeNodeRecursion, TreeNodeVisitor}; +use datafusion_common::{internal_datafusion_err, DataFusionError, Result}; +use datafusion_expr::{JoinKind, LogicalPlan}; +use indexmap::IndexMap; + +type ID = usize; + +#[derive(Clone)] +pub struct Node { + pub plan: LogicalPlan, + pub id: ID, + // subplan size of current node. + pub sub_plan_size: usize, +} + +impl Node { + fn new(plan: LogicalPlan, id: ID, sub_plan_size: usize) -> Self { + Self { + plan, + id, + sub_plan_size, + } + } +} + +#[derive(Clone)] +pub struct JoinWithDelimScan { + // Join node under DelimCandidate. + pub node: Node, + pub depth: usize, + pub can_be_eliminated: bool, + pub is_filter_generated: bool, + pub replacement_plan: Option>, +} + +impl JoinWithDelimScan { + fn new(plan: LogicalPlan, id: ID, depth: usize, sub_plan_size: usize) -> Self { + Self { + node: Node::new(plan, id, sub_plan_size), + depth, + can_be_eliminated: false, + is_filter_generated: false, + replacement_plan: None, + } + } +} + +pub struct DelimCandidate { + pub node: Node, + pub joins: Vec, + pub delim_scan_count: usize, + pub is_transformed: bool, +} + +impl DelimCandidate { + fn new(plan: LogicalPlan, id: ID, sub_plan_size: usize) -> Self { + Self { + node: Node::new(plan, id, sub_plan_size), + joins: vec![], + delim_scan_count: 0, + is_transformed: false, + } + } +} + +pub struct NodeVisitor { + nodes: IndexMap, + cur_id: ID, + // all the node ids from root to the current node + stack: Vec, +} + +impl NodeVisitor { + pub fn new() -> Self { + Self { + nodes: IndexMap::new(), + cur_id: 0, + stack: vec![], + } + } + + pub fn collect_nodes(&mut self, plan: &LogicalPlan) -> Result<()> { + plan.apply(|plan| { + self.nodes + .insert(self.cur_id, Node::new(plan.clone(), self.cur_id, 0)); + self.cur_id += 1; + + Ok(TreeNodeRecursion::Continue) + })?; + + // reset current id + self.cur_id = 0; + + plan.visit(self)?; + + println!("\n=== Nodes after visit ==="); + for (id, node) in &self.nodes { + println!( + "Node ID: {}, Type: {:?}, SubPlan Size: {}", + id, + node.plan.display().to_string(), + node.sub_plan_size + ); + } + println!("======================\n"); + + Ok(()) + } +} + +impl TreeNodeVisitor<'_> for NodeVisitor { + type Node = LogicalPlan; + + fn f_down(&mut self, _plan: &LogicalPlan) -> Result { + self.stack.push(self.cur_id); + self.cur_id += 1; + + Ok(TreeNodeRecursion::Continue) + } + + fn f_up(&mut self, plan: &LogicalPlan) -> Result { + let cur_id = self.stack.pop().ok_or(internal_datafusion_err!( + "stack cannot be empty during upward traversal" + ))?; + + // Calculate subplan size: 1 (current node) + sum of children's subplan sizes. + let mut subplan_size = 1; + let mut child_id = cur_id + 1; + plan.apply_children(|_| { + if let Some(child_node) = self.nodes.get(&child_id) { + subplan_size += child_node.sub_plan_size; + child_id = child_id + child_node.sub_plan_size; + } + + Ok(TreeNodeRecursion::Continue) + })?; + + // Store the subplan size for current node. + self.nodes + .get_mut(&cur_id) + .ok_or_else(|| { + DataFusionError::Plan( + "Node should exist when calculating subplan size".to_string(), + ) + })? 
+ .sub_plan_size = subplan_size; + + Ok(TreeNodeRecursion::Continue) + } +} + +pub struct DelimCandidateVisitor { + pub candidates: IndexMap, + node_visitor: NodeVisitor, + cur_id: ID, + // all the node ids from root to the current node + stack: Vec, +} + +impl DelimCandidateVisitor { + pub fn new(node_visitor: NodeVisitor) -> Self { + Self { + candidates: IndexMap::new(), + node_visitor, + cur_id: 0, + stack: vec![], + } + } +} + +impl TreeNodeVisitor<'_> for DelimCandidateVisitor { + type Node = LogicalPlan; + + fn f_down(&mut self, _plan: &LogicalPlan) -> Result { + self.stack.push(self.cur_id); + self.cur_id += 1; + + Ok(TreeNodeRecursion::Continue) + } + + fn f_up(&mut self, plan: &LogicalPlan) -> Result { + let cur_id = self.stack.pop().ok_or(internal_datafusion_err!( + "stack cannot be empty during upward traversal" + ))?; + + if let LogicalPlan::Join(join) = plan { + if join.join_kind == JoinKind::DelimJoin { + let sub_plan_size = self + .node_visitor + .nodes + .get(&cur_id) + .ok_or_else(|| { + DataFusionError::Plan("current node should exist".to_string()) + })? + .sub_plan_size; + + self.candidates.insert( + cur_id, + DelimCandidate::new(plan.clone(), cur_id, sub_plan_size), + ); + + let left_id = cur_id + 1; + // We calculate the right child id from left child's subplan size. + let right_id = self + .node_visitor + .nodes + .get(&left_id) + .ok_or_else(|| { + DataFusionError::Plan(format!( + "left id {} should exist in join", + left_id + )) + })? + .sub_plan_size + + left_id; + + let mut candidate = self + .candidates + .get_mut(&cur_id) + .ok_or_else(|| internal_datafusion_err!("Candidate should exist"))?; + let right_plan = &self + .node_visitor + .nodes + .get(&right_id) + .ok_or_else(|| { + DataFusionError::Plan( + "right child should exist in join".to_string(), + ) + })? + .plan; + + // DelimScan are in the RHS. + let mut collector = DelimCandidatesCollector::new( + &self.node_visitor, + &mut candidate, + 0, + right_id, + ); + right_plan.visit(&mut collector)?; + } + } + + Ok(TreeNodeRecursion::Continue) + } +} + +struct DelimCandidatesCollector<'a> { + node_visitor: &'a NodeVisitor, + candidate: &'a mut DelimCandidate, + depth: usize, + cur_id: ID, + // all the node ids from root to the current node + stack: Vec, +} + +impl<'a> DelimCandidatesCollector<'a> { + fn new( + node_visitor: &'a NodeVisitor, + candidate: &'a mut DelimCandidate, + depth: usize, + cur_id: ID, + ) -> Self { + Self { + node_visitor, + candidate, + depth, + cur_id, + stack: vec![], + } + } +} + +impl<'n> TreeNodeVisitor<'n> for DelimCandidatesCollector<'_> { + type Node = LogicalPlan; + + fn f_down(&mut self, _plan: &LogicalPlan) -> Result { + self.stack.push(self.cur_id); + self.cur_id += 1; + + Ok(TreeNodeRecursion::Continue) + } + + fn f_up(&mut self, plan: &LogicalPlan) -> Result { + let mut recursion = TreeNodeRecursion::Continue; + + let cur_id = self.stack.pop().ok_or(internal_datafusion_err!( + "stack cannot be empty during upward traversal" + ))?; + + match plan { + LogicalPlan::Join(join) => { + if join.join_kind == JoinKind::DelimJoin { + // iterate left child + let left_child_id = cur_id + 1; + let left_plan = &self + .node_visitor + .nodes + .get(&left_child_id) + .ok_or_else(|| { + DataFusionError::Plan( + "left child should exist in join".to_string(), + ) + })? 
+ .plan; + let mut new_collector = DelimCandidatesCollector::new( + &self.node_visitor, + &mut self.candidate, + self.depth + 1, + cur_id + 1, + ); + left_plan.visit(&mut new_collector)?; + + recursion = TreeNodeRecursion::Stop; + } + } + LogicalPlan::DelimGet(_) => { + self.candidate.delim_scan_count += 1; + } + _ => {} + } + + if let LogicalPlan::Join(join) = plan { + if join.join_kind == JoinKind::ComparisonJoin + && (plan_is_delim_scan(join.left.as_ref()) + || plan_is_delim_scan(join.right.as_ref())) + { + let sub_plan_size = self + .node_visitor + .nodes + .get(&cur_id) + .ok_or_else(|| { + DataFusionError::Plan("current node should exist".to_string()) + })? + .sub_plan_size; + + self.candidate.joins.push(JoinWithDelimScan::new( + plan.clone(), + cur_id, + self.depth, + sub_plan_size, + )); + } + } + + self.depth += 1; + + Ok(recursion) + } +} + +fn plan_is_delim_scan(plan: &LogicalPlan) -> bool { + match plan { + LogicalPlan::Filter(filter) => { + if let LogicalPlan::DelimGet(_) = filter.input.as_ref() { + true + } else { + false + } + } + LogicalPlan::DelimGet(_) => true, + _ => false, + } +} + +#[cfg(test)] +mod tests { + use crate::delim_candidates_collector::NodeVisitor; + use crate::test::test_table_scan_with_name; + use datafusion_common::Result; + use datafusion_expr::{expr_fn::col, lit, JoinType, LogicalPlan, LogicalPlanBuilder}; + + #[test] + fn test_collect_nodes() -> Result<()> { + let table = test_table_scan_with_name("t1")?; + let plan = LogicalPlanBuilder::from(table) + .filter(col("t1.a").eq(lit(1)))? + .project(vec![col("t1.a")])? + .build()?; + + // Projection: t1.a + // Filter: t1.a = Int32(1) + // TableScan: t1 + + let mut visitor = NodeVisitor::new(); + visitor.collect_nodes(&plan)?; + + assert_eq!(visitor.nodes.len(), 3); + + match visitor.nodes.get(&2).unwrap().plan { + LogicalPlan::TableScan(_) => (), + _ => panic!("Expected TableScan at id 2"), + } + + match visitor.nodes.get(&1).unwrap().plan { + LogicalPlan::Filter(_) => (), + _ => panic!("Expected Filter at id 1"), + } + + match visitor.nodes.get(&0).unwrap().plan { + LogicalPlan::Projection(_) => (), + _ => panic!("Expected Projection at id 0"), + } + + Ok(()) + } + + #[test] + fn test_collect_nodes_with_subplan_size() -> Result<()> { + let t1 = test_table_scan_with_name("t1")?; + let t2 = test_table_scan_with_name("t2")?; + + // Build left side: Filter -> TableScan t1 + let left = LogicalPlanBuilder::from(t1) + .filter(col("t1.a").eq(lit(1)))? + .build()?; + + // Build right side: Filter -> TableScan t2 + let right = LogicalPlanBuilder::from(t2) + .filter(col("t2.a").eq(lit(2)))? + .build()?; + + // Join them together + let plan = LogicalPlanBuilder::from(left) + .join(right, JoinType::Inner, (vec!["a"], vec!["a"]), None)? + .project(vec![col("t1.a")])? 
+ .build()?; + + // Projection: t1.a + // Inner Join(ComparisonJoin): t1.a = t2.a + // Filter: t1.a = Int32(1) + // TableScan: t1 + // Filter: t2.a = Int32(2) + // TableScan: t2 + + let mut visitor = NodeVisitor::new(); + visitor.collect_nodes(&plan)?; + + // Verify nodes count + assert_eq!(visitor.nodes.len(), 6); + + // Verify subplan sizes: + // TableScan t1 (id: 5) - size 1 (just itself) + assert_eq!(visitor.nodes.get(&5).unwrap().sub_plan_size, 1); + + // TableScan t2 (id: 3) - size 1 (just itself) + assert_eq!(visitor.nodes.get(&3).unwrap().sub_plan_size, 1); + + // Filter t1 (id: 4) - size 2 (itself + TableScan) + assert_eq!(visitor.nodes.get(&4).unwrap().sub_plan_size, 2); + + // Filter t2 (id: 2) - size 2 (itself + TableScan) + assert_eq!(visitor.nodes.get(&2).unwrap().sub_plan_size, 2); + + // Join (id: 1) - size 5 (itself + both Filter subtrees) + assert_eq!(visitor.nodes.get(&1).unwrap().sub_plan_size, 5); + + // Projection (id: 0) - size 6 (entire tree) + assert_eq!(visitor.nodes.get(&0).unwrap().sub_plan_size, 6); + + Ok(()) + } + + #[test] + fn test_complex_node_collection() -> Result<()> { + // Build a complex plan: + // Project + // | + // Join + // / \ + // Filter Join + // | / \ + // TableScan Filter TableScan + // (t1) | (t4) + // Filter + // | + // TableScan + // (t2) + + let t1 = test_table_scan_with_name("t1")?; + let t2 = test_table_scan_with_name("t2")?; + let t4 = test_table_scan_with_name("t4")?; + + // Left branch: Filter -> TableScan t1 + let left = LogicalPlanBuilder::from(t1) + .filter(col("t1.a").eq(lit(1)))? + .build()?; + + // Right branch: + // First build inner join + let right_left = LogicalPlanBuilder::from(t2) + .filter(col("t2.b").eq(lit(2)))? + .filter(col("t2.c").eq(lit(3)))? + .build()?; + + let right = LogicalPlanBuilder::from(right_left) + .join(t4, JoinType::Inner, (vec!["b"], vec!["b"]), None)? + .build()?; + + // Final plan: Join the branches and project + let plan = LogicalPlanBuilder::from(left) + .join(right, JoinType::Inner, (vec!["t1.a"], vec!["t2.a"]), None)? + .project(vec![col("t1.a"), col("t2.b"), col("t4.c")])? + .build()?; + + // Projection: t1.a, t2.b, t4.c + // Inner Join(ComparisonJoin): t1.a = t2.a + // Filter: t1.a = Int32(1) + // TableScan: t1 + // Inner Join(ComparisonJoin): t2.b = t4.b + // Filter: t2.c = Int32(3) + // Filter: t2.b = Int32(2) + // TableScan: t2 + // TableScan: t4 + + let mut visitor = NodeVisitor::new(); + visitor.collect_nodes(&plan)?; + + // Add assertions to verify the structure + assert_eq!(visitor.nodes.len(), 9); // Total number of nodes + + // Verify some key subplan sizes + // Leaf nodes should have size 1 + assert_eq!(visitor.nodes.get(&8).unwrap().sub_plan_size, 1); // TableScan t1 + assert_eq!(visitor.nodes.get(&7).unwrap().sub_plan_size, 1); // TableScan t2 + assert_eq!(visitor.nodes.get(&3).unwrap().sub_plan_size, 1); // TableScan t4 + + // Mid-level nodes + assert_eq!(visitor.nodes.get(&2).unwrap().sub_plan_size, 2); // Filter -> t1 + assert_eq!(visitor.nodes.get(&6).unwrap().sub_plan_size, 2); // First Filter -> t2 + assert_eq!(visitor.nodes.get(&5).unwrap().sub_plan_size, 3); // Second Filter -> Filter -> t2 + assert_eq!(visitor.nodes.get(&4).unwrap().sub_plan_size, 5); // Join -> (Filter chain, t4) + + // Top-level nodes + assert_eq!(visitor.nodes.get(&1).unwrap().sub_plan_size, 8); // Top Join + assert_eq!(visitor.nodes.get(&0).unwrap().sub_plan_size, 9); // Project + + Ok(()) + } + + // TODO: add test for candidate collector. 
+} diff --git a/datafusion/optimizer/src/deliminator.rs b/datafusion/optimizer/src/deliminator.rs new file mode 100644 index 000000000000..d00d7c8bf274 --- /dev/null +++ b/datafusion/optimizer/src/deliminator.rs @@ -0,0 +1,890 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::delim_candidate_rewriter::DelimCandidateRewriter; +use crate::delim_candidates_collector::{ + DelimCandidateVisitor, JoinWithDelimScan, NodeVisitor, +}; +use crate::{ApplyOrder, OptimizerConfig, OptimizerRule}; +use datafusion_common::tree_node::{ + Transformed, TreeNode, TreeNodeRecursion, TreeNodeRewriter, +}; +use datafusion_common::{internal_err, Column, DataFusionError, Result}; +use datafusion_expr::utils::{conjunction, split_conjunction}; +use datafusion_expr::{ + Expr, Filter, Join, JoinKind, JoinType, LogicalPlan, Operator, Projection, +}; +use indexmap::IndexMap; + +/// The Deliminator optimizer traverses the logical operator tree and removes any +/// redundant DelimScan/DelimJoins. +#[derive(Debug)] +pub struct Deliminator {} + +impl Deliminator { + pub fn new() -> Self { + return Deliminator {}; + } +} + +impl OptimizerRule for Deliminator { + fn supports_rewrite(&self) -> bool { + true + } + + fn rewrite( + &self, + plan: LogicalPlan, + _config: &dyn OptimizerConfig, + ) -> Result> { + // TODO: Integrated with decrrelator + // let transformer = DecorrelateDependentJoin::new(); + // let rewrite_result = transformer.rewrite(plan, config)?; + + let rewrite_result = Transformed::no(plan); + + let mut node_visitor = NodeVisitor::new(); + let _ = node_visitor.collect_nodes(&rewrite_result.data)?; + let mut candidate_visitor = DelimCandidateVisitor::new(node_visitor); + let _ = rewrite_result.data.visit(&mut candidate_visitor)?; + for (_, candidate) in candidate_visitor.candidates.iter() { + println!("=== DelimCandidate ==="); + println!(" plan: {}", candidate.node.plan.display()); + println!(" delim_get_count: {}", candidate.delim_scan_count); + println!(" joins: ["); + for join in &candidate.joins { + println!(" JoinWithDelimGet {{"); + println!(" id: {}", join.node.id); + println!(" depth: {}", join.depth); + println!(" join: {}", join.node.plan.display()); + println!(" }},"); + } + println!(" ]"); + println!("==================\n"); + } + + if candidate_visitor.candidates.is_empty() { + return Ok(rewrite_result); + } + + let mut replacement_cols: Vec<(Column, Column)> = vec![]; + for (_, candidate) in candidate_visitor.candidates.iter_mut() { + let delim_join = &mut candidate.node.plan; + + // Sort these so the deepest are first. 
+ candidate.joins.sort_by(|a, b| a.depth.cmp(&b.depth)); + + let mut all_removed = true; + if !candidate.joins.is_empty() { + let mut has_selection = false; + delim_join.apply(|plan| { + match plan { + LogicalPlan::TableScan(table_scan) => { + for expr in &table_scan.filters { + if !matches!(expr, Expr::IsNotNull(_)) { + has_selection = true; + return Ok(TreeNodeRecursion::Stop); + } + } + } + LogicalPlan::Filter(_) => { + has_selection = true; + return Ok(TreeNodeRecursion::Stop); + } + _ => {} + } + + Ok(TreeNodeRecursion::Continue) + })?; + + if has_selection { + // Keep the deepest join with DelimScan in these cases, + // as the selection can greatly reduce the cost of the RHS child of the + // DelimJoin. + candidate.joins.remove(0); + all_removed = false; + } + + let delim_join = if let LogicalPlan::Join(join) = delim_join { + join + } else { + return internal_err!("unreachable"); + }; + + let mut all_equality_conditions = true; + let mut is_transformed = false; + for join in &mut candidate.joins { + all_removed = remove_join_with_delim_scan( + delim_join, + candidate.delim_scan_count, + join, + &mut all_equality_conditions, + &mut is_transformed, + &mut replacement_cols, + )?; + } + + // Change type if there are no more duplicate-eliminated columns. + if candidate.joins.len() == candidate.delim_scan_count && all_removed { + is_transformed |= true; + delim_join.join_kind = JoinKind::ComparisonJoin; + // TODO: clear duplicate eliminated columns if any, or should it have? + } + + // Only DelimJoins are ever created as SINGLE joins, and we can switch from SINGLE + // to LEFT if the RHS is de-deuplicated by an aggr. + // TODO: add single join support and try switch single to left. + + candidate.is_transformed = is_transformed; + } + } + + // Replace all with candidate. + let mut joins = IndexMap::new(); + for candidate in candidate_visitor.candidates.values() { + for join in &candidate.joins { + joins.insert(join.node.id, join.clone()); + } + } + + println!("\n=== Processing All Joins ==="); + let mut joins = IndexMap::new(); + for candidate in candidate_visitor.candidates.values() { + for join in &candidate.joins { + println!(" Join {{"); + println!(" id: {}", join.node.id); + println!(" depth: {}", join.depth); + println!(" plan: {}", join.node.plan.display()); + if let Some(replacement) = &join.replacement_plan { + println!(" replacement_plan: {}", replacement.display()); + } + println!(" can_be_eliminated: {}", join.can_be_eliminated); + println!(" is_filter_generated: {}", join.is_filter_generated); + println!(" }},"); + joins.insert(join.node.id, join.clone()); + } + } + println!("========================\n"); + + let mut rewriter = + DelimCandidateRewriter::new(candidate_visitor.candidates, joins); + let rewrite_result = rewrite_result.data.rewrite(&mut rewriter)?; + + // Replace all columns. 
+ let mut rewriter = ColumnRewriter::new(replacement_cols); + let mut rewrite_result = rewrite_result.data.rewrite(&mut rewriter)?; + + // TODO + rewrite_result.transformed = true; + + Ok(rewrite_result) + } + + fn name(&self) -> &str { + "deliminator" + } + + fn apply_order(&self) -> Option { + None + } +} + +fn remove_join_with_delim_scan( + delim_join: &mut Join, + delim_scan_count: usize, + join_with_delim_scan: &mut JoinWithDelimScan, + all_equality_conditions: &mut bool, + is_transformed: &mut bool, + replacement_cols: &mut Vec<(Column, Column)>, +) -> Result { + let join_plan = &join_with_delim_scan.node.plan; + if let LogicalPlan::Join(join) = join_plan { + if !child_join_type_can_be_deliminated(join.join_type) { + return Ok(false); + } + + // Fetch filter (if any) and delim scan. + let mut is_delim_side_left = true; + let mut plan_pair = fetch_delim_scan(join.left.as_ref()); + if plan_pair.1.is_none() { + is_delim_side_left = false; + plan_pair = fetch_delim_scan(join.right.as_ref()); + } + + // Collect filter exprs. + let mut filter_expressions = vec![]; + if let Some(plan) = plan_pair.0 { + if let LogicalPlan::Filter(filter) = plan { + for expr in split_conjunction(&filter.predicate) { + filter_expressions.push(expr.clone()); + } + } + } + + let delim_scan = plan_pair + .1 + .ok_or_else(|| DataFusionError::Plan("No delim scan found".to_string()))?; + let delim_scan = if let LogicalPlan::DelimGet(delim_scan) = delim_scan { + delim_scan + } else { + return internal_err!("unreachable"); + }; + + // Check if joining with the DelimScan is redundant, and collect relevant column + // information. + if let Some(filter) = &join.filter { + let conditions = split_conjunction(filter); + + if conditions.len() != delim_scan.columns.len() { + // Joining with delim scan adds new information. + return Ok(false); + } + + for condition in conditions { + if let Expr::BinaryExpr(binary_expr) = condition { + *all_equality_conditions &= is_equality_join_condition(&condition); + + if !matches!(*binary_expr.left, Expr::Column(_)) + || !matches!(*binary_expr.right, Expr::Column(_)) + { + return Ok(false); + } + + let (left_col, right_col) = + if let (Expr::Column(left), Expr::Column(right)) = + (&*binary_expr.left, &*binary_expr.right) + { + (left.clone(), right.clone()) + } else { + return internal_err!("unreachable"); + }; + + if is_delim_side_left { + replacement_cols.push((left_col, right_col)); + } else { + replacement_cols.push((right_col, left_col)); + } + + if !matches!(binary_expr.op, Operator::IsNotDistinctFrom) { + let is_not_null_expr = if is_delim_side_left { + binary_expr.right.clone().is_not_null() + } else { + binary_expr.left.clone().is_not_null() + }; + filter_expressions.push(is_not_null_expr); + } + } + } + } + + if !*all_equality_conditions + && !remove_inequality_join_with_delim_scan( + delim_join, + delim_scan_count, + join_plan, + is_transformed, + )? + { + return Ok(false); + } + + // All conditions passed, we can eliminate this join + DelimScan + join_with_delim_scan.can_be_eliminated = true; + let mut replacement_plan = if is_delim_side_left { + join.right.clone() + } else { + join.left.clone() + }; + if !filter_expressions.is_empty() { + replacement_plan = LogicalPlan::Filter(Filter::try_new( + conjunction(filter_expressions).ok_or_else(|| { + DataFusionError::Plan("filter expressions must exist".to_string()) + })?, + replacement_plan, + )?) 
+ .into(); + join_with_delim_scan.is_filter_generated = true; + } + join_with_delim_scan.replacement_plan = Some(replacement_plan); + + return Ok(true); + } else { + return internal_err!("current plan must be join in remove_join_with_delim_scan"); + } +} + +fn is_equality_join_condition(expr: &Expr) -> bool { + if let Expr::BinaryExpr(binary_expr) = expr { + if matches!(binary_expr.op, Operator::IsNotDistinctFrom) + || matches!(binary_expr.op, Operator::Eq) + { + return true; + } + } + + false +} + +fn child_join_type_can_be_deliminated(join_type: JoinType) -> bool { + match join_type { + JoinType::Inner | JoinType::LeftSemi | JoinType::RightSemi => true, + _ => false, + } +} + +// fetch filter (if any) and delim scan +fn fetch_delim_scan(plan: &LogicalPlan) -> (Option<&LogicalPlan>, Option<&LogicalPlan>) { + match plan { + LogicalPlan::Filter(filter) => { + if let LogicalPlan::DelimGet(_) = filter.input.as_ref() { + return (Some(plan), Some(filter.input.as_ref())); + }; + } + LogicalPlan::DelimGet(_) => { + return (None, Some(plan)); + } + + _ => {} + } + + (None, None) +} + +fn is_delim_scan(plan: &LogicalPlan) -> bool { + match plan { + LogicalPlan::Filter(filter) => { + if let LogicalPlan::SubqueryAlias(alias) = filter.input.as_ref() { + if let LogicalPlan::DelimGet(_) = alias.input.as_ref() { + return true; + }; + }; + } + LogicalPlan::SubqueryAlias(alias) => { + if let LogicalPlan::DelimGet(_) = alias.input.as_ref() { + return true; + } + } + _ => return false, + } + + false +} + +fn remove_inequality_join_with_delim_scan( + delim_join: &mut Join, + delim_scan_count: usize, + join_plan: &LogicalPlan, + is_transformed: &mut bool, +) -> Result { + if let LogicalPlan::Join(join) = join_plan { + if delim_scan_count != 1 + || !inequality_delim_join_can_be_eliminated(&join.join_type) + { + return Ok(false); + } + + let mut delim_conditions: Vec = if let Some(filter) = &mut delim_join.filter + { + split_conjunction(filter).into_iter().cloned().collect() + } else { + return Ok(false); + }; + let join_conditions = if let Some(filter) = &join.filter { + split_conjunction(filter) + } else { + return Ok(false); + }; + if delim_conditions.len() != join_conditions.len() { + return Ok(false); + } + + // TODO add single join support + if delim_join.join_type == JoinType::LeftMark { + let mut has_one_equality = false; + for condition in &join_conditions { + has_one_equality |= is_equality_join_condition(condition); + } + + if !has_one_equality { + return Ok(false); + } + } + + // We only support colref + let mut traced_cols = vec![]; + for condition in &delim_conditions { + if let Expr::BinaryExpr(binary_expr) = condition { + if let Expr::Column(column) = &*binary_expr.right { + traced_cols.push(column.clone()); + } else { + return Ok(false); + } + } else { + return Ok(false); + } + } + + // Now we trace down the column to join (for now, we only trace it through a few + // operators). + let mut cur_op = delim_join.right.as_ref(); + while *cur_op != *join_plan { + if cur_op.inputs().len() != 1 { + return Ok(false); + } + + match cur_op { + LogicalPlan::Projection(_) => find_and_replace_cols( + &mut traced_cols, + &cur_op.expressions(), + &cur_op.schema().columns(), + )?, + LogicalPlan::Filter(_) => { + // Doesn't change bindings. 
+ break; + } + _ => return Ok(false), + }; + + cur_op = *cur_op.inputs().get(0).ok_or_else(|| { + DataFusionError::Plan("current plan has no child".to_string()) + })?; + } + + let is_left_delim_scan = is_delim_scan(join.right.as_ref()); + + let mut found_all = true; + for (idx, delim_condition) in delim_conditions.iter_mut().enumerate() { + let traced_col = traced_cols.get(idx).ok_or_else(|| { + DataFusionError::Plan("get col under traced cols".to_string()) + })?; + + let delim_comparison = + if let Expr::BinaryExpr(ref mut binary_expr) = delim_condition { + &mut binary_expr.op + } else { + return internal_err!("expr must be binary"); + }; + + let mut found = false; + for join_condition in &join_conditions { + if let Expr::BinaryExpr(binary_expr) = join_condition { + let delim_side = if is_left_delim_scan { + &*binary_expr.left + } else { + &*binary_expr.right + }; + + if let Expr::Column(column) = delim_side { + if *column == *traced_col { + let mut join_comparison = binary_expr.op; + + if matches!(delim_comparison, Operator::IsDistinctFrom) + || matches!(delim_comparison, Operator::IsNotDistinctFrom) + { + // We need to compare Null values. + if matches!(join_comparison, Operator::Eq) { + join_comparison = Operator::IsNotDistinctFrom; + } else if matches!(join_comparison, Operator::NotEq) { + join_comparison = Operator::IsDistinctFrom; + } else if !matches!( + join_comparison, + Operator::IsDistinctFrom + ) && !matches!( + join_comparison, + Operator::IsNotDistinctFrom, + ) { + // The optimization does not work here + found = false; + break; + } + + // TODO how to change delim condition's comparison + *delim_comparison = + flip_comparison_operator(join_comparison)?; + + // Join condition was a not equal and filtered out all NULLs. + // Delim join need to do that for not delim scan side. Easiest way + // is to change the comparison expression type. + if delim_join.join_type != JoinType::LeftMark { + if *delim_comparison == Operator::IsDistinctFrom { + *delim_comparison = Operator::NotEq; + } + if *delim_comparison == Operator::IsNotDistinctFrom { + *delim_comparison = Operator::Eq; + } + } + + found = true; + break; + } + } + } else { + return internal_err!("expr must be column"); + } + } else { + return internal_err!("expr must be binary"); + } + } + found_all &= found; + } + + // Construct a new filter for delim join. + if found_all { + // If we found all conditions, combine them into a new filter. + if !delim_conditions.is_empty() { + let new_filter = conjunction(delim_conditions); + delim_join.filter = new_filter; + } else { + delim_join.filter = None; + } + + *is_transformed = true; + } + + Ok(found_all) + } else { + internal_err!( + "current plan must be join in remove_inequality_join_with_delim_scan" + ) + } +} + +fn inequality_delim_join_can_be_eliminated(join_type: &JoinType) -> bool { + // TODO add single join support + *join_type == JoinType::LeftAnti + || *join_type == JoinType::RightAnti + || *join_type == JoinType::LeftSemi + || *join_type == JoinType::RightSemi +} + +fn find_and_replace_cols( + traced_cols: &mut Vec, + exprs: &Vec, + cur_cols: &Vec, +) -> Result { + for col in traced_cols { + let mut cur_idx = 0; + for (idx, _) in exprs.iter().enumerate() { + cur_idx = idx; + if *col + == *cur_cols.get(idx).ok_or_else(|| { + DataFusionError::Plan("no column at idx".to_string()) + })? 
+ { + break; + } + } + + if cur_idx == exprs.len() { + return Ok(false); + } + + if let Expr::Column(column) = exprs + .get(cur_idx) + .ok_or_else(|| DataFusionError::Plan("no expr at cur_idx".to_string()))? + { + *col = column.clone(); + } else { + return Ok(false); + } + } + + return Ok(true); +} + +fn flip_comparison_operator(operator: Operator) -> Result { + match operator { + Operator::Eq + | Operator::NotEq + | Operator::IsDistinctFrom + | Operator::IsNotDistinctFrom => Ok(operator), + Operator::Lt => Ok(Operator::Gt), + Operator::LtEq => Ok(Operator::GtEq), + Operator::Gt => Ok(Operator::Lt), + Operator::GtEq => Ok(Operator::LtEq), + _ => internal_err!("unsupported comparison type in flip"), + } +} + +struct ColumnRewriter { + // + replacement_cols: Vec<(Column, Column)>, +} + +impl ColumnRewriter { + fn new(replacement_cols: Vec<(Column, Column)>) -> Self { + Self { replacement_cols } + } +} + +impl TreeNodeRewriter for ColumnRewriter { + type Node = LogicalPlan; + + fn f_down(&mut self, plan: LogicalPlan) -> Result> { + Ok(Transformed::no(plan)) + } + + fn f_up(&mut self, plan: LogicalPlan) -> Result> { + // Helper closure to rewrite expressions + let rewrite_expr = |expr: Expr| -> Result> { + let mut transformed = false; + let new_expr = expr.clone().transform_down(|expr| { + Ok(match expr { + Expr::Column(col) => { + if let Some((_, new_col)) = self + .replacement_cols + .iter() + .find(|(old_col, _)| old_col == &col) + { + transformed = true; + Transformed::yes(Expr::Column(new_col.clone())) + } else { + Transformed::no(Expr::Column(col)) + } + } + _ => Transformed::no(expr), + }) + })?; + + Ok(if transformed { + Transformed::yes(new_expr.data) + } else { + Transformed::no(expr) + }) + }; + + // Rewrite expressions in the plan + // Apply the rewrite to all expressions in the plan node + match plan { + LogicalPlan::Filter(filter) => { + let new_predicate = rewrite_expr(filter.predicate.clone())?; + Ok(if new_predicate.transformed { + Transformed::yes(LogicalPlan::Filter(Filter::try_new( + new_predicate.data, + filter.input, + )?)) + } else { + Transformed::no(LogicalPlan::Filter(filter)) + }) + } + LogicalPlan::Projection(projection) => { + let mut transformed = false; + let new_exprs: Vec = projection + .expr + .clone() + .into_iter() + .map(|expr| { + let res = rewrite_expr(expr)?; + transformed |= res.transformed; + Ok(res.data) + }) + .collect::>()?; + + Ok(if transformed { + Transformed::yes(LogicalPlan::Projection(Projection::try_new( + new_exprs, + projection.input, + )?)) + } else { + Transformed::no(LogicalPlan::Projection(projection)) + }) + } + // Add other cases as needed... + _ => Ok(Transformed::no(plan)), + } + } +} + +#[cfg(test)] +mod tests { + use std::sync::Arc; + + use arrow::datatypes::DataType as ArrowDataType; + use datafusion_common::{Column, Result}; + use datafusion_expr::{ + col, lit, CorrelatedColumnInfo, Expr, JoinType, LogicalPlanBuilder, + }; + use datafusion_functions_aggregate::count::count; + use datafusion_sql::TableReference; + use insta::assert_snapshot; + + use crate::deliminator::Deliminator; + use crate::test::{test_delim_scan_with_name, test_table_scan_with_name}; + use crate::OptimizerContext; + + macro_rules! assert_deliminate { + ( + $plan:expr, + @ $expected:literal $(,)? 
+ ) => {{ + let rule: Arc = + Arc::new(Deliminator::new()); + let transformed = rule.rewrite( + $plan.clone(), + &OptimizerContext::new().with_skip_failing_rules(true), + )?; + let display = transformed.data.display_indent_schema(); + assert_snapshot!( + display, + @ $expected, + ) + }}; + } + + #[test] + fn test_delim_joins() -> Result<()> { + // Projection + // | + // Filter + // | + // DelimJoin1 + // / ^ \ + // Get T3 | Projection + // | | + // | InnerJoin + // | / \ + // | DelimGet1 Aggregate + // | | | + // +------+ Filter + // | | + // | InnerJoin + // | / \ + // | CrossProduct Projection + // | / \ | + // | Get t2 DelimGet2 Get t1 + // | | + // + ----------------+ + + // Bottom level plan (rightmost branch) + let get_t1 = test_table_scan_with_name("t1")?; + let get_t2 = test_table_scan_with_name("t2")?; + let get_t3 = test_table_scan_with_name("t3")?; + + // Create schema for DelimGet2 + let delim_get2 = test_delim_scan_with_name(vec![CorrelatedColumnInfo { + col: Column::new(Some(TableReference::bare("delim_get2")), "d"), + data_type: ArrowDataType::UInt32, + depth: 0, + }])?; + + // Create right branch starting with t1 + let t1_projection = LogicalPlanBuilder::from(get_t1) + .project(vec![col("t1.a")])? + .build()?; + + // Create cross product of t2 and delim_get2 + let bottom_cross = LogicalPlanBuilder::from(get_t2) + .cross_join(delim_get2)? + .build()?; + + // Join cross product with t1 projection + let bottom_join = LogicalPlanBuilder::from(bottom_cross) + .join( + t1_projection, + JoinType::Inner, + (vec![Column::from_name("a")], vec![Column::from_name("a")]), + None, + )? + .build()?; + + // Add filter and aggregate + let bottom_filter = LogicalPlanBuilder::from(bottom_join) + .filter(col("t2.a").eq(lit(1)))? + .build()?; + + let bottom_agg = LogicalPlanBuilder::from(bottom_filter) + .aggregate(Vec::::new(), vec![count(col("t2.a"))])? + .build()?; + + // Create DelimGet1 for middle join + let delim_get1 = test_delim_scan_with_name(vec![CorrelatedColumnInfo { + col: Column::new(Some(TableReference::bare("delim_get1")), "a"), + data_type: ArrowDataType::UInt32, + depth: 0, + }])?; + + // Join DelimGet1 with aggregate + let middle_join = LogicalPlanBuilder::from(delim_get1) + .join( + bottom_agg, + JoinType::Inner, + (vec![Column::from_name("a")], vec![Column::from_name("d")]), + None, + )? + .build()?; + + let middle_proj = LogicalPlanBuilder::from(middle_join) + .project(vec![col("a").alias("p_a")])? + .build()?; + + // Final DelimJoin at top level + let final_join = LogicalPlanBuilder::from(get_t3) + .delim_join( + middle_proj, + JoinType::Inner, + (vec![Column::from_name("a")], vec![Column::from_name("p_a")]), + None, + )? + .build()?; + + let final_filter = LogicalPlanBuilder::from(final_join) + .filter(col("t3.a").eq(lit(1)))? + .build()?; + + let plan = LogicalPlanBuilder::from(final_filter) + .project(vec![col("t3.a")])? 
+ .build()?; + + // Projection: t3.a + // Filter: t3.a = Int32(1) + // Inner Join(DelimJoin): t3.a = p_a + // TableScan: t3 + // Projection: a AS p_a + // Inner Join(ComparisonJoin): a = d <- eliminate here + // DelimGet: b + // Aggregate: groupBy=[[]], aggr=[[count(t2.a)]] + // Filter: t2.a = Int32(1) + // Inner Join(ComparisonJoin): t2.a = t1.a + // Cross Join(ComparisonJoin): <- keep the deepest delimscan + // TableScan: t2 + // DelimGet: b + // Projection: t1.a + // TableScan: t1 + + // let rule: Arc = + // Arc::new(Deliminator::new()); + // rule.rewrite(plan, &OptimizerContext::new().with_skip_failing_rules(true)); + + assert_deliminate!(plan, @r" + Projection: t3.a [a:UInt32] + Filter: t3.a = Int32(1) [a:UInt32, b:UInt32, c:UInt32, p_a:UInt32;N] + Inner Join(DelimJoin): t3.a = p_a [a:UInt32, b:UInt32, c:UInt32, p_a:UInt32;N] + TableScan: t3 [a:UInt32, b:UInt32, c:UInt32] + Projection: a AS p_a [p_a:UInt32;N] + Aggregate: groupBy=[[]], aggr=[[count(t2.a)]] [count(t2.a):Int64] + Filter: t2.a = Int32(1) [a:UInt32, b:UInt32, c:UInt32, d:UInt32;N, a:UInt32] + Inner Join(ComparisonJoin): t2.a = t1.a [a:UInt32, b:UInt32, c:UInt32, d:UInt32;N, a:UInt32] + Cross Join(ComparisonJoin): [a:UInt32, b:UInt32, c:UInt32, d:UInt32;N] + TableScan: t2 [a:UInt32, b:UInt32, c:UInt32] + DelimGet: b [d:UInt32;N] + Projection: t1.a [a:UInt32] + TableScan: t1 [a:UInt32, b:UInt32, c:UInt32] + "); + + Ok(()) + } +} diff --git a/datafusion/optimizer/src/eliminate_cross_join.rs b/datafusion/optimizer/src/eliminate_cross_join.rs index ae1d7df46d52..22a07fd5e3cd 100644 --- a/datafusion/optimizer/src/eliminate_cross_join.rs +++ b/datafusion/optimizer/src/eliminate_cross_join.rs @@ -27,7 +27,7 @@ use datafusion_expr::logical_plan::{ Filter, Join, JoinConstraint, JoinType, LogicalPlan, Projection, }; use datafusion_expr::utils::{can_hash, find_valid_equijoin_key_pair}; -use datafusion_expr::{and, build_join_schema, ExprSchemable, Operator}; +use datafusion_expr::{and, build_join_schema, ExprSchemable, JoinKind, Operator}; #[derive(Default, Debug)] pub struct EliminateCrossJoin; @@ -342,6 +342,7 @@ fn find_inner_join( filter: None, schema: join_schema, null_equality, + join_kind: JoinKind::ComparisonJoin, })); } } @@ -364,6 +365,7 @@ fn find_inner_join( join_type: JoinType::Inner, join_constraint: JoinConstraint::On, null_equality, + join_kind: JoinKind::ComparisonJoin, })) } @@ -496,7 +498,7 @@ mod tests { plan, @ r" Filter: t2.c < UInt32(20) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] - Inner Join: t1.a = t2.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] + Inner Join(ComparisonJoin): t1.a = t2.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] TableScan: t1 [a:UInt32, b:UInt32, c:UInt32] TableScan: t2 [a:UInt32, b:UInt32, c:UInt32] " @@ -523,7 +525,7 @@ mod tests { plan, @ r" Filter: t1.a = t2.a OR t2.b = t1.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] - Cross Join: [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] + Cross Join(ComparisonJoin): [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] TableScan: t1 [a:UInt32, b:UInt32, c:UInt32] TableScan: t2 [a:UInt32, b:UInt32, c:UInt32] " @@ -549,7 +551,7 @@ mod tests { plan, @ r" Filter: t2.c < UInt32(20) AND t2.c = UInt32(10) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] - Inner Join: t1.a = t2.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] + Inner Join(ComparisonJoin): t1.a = t2.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] TableScan: t1 
[a:UInt32, b:UInt32, c:UInt32] TableScan: t2 [a:UInt32, b:UInt32, c:UInt32] " @@ -579,7 +581,7 @@ mod tests { plan, @ r" Filter: t2.c < UInt32(15) OR t2.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] - Inner Join: t1.a = t2.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] + Inner Join(ComparisonJoin): t1.a = t2.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] TableScan: t1 [a:UInt32, b:UInt32, c:UInt32] TableScan: t2 [a:UInt32, b:UInt32, c:UInt32] " @@ -609,7 +611,7 @@ mod tests { plan, @ r" Filter: t1.a = t2.a AND t2.c < UInt32(15) OR t1.b = t2.b AND t2.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] - Cross Join: [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] + Cross Join(ComparisonJoin): [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] TableScan: t1 [a:UInt32, b:UInt32, c:UInt32] TableScan: t2 [a:UInt32, b:UInt32, c:UInt32] " @@ -635,7 +637,7 @@ mod tests { plan, @ r" Filter: t1.a = t2.a AND t2.c < UInt32(15) OR t1.a = t2.a OR t2.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] - Cross Join: [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] + Cross Join(ComparisonJoin): [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] TableScan: t1 [a:UInt32, b:UInt32, c:UInt32] TableScan: t2 [a:UInt32, b:UInt32, c:UInt32] " @@ -665,8 +667,8 @@ mod tests { @ r" Filter: t1.a > UInt32(15) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] Filter: t1.a > UInt32(20) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] - Inner Join: t1.a = t2.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] - Inner Join: t1.a = t3.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] + Inner Join(ComparisonJoin): t1.a = t2.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] + Inner Join(ComparisonJoin): t1.a = t3.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] TableScan: t1 [a:UInt32, b:UInt32, c:UInt32] TableScan: t3 [a:UInt32, b:UInt32, c:UInt32] TableScan: t2 [a:UInt32, b:UInt32, c:UInt32] @@ -711,8 +713,8 @@ mod tests { @ r" Filter: t3.c < UInt32(15) AND t3.b < UInt32(15) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] Projection: t1.a, t1.b, t1.c, t2.a, t2.b, t2.c, t3.a, t3.b, t3.c [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] - Inner Join: t3.a = t2.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] - Inner Join: t1.a = t3.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] + Inner Join(ComparisonJoin): t3.a = t2.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] + Inner Join(ComparisonJoin): t1.a = t3.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] TableScan: t1 [a:UInt32, b:UInt32, c:UInt32] TableScan: t3 [a:UInt32, b:UInt32, c:UInt32] TableScan: t2 [a:UInt32, b:UInt32, c:UInt32] @@ -783,13 +785,13 @@ mod tests { plan, @ r" Filter: t4.c < UInt32(15) OR t4.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] - Inner Join: t1.a = t3.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] + Inner Join(ComparisonJoin): t1.a = t3.a 
[a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] Filter: t2.c < UInt32(15) OR t2.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] - Inner Join: t1.a = t2.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] + Inner Join(ComparisonJoin): t1.a = t2.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] TableScan: t1 [a:UInt32, b:UInt32, c:UInt32] TableScan: t2 [a:UInt32, b:UInt32, c:UInt32] Filter: t4.c < UInt32(15) OR t3.c = UInt32(688) OR t3.b = t4.b [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] - Inner Join: t3.a = t4.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] + Inner Join(ComparisonJoin): t3.a = t4.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] TableScan: t3 [a:UInt32, b:UInt32, c:UInt32] TableScan: t4 [a:UInt32, b:UInt32, c:UInt32] " @@ -857,13 +859,13 @@ mod tests { plan, @ r" Filter: t3.a = t1.a AND t4.c < UInt32(15) OR t3.a = t1.a OR t4.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] - Cross Join: [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] + Cross Join(ComparisonJoin): [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] Filter: t2.c < UInt32(15) OR t2.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] - Inner Join: t1.a = t2.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] + Inner Join(ComparisonJoin): t1.a = t2.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] TableScan: t1 [a:UInt32, b:UInt32, c:UInt32] TableScan: t2 [a:UInt32, b:UInt32, c:UInt32] Filter: t4.c < UInt32(15) OR t3.c = UInt32(688) OR t3.b = t4.b [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] - Inner Join: t3.a = t4.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] + Inner Join(ComparisonJoin): t3.a = t4.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] TableScan: t3 [a:UInt32, b:UInt32, c:UInt32] TableScan: t4 [a:UInt32, b:UInt32, c:UInt32] " @@ -931,13 +933,13 @@ mod tests { plan, @ r" Filter: t4.c < UInt32(15) OR t4.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] - Inner Join: t1.a = t3.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] + Inner Join(ComparisonJoin): t1.a = t3.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] Filter: t2.c < UInt32(15) OR t2.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] - Inner Join: t1.a = t2.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] + Inner Join(ComparisonJoin): t1.a = t2.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] TableScan: t1 [a:UInt32, b:UInt32, c:UInt32] TableScan: t2 [a:UInt32, b:UInt32, c:UInt32] Filter: t3.a = t4.a AND t4.c < UInt32(15) OR t3.a = t4.a AND t3.c = UInt32(688) OR t3.a = t4.a OR t3.b = t4.b [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] - Cross Join: [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] + Cross Join(ComparisonJoin): [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] TableScan: t3 [a:UInt32, b:UInt32, c:UInt32] TableScan: t4 
[a:UInt32, b:UInt32, c:UInt32] " @@ -1009,13 +1011,13 @@ mod tests { plan, @ r" Filter: t4.c < UInt32(15) OR t4.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] - Inner Join: t1.a = t3.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] + Inner Join(ComparisonJoin): t1.a = t3.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] Filter: t1.a = t2.a OR t2.c < UInt32(15) OR t1.a = t2.a AND t2.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] - Cross Join: [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] + Cross Join(ComparisonJoin): [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] TableScan: t1 [a:UInt32, b:UInt32, c:UInt32] TableScan: t2 [a:UInt32, b:UInt32, c:UInt32] Filter: t4.c < UInt32(15) OR t3.c = UInt32(688) OR t3.b = t4.b [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] - Inner Join: t3.a = t4.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] + Inner Join(ComparisonJoin): t3.a = t4.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] TableScan: t3 [a:UInt32, b:UInt32, c:UInt32] TableScan: t4 [a:UInt32, b:UInt32, c:UInt32] " @@ -1097,10 +1099,10 @@ mod tests { plan, @ r" Filter: (t4.c < UInt32(15) OR t4.c = UInt32(688)) AND (t4.c < UInt32(15) OR t3.c = UInt32(688) OR t3.b = t4.b) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] - Inner Join: t3.a = t4.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] - Inner Join: t1.a = t3.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] + Inner Join(ComparisonJoin): t3.a = t4.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] + Inner Join(ComparisonJoin): t1.a = t3.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] Filter: t2.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] - Inner Join: t1.a = t2.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] + Inner Join(ComparisonJoin): t1.a = t2.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] TableScan: t1 [a:UInt32, b:UInt32, c:UInt32] TableScan: t2 [a:UInt32, b:UInt32, c:UInt32] TableScan: t3 [a:UInt32, b:UInt32, c:UInt32] @@ -1190,9 +1192,9 @@ mod tests { plan, @ r" Filter: (t4.c < UInt32(15) OR t4.c = UInt32(688)) AND (t4.c < UInt32(15) OR t3.c = UInt32(688) OR t3.b = t4.b) AND t2.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] - Inner Join: t3.a = t4.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] - Inner Join: t1.a = t3.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] - Inner Join: t1.a = t2.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] + Inner Join(ComparisonJoin): t3.a = t4.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] + Inner Join(ComparisonJoin): t1.a = t3.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, 
c:UInt32] + Inner Join(ComparisonJoin): t1.a = t2.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] TableScan: t1 [a:UInt32, b:UInt32, c:UInt32] TableScan: t2 [a:UInt32, b:UInt32, c:UInt32] TableScan: t3 [a:UInt32, b:UInt32, c:UInt32] @@ -1220,7 +1222,7 @@ mod tests { plan, @ r" Filter: t2.c < UInt32(20) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] - Inner Join: t1.a + UInt32(100) = t2.a * UInt32(2) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] + Inner Join(ComparisonJoin): t1.a + UInt32(100) = t2.a * UInt32(2) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] TableScan: t1 [a:UInt32, b:UInt32, c:UInt32] TableScan: t2 [a:UInt32, b:UInt32, c:UInt32] " @@ -1247,7 +1249,7 @@ mod tests { plan, @ r" Filter: t1.a + UInt32(100) = t2.a * UInt32(2) OR t2.b = t1.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] - Cross Join: [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] + Cross Join(ComparisonJoin): [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] TableScan: t1 [a:UInt32, b:UInt32, c:UInt32] TableScan: t2 [a:UInt32, b:UInt32, c:UInt32] " @@ -1274,7 +1276,7 @@ mod tests { plan, @ r" Filter: t2.c < UInt32(20) AND t2.c = UInt32(10) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] - Inner Join: t1.a + UInt32(100) = t2.a * UInt32(2) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] + Inner Join(ComparisonJoin): t1.a + UInt32(100) = t2.a * UInt32(2) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] TableScan: t1 [a:UInt32, b:UInt32, c:UInt32] TableScan: t2 [a:UInt32, b:UInt32, c:UInt32] " @@ -1301,7 +1303,7 @@ mod tests { plan, @ r" Filter: t2.c < UInt32(15) OR t2.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] - Inner Join: t1.a + UInt32(100) = t2.a * UInt32(2) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] + Inner Join(ComparisonJoin): t1.a + UInt32(100) = t2.a * UInt32(2) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] TableScan: t1 [a:UInt32, b:UInt32, c:UInt32] TableScan: t2 [a:UInt32, b:UInt32, c:UInt32] " @@ -1338,8 +1340,8 @@ mod tests { @ r" Filter: t3.c < UInt32(15) AND t3.b < UInt32(15) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] Projection: t1.a, t1.b, t1.c, t2.a, t2.b, t2.c, t3.a, t3.b, t3.c [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] - Inner Join: t3.a + UInt32(100) = t2.a * UInt32(2) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] - Inner Join: t1.a * UInt32(2) = t3.a + UInt32(100) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] + Inner Join(ComparisonJoin): t3.a + UInt32(100) = t2.a * UInt32(2) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] + Inner Join(ComparisonJoin): t1.a * UInt32(2) = t3.a + UInt32(100) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] TableScan: t1 [a:UInt32, b:UInt32, c:UInt32] TableScan: t3 [a:UInt32, b:UInt32, c:UInt32] TableScan: t2 [a:UInt32, b:UInt32, c:UInt32] @@ -1368,6 +1370,7 @@ mod tests { filter: None, schema: join_schema, null_equality: NullEquality::NullEqualsNull, // Test preservation + join_kind: JoinKind::ComparisonJoin, }); // Apply filter that can create join conditions diff --git a/datafusion/optimizer/src/eliminate_outer_join.rs b/datafusion/optimizer/src/eliminate_outer_join.rs index 45877642f276..52c447d4613a 100644 --- 
a/datafusion/optimizer/src/eliminate_outer_join.rs +++ b/datafusion/optimizer/src/eliminate_outer_join.rs @@ -119,6 +119,7 @@ impl OptimizerRule for EliminateOuterJoin { filter: join.filter.clone(), schema: Arc::clone(&join.schema), null_equality: join.null_equality, + join_kind: join.join_kind, })); Filter::try_new(filter.predicate, new_join) .map(|f| Transformed::yes(LogicalPlan::Filter(f))) @@ -349,7 +350,7 @@ mod tests { assert_optimized_plan_equal!(plan, @r" Filter: t2.b IS NULL - Left Join: t1.a = t2.a + Left Join(ComparisonJoin): t1.a = t2.a TableScan: t1 TableScan: t2 ") @@ -373,7 +374,7 @@ mod tests { assert_optimized_plan_equal!(plan, @r" Filter: t2.b IS NOT NULL - Inner Join: t1.a = t2.a + Inner Join(ComparisonJoin): t1.a = t2.a TableScan: t1 TableScan: t2 ") @@ -401,7 +402,7 @@ mod tests { assert_optimized_plan_equal!(plan, @r" Filter: t1.b > UInt32(10) OR t1.c < UInt32(20) - Inner Join: t1.a = t2.a + Inner Join(ComparisonJoin): t1.a = t2.a TableScan: t1 TableScan: t2 ") @@ -429,7 +430,7 @@ mod tests { assert_optimized_plan_equal!(plan, @r" Filter: t1.b > UInt32(10) AND t2.c < UInt32(20) - Inner Join: t1.a = t2.a + Inner Join(ComparisonJoin): t1.a = t2.a TableScan: t1 TableScan: t2 ") @@ -457,7 +458,7 @@ mod tests { assert_optimized_plan_equal!(plan, @r" Filter: CAST(t1.b AS Int64) > UInt32(10) AND TRY_CAST(t2.c AS Int64) < UInt32(20) - Inner Join: t1.a = t2.a + Inner Join(ComparisonJoin): t1.a = t2.a TableScan: t1 TableScan: t2 ") diff --git a/datafusion/optimizer/src/extract_equijoin_predicate.rs b/datafusion/optimizer/src/extract_equijoin_predicate.rs index 55cf33ef4304..7703e5db988a 100644 --- a/datafusion/optimizer/src/extract_equijoin_predicate.rs +++ b/datafusion/optimizer/src/extract_equijoin_predicate.rs @@ -76,6 +76,7 @@ impl OptimizerRule for ExtractEquijoinPredicate { join_constraint, schema, null_equality, + join_kind, }) => { let left_schema = left.schema(); let right_schema = right.schema(); @@ -93,6 +94,7 @@ impl OptimizerRule for ExtractEquijoinPredicate { join_constraint, schema, null_equality, + join_kind, }))) } else { Ok(Transformed::no(LogicalPlan::Join(Join { @@ -104,6 +106,7 @@ impl OptimizerRule for ExtractEquijoinPredicate { join_constraint, schema, null_equality, + join_kind, }))) } } @@ -189,7 +192,7 @@ mod tests { assert_optimized_plan_equal!( plan, @r" - Left Join: t1.a = t2.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32;N, b:UInt32;N, c:UInt32;N] + Left Join(ComparisonJoin): t1.a = t2.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32;N, b:UInt32;N, c:UInt32;N] TableScan: t1 [a:UInt32, b:UInt32, c:UInt32] TableScan: t2 [a:UInt32, b:UInt32, c:UInt32] " @@ -212,7 +215,7 @@ mod tests { assert_optimized_plan_equal!( plan, @r" - Left Join: t1.a + Int64(10) = t2.a * UInt32(2) [a:UInt32, b:UInt32, c:UInt32, a:UInt32;N, b:UInt32;N, c:UInt32;N] + Left Join(ComparisonJoin): t1.a + Int64(10) = t2.a * UInt32(2) [a:UInt32, b:UInt32, c:UInt32, a:UInt32;N, b:UInt32;N, c:UInt32;N] TableScan: t1 [a:UInt32, b:UInt32, c:UInt32] TableScan: t2 [a:UInt32, b:UInt32, c:UInt32] " @@ -239,7 +242,7 @@ mod tests { assert_optimized_plan_equal!( plan, @r" - Left Join: Filter: t1.a + Int64(10) >= t2.a * UInt32(2) AND t1.b < Int32(100) [a:UInt32, b:UInt32, c:UInt32, a:UInt32;N, b:UInt32;N, c:UInt32;N] + Left Join(ComparisonJoin): Filter: t1.a + Int64(10) >= t2.a * UInt32(2) AND t1.b < Int32(100) [a:UInt32, b:UInt32, c:UInt32, a:UInt32;N, b:UInt32;N, c:UInt32;N] TableScan: t1 [a:UInt32, b:UInt32, c:UInt32] TableScan: t2 [a:UInt32, b:UInt32, c:UInt32] " @@ -270,7 +273,7 @@ mod tests { 
assert_optimized_plan_equal!( plan, @r" - Left Join: t1.a + UInt32(11) = t2.a * UInt32(2), t1.a + Int64(10) = t2.a * UInt32(2) Filter: t1.b < Int32(100) [a:UInt32, b:UInt32, c:UInt32, a:UInt32;N, b:UInt32;N, c:UInt32;N] + Left Join(ComparisonJoin): t1.a + UInt32(11) = t2.a * UInt32(2), t1.a + Int64(10) = t2.a * UInt32(2) Filter: t1.b < Int32(100) [a:UInt32, b:UInt32, c:UInt32, a:UInt32;N, b:UInt32;N, c:UInt32;N] TableScan: t1 [a:UInt32, b:UInt32, c:UInt32] TableScan: t2 [a:UInt32, b:UInt32, c:UInt32] " @@ -300,7 +303,7 @@ mod tests { assert_optimized_plan_equal!( plan, @r" - Left Join: t1.a = t2.a, t1.b = t2.b Filter: t1.c = t2.c OR t1.a + t1.b > t2.b + t2.c [a:UInt32, b:UInt32, c:UInt32, a:UInt32;N, b:UInt32;N, c:UInt32;N] + Left Join(ComparisonJoin): t1.a = t2.a, t1.b = t2.b Filter: t1.c = t2.c OR t1.a + t1.b > t2.b + t2.c [a:UInt32, b:UInt32, c:UInt32, a:UInt32;N, b:UInt32;N, c:UInt32;N] TableScan: t1 [a:UInt32, b:UInt32, c:UInt32] TableScan: t2 [a:UInt32, b:UInt32, c:UInt32] " @@ -339,9 +342,9 @@ mod tests { assert_optimized_plan_equal!( plan, @r" - Left Join: t1.a = t2.a Filter: t1.c + t2.c + t3.c < UInt32(100) [a:UInt32, b:UInt32, c:UInt32, a:UInt32;N, b:UInt32;N, c:UInt32;N, a:UInt32;N, b:UInt32;N, c:UInt32;N] + Left Join(ComparisonJoin): t1.a = t2.a Filter: t1.c + t2.c + t3.c < UInt32(100) [a:UInt32, b:UInt32, c:UInt32, a:UInt32;N, b:UInt32;N, c:UInt32;N, a:UInt32;N, b:UInt32;N, c:UInt32;N] TableScan: t1 [a:UInt32, b:UInt32, c:UInt32] - Left Join: t2.a = t3.a Filter: t2.a + t3.b > UInt32(100) [a:UInt32, b:UInt32, c:UInt32, a:UInt32;N, b:UInt32;N, c:UInt32;N] + Left Join(ComparisonJoin): t2.a = t3.a Filter: t2.a + t3.b > UInt32(100) [a:UInt32, b:UInt32, c:UInt32, a:UInt32;N, b:UInt32;N, c:UInt32;N] TableScan: t2 [a:UInt32, b:UInt32, c:UInt32] TableScan: t3 [a:UInt32, b:UInt32, c:UInt32] " @@ -376,9 +379,9 @@ mod tests { assert_optimized_plan_equal!( plan, @r" - Left Join: t1.a = t2.a Filter: t2.c = t3.c [a:UInt32, b:UInt32, c:UInt32, a:UInt32;N, b:UInt32;N, c:UInt32;N, a:UInt32;N, b:UInt32;N, c:UInt32;N] + Left Join(ComparisonJoin): t1.a = t2.a Filter: t2.c = t3.c [a:UInt32, b:UInt32, c:UInt32, a:UInt32;N, b:UInt32;N, c:UInt32;N, a:UInt32;N, b:UInt32;N, c:UInt32;N] TableScan: t1 [a:UInt32, b:UInt32, c:UInt32] - Left Join: t2.a = t3.a Filter: t2.a + t3.b > UInt32(100) [a:UInt32, b:UInt32, c:UInt32, a:UInt32;N, b:UInt32;N, c:UInt32;N] + Left Join(ComparisonJoin): t2.a = t3.a Filter: t2.a + t3.b > UInt32(100) [a:UInt32, b:UInt32, c:UInt32, a:UInt32;N, b:UInt32;N, c:UInt32;N] TableScan: t2 [a:UInt32, b:UInt32, c:UInt32] TableScan: t3 [a:UInt32, b:UInt32, c:UInt32] " @@ -406,7 +409,7 @@ mod tests { assert_optimized_plan_equal!( plan, @r" - Left Join: t1.a + CAST(Int64(1) AS UInt32) = t2.a + CAST(Int32(2) AS UInt32) [a:UInt32, b:UInt32, c:UInt32, a:UInt32;N, b:UInt32;N, c:UInt32;N] + Left Join(ComparisonJoin): t1.a + CAST(Int64(1) AS UInt32) = t2.a + CAST(Int32(2) AS UInt32) [a:UInt32, b:UInt32, c:UInt32, a:UInt32;N, b:UInt32;N, c:UInt32;N] TableScan: t1 [a:UInt32, b:UInt32, c:UInt32] TableScan: t2 [a:UInt32, b:UInt32, c:UInt32] " diff --git a/datafusion/optimizer/src/filter_null_join_keys.rs b/datafusion/optimizer/src/filter_null_join_keys.rs index 8ad7fa53c0e3..58d6c9aa6220 100644 --- a/datafusion/optimizer/src/filter_null_join_keys.rs +++ b/datafusion/optimizer/src/filter_null_join_keys.rs @@ -137,7 +137,7 @@ mod tests { let plan = build_plan(t1, t2, "t1.optional_id", "t2.id", JoinType::Inner)?; assert_optimized_plan_equal!(plan, @r" - Inner Join: t1.optional_id = t2.id + Inner 
Join(ComparisonJoin): t1.optional_id = t2.id Filter: t1.optional_id IS NOT NULL TableScan: t1 TableScan: t2 @@ -150,7 +150,7 @@ mod tests { let plan = build_plan(t1, t2, "t1.optional_id", "t2.id", JoinType::Left)?; assert_optimized_plan_equal!(plan, @r" - Left Join: t1.optional_id = t2.id + Left Join(ComparisonJoin): t1.optional_id = t2.id TableScan: t1 TableScan: t2 ") @@ -164,7 +164,7 @@ mod tests { build_plan(t_right, t_left, "t2.id", "t1.optional_id", JoinType::Left)?; assert_optimized_plan_equal!(plan, @r" - Left Join: t2.id = t1.optional_id + Left Join(ComparisonJoin): t2.id = t1.optional_id TableScan: t2 Filter: t1.optional_id IS NOT NULL TableScan: t1 @@ -177,7 +177,7 @@ mod tests { let plan = build_plan(t1, t2, "t2.id", "t1.optional_id", JoinType::Inner)?; assert_optimized_plan_equal!(plan, @r" - Inner Join: t1.optional_id = t2.id + Inner Join(ComparisonJoin): t1.optional_id = t2.id Filter: t1.optional_id IS NOT NULL TableScan: t1 TableScan: t2 @@ -213,10 +213,10 @@ mod tests { .build()?; assert_optimized_plan_equal!(plan, @r" - Inner Join: t3.t1_id = t1.id, t3.t2_id = t2.id + Inner Join(ComparisonJoin): t3.t1_id = t1.id, t3.t2_id = t2.id Filter: t3.t1_id IS NOT NULL AND t3.t2_id IS NOT NULL TableScan: t3 - Inner Join: t1.optional_id = t2.id + Inner Join(ComparisonJoin): t1.optional_id = t2.id Filter: t1.optional_id IS NOT NULL TableScan: t1 TableScan: t2 @@ -239,7 +239,7 @@ mod tests { .build()?; assert_optimized_plan_equal!(plan, @r" - Inner Join: t1.optional_id + UInt32(1) = t2.id + UInt32(1) + Inner Join(ComparisonJoin): t1.optional_id + UInt32(1) = t2.id + UInt32(1) Filter: t1.optional_id + UInt32(1) IS NOT NULL TableScan: t1 TableScan: t2 @@ -262,7 +262,7 @@ mod tests { .build()?; assert_optimized_plan_equal!(plan, @r" - Inner Join: t1.id + UInt32(1) = t2.optional_id + UInt32(1) + Inner Join(ComparisonJoin): t1.id + UInt32(1) = t2.optional_id + UInt32(1) TableScan: t1 Filter: t2.optional_id + UInt32(1) IS NOT NULL TableScan: t2 @@ -285,7 +285,7 @@ mod tests { .build()?; assert_optimized_plan_equal!(plan, @r" - Inner Join: t1.optional_id + UInt32(1) = t2.optional_id + UInt32(1) + Inner Join(ComparisonJoin): t1.optional_id + UInt32(1) = t2.optional_id + UInt32(1) Filter: t1.optional_id + UInt32(1) IS NOT NULL TableScan: t1 Filter: t2.optional_id + UInt32(1) IS NOT NULL @@ -314,7 +314,7 @@ mod tests { .build()?; assert_optimized_plan_equal!(plan_from_cols, @r" - Inner Join: t1.optional_id = t2.optional_id + Inner Join(ComparisonJoin): t1.optional_id = t2.optional_id Filter: t1.optional_id IS NOT NULL TableScan: t1 Filter: t2.optional_id IS NOT NULL diff --git a/datafusion/optimizer/src/lib.rs b/datafusion/optimizer/src/lib.rs index 280010e3d92c..bbe0f9e2d6e4 100644 --- a/datafusion/optimizer/src/lib.rs +++ b/datafusion/optimizer/src/lib.rs @@ -40,8 +40,12 @@ pub mod analyzer; pub mod common_subexpr_eliminate; pub mod decorrelate; +pub mod decorrelate_dependent_join; pub mod decorrelate_lateral_join; pub mod decorrelate_predicate_subquery; +pub mod delim_candidate_rewriter; +pub mod delim_candidates_collector; +pub mod deliminator; pub mod eliminate_cross_join; pub mod eliminate_duplicated_expr; pub mod eliminate_filter; @@ -59,6 +63,7 @@ pub mod propagate_empty_relation; pub mod push_down_filter; pub mod push_down_limit; pub mod replace_distinct_aggregate; +pub mod rewrite_dependent_join; pub mod scalar_subquery_to_join; pub mod simplify_expressions; pub mod single_distinct_to_groupby; diff --git a/datafusion/optimizer/src/optimize_projections/mod.rs 
b/datafusion/optimizer/src/optimize_projections/mod.rs index 7b7be82b70ca..8d09452814d1 100644 --- a/datafusion/optimizer/src/optimize_projections/mod.rs +++ b/datafusion/optimizer/src/optimize_projections/mod.rs @@ -393,6 +393,10 @@ fn optimize_projections( }); vec![required_indices.append(&additional_necessary_child_indices)] } + LogicalPlan::DependentJoin(..) => { + return Ok(Transformed::no(plan)); + } + LogicalPlan::DelimGet(_) => return Ok(Transformed::no(plan)), }; // Required indices are currently ordered (child0, child1, ...) @@ -716,7 +720,8 @@ fn split_join_requirements( | JoinType::Right | JoinType::Full | JoinType::LeftMark - | JoinType::RightMark => { + | JoinType::RightMark + | JoinType::LeftSingle => { // Decrease right side indices by `left_len` so that they point to valid // positions within the right child: indices.split_off(left_len) @@ -1741,7 +1746,7 @@ mod tests { assert_snapshot!( optimized_plan.clone(), @r" - Left Join: test.a = test2.c1 + Left Join(ComparisonJoin): test.a = test2.c1 TableScan: test projection=[a, b] TableScan: test2 projection=[c1] " @@ -1796,7 +1801,7 @@ mod tests { optimized_plan.clone(), @r" Projection: test.a, test.b - Left Join: test.a = test2.c1 + Left Join(ComparisonJoin): test.a = test2.c1 TableScan: test projection=[a, b] TableScan: test2 projection=[c1] " diff --git a/datafusion/optimizer/src/optimizer.rs b/datafusion/optimizer/src/optimizer.rs index 4d2c2c7c79cd..523d91060932 100644 --- a/datafusion/optimizer/src/optimizer.rs +++ b/datafusion/optimizer/src/optimizer.rs @@ -33,8 +33,9 @@ use datafusion_common::{internal_err, DFSchema, DataFusionError, HashSet, Result use datafusion_expr::logical_plan::LogicalPlan; use crate::common_subexpr_eliminate::CommonSubexprEliminate; +use crate::decorrelate_dependent_join::DecorrelateDependentJoin; use crate::decorrelate_lateral_join::DecorrelateLateralJoin; -use crate::decorrelate_predicate_subquery::DecorrelatePredicateSubquery; +// use crate::decorrelate_predicate_subquery::DecorrelatePredicateSubquery; use crate::eliminate_cross_join::EliminateCrossJoin; use crate::eliminate_duplicated_expr::EliminateDuplicatedExpr; use crate::eliminate_filter::EliminateFilter; @@ -52,7 +53,7 @@ use crate::propagate_empty_relation::PropagateEmptyRelation; use crate::push_down_filter::PushDownFilter; use crate::push_down_limit::PushDownLimit; use crate::replace_distinct_aggregate::ReplaceDistinctWithAggregate; -use crate::scalar_subquery_to_join::ScalarSubqueryToJoin; +//use crate::scalar_subquery_to_join::ScalarSubqueryToJoin; use crate::simplify_expressions::SimplifyExpressions; use crate::single_distinct_to_groupby::SingleDistinctToGroupBy; use crate::utils::log_plan; @@ -225,8 +226,9 @@ impl Optimizer { Arc::new(SimplifyExpressions::new()), Arc::new(ReplaceDistinctWithAggregate::new()), Arc::new(EliminateJoin::new()), - Arc::new(DecorrelatePredicateSubquery::new()), - Arc::new(ScalarSubqueryToJoin::new()), + Arc::new(DecorrelateDependentJoin::new()), // TODO + // Arc::new(DecorrelatePredicateSubquery::new()), + // Arc::new(ScalarSubqueryToJoin::new()), Arc::new(DecorrelateLateralJoin::new()), Arc::new(ExtractEquijoinPredicate::new()), Arc::new(EliminateDuplicatedExpr::new()), diff --git a/datafusion/optimizer/src/push_down_filter.rs b/datafusion/optimizer/src/push_down_filter.rs index bcb867f6e7fa..ced28450350a 100644 --- a/datafusion/optimizer/src/push_down_filter.rs +++ b/datafusion/optimizer/src/push_down_filter.rs @@ -161,7 +161,7 @@ pub struct PushDownFilter {} pub(crate) fn lr_is_preserved(join_type: 
JoinType) -> (bool, bool) { match join_type { JoinType::Inner => (true, true), - JoinType::Left => (true, false), + JoinType::Left | JoinType::LeftSingle => (true, false), JoinType::Right => (false, true), JoinType::Full => (false, false), // No columns from the right side of the join can be referenced in output @@ -185,7 +185,7 @@ pub(crate) fn lr_is_preserved(join_type: JoinType) -> (bool, bool) { pub(crate) fn on_lr_is_preserved(join_type: JoinType) -> (bool, bool) { match join_type { JoinType::Inner => (true, true), - JoinType::Left => (false, true), + JoinType::Left | JoinType::LeftSingle => (false, true), JoinType::Right => (true, false), JoinType::Full => (false, false), JoinType::LeftSemi | JoinType::RightSemi => (true, true), @@ -686,7 +686,7 @@ fn infer_join_predicates_from_on_filters( on_filters, inferred_predicates, ), - JoinType::Left | JoinType::LeftSemi | JoinType::LeftMark => { + JoinType::Left | JoinType::LeftSemi | JoinType::LeftMark | JoinType::LeftSingle => { infer_join_predicates_impl::( join_col_keys, on_filters, @@ -1103,7 +1103,13 @@ impl OptimizerRule for PushDownFilter { let (volatile_filters, non_volatile_filters): (Vec<&Expr>, Vec<&Expr>) = filter_predicates .into_iter() - .partition(|pred| pred.is_volatile()); + // TODO: subquery decorrelation sometimes cannot decorrelated all the expr + // (i.e in the case of recursive subquery) + // this function may accidentally pushdown the subquery expr as well + // until then, we have to exclude these exprs here + .partition(|pred| { + pred.is_volatile() || has_scalar_subquery(pred) + }); // Check which non-volatile filters are supported by source let supported_filters = scan @@ -1396,6 +1402,14 @@ fn contain(e: &Expr, check_map: &HashMap) -> bool { is_contain } +fn has_scalar_subquery(expr: &Expr) -> bool { + expr.exists(|e| match e { + Expr::ScalarSubquery(_) => Ok(true), + _ => Ok(false), + }) + .unwrap() +} + #[cfg(test)] mod tests { use std::any::Any; @@ -2243,7 +2257,7 @@ mod tests { plan, @r" Projection: test.a, test1.d - Cross Join: + Cross Join(ComparisonJoin): Projection: test.a, test.b, test.c TableScan: test, full_filters=[test.a = Int32(1)] Projection: test1.d, test1.e, test1.f @@ -2273,7 +2287,7 @@ mod tests { plan, @r" Projection: test.a, test1.a - Cross Join: + Cross Join(ComparisonJoin): Projection: test.a, test.b, test.c TableScan: test, full_filters=[test.a = Int32(1)] Projection: test1.a, test1.b, test1.c @@ -2401,7 +2415,7 @@ mod tests { assert_snapshot!(plan, @r" Filter: test.a <= Int64(1) - Inner Join: test.a = test2.a + Inner Join(ComparisonJoin): test.a = test2.a TableScan: test Projection: test2.a TableScan: test2 @@ -2484,7 +2498,7 @@ mod tests { assert_snapshot!(plan, @r" Filter: test.c <= test2.b - Inner Join: test.a = test2.a + Inner Join(ComparisonJoin): test.a = test2.a Projection: test.a, test.c TableScan: test Projection: test2.a, test2.b @@ -2530,7 +2544,7 @@ mod tests { assert_snapshot!(plan, @r" Filter: test.b <= Int64(1) - Inner Join: test.a = test2.a + Inner Join(ComparisonJoin): test.a = test2.a Projection: test.a, test.b TableScan: test Projection: test2.a, test2.c @@ -2741,7 +2755,7 @@ mod tests { // not part of the test, just good to know: assert_snapshot!(plan, @r" - Inner Join: test.a = test2.a Filter: test.c > UInt32(1) AND test.b < test2.b AND test2.c > UInt32(4) + Inner Join(ComparisonJoin): test.a = test2.a Filter: test.c > UInt32(1) AND test.b < test2.b AND test2.c > UInt32(4) Projection: test.a, test.b, test.c TableScan: test Projection: test2.a, test2.b, test2.c @@ 
-2786,7 +2800,7 @@ mod tests { // not part of the test, just good to know: assert_snapshot!(plan, @r" - Inner Join: test.a = test2.a Filter: test.b > UInt32(1) AND test2.c > UInt32(4) + Inner Join(ComparisonJoin): test.a = test2.a Filter: test.b > UInt32(1) AND test2.c > UInt32(4) Projection: test.a, test.b, test.c TableScan: test Projection: test2.a, test2.b, test2.c @@ -2829,7 +2843,7 @@ mod tests { // not part of the test, just good to know: assert_snapshot!(plan, @r" - Inner Join: test.a = test2.b Filter: test.a > UInt32(1) + Inner Join(ComparisonJoin): test.a = test2.b Filter: test.a > UInt32(1) Projection: test.a TableScan: test Projection: test2.b @@ -2875,7 +2889,7 @@ mod tests { // not part of the test, just good to know: assert_snapshot!(plan, @r" - Left Join: test.a = test2.a Filter: test.a > UInt32(1) AND test.b < test2.b AND test2.c > UInt32(4) + Left Join(ComparisonJoin): test.a = test2.a Filter: test.a > UInt32(1) AND test.b < test2.b AND test2.c > UInt32(4) Projection: test.a, test.b, test.c TableScan: test Projection: test2.a, test2.b, test2.c @@ -2921,7 +2935,7 @@ mod tests { // not part of the test, just good to know: assert_snapshot!(plan, @r" - Right Join: test.a = test2.a Filter: test.a > UInt32(1) AND test.b < test2.b AND test2.c > UInt32(4) + Right Join(ComparisonJoin): test.a = test2.a Filter: test.a > UInt32(1) AND test.b < test2.b AND test2.c > UInt32(4) Projection: test.a, test.b, test.c TableScan: test Projection: test2.a, test2.b, test2.c @@ -2967,7 +2981,7 @@ mod tests { // not part of the test, just good to know: assert_snapshot!(plan, @r" - Full Join: test.a = test2.a Filter: test.a > UInt32(1) AND test.b < test2.b AND test2.c > UInt32(4) + Full Join(ComparisonJoin): test.a = test2.a Filter: test.a > UInt32(1) AND test.b < test2.b AND test2.c > UInt32(4) Projection: test.a, test.b, test.c TableScan: test Projection: test2.a, test2.b, test2.c @@ -3256,7 +3270,7 @@ mod tests { assert_snapshot!(plan, @r" - Inner Join: c = d Filter: c > UInt32(1) + Inner Join(ComparisonJoin): c = d Filter: c > UInt32(1) Projection: test.a AS c TableScan: test Projection: test2.b AS d @@ -3439,7 +3453,7 @@ mod tests { .build()?; assert_optimized_plan_eq_with_rewrite_predicate!(plan.clone(), @r" - Inner Join: Filter: test.a = d AND test.b > UInt32(1) OR test.b = e AND test.c < UInt32(10) + Inner Join(ComparisonJoin): Filter: test.a = d AND test.b > UInt32(1) OR test.b = e AND test.c < UInt32(10) Projection: test.a, test.b, test.c TableScan: test, full_filters=[test.b > UInt32(1) OR test.c < UInt32(10)] Projection: test1.a AS d, test1.a AS e @@ -3488,7 +3502,7 @@ mod tests { assert_snapshot!(plan, @r" Filter: test2.a <= Int64(1) - LeftSemi Join: test1.a = test2.a + LeftSemi Join(ComparisonJoin): test1.a = test2.a TableScan: test1 Projection: test2.a, test2.b TableScan: test2 @@ -3533,7 +3547,7 @@ mod tests { // not part of the test, just good to know: assert_snapshot!(plan, @r" - LeftSemi Join: test1.a = test2.a Filter: test1.b > UInt32(1) AND test2.b > UInt32(2) + LeftSemi Join(ComparisonJoin): test1.a = test2.a Filter: test1.b > UInt32(1) AND test2.b > UInt32(2) TableScan: test1 Projection: test2.a, test2.b TableScan: test2 @@ -3575,7 +3589,7 @@ mod tests { assert_snapshot!(plan, @r" Filter: test1.a <= Int64(1) - RightSemi Join: test1.a = test2.a + RightSemi Join(ComparisonJoin): test1.a = test2.a TableScan: test1 Projection: test2.a, test2.b TableScan: test2 @@ -3620,7 +3634,7 @@ mod tests { // not part of the test, just good to know: assert_snapshot!(plan, @r" - RightSemi 
Join: test1.a = test2.a Filter: test1.b > UInt32(1) AND test2.b > UInt32(2) + RightSemi Join(ComparisonJoin): test1.a = test2.a Filter: test1.b > UInt32(1) AND test2.b > UInt32(2) TableScan: test1 Projection: test2.a, test2.b TableScan: test2 @@ -3665,7 +3679,7 @@ mod tests { assert_snapshot!(plan, @r" Filter: test2.a > UInt32(2) - LeftAnti Join: test1.a = test2.a + LeftAnti Join(ComparisonJoin): test1.a = test2.a Projection: test1.a, test1.b TableScan: test1 Projection: test2.a, test2.b @@ -3715,7 +3729,7 @@ mod tests { // not part of the test, just good to know: assert_snapshot!(plan, @r" - LeftAnti Join: test1.a = test2.a Filter: test1.b > UInt32(1) AND test2.b > UInt32(2) + LeftAnti Join(ComparisonJoin): test1.a = test2.a Filter: test1.b > UInt32(1) AND test2.b > UInt32(2) Projection: test1.a, test1.b TableScan: test1 Projection: test2.a, test2.b @@ -3762,7 +3776,7 @@ mod tests { assert_snapshot!(plan, @r" Filter: test1.a > UInt32(2) - RightAnti Join: test1.a = test2.a + RightAnti Join(ComparisonJoin): test1.a = test2.a Projection: test1.a, test1.b TableScan: test1 Projection: test2.a, test2.b @@ -3812,7 +3826,7 @@ mod tests { // not part of the test, just good to know: assert_snapshot!(plan, @r" - RightAnti Join: test1.a = test2.a Filter: test1.b > UInt32(1) AND test2.b > UInt32(2) + RightAnti Join(ComparisonJoin): test1.a = test2.a Filter: test1.b > UInt32(1) AND test2.b > UInt32(2) Projection: test1.a, test1.b TableScan: test1 Projection: test2.a, test2.b @@ -3931,7 +3945,7 @@ mod tests { Filter: t.r > Float64(0.8) SubqueryAlias: t Projection: test1.a AS a, TestScalarUDF() AS r - Inner Join: test1.a = test2.a + Inner Join(ComparisonJoin): test1.a = test2.a TableScan: test1 TableScan: test2 ", diff --git a/datafusion/optimizer/src/push_down_limit.rs b/datafusion/optimizer/src/push_down_limit.rs index ec042dd350ca..1319aa059615 100644 --- a/datafusion/optimizer/src/push_down_limit.rs +++ b/datafusion/optimizer/src/push_down_limit.rs @@ -843,7 +843,7 @@ mod test { plan, @r" Limit: skip=10, fetch=1000 - Inner Join: test.a = test2.a + Inner Join(ComparisonJoin): test.a = test2.a TableScan: test TableScan: test2 " @@ -870,7 +870,7 @@ mod test { plan, @r" Limit: skip=10, fetch=1000 - Inner Join: test.a = test2.a + Inner Join(ComparisonJoin): test.a = test2.a TableScan: test TableScan: test2 " @@ -961,7 +961,7 @@ mod test { plan, @r" Limit: skip=10, fetch=1000 - Left Join: test.a = test2.a + Left Join(ComparisonJoin): test.a = test2.a Limit: skip=0, fetch=1010 TableScan: test, fetch=1010 TableScan: test2 @@ -989,7 +989,7 @@ mod test { plan, @r" Limit: skip=0, fetch=1000 - Right Join: test.a = test2.a + Right Join(ComparisonJoin): test.a = test2.a TableScan: test Limit: skip=0, fetch=1000 TableScan: test2, fetch=1000 @@ -1017,7 +1017,7 @@ mod test { plan, @r" Limit: skip=10, fetch=1000 - Right Join: test.a = test2.a + Right Join(ComparisonJoin): test.a = test2.a TableScan: test Limit: skip=0, fetch=1010 TableScan: test2, fetch=1010 @@ -1039,7 +1039,7 @@ mod test { plan, @r" Limit: skip=0, fetch=1000 - Cross Join: + Cross Join(ComparisonJoin): Limit: skip=0, fetch=1000 TableScan: test, fetch=1000 Limit: skip=0, fetch=1000 @@ -1062,7 +1062,7 @@ mod test { plan, @r" Limit: skip=1000, fetch=1000 - Cross Join: + Cross Join(ComparisonJoin): Limit: skip=0, fetch=2000 TableScan: test, fetch=2000 Limit: skip=0, fetch=2000 diff --git a/datafusion/optimizer/src/rewrite_dependent_join.rs b/datafusion/optimizer/src/rewrite_dependent_join.rs new file mode 100644 index 000000000000..22fec13684c0 --- 
/dev/null +++ b/datafusion/optimizer/src/rewrite_dependent_join.rs @@ -0,0 +1,2451 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! [`DependentJoinRewriter`] converts correlated subqueries to `DependentJoin` + +use std::collections::VecDeque; +use std::ops::Deref; +use std::sync::Arc; + +use crate::{ApplyOrder, OptimizerConfig, OptimizerRule}; + +use arrow::datatypes::DataType; +use datafusion_common::alias::AliasGenerator; +use datafusion_common::tree_node::{ + Transformed, TreeNode, TreeNodeRecursion, TreeNodeRewriter, +}; +use datafusion_common::{ + internal_datafusion_err, internal_err, not_impl_err, Column, HashMap, Result, +}; +use datafusion_expr::{ + col, lit, not, Aggregate, CorrelatedColumnInfo, Expr, Filter, Join, LogicalPlan, + LogicalPlanBuilder, Projection, +}; + +use indexmap::map::Entry; +use indexmap::IndexMap; +use itertools::Itertools; + +pub struct DependentJoinRewriter { + // each logical plan visited during traversal is assigned an integer id + current_id: usize, + subquery_depth: usize, + // each newly visited `LogicalPlan` is inserted inside this map for tracking + nodes: IndexMap, + // all the node ids from root to the current node + // this is mutated during traversal + stack: Vec, + // tracks, for each column, the nodes/logical plans that reference it within the tree + all_outer_ref_columns: IndexMap>, + alias_generator: Arc, +} + +#[derive(Debug, Hash, PartialEq, PartialOrd, Eq, Clone)] +struct ColumnAccess { + // node ids from root to the node that is referencing the column + stack: Vec, + // the node referencing the column + node_id: usize, + col: Column, + data_type: DataType, + subquery_depth: usize, +} + +impl DependentJoinRewriter { + // this function rewrites a logical plan whose arbitrary exprs contain + // subquery exprs into a dependent join logical plan + fn rewrite_exprs_into_dependent_join_plan( + exprs: Vec>, + dependent_join_node: &Node, + current_subquery_depth: usize, + mut current_plan: LogicalPlanBuilder, + subquery_alias_by_offset: HashMap, + ) -> Result<(LogicalPlanBuilder, Vec>)> { + // every time we meet a subquery during traversal, we increment this by 1 + // we can use this offset to look up the original subquery info + // in subquery_alias_by_offset + // the reason why we cannot create a hashmap keyed by the Subquery object + // is that the subquery inside this filter expr may have been rewritten at + // the lower level + let mut offset = 0; + let offset_ref = &mut offset; + let mut subquery_expr_by_offset = HashMap::new(); + let mut rewritten_exprs_groups = vec![]; + for expr_group in exprs { + let rewritten_exprs = expr_group + .iter() + .cloned() + .map(|e| { + Ok(e.clone() + .transform(|e| { + // replace any subquery expr with subquery_alias.output column + let alias = match e {
Expr::InSubquery(_) + | Expr::Exists(_) + | Expr::ScalarSubquery(_) => subquery_alias_by_offset + .get(offset_ref) + .ok_or(internal_datafusion_err!( + "subquery alias not found at offset {}", + *offset_ref + )), + _ => return Ok(Transformed::no(e)), + }?; + + // We are aware that the original subquery can be rewritten update the + // latest expr to this map. + subquery_expr_by_offset.insert(*offset_ref, e); + *offset_ref += 1; + + Ok(Transformed::yes(col(format!("{alias}")))) + })? + .data) + }) + .collect::>>()?; + rewritten_exprs_groups.push(rewritten_exprs); + } + + for (subquery_offset, (_, column_accesses)) in dependent_join_node + .columns_accesses_by_subquery_id + .iter() + .enumerate() + { + let alias = subquery_alias_by_offset.get(&subquery_offset).ok_or( + internal_datafusion_err!( + "subquery alias not found at offset {subquery_offset}" + ), + )?; + let subquery_expr = subquery_expr_by_offset.get(&subquery_offset).ok_or( + internal_datafusion_err!( + "subquery expr not found at offset {subquery_offset}" + ), + )?; + + let (splan, sexpr) = unwrap_subquery_input_from_expr(subquery_expr); + + let correlated_columns = column_accesses + .iter() + .map(|ac| CorrelatedColumnInfo { + col: ac.col.clone(), + data_type: ac.data_type.clone(), + depth: ac.subquery_depth, + }) + .unique() + .collect(); + + current_plan = current_plan.dependent_join( + splan, + correlated_columns, + Some(sexpr), + current_subquery_depth, + alias.clone(), + None, + )?; + } + Ok((current_plan, rewritten_exprs_groups)) + } + + fn rewrite_filter( + &mut self, + filter: &Filter, + dependent_join_node: &Node, + current_subquery_depth: usize, + current_plan: LogicalPlanBuilder, + subquery_alias_by_offset: HashMap, + ) -> Result { + // because dependent join may introduce extra columns + // to evaluate the subquery, the final plan should + // have another projection to remove these redundant columns + let post_join_projections: Vec = filter + .input + .schema() + .columns() + .iter() + .map(|c| col(c.clone())) + .collect(); + + // Extract NOT from negated subqueries before processing. + let normalized_predicate = normalize_negated_subqueries(&filter.predicate)?; + + let (transformed_plan, transformed_exprs) = + Self::rewrite_exprs_into_dependent_join_plan( + vec![vec![&normalized_predicate]], + dependent_join_node, + current_subquery_depth, + current_plan, + subquery_alias_by_offset, + )?; + + let transformed_predicate = transformed_exprs + .first() + .ok_or(internal_datafusion_err!( + "transform predicate does not return 1 element" + ))? + .first() + .ok_or(internal_datafusion_err!( + "transform predicate does not return 1 element" + ))?; + + transformed_plan + .filter(transformed_predicate.clone())? + .project(post_join_projections) + } + + fn rewrite_projection( + &mut self, + original_proj: &Projection, + dependent_join_node: &Node, + current_subquery_depth: usize, + current_plan: LogicalPlanBuilder, + subquery_alias_by_offset: HashMap, + ) -> Result { + // Normalize negated subquries in projection expressions. 
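+ // e.g. `col NOT IN (<subquery>)` has already been rewritten to `NOT (col IN (<subquery>))`, so the NOT stays in the + // outer predicate while the inner subquery expr is replaced by the subquery alias column below.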
+ let normalized_exprs: Vec = original_proj + .expr + .iter() + .map(normalize_negated_subqueries) + .collect::>>()?; + + let (transformed_plan, transformed_exprs) = + Self::rewrite_exprs_into_dependent_join_plan( + vec![normalized_exprs.iter().collect()], + dependent_join_node, + current_subquery_depth, + current_plan, + subquery_alias_by_offset, + )?; + let transformed_proj_exprs = + transformed_exprs.first().ok_or(internal_datafusion_err!( + "transform projection expr does not return 1 element" + ))?; + transformed_plan.project(transformed_proj_exprs.clone()) + } + + fn rewrite_aggregate( + &mut self, + aggregate: &Aggregate, + dependent_join_node: &Node, + current_subquery_depth: usize, + current_plan: LogicalPlanBuilder, + subquery_alias_by_offset: HashMap, + ) -> Result { + // because dependent join may introduce extra columns + // to evaluate the subquery, the final plan should + // have another projection to remove these redundant columns + let post_join_projections: Vec = aggregate + .schema + .columns() + .iter() + .map(|c| col(c.clone())) + .collect(); + + // Normalize negated subqueries in group and aggregate expressions + let normalized_group_exprs: Vec = aggregate + .group_expr + .iter() + .map(normalize_negated_subqueries) + .collect::>>()?; + let normalized_aggr_exprs: Vec = aggregate + .aggr_expr + .iter() + .map(normalize_negated_subqueries) + .collect::>>()?; + + let (transformed_plan, transformed_exprs) = + Self::rewrite_exprs_into_dependent_join_plan( + vec![ + normalized_group_exprs.iter().collect(), + normalized_aggr_exprs.iter().collect(), + ], + dependent_join_node, + current_subquery_depth, + current_plan, + subquery_alias_by_offset, + )?; + let (new_group_exprs, new_aggr_exprs) = match transformed_exprs.as_slice() { + [first, second] => (first, second), + _ => { + return internal_err!( + "transform group and aggr exprs does not return vector of 2 Vec") + } + }; + + transformed_plan + .aggregate(new_group_exprs.clone(), new_aggr_exprs.clone())? 
+ .project(post_join_projections) + } + + fn rewrite_lateral_join( + &mut self, + join: &Join, + dependent_join_node: &Node, + current_subquery_depth: usize, + current_plan: LogicalPlanBuilder, + subquery_alias_by_offset: HashMap, + ) -> Result { + // this is lateral join + assert!(dependent_join_node.columns_accesses_by_subquery_id.len() == 1); + let (_, column_accesses) = dependent_join_node + .columns_accesses_by_subquery_id + .first() + .ok_or(internal_datafusion_err!( + "a lateral join should always have one child subquery" + ))?; + let alias = subquery_alias_by_offset + .get(&0) + .ok_or(internal_datafusion_err!( + "cannot find subquery alias for only-child of lateral join" + ))?; + let correlated_columns = column_accesses + .iter() + .map(|ac| CorrelatedColumnInfo { + col: ac.col.clone(), + data_type: ac.data_type.clone(), + depth: ac.subquery_depth, + }) + .unique() + .collect(); + + let sq = if let LogicalPlan::Subquery(sq) = join.right.as_ref() { + sq + } else { + return internal_err!("right side of a lateral join is not a subquery"); + }; + let right = sq.subquery.deref().clone(); + // At the time of implementation lateral join condition is not fully clear yet + // So a TODO for future tracking + let lateral_join_condition = if let Some(ref filter) = join.filter { + filter.clone() + } else { + lit(true) + }; + current_plan.dependent_join( + right, + correlated_columns, + None, + current_subquery_depth, + alias.to_string(), + Some((join.join_type, lateral_join_condition)), + ) + } + + // TODO: it is sub-optimal that we completely remove all + // the filters (including the ones that have no subquery attached) + // from the original join + // We have to check if after decorrelation, the other optimizers + // that follows are capable of merging these filters back to the + // join node or not + fn rewrite_join( + &mut self, + join: &Join, + dependent_join_node: &Node, + current_subquery_depth: usize, + subquery_alias_by_offset: HashMap, + ) -> Result { + let mut new_join = join.clone(); + let filter = if let Some(ref filter) = join.filter { + filter + } else { + return internal_err!( + "rewriting a correlated join node without any filter condition" + ); + }; + + new_join.filter = None; + + // Normalize negated subqueries in join filter + let normalized_filter = normalize_negated_subqueries(filter)?; + + let (transformed_plan, transformed_exprs) = + Self::rewrite_exprs_into_dependent_join_plan( + vec![vec![&normalized_filter]], + dependent_join_node, + current_subquery_depth, + LogicalPlanBuilder::new(LogicalPlan::Join(new_join)), + subquery_alias_by_offset, + )?; + + let transformed_predicate = transformed_exprs + .first() + .ok_or(internal_datafusion_err!( + "transform predicate does not return 1 element" + ))? 
+ .first() + .ok_or(internal_datafusion_err!( + "transform predicate does not return 1 element" + ))?; + + transformed_plan.filter(transformed_predicate.clone()) + } + + // lowest common ancestor from stack + // given a tree of + // n1 + // | + // n2 filter where outer.column = exists(subquery) + // ---------------------- + // | \ + // | n5: subquery + // | | + // n3 scan table outer n6 filter outer.column=inner.column + // | + // n7 scan table inner + // this function is called with 2 args a:[1,2,3] and [1,2,5,6,7] + // it then returns the id of the dependent join node (2) + // and the id of the subquery node (5) + fn dependent_join_and_subquery_node_ids( + stack_with_table_provider: &[usize], + stack_with_subquery: &[usize], + ) -> (usize, usize) { + let mut lowest_common_ancestor = 0; + let mut subquery_node_id = 0; + + let min_len = stack_with_table_provider + .len() + .min(stack_with_subquery.len()); + + for i in 0..min_len { + let right_id = stack_with_subquery[i]; + let left_id = stack_with_table_provider[i]; + + if right_id == left_id { + // common parent + lowest_common_ancestor = right_id; + subquery_node_id = stack_with_subquery[i + 1]; + } else { + break; + } + } + + (lowest_common_ancestor, subquery_node_id) + } + + // because the column providers are visited after column-accessor + // (function visit_with_subqueries always visit the subquery before visiting the other children) + // we can always infer the LCA inside this function, by getting the deepest common parent + fn conclude_lowest_dependent_join_node_if_any( + &mut self, + child_id: usize, + col: &Column, + ) -> Result<()> { + if let Some(accesses) = self.all_outer_ref_columns.get(col) { + for access in accesses.iter() { + let mut cur_stack = self.stack.clone(); + + cur_stack.push(child_id); + let (dependent_join_node_id, subquery_node_id) = + Self::dependent_join_and_subquery_node_ids(&cur_stack, &access.stack); + let node = self.nodes.get_mut(&dependent_join_node_id).ok_or( + internal_datafusion_err!( + "dependent join node with id {dependent_join_node_id} not found" + ), + )?; + let accesses = node + .columns_accesses_by_subquery_id + .entry(subquery_node_id) + .or_default(); + accesses.push(ColumnAccess { + col: col.clone(), + node_id: access.node_id, + stack: access.stack.clone(), + data_type: access.data_type.clone(), + subquery_depth: access.subquery_depth, + }); + } + } + Ok(()) + } + + fn mark_outer_column_access( + &mut self, + child_id: usize, + data_type: &DataType, + col: &Column, + ) { + // iter from bottom to top, the goal is to mark the dependent node + // the current child's access + self.all_outer_ref_columns + .entry(col.clone()) + .or_default() + .push(ColumnAccess { + stack: self.stack.clone(), + node_id: child_id, + col: col.clone(), + data_type: data_type.clone(), + subquery_depth: self.subquery_depth, + }); + } + + pub fn rewrite_subqueries_into_dependent_joins( + &mut self, + plan: LogicalPlan, + ) -> Result> { + plan.rewrite_with_subqueries(self) + } +} + +impl DependentJoinRewriter { + pub fn new(alias_generator: Arc) -> Self { + DependentJoinRewriter { + alias_generator, + current_id: 0, + nodes: IndexMap::new(), + stack: vec![], + all_outer_ref_columns: IndexMap::new(), + subquery_depth: 0, + } + } +} + +#[derive(Debug, Clone)] +struct Node { + plan: LogicalPlan, + + // This field is only meaningful if the node is dependent join node. 
+ // It tracks which descendant nodes are still accessing the outer columns provided by its + // left child + // The key of this map is the node_id of the child subqueries. + // The insertion order matters here, and thus we use IndexMap + columns_accesses_by_subquery_id: IndexMap>, + + is_dependent_join_node: bool, + subquery_types: VecDeque, + // a dependent join node of the LogicalPlan::Join variant can have subquery children + // in two scenarios: + // - it is a lateral join + // - it is a normal join, but the join conditions contain a subquery + // These two scenarios are mutually exclusive and we need to maintain a flag for this + is_lateral_join: bool, + + // note that for dependent join nodes, there can be more than one + // subquery child at a time, but always one outer-column-providing child, + // which is the last element + subquery_type: SubqueryType, +} +#[derive(Debug, Clone)] +enum SubqueryType { + None, + In, + Exists, + Scalar, + LateralJoin, +} + +impl SubqueryType { + fn prefix(&self) -> String { + match self { + SubqueryType::None => "", + SubqueryType::In => "__in_sq", + SubqueryType::Exists => "__exists_sq", + SubqueryType::Scalar => "__scalar_sq", + SubqueryType::LateralJoin => "__lateral_sq", + } + .to_string() + } +} + +fn unwrap_subquery_input_from_expr(expr: &Expr) -> (LogicalPlan, Expr) { + match expr { + Expr::ScalarSubquery(sq) => (sq.subquery.as_ref().clone(), expr.clone()), + Expr::Exists(exists) => (exists.subquery.subquery.as_ref().clone(), expr.clone()), + Expr::InSubquery(in_sq) => { + (in_sq.subquery.subquery.as_ref().clone(), expr.clone()) + } + _ => unreachable!(), + } +} + +// returns true if the current expr contains any subquery expr +// note: this check must not recurse into nested subquery plans +fn contains_subquery(expr: &Expr) -> bool { + expr.exists(|expr| { + Ok(matches!( + expr, + Expr::ScalarSubquery(_) | Expr::InSubquery(_) | Expr::Exists(_) + )) + }) + .expect("Inner is always Ok") +} + +/// The rewriting happens top-down, where parent nodes are visited on the way down +/// before their children (subquery children are visited first). +/// This ensures that, at any moment, if we observe a `LogicalPlan` +/// that provides the data for columns, we can assume that all subqueries that reference +/// its data were already visited, and we can conclude the information of +/// the `DependentJoin` +/// needed for the decorrelation: +/// - The subquery expr +/// - The correlated columns on the LHS referenced from the RHS +/// (and its nested subqueries, if any) +/// +/// If the original node contains multiple subqueries at the same time, +/// nested `DependentJoin` plans are generated for each of them (all with equal depth).
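+/// For example, a filter predicate `EXISTS (<subquery 1>) AND a IN (<subquery 2>)` produces two +/// stacked `DependentJoin` nodes, both at depth 1 (see the `two_subqueries_in_the_same_filter_expr` test below).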
+/// +/// For illustration, given this query +/// ```sql +/// SELECT ID FROM T1 WHERE EXISTS(SELECT * FROM T2 WHERE T2.ID=T1.ID) OR EXISTS(SELECT * FROM T2 WHERE T2.VALUE=T1.ID); +/// ``` +/// +/// The traversal happens in the following sequence +/// +/// ```text +/// ↓1 +/// ↑12 +/// ┌──────────────┐ +/// │ FILTER │<--- DependentJoin rewrite +/// │ (1) │ happens here (step 12) +/// └─┬─────┬────┬─┘ Here we already have enough information +/// │ │ │ of which node is accessing which column +/// │ │ │ provided by "Table Scan t1" node +/// │ │ │ (for example node (6) below ) +/// │ │ │ +/// │ │ │ +/// │ │ │ +/// ↓2────┘ ↓6 └────↓10 +/// ↑5 ↑11 ↑11 +/// ┌───▼───┐ ┌──▼───┐ ┌───▼───────┐ +/// │SUBQ1 │ │SUBQ2 │ │TABLE SCAN │ +/// └──┬────┘ └──┬───┘ │ t1 │ +/// │ │ └───────────┘ +/// │ │ +/// │ │ +/// │ ↓7 +/// │ ↑10 +/// │ ┌───▼──────┐ +/// │ │Filter │----> mark_outer_column_access(outer_ref) +/// │ │outer_ref │ +/// │ │ (6) │ +/// │ └──┬───────┘ +/// │ │ +/// ↓3 ↓8 +/// ↑4 ↑9 +/// ┌──▼────┐ ┌──▼────┐ +/// │SCAN t2│ │SCAN t2│ +/// └───────┘ └───────┘ +/// ``` +impl TreeNodeRewriter for DependentJoinRewriter { + type Node = LogicalPlan; + + fn f_down(&mut self, node: LogicalPlan) -> Result> { + let new_id = self.current_id; + self.current_id += 1; + let mut is_dependent_join_node = false; + let mut subquery_type = SubqueryType::None; + // for each node, find which column it is accessing, which column it is providing + // Set of columns current node access + let mut subquery_types = VecDeque::new(); + match &node { + LogicalPlan::Filter(f) => { + collect_subquery_types( + &f.predicate, + &mut is_dependent_join_node, + &mut subquery_types, + ); + + f.predicate + .apply(|expr| { + if let Expr::OuterReferenceColumn(data_type, col) = expr { + self.mark_outer_column_access(new_id, data_type, col); + } + Ok(TreeNodeRecursion::Continue) + }) + .expect("traversal is infallible"); + } + LogicalPlan::Projection(proj) => { + for expr in &proj.expr { + collect_subquery_types( + expr, + &mut is_dependent_join_node, + &mut subquery_types, + ); + expr.apply(|expr| { + if let Expr::OuterReferenceColumn(data_type, col) = expr { + self.mark_outer_column_access(new_id, data_type, col); + } + Ok(TreeNodeRecursion::Continue) + })?; + } + } + LogicalPlan::Subquery(_) => { + let parent = self.stack.last().ok_or(internal_datafusion_err!( + "subquery node cannot be at the beginning of the query plan" + ))?; + + let parent_node = self + .nodes + .get_mut(parent) + .ok_or(internal_datafusion_err!("node {parent} not found"))?; + // the inserting sequence matter here + // when a parent has multiple children subquery at the same time + // we rely on the order in which subquery children are visited + // to later on find back the corresponding subquery (if some part of them + // were rewritten in the lower node) + parent_node + .columns_accesses_by_subquery_id + .insert(new_id, vec![]); + + if parent_node.is_lateral_join { + subquery_type = SubqueryType::LateralJoin; + } else { + subquery_type = parent_node.subquery_types.pop_front().ok_or( + internal_datafusion_err!("subquery_types queue is empty"), + )?; + } + } + LogicalPlan::Aggregate(aggregate) => { + for expr in &aggregate.group_expr { + collect_subquery_types( + expr, + &mut is_dependent_join_node, + &mut subquery_types, + ); + + expr.apply(|expr| { + if let Expr::OuterReferenceColumn(data_type, col) = expr { + self.mark_outer_column_access(new_id, data_type, col); + } + Ok(TreeNodeRecursion::Continue) + })?; + } + + for expr in &aggregate.aggr_expr { + 
collect_subquery_types( + expr, + &mut is_dependent_join_node, + &mut subquery_types, + ); + + expr.apply(|expr| { + if let Expr::OuterReferenceColumn(data_type, col) = expr { + self.mark_outer_column_access(new_id, data_type, col); + } + Ok(TreeNodeRecursion::Continue) + })?; + } + } + LogicalPlan::Join(join) => { + if let LogicalPlan::Subquery(_) = &join.left.as_ref() { + return internal_err!("left side of a join cannot be a subquery"); + } + + // Handle the case lateral join + if let LogicalPlan::Subquery(_) = join.right.as_ref() { + if let Some(ref filter) = join.filter { + if contains_subquery(filter) { + return not_impl_err!( + "subquery inside lateral join condition is not supported" + ); + } + } + self.subquery_depth += 1; + self.stack.push(new_id); + self.nodes.insert( + new_id, + Node { + plan: node.clone(), + is_dependent_join_node: true, + subquery_types: VecDeque::new(), + columns_accesses_by_subquery_id: IndexMap::new(), + subquery_type, + is_lateral_join: true, + }, + ); + + // we assume that RHS is always a subquery for the lateral join + // and because this function assume that subquery side is always + // visited first during f_down, we have to explicitly swap the rewrite + // order at this step, else the function visit_with_subqueries will + // call f_down for the LHS instead + let transformed_subquery = self + .rewrite_subqueries_into_dependent_joins( + join.right.deref().clone(), + )? + .data; + let transformed_left = self + .rewrite_subqueries_into_dependent_joins( + join.left.deref().clone(), + )? + .data; + let mut new_join_node = join.clone(); + new_join_node.right = Arc::new(transformed_subquery); + new_join_node.left = Arc::new(transformed_left); + return Ok(Transformed::new( + LogicalPlan::Join(new_join_node), + true, + // since we rewrite the children directly in this function, + TreeNodeRecursion::Jump, + )); + } + + if let Some(filter) = &join.filter { + collect_subquery_types( + filter, + &mut is_dependent_join_node, + &mut subquery_types, + ); + + filter.apply(|expr| { + if let Expr::OuterReferenceColumn(data_type, col) = expr { + self.mark_outer_column_access(new_id, data_type, col); + } + Ok(TreeNodeRecursion::Continue) + })?; + } + } + LogicalPlan::Sort(sort) => { + for expr in &sort.expr { + collect_subquery_types( + &expr.expr, + &mut is_dependent_join_node, + &mut subquery_types, + ); + + expr.expr.apply(|expr| { + if let Expr::OuterReferenceColumn(data_type, col) = expr { + self.mark_outer_column_access(new_id, data_type, col); + } + Ok(TreeNodeRecursion::Continue) + })?; + } + } + // TODO: maybe there are more logical plan that provides columns + // aside from TableScan + LogicalPlan::TableScan(tbl_scan) => { + tbl_scan + .projected_schema + .columns() + .iter() + .try_for_each(|col| { + self.conclude_lowest_dependent_join_node_if_any(new_id, col) + })?; + } + // Similar to TableScan, this node may provide column names which + // is referenced inside some subqueries + LogicalPlan::SubqueryAlias(alias) => { + alias.schema.columns().iter().try_for_each(|col| { + self.conclude_lowest_dependent_join_node_if_any(new_id, col) + })?; + } + _ => {} + }; + + if is_dependent_join_node { + self.subquery_depth += 1 + } + self.stack.push(new_id); + self.nodes.insert( + new_id, + Node { + plan: node.clone(), + is_dependent_join_node, + columns_accesses_by_subquery_id: IndexMap::new(), + subquery_types, + subquery_type, + is_lateral_join: false, + }, + ); + + Ok(Transformed::no(node)) + } + + /// All rewrite happens inside upward traversal + /// and only happens 
if the node is a "dependent join node" + /// (i.e. a node with at least one subquery expr), + /// when all dependency information has already been collected + fn f_up(&mut self, node: LogicalPlan) -> Result> { + // if the node popped from the stack during f_up was marked as a dependent join node, + // it is rewritten here by building the corresponding dependent join plan + // from the column accesses collected on the way down + let current_node_id = self.stack.pop().ok_or(internal_datafusion_err!( + "stack cannot be empty during upward traversal" + ))?; + let node_info = if let Entry::Occupied(e) = self.nodes.entry(current_node_id) { + let node_info = e.get(); + if !node_info.is_dependent_join_node { + return Ok(Transformed::no(node)); + } + e.swap_remove() + } else { + unreachable!() + }; + + let current_subquery_depth = self.subquery_depth; + self.subquery_depth -= 1; + + let cloned_input = (**node.inputs().first().ok_or(internal_datafusion_err!( + "logical plan {} does not have any input", + node + ))?) + .clone(); + let mut current_plan = LogicalPlanBuilder::new(cloned_input); + let mut subquery_alias_by_offset = HashMap::new(); + for (subquery_offset, (subquery_id, _)) in + node_info.columns_accesses_by_subquery_id.iter().enumerate() + { + let subquery_node = self + .nodes + .get(subquery_id) + .ok_or(internal_datafusion_err!("node {subquery_id} not found"))?; + let alias = self + .alias_generator + .next(&subquery_node.subquery_type.prefix()); + subquery_alias_by_offset.insert(subquery_offset, alias); + } + + match &node { + LogicalPlan::Projection(projection) => { + current_plan = self.rewrite_projection( + projection, + &node_info, + current_subquery_depth, + current_plan, + subquery_alias_by_offset, + )?; + } + LogicalPlan::Filter(filter) => { + current_plan = self.rewrite_filter( + filter, + &node_info, + current_subquery_depth, + current_plan, + subquery_alias_by_offset, + )?; + } + + LogicalPlan::Join(join) => { + if node_info.is_lateral_join { + current_plan = self.rewrite_lateral_join( + join, + &node_info, + current_subquery_depth, + current_plan, + subquery_alias_by_offset, + )? + } else { + // Correlated subquery in join filter.
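+ // Note that rewrite_join drops the original join filter and re-applies the rewritten + // predicate as a Filter on top of the resulting DependentJoin (see the TODO on + // rewrite_join about merging such filters back into the join).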
+ current_plan = self.rewrite_join( + join, + &node_info, + current_subquery_depth, + subquery_alias_by_offset, + )?; + }; + } + LogicalPlan::Aggregate(aggregate) => { + current_plan = self.rewrite_aggregate( + aggregate, + &node_info, + current_subquery_depth, + current_plan, + subquery_alias_by_offset, + )?; + } + _ => { + unimplemented!( + "implement more dependent join node creation for node {}", + node + ) + } + } + Ok(Transformed::yes(current_plan.build()?)) + } +} + +fn collect_subquery_types( + sub_expr: &Expr, + is_dependent_join_node: &mut bool, + subquery_types: &mut VecDeque, +) { + sub_expr + .apply(|expr| { + Ok(match expr { + Expr::ScalarSubquery(_) => { + *is_dependent_join_node = true; + subquery_types.push_back(SubqueryType::Scalar); + TreeNodeRecursion::Continue + } + Expr::InSubquery(_) => { + *is_dependent_join_node = true; + subquery_types.push_back(SubqueryType::In); + TreeNodeRecursion::Continue + } + Expr::Exists(_) => { + *is_dependent_join_node = true; + subquery_types.push_back(SubqueryType::Exists); + TreeNodeRecursion::Continue + } + _ => TreeNodeRecursion::Continue, + }) + }) + .expect("Inner is always Ok"); + + //match sub_expr { + // Expr::ScalarSubquery(_) => { + // *is_dependent_join_node = true; + // subquery_types.push_back(SubqueryType::Scalar); + // } + // Expr::InSubquery(_) => { + // *is_dependent_join_node = true; + // subquery_types.push_back(SubqueryType::In); + // } + // Expr::Exists(_) => { + // *is_dependent_join_node = true; + // subquery_types.push_back(SubqueryType::Exists); + // } + // _ => {} + //} +} + +/// Normalize negated subqueries by extracting the NOT to the top level +/// For example: `InSubquery{negated: true, ...}` becomes `NOT(InSubquery{negated: false, ...})` +fn normalize_negated_subqueries(expr: &Expr) -> Result { + expr.clone() + .transform(|e| { + match e { + Expr::InSubquery(mut in_subquery) if in_subquery.negated => { + // Convert negated InSubquery to NOT(InSubquery{negated: false}) + in_subquery.negated = false; + Ok(Transformed::yes(not(Expr::InSubquery(in_subquery)))) + } + Expr::Exists(mut exists_subuqery) if exists_subuqery.negated => { + // Convert negated ExistsSubquery to NOT(ExistsSubquery{negated: false}) + exists_subuqery.negated = false; + Ok(Transformed::yes(not(Expr::Exists(exists_subuqery)))) + } + _ => Ok(Transformed::no(e)), + } + }) + .map(|t| t.data) +} + +/// Optimizer rule for rewriting subqueries to dependent join. +#[allow(dead_code)] +#[derive(Debug)] +pub struct RewriteDependentJoin {} + +impl Default for RewriteDependentJoin { + fn default() -> Self { + Self::new() + } +} + +impl RewriteDependentJoin { + pub fn new() -> Self { + RewriteDependentJoin {} + } +} + +impl OptimizerRule for RewriteDependentJoin { + fn supports_rewrite(&self) -> bool { + true + } + + // Convert all subqueries (maybe including lateral join in the future) to temporary + // LogicalPlan node called DependentJoin. 
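+ // The rule drives its own traversal via `rewrite_with_subqueries` in + // `rewrite_subqueries_into_dependent_joins`; `apply_order` below returns `None`.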
+ fn rewrite( + &self, + plan: LogicalPlan, + config: &dyn OptimizerConfig, + ) -> Result> { + let mut transformer = + DependentJoinRewriter::new(Arc::clone(config.alias_generator())); + let rewrite_result = transformer.rewrite_subqueries_into_dependent_joins(plan)?; + Ok(rewrite_result) + } + + fn name(&self) -> &str { + "rewrite_dependent_join" + } + + fn apply_order(&self) -> Option { + None + } +} + +#[cfg(test)] +mod tests { + use super::DependentJoinRewriter; + + use crate::test::{test_table_scan_with_name, test_table_with_columns}; + use arrow::datatypes::{DataType, TimeUnit}; + use datafusion_common::{alias::AliasGenerator, Result, Spans}; + use datafusion_expr::{ + and, binary_expr, exists, expr::InSubquery, expr_fn::col, in_subquery, lit, + not_exists, out_ref_col, scalar_subquery, Expr, JoinType, LogicalPlan, + LogicalPlanBuilder, Operator, SortExpr, Subquery, + }; + use datafusion_functions_aggregate::{count::count, sum::sum}; + use insta::assert_snapshot; + use std::sync::Arc; + + macro_rules! assert_dependent_join_rewrite_err { + ( + $plan:expr + // @ $expected:literal $(,)? + ) => {{ + let mut index = DependentJoinRewriter::new(Arc::new(AliasGenerator::new())); + let transformed = + index.rewrite_subqueries_into_dependent_joins($plan.clone()); + if let Err(err) = transformed { + // assert_snapshot!( + // err, + // @ $expected, + // ) + } else { + panic!("rewriting {} was not returning error", $plan) + } + }}; + } + + macro_rules! assert_dependent_join_rewrite { + ( + $plan:expr, + @ $expected:literal $(,)? + ) => {{ + let mut index = DependentJoinRewriter::new(Arc::new(AliasGenerator::new())); + let transformed = index.rewrite_subqueries_into_dependent_joins($plan)?; + assert!(transformed.transformed); + let display = transformed.data.display_indent_schema(); + assert_snapshot!( + display, + @ $expected, + ) + }}; + } + + #[test] + fn uncorrelated_lateral_join() -> Result<()> { + let outer_table = test_table_scan_with_name("outer_table")?; + let inner_table_lv1 = test_table_scan_with_name("inner_table_lv1")?; + + let lateral_join_rhs = Arc::new( + LogicalPlanBuilder::from(inner_table_lv1.clone()) + .filter(col("inner_table_lv1.c").eq(lit(1)))? + .build()?, + ); + + let plan = LogicalPlanBuilder::from(outer_table.clone()) + .join_on( + LogicalPlan::Subquery(Subquery { + subquery: lateral_join_rhs, + outer_ref_columns: vec![], + spans: Spans::new(), + }), + JoinType::Inner, + vec![lit(true)], + )? + .build()?; + + // Inner Join: Filter: Boolean(true) + // TableScan: outer_table + // Subquery: + // Filter: inner_table_lv1.c = outer_ref(outer_table.c) + // TableScan: inner_table_lv1 + + assert_dependent_join_rewrite!(plan, @r" + DependentJoin on [] lateral Inner join with Boolean(true) depth 1 [a:UInt32, b:UInt32, c:UInt32] + TableScan: outer_table [a:UInt32, b:UInt32, c:UInt32] + Filter: inner_table_lv1.c = Int32(1) [a:UInt32, b:UInt32, c:UInt32] + TableScan: inner_table_lv1 [a:UInt32, b:UInt32, c:UInt32] + "); + Ok(()) + } + + #[test] + fn correlated_lateral_join() -> Result<()> { + let outer_table = test_table_scan_with_name("outer_table")?; + let inner_table_lv1 = test_table_scan_with_name("inner_table_lv1")?; + + let lateral_join_rhs = Arc::new( + LogicalPlanBuilder::from(inner_table_lv1.clone()) + .filter( + col("inner_table_lv1.c") + .eq(out_ref_col(DataType::UInt32, "outer_table.c")), + )? 
+ .build()?, + ); + + let plan = LogicalPlanBuilder::from(outer_table.clone()) + .join_on( + LogicalPlan::Subquery(Subquery { + subquery: lateral_join_rhs, + outer_ref_columns: vec![out_ref_col( + DataType::UInt32, + "outer_table.c", + )], + spans: Spans::new(), + }), + JoinType::Inner, + vec![lit(true)], + )? + .build()?; + + // Inner Join: Filter: Boolean(true) + // TableScan: outer_table + // Subquery: + // Filter: inner_table_lv1.c = outer_ref(outer_table.c) + // TableScan: inner_table_lv1 + + assert_dependent_join_rewrite!(plan, @r" + DependentJoin on [outer_table.c lvl 1] lateral Inner join with Boolean(true) depth 1 [a:UInt32, b:UInt32, c:UInt32] + TableScan: outer_table [a:UInt32, b:UInt32, c:UInt32] + Filter: inner_table_lv1.c = outer_ref(outer_table.c) [a:UInt32, b:UInt32, c:UInt32] + TableScan: inner_table_lv1 [a:UInt32, b:UInt32, c:UInt32] + "); + Ok(()) + } + + #[test] + fn scalar_subquery_nested_inside_a_lateral_join() -> Result<()> { + let outer_table = test_table_scan_with_name("outer_table")?; + + let inner_table_lv1 = test_table_scan_with_name("inner_table_lv1")?; + + let inner_table_lv2 = test_table_scan_with_name("inner_table_lv2")?; + let scalar_sq_level2 = Arc::new( + LogicalPlanBuilder::from(inner_table_lv2) + .filter( + col("inner_table_lv2.a") + .eq(out_ref_col(DataType::UInt32, "outer_table.a")) + .and( + col("inner_table_lv2.b") + .eq(out_ref_col(DataType::UInt32, "inner_table_lv1.b")), + ), + )? + .aggregate(Vec::::new(), vec![count(col("inner_table_lv2.a"))])? + .build()?, + ); + let sq_level1 = Arc::new( + LogicalPlanBuilder::from(inner_table_lv1.clone()) + .filter( + col("inner_table_lv1.c") + .eq(out_ref_col(DataType::UInt32, "outer_table.c")) + .and(scalar_subquery(scalar_sq_level2).eq(lit(1))), + )? + .aggregate(Vec::::new(), vec![count(col("inner_table_lv1.a"))])? + .build()?, + ); + + let plan = LogicalPlanBuilder::from(outer_table.clone()) + .join_on( + LogicalPlan::Subquery(Subquery { + subquery: sq_level1, + outer_ref_columns: vec![out_ref_col( + DataType::UInt32, + "outer_table.c", + // note that subquery lvl2 is referencing outer_table.a, and it is not being listed here + // this simulate the limitation of current subquery planning and assert + // that the rewriter can fill in this gap + )], + spans: Spans::new(), + }), + JoinType::Inner, + vec![lit(true)], + )? 
+ .build()?; + + // Inner Join: Filter: Boolean(true) + // TableScan: outer_table + // Subquery: + // Aggregate: groupBy=[[]], aggr=[[count(inner_table_lv1.a)]] + // Filter: inner_table_lv1.c = outer_ref(outer_table.c) AND () = Int32(1) + // Subquery: + // Aggregate: groupBy=[[]], aggr=[[count(inner_table_lv2.a)]] + // Filter: inner_table_lv2.a = outer_ref(outer_table.a) AND inner_table_lv2.b = outer_ref(inner_table_lv1.b) + // TableScan: inner_table_lv2 + // TableScan: inner_table_lv1 + + assert_dependent_join_rewrite!(plan, @r" + DependentJoin on [outer_table.a lvl 2, outer_table.c lvl 1] lateral Inner join with Boolean(true) depth 1 [a:UInt32, b:UInt32, c:UInt32] + TableScan: outer_table [a:UInt32, b:UInt32, c:UInt32] + Aggregate: groupBy=[[]], aggr=[[count(inner_table_lv1.a)]] [count(inner_table_lv1.a):Int64] + Projection: inner_table_lv1.a, inner_table_lv1.b, inner_table_lv1.c [a:UInt32, b:UInt32, c:UInt32] + Filter: inner_table_lv1.c = outer_ref(outer_table.c) AND __scalar_sq_1 = Int32(1) [a:UInt32, b:UInt32, c:UInt32, __scalar_sq_1:Int64] + DependentJoin on [inner_table_lv1.b lvl 2] with expr () depth 2 [a:UInt32, b:UInt32, c:UInt32, __scalar_sq_1:Int64] + TableScan: inner_table_lv1 [a:UInt32, b:UInt32, c:UInt32] + Aggregate: groupBy=[[]], aggr=[[count(inner_table_lv2.a)]] [count(inner_table_lv2.a):Int64] + Filter: inner_table_lv2.a = outer_ref(outer_table.a) AND inner_table_lv2.b = outer_ref(inner_table_lv1.b) [a:UInt32, b:UInt32, c:UInt32] + TableScan: inner_table_lv2 [a:UInt32, b:UInt32, c:UInt32] + "); + Ok(()) + } + + #[test] + fn join_logical_plan_with_subquery_in_filter_expr() -> Result<()> { + let outer_left_table = test_table_scan_with_name("outer_right_table")?; + let outer_right_table = test_table_scan_with_name("outer_left_table")?; + let inner_table_lv1 = test_table_scan_with_name("inner_table_lv1")?; + let sq_level1 = Arc::new( + LogicalPlanBuilder::from(inner_table_lv1) + .filter(col("inner_table_lv1.a").eq(binary_expr( + out_ref_col(DataType::UInt32, "outer_left_table.a"), + Operator::Plus, + out_ref_col(DataType::UInt32, "outer_right_table.a"), + )))? + .aggregate(Vec::::new(), vec![count(col("inner_table_lv1.a"))])? + .project(vec![count(col("inner_table_lv1.a")).alias("count_a")])? + .build()?, + ); + + let plan = LogicalPlanBuilder::from(outer_left_table.clone()) + .join_on( + outer_right_table, + JoinType::Left, + vec![col("outer_left_table.a").eq(col("outer_right_table.a"))], + )? + .filter( + col("outer_table.a") + .gt(lit(1)) + .and(in_subquery(col("outer_table.c"), sq_level1)), + )? 
+ .build()?; + + // Filter: outer_table.a > Int32(1) AND outer_table.c IN () + // Subquery: + // Projection: count(inner_table_lv1.a) AS count_a + // Aggregate: groupBy=[[]], aggr=[[count(inner_table_lv1.a)]] + // Filter: inner_table_lv1.a = outer_ref(outer_left_table.a) + outer_ref(outer_right_table.a) + // TableScan: inner_table_lv1 + // Left Join: Filter: outer_left_table.a = outer_right_table.a + // TableScan: outer_right_table + // TableScan: outer_left_table + + assert_dependent_join_rewrite!(plan, @r" + Projection: outer_right_table.a, outer_right_table.b, outer_right_table.c, outer_left_table.a, outer_left_table.b, outer_left_table.c [a:UInt32, b:UInt32, c:UInt32, a:UInt32;N, b:UInt32;N, c:UInt32;N] + Filter: outer_table.a > Int32(1) AND __in_sq_1 [a:UInt32, b:UInt32, c:UInt32, a:UInt32;N, b:UInt32;N, c:UInt32;N, __in_sq_1:Boolean] + DependentJoin on [outer_right_table.a lvl 1, outer_left_table.a lvl 1] with expr outer_table.c IN () depth 1 [a:UInt32, b:UInt32, c:UInt32, a:UInt32;N, b:UInt32;N, c:UInt32;N, __in_sq_1:Boolean] + Left Join(ComparisonJoin): Filter: outer_left_table.a = outer_right_table.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32;N, b:UInt32;N, c:UInt32;N] + TableScan: outer_right_table [a:UInt32, b:UInt32, c:UInt32] + TableScan: outer_left_table [a:UInt32, b:UInt32, c:UInt32] + Projection: count(inner_table_lv1.a) AS count_a [count_a:Int64] + Aggregate: groupBy=[[]], aggr=[[count(inner_table_lv1.a)]] [count(inner_table_lv1.a):Int64] + Filter: inner_table_lv1.a = outer_ref(outer_left_table.a) + outer_ref(outer_right_table.a) [a:UInt32, b:UInt32, c:UInt32] + TableScan: inner_table_lv1 [a:UInt32, b:UInt32, c:UInt32] + "); + Ok(()) + } + #[test] + fn subquery_in_from_expr() -> Result<()> { + Ok(()) + } + #[test] + fn nested_subquery_in_projection_expr() -> Result<()> { + let outer_table = test_table_scan_with_name("outer_table")?; + let inner_table_lv1 = test_table_scan_with_name("inner_table_lv1")?; + + let inner_table_lv2 = test_table_scan_with_name("inner_table_lv2")?; + let scalar_sq_level2 = Arc::new( + LogicalPlanBuilder::from(inner_table_lv2) + .filter( + col("inner_table_lv2.a") + .eq(out_ref_col(DataType::UInt32, "outer_table.a")) + .and( + col("inner_table_lv2.b") + .eq(out_ref_col(DataType::UInt32, "inner_table_lv1.b")), + ), + )? + .aggregate(Vec::::new(), vec![count(col("inner_table_lv2.a"))])? + .build()?, + ); + let scalar_sq_level1_a = Arc::new( + LogicalPlanBuilder::from(inner_table_lv1.clone()) + .filter( + col("inner_table_lv1.c") + .eq(out_ref_col(DataType::UInt32, "outer_table.c")) + // scalar_sq_level2 is intentionally shared between both + // scalar_sq_level1_a and scalar_sq_level1_b + // to check if the framework can uniquely identify the correlated columns + .and(scalar_subquery(Arc::clone(&scalar_sq_level2)).eq(lit(1))), + )? + .aggregate(Vec::::new(), vec![count(col("inner_table_lv1.a"))])? + .build()?, + ); + let scalar_sq_level1_b = Arc::new( + LogicalPlanBuilder::from(inner_table_lv1.clone()) + .filter( + col("inner_table_lv1.c") + .eq(out_ref_col(DataType::UInt32, "outer_table.c")) + .and(scalar_subquery(scalar_sq_level2).eq(lit(1))), + )? + .aggregate(Vec::::new(), vec![count(col("inner_table_lv1.b"))])? + .build()?, + ); + + let plan = LogicalPlanBuilder::from(outer_table.clone()) + .project(vec![ + col("outer_table.a"), + binary_expr( + scalar_subquery(scalar_sq_level1_a), + Operator::Plus, + scalar_subquery(scalar_sq_level1_b), + ), + ])? 
+ .build()?; + + // Projection: outer_table.a, () + () + // Subquery: + // Aggregate: groupBy=[[]], aggr=[[count(inner_table_lv1.a)]] + // Filter: inner_table_lv1.c = outer_ref(outer_table.c) AND () = Int32(1) + // Subquery: + // Aggregate: groupBy=[[]], aggr=[[count(inner_table_lv2.a)]] + // Filter: inner_table_lv2.a = outer_ref(outer_table.a) AND inner_table_lv2.b = outer_ref(inner_table_lv1.b) + // TableScan: inner_table_lv2 + // TableScan: inner_table_lv1 + // Subquery: + // Aggregate: groupBy=[[]], aggr=[[count(inner_table_lv1.b)]] + // Filter: inner_table_lv1.c = outer_ref(outer_table.c) AND () = Int32(1) + // Subquery: + // Aggregate: groupBy=[[]], aggr=[[count(inner_table_lv2.a)]] + // Filter: inner_table_lv2.a = outer_ref(outer_table.a) AND inner_table_lv2.b = outer_ref(inner_table_lv1.b) + // TableScan: inner_table_lv2 + // TableScan: inner_table_lv1 + // TableScan: outer_table + + assert_dependent_join_rewrite!(plan, @r" + Projection: outer_table.a, __scalar_sq_3 + __scalar_sq_4 [a:UInt32, __scalar_sq_3 + __scalar_sq_4:Int64] + DependentJoin on [outer_table.a lvl 2, outer_table.c lvl 1] with expr () depth 1 [a:UInt32, b:UInt32, c:UInt32, __scalar_sq_3:Int64, __scalar_sq_4:Int64] + DependentJoin on [inner_table_lv1.b lvl 2, outer_table.a lvl 2, outer_table.c lvl 1] with expr () depth 1 [a:UInt32, b:UInt32, c:UInt32, __scalar_sq_3:Int64] + TableScan: outer_table [a:UInt32, b:UInt32, c:UInt32] + Aggregate: groupBy=[[]], aggr=[[count(inner_table_lv1.a)]] [count(inner_table_lv1.a):Int64] + Projection: inner_table_lv1.a, inner_table_lv1.b, inner_table_lv1.c [a:UInt32, b:UInt32, c:UInt32] + Filter: inner_table_lv1.c = outer_ref(outer_table.c) AND __scalar_sq_1 = Int32(1) [a:UInt32, b:UInt32, c:UInt32, __scalar_sq_1:Int64] + DependentJoin on [inner_table_lv1.b lvl 2] with expr () depth 2 [a:UInt32, b:UInt32, c:UInt32, __scalar_sq_1:Int64] + TableScan: inner_table_lv1 [a:UInt32, b:UInt32, c:UInt32] + Aggregate: groupBy=[[]], aggr=[[count(inner_table_lv2.a)]] [count(inner_table_lv2.a):Int64] + Filter: inner_table_lv2.a = outer_ref(outer_table.a) AND inner_table_lv2.b = outer_ref(inner_table_lv1.b) [a:UInt32, b:UInt32, c:UInt32] + TableScan: inner_table_lv2 [a:UInt32, b:UInt32, c:UInt32] + Aggregate: groupBy=[[]], aggr=[[count(inner_table_lv1.b)]] [count(inner_table_lv1.b):Int64] + Projection: inner_table_lv1.a, inner_table_lv1.b, inner_table_lv1.c [a:UInt32, b:UInt32, c:UInt32] + Filter: inner_table_lv1.c = outer_ref(outer_table.c) AND __scalar_sq_2 = Int32(1) [a:UInt32, b:UInt32, c:UInt32, __scalar_sq_2:Int64] + DependentJoin on [inner_table_lv1.b lvl 2] with expr () depth 2 [a:UInt32, b:UInt32, c:UInt32, __scalar_sq_2:Int64] + TableScan: inner_table_lv1 [a:UInt32, b:UInt32, c:UInt32] + Aggregate: groupBy=[[]], aggr=[[count(inner_table_lv2.a)]] [count(inner_table_lv2.a):Int64] + Filter: inner_table_lv2.a = outer_ref(outer_table.a) AND inner_table_lv2.b = outer_ref(inner_table_lv1.b) [a:UInt32, b:UInt32, c:UInt32] + TableScan: inner_table_lv2 [a:UInt32, b:UInt32, c:UInt32] + "); + Ok(()) + } + + #[test] + fn nested_subquery_in_filter() -> Result<()> { + let outer_table = test_table_scan_with_name("outer_table")?; + let inner_table_lv1 = test_table_scan_with_name("inner_table_lv1")?; + + let inner_table_lv2 = test_table_scan_with_name("inner_table_lv2")?; + let scalar_sq_level2 = Arc::new( + LogicalPlanBuilder::from(inner_table_lv2) + .filter( + col("inner_table_lv2.a") + .eq(out_ref_col(DataType::UInt32, "outer_table.a")) + .and( + col("inner_table_lv2.b") + 
.eq(out_ref_col(DataType::UInt32, "inner_table_lv1.b")),
+ ),
+ )?
+ .aggregate(Vec::<Expr>::new(), vec![count(col("inner_table_lv2.a"))])?
+ .build()?,
+ );
+ let scalar_sq_level1 = Arc::new(
+ LogicalPlanBuilder::from(inner_table_lv1.clone())
+ .filter(
+ col("inner_table_lv1.c")
+ .eq(out_ref_col(DataType::UInt32, "outer_table.c"))
+ .and(scalar_subquery(scalar_sq_level2).eq(lit(1))),
+ )?
+ .aggregate(Vec::<Expr>::new(), vec![count(col("inner_table_lv1.a"))])?
+ .build()?,
+ );
+
+ let plan = LogicalPlanBuilder::from(outer_table.clone())
+ .filter(
+ col("outer_table.a")
+ .gt(lit(1))
+ .and(scalar_subquery(scalar_sq_level1).eq(col("outer_table.a"))),
+ )?
+ .build()?;
+
+ // Filter: outer_table.a > Int32(1) AND () = outer_table.a
+ // Subquery:
+ // Aggregate: groupBy=[[]], aggr=[[count(inner_table_lv1.a)]]
+ // Filter: inner_table_lv1.c = outer_ref(outer_table.c) AND () = Int32(1)
+ // Subquery:
+ // Aggregate: groupBy=[[]], aggr=[[count(inner_table_lv2.a)]]
+ // Filter: inner_table_lv2.a = outer_ref(outer_table.a) AND inner_table_lv2.b = outer_ref(inner_table_lv1.b)
+ // TableScan: inner_table_lv2
+ // TableScan: inner_table_lv1
+ // TableScan: outer_table
+
+ assert_dependent_join_rewrite!(plan, @r"
+ Projection: outer_table.a, outer_table.b, outer_table.c [a:UInt32, b:UInt32, c:UInt32]
+ Filter: outer_table.a > Int32(1) AND __scalar_sq_2 = outer_table.a [a:UInt32, b:UInt32, c:UInt32, __scalar_sq_2:Int64]
+ DependentJoin on [outer_table.a lvl 2, outer_table.c lvl 1] with expr () depth 1 [a:UInt32, b:UInt32, c:UInt32, __scalar_sq_2:Int64]
+ TableScan: outer_table [a:UInt32, b:UInt32, c:UInt32]
+ Aggregate: groupBy=[[]], aggr=[[count(inner_table_lv1.a)]] [count(inner_table_lv1.a):Int64]
+ Projection: inner_table_lv1.a, inner_table_lv1.b, inner_table_lv1.c [a:UInt32, b:UInt32, c:UInt32]
+ Filter: inner_table_lv1.c = outer_ref(outer_table.c) AND __scalar_sq_1 = Int32(1) [a:UInt32, b:UInt32, c:UInt32, __scalar_sq_1:Int64]
+ DependentJoin on [inner_table_lv1.b lvl 2] with expr () depth 2 [a:UInt32, b:UInt32, c:UInt32, __scalar_sq_1:Int64]
+ TableScan: inner_table_lv1 [a:UInt32, b:UInt32, c:UInt32]
+ Aggregate: groupBy=[[]], aggr=[[count(inner_table_lv2.a)]] [count(inner_table_lv2.a):Int64]
+ Filter: inner_table_lv2.a = outer_ref(outer_table.a) AND inner_table_lv2.b = outer_ref(inner_table_lv1.b) [a:UInt32, b:UInt32, c:UInt32]
+ TableScan: inner_table_lv2 [a:UInt32, b:UInt32, c:UInt32]
+ ");
+ Ok(())
+ }
+ #[test]
+ fn two_subqueries_in_the_same_filter_expr() -> Result<()> {
+ let outer_table = test_table_scan_with_name("outer_table")?;
+ let inner_table_lv1 = test_table_scan_with_name("inner_table_lv1")?;
+ let in_sq_level1 = Arc::new(
+ LogicalPlanBuilder::from(inner_table_lv1.clone())
+ .filter(col("inner_table_lv1.c").eq(lit(2)))?
+ .project(vec![col("inner_table_lv1.a")])?
+ .build()?,
+ );
+ let exist_sq_level1 = Arc::new(
+ LogicalPlanBuilder::from(inner_table_lv1)
+ .filter(
+ col("inner_table_lv1.a").and(col("inner_table_lv1.b").eq(lit(1))),
+ )?
+ .build()?,
+ );
+
+ let plan = LogicalPlanBuilder::from(outer_table.clone())
+ .filter(
+ col("outer_table.a")
+ .gt(lit(1))
+ .and(exists(exist_sq_level1))
+ .and(in_subquery(col("outer_table.b"), in_sq_level1)),
+ )?
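+ // Both subqueries here are uncorrelated, so each DependentJoin in the snapshot is created with an empty correlated-column list.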
+ .build()?; + + // Filter: outer_table.a > Int32(1) AND EXISTS () AND outer_table.b IN () + // Subquery: + // Filter: inner_table_lv1.a AND inner_table_lv1.b = Int32(1) + // TableScan: inner_table_lv1 + // Subquery: + // Projection: inner_table_lv1.a + // Filter: inner_table_lv1.c = Int32(2) + // TableScan: inner_table_lv1 + // TableScan: outer_table + + assert_dependent_join_rewrite!(plan, @r" + Projection: outer_table.a, outer_table.b, outer_table.c [a:UInt32, b:UInt32, c:UInt32] + Filter: outer_table.a > Int32(1) AND __exists_sq_1 AND __in_sq_2 [a:UInt32, b:UInt32, c:UInt32, __exists_sq_1:Boolean, __in_sq_2:Boolean] + DependentJoin on [] with expr outer_table.b IN () depth 1 [a:UInt32, b:UInt32, c:UInt32, __exists_sq_1:Boolean, __in_sq_2:Boolean] + DependentJoin on [] with expr EXISTS () depth 1 [a:UInt32, b:UInt32, c:UInt32, __exists_sq_1:Boolean] + TableScan: outer_table [a:UInt32, b:UInt32, c:UInt32] + Filter: inner_table_lv1.a AND inner_table_lv1.b = Int32(1) [a:UInt32, b:UInt32, c:UInt32] + TableScan: inner_table_lv1 [a:UInt32, b:UInt32, c:UInt32] + Projection: inner_table_lv1.a [a:UInt32] + Filter: inner_table_lv1.c = Int32(2) [a:UInt32, b:UInt32, c:UInt32] + TableScan: inner_table_lv1 [a:UInt32, b:UInt32, c:UInt32] + "); + Ok(()) + } + + #[test] + fn in_subquery_with_count_of_1_depth() -> Result<()> { + let outer_table = test_table_scan_with_name("outer_table")?; + let inner_table_lv1 = test_table_scan_with_name("inner_table_lv1")?; + let sq_level1 = Arc::new( + LogicalPlanBuilder::from(inner_table_lv1) + .filter( + col("inner_table_lv1.a") + .eq(out_ref_col(DataType::UInt32, "outer_table.a")) + .and( + out_ref_col(DataType::UInt32, "outer_table.a") + .gt(col("inner_table_lv1.c")), + ) + .and(col("inner_table_lv1.b").eq(lit(1))) + .and( + out_ref_col(DataType::UInt32, "outer_table.b") + .eq(col("inner_table_lv1.b")), + ), + )? + .aggregate(Vec::::new(), vec![count(col("inner_table_lv1.a"))])? + .project(vec![count(col("inner_table_lv1.a")).alias("count_a")])? + .build()?, + ); + + let plan = LogicalPlanBuilder::from(outer_table.clone()) + .filter( + col("outer_table.a") + .gt(lit(1)) + .and(in_subquery(col("outer_table.c"), sq_level1)), + )? 
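+ // The IN-subquery references outer_table.a and outer_table.b, which surface as correlated columns on the DependentJoin below.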
+ .build()?; + + // Filter: outer_table.a > Int32(1) AND outer_table.c IN () + // Subquery: + // Projection: count(inner_table_lv1.a) AS count_a + // Aggregate: groupBy=[[]], aggr=[[count(inner_table_lv1.a)]] + // Filter: inner_table_lv1.a = outer_ref(outer_table.a) AND outer_ref(outer_table.a) > inner_table_lv1.c AND inner_table_lv1.b = Int32(1) AND outer_ref(outer_table.b) = inner_table_lv1.b + // TableScan: inner_table_lv1 + // TableScan: outer_table + + assert_dependent_join_rewrite!(plan, @r" + Projection: outer_table.a, outer_table.b, outer_table.c [a:UInt32, b:UInt32, c:UInt32] + Filter: outer_table.a > Int32(1) AND __in_sq_1 [a:UInt32, b:UInt32, c:UInt32, __in_sq_1:Boolean] + DependentJoin on [outer_table.a lvl 1, outer_table.b lvl 1] with expr outer_table.c IN () depth 1 [a:UInt32, b:UInt32, c:UInt32, __in_sq_1:Boolean] + TableScan: outer_table [a:UInt32, b:UInt32, c:UInt32] + Projection: count(inner_table_lv1.a) AS count_a [count_a:Int64] + Aggregate: groupBy=[[]], aggr=[[count(inner_table_lv1.a)]] [count(inner_table_lv1.a):Int64] + Filter: inner_table_lv1.a = outer_ref(outer_table.a) AND outer_ref(outer_table.a) > inner_table_lv1.c AND inner_table_lv1.b = Int32(1) AND outer_ref(outer_table.b) = inner_table_lv1.b [a:UInt32, b:UInt32, c:UInt32] + TableScan: inner_table_lv1 [a:UInt32, b:UInt32, c:UInt32] + "); + Ok(()) + } + #[test] + fn correlated_exist_subquery() -> Result<()> { + let outer_table = test_table_scan_with_name("outer_table")?; + let inner_table_lv1 = test_table_scan_with_name("inner_table_lv1")?; + let sq_level1 = Arc::new( + LogicalPlanBuilder::from(inner_table_lv1) + .filter( + col("inner_table_lv1.a") + .eq(out_ref_col(DataType::UInt32, "outer_table.a")) + .and( + out_ref_col(DataType::UInt32, "outer_table.a") + .gt(col("inner_table_lv1.c")), + ) + .and(col("inner_table_lv1.b").eq(lit(1))) + .and( + out_ref_col(DataType::UInt32, "outer_table.b") + .eq(col("inner_table_lv1.b")), + ), + )? + .project(vec![ + out_ref_col(DataType::UInt32, "outer_table.b").alias("outer_b_alias") + ])? + .build()?, + ); + + let plan = LogicalPlanBuilder::from(outer_table.clone()) + .filter(col("outer_table.a").gt(lit(1)).and(exists(sq_level1)))? 
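+ // The EXISTS predicate is replaced by the boolean marker column __exists_sq_1 produced by the DependentJoin.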
+ .build()?; + + // Filter: outer_table.a > Int32(1) AND EXISTS () + // Subquery: + // Projection: outer_ref(outer_table.b) AS outer_b_alias + // Filter: inner_table_lv1.a = outer_ref(outer_table.a) AND outer_ref(outer_table.a) > inner_table_lv1.c AND in + // ner_table_lv1.b = Int32(1) AND outer_ref(outer_table.b) = inner_table_lv1.b + // TableScan: inner_table_lv1 + // TableScan: outer_table + + assert_dependent_join_rewrite!(plan, @r" + Projection: outer_table.a, outer_table.b, outer_table.c [a:UInt32, b:UInt32, c:UInt32] + Filter: outer_table.a > Int32(1) AND __exists_sq_1 [a:UInt32, b:UInt32, c:UInt32, __exists_sq_1:Boolean] + DependentJoin on [outer_table.a lvl 1, outer_table.b lvl 1] with expr EXISTS () depth 1 [a:UInt32, b:UInt32, c:UInt32, __exists_sq_1:Boolean] + TableScan: outer_table [a:UInt32, b:UInt32, c:UInt32] + Projection: outer_ref(outer_table.b) AS outer_b_alias [outer_b_alias:UInt32;N] + Filter: inner_table_lv1.a = outer_ref(outer_table.a) AND outer_ref(outer_table.a) > inner_table_lv1.c AND inner_table_lv1.b = Int32(1) AND outer_ref(outer_table.b) = inner_table_lv1.b [a:UInt32, b:UInt32, c:UInt32] + TableScan: inner_table_lv1 [a:UInt32, b:UInt32, c:UInt32] + "); + Ok(()) + } + + #[test] + fn uncorrelated_exist_subquery() -> Result<()> { + let outer_table = test_table_scan_with_name("outer_table")?; + let inner_table_lv1 = test_table_scan_with_name("inner_table_lv1")?; + let sq_level1 = Arc::new( + LogicalPlanBuilder::from(inner_table_lv1) + .filter(col("inner_table_lv1.b").eq(lit(1)))? + .project(vec![col("inner_table_lv1.b"), col("inner_table_lv1.a")])? + .build()?, + ); + + let plan = LogicalPlanBuilder::from(outer_table.clone()) + .filter(col("outer_table.a").gt(lit(1)).and(exists(sq_level1)))? + .build()?; + + // Filter: outer_table.a > Int32(1) AND EXISTS () + // Subquery: + // Projection: inner_table_lv1.b, inner_table_lv1.a + // Filter: inner_table_lv1.b = Int32(1) + // TableScan: inner_table_lv1 + // TableScan: outer_table + + assert_dependent_join_rewrite!(plan, @r" + Projection: outer_table.a, outer_table.b, outer_table.c [a:UInt32, b:UInt32, c:UInt32] + Filter: outer_table.a > Int32(1) AND __exists_sq_1 [a:UInt32, b:UInt32, c:UInt32, __exists_sq_1:Boolean] + DependentJoin on [] with expr EXISTS () depth 1 [a:UInt32, b:UInt32, c:UInt32, __exists_sq_1:Boolean] + TableScan: outer_table [a:UInt32, b:UInt32, c:UInt32] + Projection: inner_table_lv1.b, inner_table_lv1.a [b:UInt32, a:UInt32] + Filter: inner_table_lv1.b = Int32(1) [a:UInt32, b:UInt32, c:UInt32] + TableScan: inner_table_lv1 [a:UInt32, b:UInt32, c:UInt32] + "); + + Ok(()) + } + #[test] + fn uncorrelated_in_subquery() -> Result<()> { + let outer_table = test_table_scan_with_name("outer_table")?; + let inner_table_lv1 = test_table_scan_with_name("inner_table_lv1")?; + let sq_level1 = Arc::new( + LogicalPlanBuilder::from(inner_table_lv1) + .filter(col("inner_table_lv1.b").eq(lit(1)))? + .project(vec![col("inner_table_lv1.b")])? + .build()?, + ); + + let plan = LogicalPlanBuilder::from(outer_table.clone()) + .filter( + col("outer_table.a") + .gt(lit(1)) + .and(in_subquery(col("outer_table.c"), sq_level1)), + )? 
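+ // No outer references in the subquery, so the DependentJoin below carries an empty correlated-column list.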
+ .build()?; + + // Filter: outer_table.a > Int32(1) AND outer_table.c IN () + // Subquery: + // Projection: inner_table_lv1.b + // Filter: inner_table_lv1.b = Int32(1) + // TableScan: inner_table_lv1 + // TableScan: outer_table + + assert_dependent_join_rewrite!(plan, @r" + Projection: outer_table.a, outer_table.b, outer_table.c [a:UInt32, b:UInt32, c:UInt32] + Filter: outer_table.a > Int32(1) AND __in_sq_1 [a:UInt32, b:UInt32, c:UInt32, __in_sq_1:Boolean] + DependentJoin on [] with expr outer_table.c IN () depth 1 [a:UInt32, b:UInt32, c:UInt32, __in_sq_1:Boolean] + TableScan: outer_table [a:UInt32, b:UInt32, c:UInt32] + Projection: inner_table_lv1.b [b:UInt32] + Filter: inner_table_lv1.b = Int32(1) [a:UInt32, b:UInt32, c:UInt32] + TableScan: inner_table_lv1 [a:UInt32, b:UInt32, c:UInt32] + "); + + Ok(()) + } + #[test] + fn correlated_in_subquery() -> Result<()> { + let outer_table = test_table_scan_with_name("outer_table")?; + let inner_table_lv1 = test_table_scan_with_name("inner_table_lv1")?; + let sq_level1 = Arc::new( + LogicalPlanBuilder::from(inner_table_lv1) + .filter( + col("inner_table_lv1.a") + .eq(out_ref_col(DataType::UInt32, "outer_table.a")) + .and( + out_ref_col(DataType::UInt32, "outer_table.a") + .gt(col("inner_table_lv1.c")), + ) + .and(col("inner_table_lv1.b").eq(lit(1))) + .and( + out_ref_col(DataType::UInt32, "outer_table.b") + .eq(col("inner_table_lv1.b")), + ), + )? + .project(vec![ + out_ref_col(DataType::UInt32, "outer_table.b").alias("outer_b_alias") + ])? + .build()?, + ); + + let plan = LogicalPlanBuilder::from(outer_table.clone()) + .filter( + col("outer_table.a") + .gt(lit(1)) + .and(in_subquery(col("outer_table.c"), sq_level1)), + )? + .build()?; + + // Filter: outer_table.a > Int32(1) AND outer_table.c IN () + // Subquery: + // Projection: outer_ref(outer_table.b) AS outer_b_alias + // Filter: inner_table_lv1.a = outer_ref(outer_table.a) AND outer_ref(outer_table.a) > inner_table_lv1.c AND inner_table_lv1.b = Int32(1) AND outer_ref(outer_table.b) = inner_table_lv1.b + // TableScan: inner_table_lv1 + // TableScan: outer_table + + assert_dependent_join_rewrite!(plan, @r" + Projection: outer_table.a, outer_table.b, outer_table.c [a:UInt32, b:UInt32, c:UInt32] + Filter: outer_table.a > Int32(1) AND __in_sq_1 [a:UInt32, b:UInt32, c:UInt32, __in_sq_1:Boolean] + DependentJoin on [outer_table.a lvl 1, outer_table.b lvl 1] with expr outer_table.c IN () depth 1 [a:UInt32, b:UInt32, c:UInt32, __in_sq_1:Boolean] + TableScan: outer_table [a:UInt32, b:UInt32, c:UInt32] + Projection: outer_ref(outer_table.b) AS outer_b_alias [outer_b_alias:UInt32;N] + Filter: inner_table_lv1.a = outer_ref(outer_table.a) AND outer_ref(outer_table.a) > inner_table_lv1.c AND inner_table_lv1.b = Int32(1) AND outer_ref(outer_table.b) = inner_table_lv1.b [a:UInt32, b:UInt32, c:UInt32] + TableScan: inner_table_lv1 [a:UInt32, b:UInt32, c:UInt32] + "); + Ok(()) + } + + #[test] + fn correlated_subquery_with_alias() -> Result<()> { + let outer_table = test_table_scan_with_name("outer_table")?; + let inner_table_lv1 = test_table_scan_with_name("inner_table_lv1")?; + let sq_level1 = Arc::new( + LogicalPlanBuilder::from(inner_table_lv1) + .filter( + col("inner_table_lv1.a") + .eq(out_ref_col(DataType::UInt32, "outer_table_alias.a")), + )? + .aggregate(Vec::::new(), vec![count(col("inner_table_lv1.a"))])? + .project(vec![count(col("inner_table_lv1.a")).alias("count_a")])? + .build()?, + ); + + let plan = LogicalPlanBuilder::from(outer_table.clone()) + .alias("outer_table_alias")? 
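+ // The subquery refers to the alias (outer_table_alias.a), so the SubqueryAlias node stays below the DependentJoin in the rewritten plan.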
+ .filter( + col("outer_table.a") + .gt(lit(1)) + .and(in_subquery(col("outer_table.c"), sq_level1)), + )? + .build()?; + + // Filter: outer_table.a > Int32(1) AND outer_table.c IN () + // Subquery: + // Projection: count(inner_table_lv1.a) AS count_a + // Aggregate: groupBy=[[]], aggr=[[count(inner_table_lv1.a)]] + // Filter: inner_table_lv1.a = outer_ref(outer_table_alias.a) + // TableScan: inner_table_lv1 + // SubqueryAlias: outer_table_alias + // TableScan: outer_table + + assert_dependent_join_rewrite!(plan, @r" + Projection: outer_table_alias.a, outer_table_alias.b, outer_table_alias.c [a:UInt32, b:UInt32, c:UInt32] + Filter: outer_table.a > Int32(1) AND __in_sq_1 [a:UInt32, b:UInt32, c:UInt32, __in_sq_1:Boolean] + DependentJoin on [outer_table_alias.a lvl 1] with expr outer_table.c IN () depth 1 [a:UInt32, b:UInt32, c:UInt32, __in_sq_1:Boolean] + SubqueryAlias: outer_table_alias [a:UInt32, b:UInt32, c:UInt32] + TableScan: outer_table [a:UInt32, b:UInt32, c:UInt32] + Projection: count(inner_table_lv1.a) AS count_a [count_a:Int64] + Aggregate: groupBy=[[]], aggr=[[count(inner_table_lv1.a)]] [count(inner_table_lv1.a):Int64] + Filter: inner_table_lv1.a = outer_ref(outer_table_alias.a) [a:UInt32, b:UInt32, c:UInt32] + TableScan: inner_table_lv1 [a:UInt32, b:UInt32, c:UInt32] + "); + Ok(()) + } + + // from duckdb test: https://github.com/duckdb/duckdb/blob/main/test/sql/subquery/any_all/test_correlated_any_all.test + #[test] + fn test_correlated_any_all_1() -> Result<()> { + // CREATE TABLE integers(i INTEGER); + // SELECT i = ANY( + // SELECT i + // FROM integers + // WHERE i = i1.i + // ) + // FROM integers i1 + // ORDER BY i; + + // Create base table + let integers = test_table_with_columns("integers", &[("i", DataType::Int32)])?; + + // Build correlated subquery: + // SELECT i FROM integers WHERE i = i1.i + let subquery = Arc::new( + LogicalPlanBuilder::from(integers.clone()) + .filter(col("integers.i").eq(out_ref_col(DataType::Int32, "i1.i")))? + .project(vec![col("integers.i")])? + .build()?, + ); + + // Build main query with table alias i1 + let plan = LogicalPlanBuilder::from(integers) + .alias("i1")? // Alias the table as i1 + .filter( + // i = ANY(subquery) + Expr::InSubquery(InSubquery { + expr: Box::new(col("i1.i")), + subquery: Subquery { + subquery, + outer_ref_columns: vec![out_ref_col(DataType::Int32, "i1.i")], + spans: Spans::new(), + }, + negated: false, + }), + )? + .sort(vec![SortExpr::new(col("i1.i"), false, false)])? 
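+ // SortExpr::new(expr, asc, nulls_first): (false, false) renders as DESC NULLS LAST in the snapshot below.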
// ORDER BY i + .build()?; + + // original plan: + // Sort: i1.i DESC NULLS LAST + // Filter: i1.i IN () + // Subquery: + // Projection: integers.i + // Filter: integers.i = outer_ref(i1.i) + // TableScan: integers + // SubqueryAlias: i1 + // TableScan: integers + + // Verify the rewrite result + assert_dependent_join_rewrite!( + plan, + @r" + Sort: i1.i DESC NULLS LAST [i:Int32] + Projection: i1.i [i:Int32] + Filter: __in_sq_1 [i:Int32, __in_sq_1:Boolean] + DependentJoin on [i1.i lvl 1] with expr i1.i IN () depth 1 [i:Int32, __in_sq_1:Boolean] + SubqueryAlias: i1 [i:Int32] + TableScan: integers [i:Int32] + Projection: integers.i [i:Int32] + Filter: integers.i = outer_ref(i1.i) [i:Int32] + TableScan: integers [i:Int32] + " + ); + + Ok(()) + } + + // from duckdb: https://github.com/duckdb/duckdb/blob/main/test/sql/subquery/any_all/issue_2999.test + #[test] + fn test_any_subquery_with_derived_join() -> Result<()> { + // SQL equivalent: + // CREATE TABLE t0 (c0 INT); + // CREATE TABLE t1 (c0 INT); + // SELECT 1 = ANY( + // SELECT 1 + // FROM t1 + // JOIN ( + // SELECT count(*) + // GROUP BY t0.c0 + // ) AS x(x) ON TRUE + // ) + // FROM t0; + + // Create base tables + let t0 = test_table_with_columns("t0", &[("c0", DataType::Int32)])?; + let t1 = test_table_with_columns("t1", &[("c0", DataType::Int32)])?; + + // Build derived table subquery: + // SELECT count(*) GROUP BY t0.c0 + let derived_table = Arc::new( + LogicalPlanBuilder::from(t1.clone()) + .aggregate( + vec![out_ref_col(DataType::Int32, "t0.c0")], // GROUP BY t0.c0 + vec![count(lit(1))], // count(*) + )? + .build()?, + ); + + // Build the join subquery: + // SELECT 1 FROM t1 JOIN (derived_table) x(x) ON TRUE + let join_subquery = Arc::new( + LogicalPlanBuilder::from(t1) + .join_on( + LogicalPlan::Subquery(Subquery { + subquery: derived_table, + outer_ref_columns: vec![out_ref_col(DataType::Int32, "t0.c0")], + spans: Spans::new(), + }), + JoinType::Inner, + vec![lit(true)], // ON TRUE + )? + .project(vec![lit(1)])? // SELECT 1 + .build()?, + ); + + // Build main query + let plan = LogicalPlanBuilder::from(t0) + .filter( + // 1 = ANY(subquery) + Expr::InSubquery(InSubquery { + expr: Box::new(lit(1)), + subquery: Subquery { + subquery: join_subquery, + outer_ref_columns: vec![out_ref_col(DataType::Int32, "t0.c0")], + spans: Spans::new(), + }, + negated: false, + }), + )? 
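+ // The derived table references t0.c0 two query levels up, so the rewrite nests a depth-2 lateral DependentJoin under the depth-1 one.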
+ .build()?; + + // Filter: Int32(1) IN () + // Subquery: + // Projection: Int32(1) + // Inner Join: Filter: Boolean(true) + // TableScan: t1 + // Subquery: + // Aggregate: groupBy=[[outer_ref(t0.c0)]], aggr=[[count(Int32(1))]] + // TableScan: t1 + // TableScan: t0 + + // Verify the rewrite result + assert_dependent_join_rewrite!( + plan, + @r" + Projection: t0.c0 [c0:Int32] + Filter: __in_sq_2 [c0:Int32, __in_sq_2:Boolean] + DependentJoin on [t0.c0 lvl 2] with expr Int32(1) IN () depth 1 [c0:Int32, __in_sq_2:Boolean] + TableScan: t0 [c0:Int32] + Projection: Int32(1) [Int32(1):Int32] + DependentJoin on [] lateral Inner join with Boolean(true) depth 2 [c0:Int32] + TableScan: t1 [c0:Int32] + Aggregate: groupBy=[[outer_ref(t0.c0)]], aggr=[[count(Int32(1))]] [outer_ref(t0.c0):Int32;N, count(Int32(1)):Int64] + TableScan: t1 [c0:Int32] + " + ); + + Ok(()) + } + + #[test] + fn test_simple_correlated_agg_subquery() -> Result<()> { + // CREATE TABLE t(a INT, b INT); + // SELECT a, + // (SELECT SUM(b) + // FROM t t2 + // WHERE t2.a = t1.a) as sum_b + // FROM t t1; + + // Create base table + let t = test_table_with_columns( + "t", + &[("a", DataType::Int32), ("b", DataType::Int32)], + )?; + + // Build scalar subquery: + // SELECT SUM(b) FROM t t2 WHERE t2.a = t1.a + let scalar_sub = Arc::new( + LogicalPlanBuilder::from(t.clone()) + .alias("t2")? + .filter(col("t2.a").eq(out_ref_col(DataType::Int32, "t1.a")))? + .aggregate( + vec![col("t2.b")], // No GROUP BY + vec![sum(col("t2.b"))], // SUM(b) + )? + .build()?, + ); + + // Build main query + let plan = LogicalPlanBuilder::from(t) + .alias("t1")? + .project(vec![ + col("t1.a"), // a + scalar_subquery(scalar_sub), // (SELECT SUM(b) ...) + ])? + .build()?; + + // Projection: t1.a, () + // Subquery: + // Aggregate: groupBy=[[t2.b]], aggr=[[sum(t2.b)]] + // Filter: t2.a = outer_ref(t1.a) + // SubqueryAlias: t2 + // TableScan: t + // SubqueryAlias: t1 + // TableScan: t + + // Verify the rewrite result + assert_dependent_join_rewrite!( + plan, + @r" + Projection: t1.a, __scalar_sq_1 [a:Int32, __scalar_sq_1:Int32] + DependentJoin on [t1.a lvl 1] with expr () depth 1 [a:Int32, b:Int32, __scalar_sq_1:Int32] + SubqueryAlias: t1 [a:Int32, b:Int32] + TableScan: t [a:Int32, b:Int32] + Aggregate: groupBy=[[t2.b]], aggr=[[sum(t2.b)]] [b:Int32, sum(t2.b):Int64;N] + Filter: t2.a = outer_ref(t1.a) [a:Int32, b:Int32] + SubqueryAlias: t2 [a:Int32, b:Int32] + TableScan: t [a:Int32, b:Int32] + " + ); + + Ok(()) + } + + #[test] + fn test_simple_subquery_in_agg() -> Result<()> { + // CREATE TABLE t(a INT, b INT); + // SELECT a, + // SUM( + // (SELECT b FROM t t2 WHERE t2.a = t1.a) + // ) as sum_scalar + // FROM t t1 + // GROUP BY a; + + // Create base table + let t = test_table_with_columns( + "t", + &[("a", DataType::Int32), ("b", DataType::Int32)], + )?; + + // Build inner scalar subquery: + // SELECT b FROM t t2 WHERE t2.a = t1.a + let scalar_sub = Arc::new( + LogicalPlanBuilder::from(t.clone()) + .alias("t2")? + .filter(col("t2.a").eq(out_ref_col(DataType::Int32, "t1.a")))? + .project(vec![col("t2.b")])? // SELECT b + .build()?, + ); + + // Build main query + let plan = LogicalPlanBuilder::from(t) + .alias("t1")? + .aggregate( + vec![col("t1.a")], // GROUP BY a + vec![sum(scalar_subquery(scalar_sub)) // SUM((SELECT b ...)) + .alias("sum_scalar")], + )? 
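+ // The scalar subquery appears inside an aggregate argument; the rewrite materializes it as __scalar_sq_1 below the Aggregate (see snapshot).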
+ .build()?; + + // Aggregate: groupBy=[[t1.a]], aggr=[[sum(()) AS sum_scalar]] + // Subquery: + // Projection: t2.b + // Filter: t2.a = outer_ref(t1.a) + // SubqueryAlias: t2 + // TableScan: t + // SubqueryAlias: t1 + // TableScan: t + + // Verify the rewrite result + assert_dependent_join_rewrite!( + plan, + @r" + Projection: t1.a, sum_scalar [a:Int32, sum_scalar:Int64;N] + Aggregate: groupBy=[[t1.a]], aggr=[[sum(__scalar_sq_1) AS sum_scalar]] [a:Int32, sum_scalar:Int64;N] + DependentJoin on [t1.a lvl 1] with expr () depth 1 [a:Int32, b:Int32, __scalar_sq_1:Int32] + SubqueryAlias: t1 [a:Int32, b:Int32] + TableScan: t [a:Int32, b:Int32] + Projection: t2.b [b:Int32] + Filter: t2.a = outer_ref(t1.a) [a:Int32, b:Int32] + SubqueryAlias: t2 [a:Int32, b:Int32] + TableScan: t [a:Int32, b:Int32] + " + ); + + Ok(()) + } + + #[test] + // https://github.com/duckdb/duckdb/blob/4d7cb701cabd646d8232a9933dd058a089ea7348/test/sql/subquery/any_all/subquery_in.test + fn correlated_scalar_subquery_returning_more_than_1_row() -> Result<()> { + // SELECT (FALSE) IN (TRUE, (SELECT TIME '13:35:07' FROM t1) BETWEEN t0.c0 AND t0.c0) FROM t0; + let t0 = test_table_with_columns( + "t0", + &[ + ("c0", DataType::Time64(TimeUnit::Second)), + ("c1", DataType::Float64), + ], + )?; + let t1 = test_table_with_columns("t1", &[("c0", DataType::Int32)])?; + let t1_subquery = Arc::new( + LogicalPlanBuilder::from(t1) + .project(vec![lit("13:35:07")])? + .build()?, + ); + let plan = LogicalPlanBuilder::from(t0) + .project(vec![lit(false).in_list( + vec![ + lit(true), + scalar_subquery(t1_subquery).between(col("t0.c0"), col("t0.c0")), + ], + false, + )])? + .build()?; + // Projection: Boolean(false) IN ([Boolean(true), () BETWEEN t0.c0 AND t0.c0]) + // Subquery: + // Projection: Utf8("13:35:07") + // TableScan: t1 + // TableScan: t0 + assert_dependent_join_rewrite!( + plan, + @r#" + Projection: Boolean(false) IN ([Boolean(true), __scalar_sq_1 BETWEEN t0.c0 AND t0.c0]) [Boolean(false) IN Boolean(true), __scalar_sq_1 BETWEEN t0.c0 AND t0.c0:Boolean] + DependentJoin on [] with expr () depth 1 [c0:Time64(Second), c1:Float64, __scalar_sq_1:Utf8] + TableScan: t0 [c0:Time64(Second), c1:Float64] + Projection: Utf8("13:35:07") [Utf8("13:35:07"):Utf8] + TableScan: t1 [c0:Int32] + "# + ); + + Ok(()) + } + + #[test] + fn test_correlated_subquery_in_join_filter() -> Result<()> { + // Test demonstrates traversal order issue with subquery in JOIN condition + // Query pattern: + // SELECT * FROM t1 + // JOIN t2 ON t2.key = t1.key + // AND t2.val > (SELECT COUNT(*) FROM t3 WHERE t3.id = t1.id); + + let t1 = test_table_with_columns( + "t1", + &[ + ("key", DataType::Int32), + ("id", DataType::Int32), + ("val", DataType::Int32), + ], + )?; + + let t2 = test_table_with_columns( + "t2", + &[("key", DataType::Int32), ("val", DataType::Int32)], + )?; + + let t3 = test_table_with_columns( + "t3", + &[("id", DataType::Int32), ("val", DataType::Int32)], + )?; + + // Subquery in join condition: SELECT COUNT(*) FROM t3 WHERE t3.id = t1.id + let scalar_sq = Arc::new( + LogicalPlanBuilder::from(t3) + .filter(col("t3.id").eq(out_ref_col(DataType::Int32, "t1.id")))? + .aggregate(Vec::::new(), vec![count(lit(1))])? + .build()?, + ); + + // Build join condition: t2.key = t1.key AND t2.val > scalar_sq AND EXISTS(exists_sq) + let join_condition = and( + col("t2.key").eq(col("t1.key")), + col("t2.val").gt(scalar_subquery(scalar_sq)), + ); + let plan = LogicalPlanBuilder::from(t1) + .join_on(t2, JoinType::Inner, vec![join_condition])? 
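+ // The scalar subquery lives in the join filter, so the rewrite lifts the filter above a Cross Join of t1 and t2 and places the DependentJoin in between.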
+ .build()?; + + // Inner Join: Filter: t2.key = t1.key AND t2.val > () + // Subquery: + // Aggregate: groupBy=[[]], aggr=[[count(Int32(1))]] + // Filter: t3.id = outer_ref(t1.id) + // TableScan: t3 + // TableScan: t1 + // TableScan: t2 + + assert_dependent_join_rewrite!( + plan, + @r" + Filter: t2.key = t1.key AND t2.val > __scalar_sq_1 [key:Int32, id:Int32, val:Int32, key:Int32, val:Int32, __scalar_sq_1:Int64] + DependentJoin on [t1.id lvl 1] with expr () depth 1 [key:Int32, id:Int32, val:Int32, key:Int32, val:Int32, __scalar_sq_1:Int64] + Cross Join(ComparisonJoin): [key:Int32, id:Int32, val:Int32, key:Int32, val:Int32] + TableScan: t1 [key:Int32, id:Int32, val:Int32] + TableScan: t2 [key:Int32, val:Int32] + Aggregate: groupBy=[[]], aggr=[[count(Int32(1))]] [count(Int32(1)):Int64] + Filter: t3.id = outer_ref(t1.id) [id:Int32, val:Int32] + TableScan: t3 [id:Int32, val:Int32] + " + ); + + Ok(()) + } + + #[test] + fn test_correlated_subquery_in_lateral_join_filter() -> Result<()> { + // Test demonstrates traversal order issue with subquery in JOIN condition + // Query pattern: + // SELECT * FROM t1 + // JOIN t2 ON t2.key = t1.key + // AND t2.val > (SELECT COUNT(*) FROM t3 WHERE t3.id = t1.id); + + let t1 = test_table_with_columns( + "t1", + &[ + ("key", DataType::Int32), + ("id", DataType::Int32), + ("val", DataType::Int32), + ], + )?; + + let t2 = test_table_with_columns( + "t2", + &[("key", DataType::Int32), ("val", DataType::Int32)], + )?; + + let t3 = test_table_with_columns( + "t3", + &[("id", DataType::Int32), ("val", DataType::Int32)], + )?; + + // Subquery in join condition: SELECT COUNT(*) FROM t3 WHERE t3.id = t1.id + let scalar_sq = Arc::new( + LogicalPlanBuilder::from(t3) + .filter(col("t3.id").eq(out_ref_col(DataType::Int32, "t1.id")))? + .aggregate(Vec::::new(), vec![count(lit(1))])? + .build()?, + ); + + // Build join condition: t2.key = t1.key AND t2.val > scalar_sq AND EXISTS(exists_sq) + let join_condition = and( + col("t2.key").eq(col("t1.key")), + col("t2.val").gt(scalar_subquery(scalar_sq)), + ); + + let plan = LogicalPlanBuilder::from(t1) + .join_on( + LogicalPlan::Subquery(Subquery { + subquery: t2.into(), + outer_ref_columns: vec![], + spans: Spans::new(), + }), + JoinType::Inner, + vec![join_condition], + )? 
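+ // A correlated subquery inside a lateral join condition is not supported yet, so this plan is expected to fail the rewrite.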
+ .build()?;
+
+ // Inner Join: Filter: t2.key = t1.key AND t2.val > ()
+ // Subquery:
+ // Aggregate: groupBy=[[]], aggr=[[count(Int32(1))]]
+ // Filter: t3.id = outer_ref(t1.id)
+ // TableScan: t3
+ // TableScan: t1
+ // TableScan: t2
+ assert_dependent_join_rewrite_err!(
+ plan //@"This feature is not implemented: subquery inside lateral join condition is not supported"
+ );
+
+ Ok(())
+ }
+
+ #[test]
+ fn test_multiple_correlated_subqueries_in_join_filter() -> Result<()> {
+ // Test demonstrates traversal order issue with subquery in JOIN condition
+ // Query pattern:
+ // SELECT * FROM t1
+ // JOIN t2 ON (t2.key = t1.key
+ // AND t2.val > (SELECT COUNT(*) FROM t3 WHERE t3.id = t1.id))
+ // OR EXISTS (
+ // SELECT * FROM T3 WHERE T3.ID = T2.KEY
+ // );
+
+ let t1 = test_table_with_columns(
+ "t1",
+ &[
+ ("key", DataType::Int32),
+ ("id", DataType::Int32),
+ ("val", DataType::Int32),
+ ],
+ )?;
+
+ let t2 = test_table_with_columns(
+ "t2",
+ &[("key", DataType::Int32), ("val", DataType::Int32)],
+ )?;
+
+ let t3 = test_table_with_columns(
+ "t3",
+ &[("id", DataType::Int32), ("val", DataType::Int32)],
+ )?;
+
+ // Subquery in join condition: SELECT COUNT(*) FROM t3 WHERE t3.id = t1.id
+ let scalar_sq = Arc::new(
+ LogicalPlanBuilder::from(t3.clone())
+ .filter(col("t3.id").eq(out_ref_col(DataType::Int32, "t1.id")))?
+ .aggregate(Vec::<Expr>::new(), vec![count(lit(1))])?
+ .build()?,
+ );
+ let exists_sq = Arc::new(
+ LogicalPlanBuilder::from(t3)
+ .filter(col("t3.id").eq(out_ref_col(DataType::Int32, "t2.key")))?
+ .build()?,
+ );
+
+ // Build join condition: (t2.key = t1.key AND t2.val > scalar_sq) OR (exists(exists_sq))
+ let join_condition = and(
+ col("t2.key").eq(col("t1.key")),
+ col("t2.val").gt(scalar_subquery(scalar_sq)),
+ )
+ .or(exists(exists_sq));
+
+ let plan = LogicalPlanBuilder::from(t1)
+ .join_on(t2, JoinType::Inner, vec![join_condition])?
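+ // Two correlated subqueries in the same join filter are rewritten into two stacked DependentJoins above the Cross Join.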
+ .build()?; + // Inner Join: Filter: t2.key = t1.key AND t2.val > () OR EXISTS () + // Subquery: + // Aggregate: groupBy=[[]], aggr=[[count(Int32(1))]] + // Filter: t3.id = outer_ref(t1.id) + // TableScan: t3 + // Subquery: + // Filter: t3.id = outer_ref(t2.key) + // TableScan: t3 + // TableScan: t1 + // TableScan: t2 + + assert_dependent_join_rewrite!( + plan, + @r" + Filter: t2.key = t1.key AND t2.val > __scalar_sq_1 OR __exists_sq_2 [key:Int32, id:Int32, val:Int32, key:Int32, val:Int32, __scalar_sq_1:Int64, __exists_sq_2:Boolean] + DependentJoin on [t2.key lvl 1] with expr EXISTS () depth 1 [key:Int32, id:Int32, val:Int32, key:Int32, val:Int32, __scalar_sq_1:Int64, __exists_sq_2:Boolean] + DependentJoin on [t1.id lvl 1] with expr () depth 1 [key:Int32, id:Int32, val:Int32, key:Int32, val:Int32, __scalar_sq_1:Int64] + Cross Join(ComparisonJoin): [key:Int32, id:Int32, val:Int32, key:Int32, val:Int32] + TableScan: t1 [key:Int32, id:Int32, val:Int32] + TableScan: t2 [key:Int32, val:Int32] + Aggregate: groupBy=[[]], aggr=[[count(Int32(1))]] [count(Int32(1)):Int64] + Filter: t3.id = outer_ref(t1.id) [id:Int32, val:Int32] + TableScan: t3 [id:Int32, val:Int32] + Filter: t3.id = outer_ref(t2.key) [id:Int32, val:Int32] + TableScan: t3 [id:Int32, val:Int32] + " + ); + + Ok(()) + } + + #[test] + fn test_two_exists_subqueries_with_or_filter() -> Result<()> { + // Test case for the SQL pattern mentioned in the documentation comment: + // SELECT ID FROM T1 WHERE EXISTS(SELECT * FROM T2 WHERE T2.ID=T1.ID) OR EXISTS(SELECT * FROM T2 WHERE T2.VALUE=T1.ID); + + let t1 = test_table_with_columns( + "t1", + &[("id", DataType::Int32), ("value", DataType::Int32)], + )?; + + let t2 = test_table_with_columns( + "t2", + &[("id", DataType::Int32), ("value", DataType::Int32)], + )?; + + // First EXISTS subquery: SELECT * FROM T2 WHERE T2.ID=T1.ID + let exists_sq1 = Arc::new( + LogicalPlanBuilder::from(t2.clone()) + .filter(col("t2.id").eq(out_ref_col(DataType::Int32, "t1.id")))? + .build()?, + ); + + // Second EXISTS subquery: SELECT * FROM T2 WHERE T2.VALUE=T1.ID + let exists_sq2 = Arc::new( + LogicalPlanBuilder::from(t2) + .filter(col("t2.value").eq(out_ref_col(DataType::Int32, "t1.id")))? + .build()?, + ); + + // Build the main query: SELECT ID FROM T1 WHERE EXISTS(...) OR EXISTS(...) + let plan = LogicalPlanBuilder::from(t1) + .filter(exists(exists_sq1).or(not_exists(exists_sq2)))? + .project(vec![col("t1.id")])? 
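+ // The filter combines EXISTS and NOT EXISTS over the same inner table; each predicate gets its own DependentJoin and marker column.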
+ .build()?; + + // Filter: EXISTS () OR EXISTS () + // Subquery: + // Filter: t2.id = outer_ref(t1.id) + // TableScan: t2 + // Subquery: + // Filter: t2.value = outer_ref(t1.id) + // TableScan: t2 + // TableScan: t1 + + assert_dependent_join_rewrite!( + plan, + @r" + Projection: t1.id [id:Int32] + Projection: t1.id, t1.value [id:Int32, value:Int32] + Filter: __exists_sq_1 OR NOT __exists_sq_2 [id:Int32, value:Int32, __exists_sq_1:Boolean, __exists_sq_2:Boolean] + DependentJoin on [t1.id lvl 1] with expr EXISTS () depth 1 [id:Int32, value:Int32, __exists_sq_1:Boolean, __exists_sq_2:Boolean] + DependentJoin on [t1.id lvl 1] with expr EXISTS () depth 1 [id:Int32, value:Int32, __exists_sq_1:Boolean] + TableScan: t1 [id:Int32, value:Int32] + Filter: t2.id = outer_ref(t1.id) [id:Int32, value:Int32] + TableScan: t2 [id:Int32, value:Int32] + Filter: t2.value = outer_ref(t1.id) [id:Int32, value:Int32] + TableScan: t2 [id:Int32, value:Int32] + " + ); + + Ok(()) + } +} diff --git a/datafusion/optimizer/src/scalar_subquery_to_join.rs b/datafusion/optimizer/src/scalar_subquery_to_join.rs index 2f9a2f6bb9ed..d2a856e911d5 100644 --- a/datafusion/optimizer/src/scalar_subquery_to_join.rs +++ b/datafusion/optimizer/src/scalar_subquery_to_join.rs @@ -73,7 +73,6 @@ impl OptimizerRule for ScalarSubqueryToJoin { fn supports_rewrite(&self) -> bool { true } - fn rewrite( &self, plan: LogicalPlan, @@ -453,8 +452,8 @@ mod tests { @r" Projection: customer.c_custkey [c_custkey:Int64] Filter: Int32(1) < __scalar_sq_1.max(orders.o_custkey) AND Int32(1) < __scalar_sq_2.max(orders.o_custkey) [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N, o_custkey:Int64;N, __always_true:Boolean;N, max(orders.o_custkey):Int64;N, o_custkey:Int64;N, __always_true:Boolean;N] - Left Join: Filter: __scalar_sq_2.o_custkey = customer.c_custkey [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N, o_custkey:Int64;N, __always_true:Boolean;N, max(orders.o_custkey):Int64;N, o_custkey:Int64;N, __always_true:Boolean;N] - Left Join: Filter: __scalar_sq_1.o_custkey = customer.c_custkey [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N, o_custkey:Int64;N, __always_true:Boolean;N] + Left Join(ComparisonJoin): Filter: __scalar_sq_2.o_custkey = customer.c_custkey [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N, o_custkey:Int64;N, __always_true:Boolean;N, max(orders.o_custkey):Int64;N, o_custkey:Int64;N, __always_true:Boolean;N] + Left Join(ComparisonJoin): Filter: __scalar_sq_1.o_custkey = customer.c_custkey [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N, o_custkey:Int64;N, __always_true:Boolean;N] TableScan: customer [c_custkey:Int64, c_name:Utf8] SubqueryAlias: __scalar_sq_1 [max(orders.o_custkey):Int64;N, o_custkey:Int64, __always_true:Boolean] Projection: max(orders.o_custkey), orders.o_custkey, __always_true [max(orders.o_custkey):Int64;N, o_custkey:Int64, __always_true:Boolean] @@ -507,13 +506,13 @@ mod tests { @r" Projection: customer.c_custkey [c_custkey:Int64] Filter: customer.c_acctbal < __scalar_sq_1.sum(orders.o_totalprice) [c_custkey:Int64, c_name:Utf8, sum(orders.o_totalprice):Float64;N, o_custkey:Int64;N, __always_true:Boolean;N] - Left Join: Filter: __scalar_sq_1.o_custkey = customer.c_custkey [c_custkey:Int64, c_name:Utf8, sum(orders.o_totalprice):Float64;N, o_custkey:Int64;N, __always_true:Boolean;N] + Left Join(ComparisonJoin): Filter: __scalar_sq_1.o_custkey = customer.c_custkey [c_custkey:Int64, c_name:Utf8, sum(orders.o_totalprice):Float64;N, o_custkey:Int64;N, 
__always_true:Boolean;N] TableScan: customer [c_custkey:Int64, c_name:Utf8] SubqueryAlias: __scalar_sq_1 [sum(orders.o_totalprice):Float64;N, o_custkey:Int64, __always_true:Boolean] Projection: sum(orders.o_totalprice), orders.o_custkey, __always_true [sum(orders.o_totalprice):Float64;N, o_custkey:Int64, __always_true:Boolean] Aggregate: groupBy=[[orders.o_custkey, Boolean(true) AS __always_true]], aggr=[[sum(orders.o_totalprice)]] [o_custkey:Int64, __always_true:Boolean, sum(orders.o_totalprice):Float64;N] Filter: orders.o_totalprice < __scalar_sq_2.sum(lineitem.l_extendedprice) [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N, sum(lineitem.l_extendedprice):Float64;N, l_orderkey:Int64;N, __always_true:Boolean;N] - Left Join: Filter: __scalar_sq_2.l_orderkey = orders.o_orderkey [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N, sum(lineitem.l_extendedprice):Float64;N, l_orderkey:Int64;N, __always_true:Boolean;N] + Left Join(ComparisonJoin): Filter: __scalar_sq_2.l_orderkey = orders.o_orderkey [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N, sum(lineitem.l_extendedprice):Float64;N, l_orderkey:Int64;N, __always_true:Boolean;N] TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N] SubqueryAlias: __scalar_sq_2 [sum(lineitem.l_extendedprice):Float64;N, l_orderkey:Int64, __always_true:Boolean] Projection: sum(lineitem.l_extendedprice), lineitem.l_orderkey, __always_true [sum(lineitem.l_extendedprice):Float64;N, l_orderkey:Int64, __always_true:Boolean] @@ -548,7 +547,7 @@ mod tests { @r" Projection: customer.c_custkey [c_custkey:Int64] Filter: customer.c_custkey = __scalar_sq_1.max(orders.o_custkey) [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N, o_custkey:Int64;N, __always_true:Boolean;N] - Left Join: Filter: customer.c_custkey = __scalar_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N, o_custkey:Int64;N, __always_true:Boolean;N] + Left Join(ComparisonJoin): Filter: customer.c_custkey = __scalar_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N, o_custkey:Int64;N, __always_true:Boolean;N] TableScan: customer [c_custkey:Int64, c_name:Utf8] SubqueryAlias: __scalar_sq_1 [max(orders.o_custkey):Int64;N, o_custkey:Int64, __always_true:Boolean] Projection: max(orders.o_custkey), orders.o_custkey, __always_true [max(orders.o_custkey):Int64;N, o_custkey:Int64, __always_true:Boolean] @@ -584,7 +583,7 @@ mod tests { @r" Projection: customer.c_custkey [c_custkey:Int64] Filter: customer.c_custkey = __scalar_sq_1.max(orders.o_custkey) [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N] - Left Join: Filter: Boolean(true) [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N] + Left Join(ComparisonJoin): Filter: Boolean(true) [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N] TableScan: customer [c_custkey:Int64, c_name:Utf8] SubqueryAlias: __scalar_sq_1 [max(orders.o_custkey):Int64;N] Projection: max(orders.o_custkey) [max(orders.o_custkey):Int64;N] @@ -615,7 +614,7 @@ mod tests { @r" Projection: customer.c_custkey [c_custkey:Int64] Filter: customer.c_custkey = __scalar_sq_1.max(orders.o_custkey) [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N] - Left Join: Filter: Boolean(true) [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N] + Left Join(ComparisonJoin): Filter: Boolean(true) [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N] TableScan: customer 
[c_custkey:Int64, c_name:Utf8] SubqueryAlias: __scalar_sq_1 [max(orders.o_custkey):Int64;N] Projection: max(orders.o_custkey) [max(orders.o_custkey):Int64;N] @@ -775,7 +774,7 @@ mod tests { @r" Projection: customer.c_custkey [c_custkey:Int64] Filter: customer.c_custkey = __scalar_sq_1.max(orders.o_custkey) + Int32(1) [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey) + Int32(1):Int64;N, o_custkey:Int64;N, __always_true:Boolean;N] - Left Join: Filter: customer.c_custkey = __scalar_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey) + Int32(1):Int64;N, o_custkey:Int64;N, __always_true:Boolean;N] + Left Join(ComparisonJoin): Filter: customer.c_custkey = __scalar_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey) + Int32(1):Int64;N, o_custkey:Int64;N, __always_true:Boolean;N] TableScan: customer [c_custkey:Int64, c_name:Utf8] SubqueryAlias: __scalar_sq_1 [max(orders.o_custkey) + Int32(1):Int64;N, o_custkey:Int64, __always_true:Boolean] Projection: max(orders.o_custkey) + Int32(1), orders.o_custkey, __always_true [max(orders.o_custkey) + Int32(1):Int64;N, o_custkey:Int64, __always_true:Boolean] @@ -816,7 +815,7 @@ mod tests { plan, @r#" Projection: customer.c_custkey, CASE WHEN __scalar_sq_1.__always_true IS NULL THEN CASE WHEN CAST(NULL AS Boolean) THEN Utf8("a") ELSE Utf8("b") END ELSE __scalar_sq_1.CASE WHEN max(orders.o_totalprice) THEN Utf8("a") ELSE Utf8("b") END END AS CASE WHEN max(orders.o_totalprice) THEN Utf8("a") ELSE Utf8("b") END [c_custkey:Int64, CASE WHEN max(orders.o_totalprice) THEN Utf8("a") ELSE Utf8("b") END:Utf8;N] - Left Join: Filter: customer.c_custkey = __scalar_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8, CASE WHEN max(orders.o_totalprice) THEN Utf8("a") ELSE Utf8("b") END:Utf8;N, o_custkey:Int64;N, __always_true:Boolean;N] + Left Join(ComparisonJoin): Filter: customer.c_custkey = __scalar_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8, CASE WHEN max(orders.o_totalprice) THEN Utf8("a") ELSE Utf8("b") END:Utf8;N, o_custkey:Int64;N, __always_true:Boolean;N] TableScan: customer [c_custkey:Int64, c_name:Utf8] SubqueryAlias: __scalar_sq_1 [CASE WHEN max(orders.o_totalprice) THEN Utf8("a") ELSE Utf8("b") END:Utf8, o_custkey:Int64, __always_true:Boolean] Projection: CASE WHEN max(orders.o_totalprice) THEN Utf8("a") ELSE Utf8("b") END, orders.o_custkey, __always_true [CASE WHEN max(orders.o_totalprice) THEN Utf8("a") ELSE Utf8("b") END:Utf8, o_custkey:Int64, __always_true:Boolean] @@ -878,7 +877,7 @@ mod tests { @r" Projection: customer.c_custkey [c_custkey:Int64] Filter: customer.c_custkey >= __scalar_sq_1.max(orders.o_custkey) AND customer.c_custkey = Int32(1) [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N, o_custkey:Int64;N, __always_true:Boolean;N] - Left Join: Filter: customer.c_custkey = __scalar_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N, o_custkey:Int64;N, __always_true:Boolean;N] + Left Join(ComparisonJoin): Filter: customer.c_custkey = __scalar_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N, o_custkey:Int64;N, __always_true:Boolean;N] TableScan: customer [c_custkey:Int64, c_name:Utf8] SubqueryAlias: __scalar_sq_1 [max(orders.o_custkey):Int64;N, o_custkey:Int64, __always_true:Boolean] Projection: max(orders.o_custkey), orders.o_custkey, __always_true [max(orders.o_custkey):Int64;N, o_custkey:Int64, __always_true:Boolean] @@ -915,7 +914,7 @@ mod tests { @r" Projection: customer.c_custkey [c_custkey:Int64] Filter: customer.c_custkey = 
__scalar_sq_1.max(orders.o_custkey) AND customer.c_custkey = Int32(1) [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N, o_custkey:Int64;N, __always_true:Boolean;N] - Left Join: Filter: customer.c_custkey = __scalar_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N, o_custkey:Int64;N, __always_true:Boolean;N] + Left Join(ComparisonJoin): Filter: customer.c_custkey = __scalar_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N, o_custkey:Int64;N, __always_true:Boolean;N] TableScan: customer [c_custkey:Int64, c_name:Utf8] SubqueryAlias: __scalar_sq_1 [max(orders.o_custkey):Int64;N, o_custkey:Int64, __always_true:Boolean] Projection: max(orders.o_custkey), orders.o_custkey, __always_true [max(orders.o_custkey):Int64;N, o_custkey:Int64, __always_true:Boolean] @@ -953,7 +952,7 @@ mod tests { @r" Projection: customer.c_custkey [c_custkey:Int64] Filter: customer.c_custkey = __scalar_sq_1.max(orders.o_custkey) OR customer.c_custkey = Int32(1) [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N, o_custkey:Int64;N, __always_true:Boolean;N] - Left Join: Filter: customer.c_custkey = __scalar_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N, o_custkey:Int64;N, __always_true:Boolean;N] + Left Join(ComparisonJoin): Filter: customer.c_custkey = __scalar_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N, o_custkey:Int64;N, __always_true:Boolean;N] TableScan: customer [c_custkey:Int64, c_name:Utf8] SubqueryAlias: __scalar_sq_1 [max(orders.o_custkey):Int64;N, o_custkey:Int64, __always_true:Boolean] Projection: max(orders.o_custkey), orders.o_custkey, __always_true [max(orders.o_custkey):Int64;N, o_custkey:Int64, __always_true:Boolean] @@ -984,7 +983,7 @@ mod tests { @r" Projection: test.c [c:UInt32] Filter: test.c < __scalar_sq_1.min(sq.c) [a:UInt32, b:UInt32, c:UInt32, min(sq.c):UInt32;N, a:UInt32;N, __always_true:Boolean;N] - Left Join: Filter: test.a = __scalar_sq_1.a [a:UInt32, b:UInt32, c:UInt32, min(sq.c):UInt32;N, a:UInt32;N, __always_true:Boolean;N] + Left Join(ComparisonJoin): Filter: test.a = __scalar_sq_1.a [a:UInt32, b:UInt32, c:UInt32, min(sq.c):UInt32;N, a:UInt32;N, __always_true:Boolean;N] TableScan: test [a:UInt32, b:UInt32, c:UInt32] SubqueryAlias: __scalar_sq_1 [min(sq.c):UInt32;N, a:UInt32, __always_true:Boolean] Projection: min(sq.c), sq.a, __always_true [min(sq.c):UInt32;N, a:UInt32, __always_true:Boolean] @@ -1014,7 +1013,7 @@ mod tests { @r" Projection: customer.c_custkey [c_custkey:Int64] Filter: customer.c_custkey < __scalar_sq_1.max(orders.o_custkey) [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N] - Left Join: Filter: Boolean(true) [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N] + Left Join(ComparisonJoin): Filter: Boolean(true) [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N] TableScan: customer [c_custkey:Int64, c_name:Utf8] SubqueryAlias: __scalar_sq_1 [max(orders.o_custkey):Int64;N] Projection: max(orders.o_custkey) [max(orders.o_custkey):Int64;N] @@ -1043,7 +1042,7 @@ mod tests { @r" Projection: customer.c_custkey [c_custkey:Int64] Filter: customer.c_custkey = __scalar_sq_1.max(orders.o_custkey) [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N] - Left Join: Filter: Boolean(true) [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N] + Left Join(ComparisonJoin): Filter: Boolean(true) [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N] TableScan: customer [c_custkey:Int64, c_name:Utf8] 
SubqueryAlias: __scalar_sq_1 [max(orders.o_custkey):Int64;N] Projection: max(orders.o_custkey) [max(orders.o_custkey):Int64;N] @@ -1093,8 +1092,8 @@ mod tests { @r" Projection: customer.c_custkey [c_custkey:Int64] Filter: customer.c_custkey BETWEEN __scalar_sq_1.min(orders.o_custkey) AND __scalar_sq_2.max(orders.o_custkey) [c_custkey:Int64, c_name:Utf8, min(orders.o_custkey):Int64;N, o_custkey:Int64;N, __always_true:Boolean;N, max(orders.o_custkey):Int64;N, o_custkey:Int64;N, __always_true:Boolean;N] - Left Join: Filter: customer.c_custkey = __scalar_sq_2.o_custkey [c_custkey:Int64, c_name:Utf8, min(orders.o_custkey):Int64;N, o_custkey:Int64;N, __always_true:Boolean;N, max(orders.o_custkey):Int64;N, o_custkey:Int64;N, __always_true:Boolean;N] - Left Join: Filter: customer.c_custkey = __scalar_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8, min(orders.o_custkey):Int64;N, o_custkey:Int64;N, __always_true:Boolean;N] + Left Join(ComparisonJoin): Filter: customer.c_custkey = __scalar_sq_2.o_custkey [c_custkey:Int64, c_name:Utf8, min(orders.o_custkey):Int64;N, o_custkey:Int64;N, __always_true:Boolean;N, max(orders.o_custkey):Int64;N, o_custkey:Int64;N, __always_true:Boolean;N] + Left Join(ComparisonJoin): Filter: customer.c_custkey = __scalar_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8, min(orders.o_custkey):Int64;N, o_custkey:Int64;N, __always_true:Boolean;N] TableScan: customer [c_custkey:Int64, c_name:Utf8] SubqueryAlias: __scalar_sq_1 [min(orders.o_custkey):Int64;N, o_custkey:Int64, __always_true:Boolean] Projection: min(orders.o_custkey), orders.o_custkey, __always_true [min(orders.o_custkey):Int64;N, o_custkey:Int64, __always_true:Boolean] @@ -1140,8 +1139,8 @@ mod tests { @r" Projection: customer.c_custkey [c_custkey:Int64] Filter: customer.c_custkey BETWEEN __scalar_sq_1.min(orders.o_custkey) AND __scalar_sq_2.max(orders.o_custkey) [c_custkey:Int64, c_name:Utf8, min(orders.o_custkey):Int64;N, max(orders.o_custkey):Int64;N] - Left Join: Filter: Boolean(true) [c_custkey:Int64, c_name:Utf8, min(orders.o_custkey):Int64;N, max(orders.o_custkey):Int64;N] - Left Join: Filter: Boolean(true) [c_custkey:Int64, c_name:Utf8, min(orders.o_custkey):Int64;N] + Left Join(ComparisonJoin): Filter: Boolean(true) [c_custkey:Int64, c_name:Utf8, min(orders.o_custkey):Int64;N, max(orders.o_custkey):Int64;N] + Left Join(ComparisonJoin): Filter: Boolean(true) [c_custkey:Int64, c_name:Utf8, min(orders.o_custkey):Int64;N] TableScan: customer [c_custkey:Int64, c_name:Utf8] SubqueryAlias: __scalar_sq_1 [min(orders.o_custkey):Int64;N] Projection: min(orders.o_custkey) [min(orders.o_custkey):Int64;N] diff --git a/datafusion/optimizer/src/simplify_expressions/simplify_exprs.rs b/datafusion/optimizer/src/simplify_expressions/simplify_exprs.rs index ccf90893e17e..0e0c8dd8b9fb 100644 --- a/datafusion/optimizer/src/simplify_expressions/simplify_exprs.rs +++ b/datafusion/optimizer/src/simplify_expressions/simplify_exprs.rs @@ -799,7 +799,7 @@ mod tests { assert_optimized_plan_equal!( plan, @ r" - Inner Join: t1.a + UInt32(1) = t2.a + UInt32(2) + Inner Join(ComparisonJoin): t1.a + UInt32(1) = t2.a + UInt32(2) TableScan: t1 TableScan: t2 " diff --git a/datafusion/optimizer/src/test/mod.rs b/datafusion/optimizer/src/test/mod.rs index 6e0b734bb928..0220b01ccdcd 100644 --- a/datafusion/optimizer/src/test/mod.rs +++ b/datafusion/optimizer/src/test/mod.rs @@ -21,6 +21,7 @@ use crate::{OptimizerContext, OptimizerRule}; use arrow::datatypes::{DataType, Field, Schema}; use datafusion_common::config::ConfigOptions; use 
datafusion_common::{assert_contains, Result};
+use datafusion_expr::CorrelatedColumnInfo;
 use datafusion_expr::{logical_plan::table_scan, LogicalPlan, LogicalPlanBuilder};
 use std::sync::Arc;
@@ -45,6 +46,39 @@ pub fn test_table_scan() -> Result<LogicalPlan> {
 test_table_scan_with_name("test")
 }
+pub fn test_delim_scan_with_name(
+ correlated_columns: Vec<CorrelatedColumnInfo>,
+) -> Result<LogicalPlan> {
+ LogicalPlanBuilder::delim_get(&correlated_columns)?.build()
+}
+
+/// Create a table with the given name and column definitions.
+///
+/// # Arguments
+/// * `name` - The name of the table to create
+/// * `columns` - Column definitions as slice of tuples (name, data_type)
+///
+/// # Example
+/// ```
+/// let plan = test_table_with_columns("integers", &[("i", DataType::Int32)])?;
+/// ```
+pub fn test_table_with_columns(
+ name: &str,
+ columns: &[(&str, DataType)],
+) -> Result<LogicalPlan> {
+ // Create fields with specified types for each column
+ let fields: Vec<Field> = columns
+ .iter()
+ .map(|&(col_name, ref data_type)| Field::new(col_name, data_type.clone(), false))
+ .collect();
+
+ // Create schema from fields
+ let schema = Schema::new(fields);
+
+ // Create table scan
+ table_scan(Some(name), &schema, None)?.build()
+}
+
 /// Scan an empty data source, mainly used in tests
 pub fn scan_empty(
 name: Option<&str>,
diff --git a/datafusion/physical-expr/src/equivalence/class.rs b/datafusion/physical-expr/src/equivalence/class.rs
index 8af6f3be0389..588e11855527 100644
--- a/datafusion/physical-expr/src/equivalence/class.rs
+++ b/datafusion/physical-expr/src/equivalence/class.rs
@@ -768,7 +768,11 @@ impl EquivalenceGroup {
 on: &[(PhysicalExprRef, PhysicalExprRef)],
 ) -> Result<Self> {
 let group = match join_type {
- JoinType::Inner | JoinType::Left | JoinType::Full | JoinType::Right => {
+ JoinType::Inner
+ | JoinType::Left
+ | JoinType::Full
+ | JoinType::Right
+ | JoinType::LeftSingle => {
 let mut result = Self::new(
 self.iter().cloned().chain(
 right_equivalences
diff --git a/datafusion/physical-optimizer/src/enforce_distribution.rs b/datafusion/physical-optimizer/src/enforce_distribution.rs
index 39eb557ea601..dad9f2104ef2 100644
--- a/datafusion/physical-optimizer/src/enforce_distribution.rs
+++ b/datafusion/physical-optimizer/src/enforce_distribution.rs
@@ -341,7 +341,8 @@ pub fn adjust_input_keys_ordering(
 | JoinType::LeftSemi
 | JoinType::LeftAnti
 | JoinType::Full
- | JoinType::LeftMark => vec![],
+ | JoinType::LeftMark
+ | JoinType::LeftSingle => vec![],
 };
 }
 PartitionMode::Auto => {
diff --git a/datafusion/physical-optimizer/src/enforce_sorting/sort_pushdown.rs b/datafusion/physical-optimizer/src/enforce_sorting/sort_pushdown.rs
index 6e4e78486612..ba386b72fe58 100644
--- a/datafusion/physical-optimizer/src/enforce_sorting/sort_pushdown.rs
+++ b/datafusion/physical-optimizer/src/enforce_sorting/sort_pushdown.rs
@@ -525,7 +525,8 @@ fn expr_source_side(
 | JoinType::Right
 | JoinType::Full
 | JoinType::LeftMark
- | JoinType::RightMark => {
+ | JoinType::RightMark
+ | JoinType::LeftSingle => {
 let eq_group = eqp.eq_group();
 let mut right_ordering = ordering.clone();
 let (mut valid_left, mut valid_right) = (true, true);
diff --git a/datafusion/physical-plan/src/joins/hash_join.rs b/datafusion/physical-plan/src/joins/hash_join.rs
index c8c4c0806f03..6e25470fb230 100644
--- a/datafusion/physical-plan/src/joins/hash_join.rs
+++ b/datafusion/physical-plan/src/joins/hash_join.rs
@@ -17,6 +17,7 @@
 //! [`HashJoinExec`] Partitioned Hash Join Operator
+use std::collections::HashMap as StdHashMap;
 use std::fmt;
 use std::mem::size_of;
 use std::sync::atomic::{AtomicUsize, Ordering};
@@ -562,7 +563,8 @@ impl HashJoinExec {
 JoinType::Left
 | JoinType::LeftAnti
 | JoinType::LeftMark
- | JoinType::Full => EmissionType::Both,
+ | JoinType::Full
+ | JoinType::LeftSingle => EmissionType::Both,
 }
 } else {
 right.pipeline_behavior()
@@ -875,6 +877,7 @@ impl ExecutionPlan for HashJoinExec {
 batch_size,
 hashes_buffer: vec![],
 right_side_ordered: self.right.output_ordering().is_some(),
+ left_match_counts: StdHashMap::new(),
 }))
 }
@@ -1249,6 +1252,8 @@ struct HashJoinStream {
 hashes_buffer: Vec<u64>,
 /// Specifies whether the right side has an ordering to potentially preserve
 right_side_ordered: bool,
+ /// Used by LeftSingle join to check if multiple rows matched at runtime.
+ left_match_counts: StdHashMap<u64, usize>,
 }
 impl RecordBatchStream for HashJoinStream {
@@ -1592,6 +1597,21 @@ impl HashJoinStream {
 )?
 };
+ // Validates cardinality constraints for single join types
+ // TODO: RightSingle support.
+ if matches!(self.join_type, JoinType::LeftSingle) {
+ for &left_idx in left_indices.values() {
+ let count = self.left_match_counts.entry(left_idx).or_insert(0);
+ *count += 1;
+ if *count > 1 {
+ return internal_err!(
+ "LeftSingle join constraint violated: build side row at index {} has multiple matches",
+ left_idx
+ );
+ }
+ }
+ }
+
 self.join_metrics.output_batches.add(1);
 timer.done();
@@ -4686,6 +4706,89 @@ mod tests {
 Ok(())
 }
+ #[apply(batch_sizes)]
+ #[tokio::test]
+ async fn join_left_single_success(batch_size: usize) -> Result<()> {
+ let task_ctx = prepare_task_ctx(batch_size);
+ let left = build_table(
+ ("a1", &vec![1, 2, 3]),
+ ("b1", &vec![4, 5, 7]), // each value appears once
+ ("c1", &vec![7, 8, 9]),
+ );
+ let right = build_table(
+ ("a2", &vec![10, 20, 30]),
+ ("b1", &vec![4, 5, 6]), // each value appears at most once in matching positions
+ ("c2", &vec![70, 80, 90]),
+ );
+ let on = vec![(
+ Arc::new(Column::new_with_schema("b1", &left.schema())?) as _,
+ Arc::new(Column::new_with_schema("b1", &right.schema())?) as _,
+ )];
+
+ let (columns, batches) = join_collect(
+ Arc::clone(&left),
+ Arc::clone(&right),
+ on.clone(),
+ &JoinType::LeftSingle,
+ NullEquality::NullEqualsNothing,
+ task_ctx,
+ )
+ .await?;
+ assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b1", "c2"]);
+
+ allow_duplicates! {
+ assert_snapshot!(batches_to_sort_string(&batches), @r#"
+ +----+----+----+----+----+----+
+ | a1 | b1 | c1 | a2 | b1 | c2 |
+ +----+----+----+----+----+----+
+ | 1 | 4 | 7 | 10 | 4 | 70 |
+ | 2 | 5 | 8 | 20 | 5 | 80 |
+ | 3 | 7 | 9 | | | |
+ +----+----+----+----+----+----+
+ "#);
+ }
+
+ Ok(())
+ }
+
+ #[apply(batch_sizes)]
+ #[tokio::test]
+ async fn join_left_single_cardinality_violation(batch_size: usize) -> Result<()> {
+ let task_ctx = prepare_task_ctx(batch_size);
+ let left = build_table(
+ ("a1", &vec![1, 2, 3]),
+ ("b1", &vec![4, 5, 5]), // 5 appears twice - this should be fine
+ ("c1", &vec![7, 8, 9]),
+ );
+ let right = build_table(
+ ("a2", &vec![10, 20, 30]),
+ ("b1", &vec![5, 5, 6]), // 5 appears twice - this creates multiple matches for left side
+ ("c2", &vec![70, 80, 90]),
+ );
+ let on = vec![(
+ Arc::new(Column::new_with_schema("b1", &left.schema())?) as _,
+ Arc::new(Column::new_with_schema("b1", &right.schema())?)
as _, + )]; + + let join = join( + Arc::clone(&left), + Arc::clone(&right), + on, + &JoinType::LeftSingle, + NullEquality::NullEqualsNothing, + )?; + + let stream = join.execute(0, task_ctx)?; + let result = common::collect(stream).await; + + assert!(result.is_err()); + let error_msg = result.unwrap_err().to_string(); + assert!(error_msg.contains("LeftSingle join constraint violated")); + assert!(error_msg.contains("has multiple matches")); + + Ok(()) + } + /// Returns the column names on the schema fn columns(schema: &Schema) -> Vec { schema.fields().iter().map(|f| f.name().clone()).collect() diff --git a/datafusion/physical-plan/src/joins/nested_loop_join.rs b/datafusion/physical-plan/src/joins/nested_loop_join.rs index c84b3a9d402c..f19c1e787667 100644 --- a/datafusion/physical-plan/src/joins/nested_loop_join.rs +++ b/datafusion/physical-plan/src/joins/nested_loop_join.rs @@ -282,6 +282,7 @@ impl NestedLoopJoinExec { | JoinType::LeftAnti | JoinType::LeftMark | JoinType::Full => EmissionType::Both, + JoinType::LeftSingle => unimplemented!() } } else { right.pipeline_behavior() diff --git a/datafusion/physical-plan/src/joins/sort_merge_join.rs b/datafusion/physical-plan/src/joins/sort_merge_join.rs index 9a6832283486..c71d7c93dfb9 100644 --- a/datafusion/physical-plan/src/joins/sort_merge_join.rs +++ b/datafusion/physical-plan/src/joins/sort_merge_join.rs @@ -242,6 +242,7 @@ impl SortMergeJoinExec { | JoinType::LeftAnti | JoinType::LeftSemi | JoinType::LeftMark => JoinSide::Left, + JoinType::LeftSingle => unimplemented!(), } } @@ -259,6 +260,7 @@ impl SortMergeJoinExec { | JoinType::RightMark => { vec![false, true] } + JoinType::LeftSingle => unimplemented!(), _ => vec![false, false], } } diff --git a/datafusion/physical-plan/src/joins/utils.rs b/datafusion/physical-plan/src/joins/utils.rs index 6420348f880e..d79a1123c5ac 100644 --- a/datafusion/physical-plan/src/joins/utils.rs +++ b/datafusion/physical-plan/src/joins/utils.rs @@ -209,7 +209,7 @@ pub struct ColumnIndex { fn output_join_field(old_field: &Field, join_type: &JoinType, is_left: bool) -> Field { let force_nullable = match join_type { JoinType::Inner => false, - JoinType::Left => !is_left, // right input is padded with nulls + JoinType::Left | JoinType::LeftSingle => !is_left, // right input is padded with nulls JoinType::Right => is_left, // left input is padded with nulls JoinType::Full => true, // both inputs can be padded with nulls JoinType::LeftSemi => false, // doesn't introduce nulls @@ -268,7 +268,11 @@ pub fn build_join_schema( }; let (fields, column_indices): (SchemaBuilder, Vec) = match join_type { - JoinType::Inner | JoinType::Left | JoinType::Full | JoinType::Right => { + JoinType::Inner + | JoinType::Left + | JoinType::Full + | JoinType::Right + | JoinType::LeftSingle => { // left then right left_fields().chain(right_fields()).unzip() } @@ -436,7 +440,11 @@ fn estimate_join_cardinality( .unzip::<_, _, Vec<_>, Vec<_>>(); match join_type { - JoinType::Inner | JoinType::Left | JoinType::Right | JoinType::Full => { + JoinType::Inner + | JoinType::Left + | JoinType::Right + | JoinType::Full + | JoinType::LeftSingle => { let ij_cardinality = estimate_inner_join_cardinality( Statistics { num_rows: left_stats.num_rows, @@ -785,6 +793,7 @@ pub(crate) fn need_produce_result_in_final(join_type: JoinType) -> bool { | JoinType::LeftSemi | JoinType::LeftMark | JoinType::Full + | JoinType::LeftSingle ) } @@ -942,7 +951,7 @@ pub(crate) fn adjust_indices_by_join_type( // matched Ok((left_indices, right_indices)) } - 
JoinType::Left => { + JoinType::Left | JoinType::LeftSingle => { // matched Ok((left_indices, right_indices)) // unmatched left row will be produced in the end of loop, and it has been set in the left visited bitmap @@ -1329,9 +1338,11 @@ pub(crate) fn symmetric_join_output_partitioning( let left_partitioning = left.output_partitioning(); let right_partitioning = right.output_partitioning(); let result = match join_type { - JoinType::Left | JoinType::LeftSemi | JoinType::LeftAnti | JoinType::LeftMark => { - left_partitioning.clone() - } + JoinType::Left + | JoinType::LeftSemi + | JoinType::LeftAnti + | JoinType::LeftMark + | JoinType::LeftSingle => left_partitioning.clone(), JoinType::RightSemi | JoinType::RightAnti | JoinType::RightMark => { right_partitioning.clone() } @@ -1363,7 +1374,8 @@ pub(crate) fn asymmetric_join_output_partitioning( | JoinType::LeftSemi | JoinType::LeftAnti | JoinType::Full - | JoinType::LeftMark => Partitioning::UnknownPartitioning( + | JoinType::LeftMark + | JoinType::LeftSingle => Partitioning::UnknownPartitioning( right.output_partitioning().partition_count(), ), }; diff --git a/datafusion/proto/src/generated/datafusion_proto_common.rs b/datafusion/proto/src/generated/datafusion_proto_common.rs index 411d72af4c62..f0bd48b05131 100644 --- a/datafusion/proto/src/generated/datafusion_proto_common.rs +++ b/datafusion/proto/src/generated/datafusion_proto_common.rs @@ -905,6 +905,7 @@ pub enum JoinType { Rightanti = 7, Leftmark = 8, Rightmark = 9, + LeftSingle = 10, } impl JoinType { /// String value of the enum field names used in the ProtoBuf definition. @@ -923,6 +924,7 @@ impl JoinType { Self::Rightanti => "RIGHTANTI", Self::Leftmark => "LEFTMARK", Self::Rightmark => "RIGHTMARK", + Self::LeftSingle => "LEFTSINGLE", } } /// Creates an enum from field names used in the ProtoBuf definition. 
@@ -938,6 +940,7 @@ impl JoinType { "RIGHTANTI" => Some(Self::Rightanti), "LEFTMARK" => Some(Self::Leftmark), "RIGHTMARK" => Some(Self::Rightmark), + "LEFTSINGLE" => Some(Self::LeftSingle), _ => None, } } diff --git a/datafusion/proto/src/logical_plan/from_proto.rs b/datafusion/proto/src/logical_plan/from_proto.rs index 6c5b348698c7..327b92e2d8e2 100644 --- a/datafusion/proto/src/logical_plan/from_proto.rs +++ b/datafusion/proto/src/logical_plan/from_proto.rs @@ -206,6 +206,7 @@ impl From for JoinType { protobuf::JoinType::Rightanti => JoinType::RightAnti, protobuf::JoinType::Leftmark => JoinType::LeftMark, protobuf::JoinType::Rightmark => JoinType::RightMark, + protobuf::JoinType::LeftSingle => JoinType::LeftSingle, } } } diff --git a/datafusion/proto/src/logical_plan/mod.rs b/datafusion/proto/src/logical_plan/mod.rs index 9915d3617ff9..d4c4df61a8c6 100644 --- a/datafusion/proto/src/logical_plan/mod.rs +++ b/datafusion/proto/src/logical_plan/mod.rs @@ -1778,6 +1778,12 @@ impl AsLogicalPlan for LogicalPlanNode { ))), }) } + LogicalPlan::DependentJoin(_) => Err(proto_error( + "LogicalPlan serde is not implemented for DependentJoin", + )), + LogicalPlan::DelimGet(_) => Err(proto_error( + "LogicalPlan serde is not implemented for DelimGet", + )), } } } diff --git a/datafusion/proto/src/logical_plan/to_proto.rs b/datafusion/proto/src/logical_plan/to_proto.rs index 43afaa0fbe65..3111ec01411a 100644 --- a/datafusion/proto/src/logical_plan/to_proto.rs +++ b/datafusion/proto/src/logical_plan/to_proto.rs @@ -683,6 +683,7 @@ impl From for protobuf::JoinType { JoinType::RightAnti => protobuf::JoinType::Rightanti, JoinType::LeftMark => protobuf::JoinType::Leftmark, JoinType::RightMark => protobuf::JoinType::Rightmark, + JoinType::LeftSingle => protobuf::JoinType::LeftSingle, } } } diff --git a/datafusion/sql/src/expr/identifier.rs b/datafusion/sql/src/expr/identifier.rs index 7c276ce53e35..9ee7b22e6dde 100644 --- a/datafusion/sql/src/expr/identifier.rs +++ b/datafusion/sql/src/expr/identifier.rs @@ -69,7 +69,7 @@ impl SqlToRel<'_, S> { } // Check the outer query schema - if let Some(outer) = planner_context.outer_query_schema() { + for outer in planner_context.outer_queries_schemas() { if let Ok((qualifier, field)) = outer.qualified_field_with_unqualified_name(normalize_ident.as_str()) { @@ -165,35 +165,43 @@ impl SqlToRel<'_, S> { not_impl_err!("compound identifier: {ids:?}") } else { // Check the outer_query_schema and try to find a match - if let Some(outer) = planner_context.outer_query_schema() { - let search_result = search_dfschema(&ids, outer); - match search_result { - // Found matching field with spare identifier(s) for nested field(s) in structure - Some((field, qualifier, nested_names)) - if !nested_names.is_empty() => - { - // TODO: remove when can support nested identifiers for OuterReferenceColumn - not_impl_err!( + let outer_schemas = planner_context.outer_queries_schemas(); + let mut maybe_result = None; + if !outer_schemas.is_empty() { + for outer in planner_context.outer_queries_schemas() { + let search_result = search_dfschema(&ids, &outer); + let result = match search_result { + // Found matching field with spare identifier(s) for nested field(s) in structure + Some((field, qualifier, nested_names)) + if !nested_names.is_empty() => + { + // TODO: remove when can support nested identifiers for OuterReferenceColumn + not_impl_err!( "Nested identifiers are not yet supported for OuterReferenceColumn {}", Column::from((qualifier, field)).quoted_flat_name() ) - } - // Found matching 
field with no spare identifier(s) - Some((field, qualifier, _nested_names)) => { - // Found an exact match on a qualified name in the outer plan schema, so this is an outer reference column - Ok(Expr::OuterReferenceColumn( - field.data_type().clone(), - Column::from((qualifier, field)), - )) - } - // Found no matching field, will return a default - None => { - let s = &ids[0..ids.len()]; - // safe unwrap as s can never be empty or exceed the bounds - let (relation, column_name) = - form_identifier(s).unwrap(); - Ok(Expr::Column(Column::new(relation, column_name))) - } + } + // Found matching field with no spare identifier(s) + Some((field, qualifier, _nested_names)) => { + // Found an exact match on a qualified name in the outer plan schema, so this is an outer reference column + Ok(Expr::OuterReferenceColumn( + field.data_type().clone(), + Column::from((qualifier, field)), + )) + } + // Found no matching field, will return a default + None => continue, + }; + maybe_result = Some(result); + break; + } + if let Some(result) = maybe_result { + result + } else { + let s = &ids[0..ids.len()]; + // safe unwrap as s can never be empty or exceed the bounds + let (relation, column_name) = form_identifier(s).unwrap(); + Ok(Expr::Column(Column::new(relation, column_name))) } } else { let s = &ids[0..ids.len()]; diff --git a/datafusion/sql/src/expr/subquery.rs b/datafusion/sql/src/expr/subquery.rs index 602d39233d58..6e10607d8533 100644 --- a/datafusion/sql/src/expr/subquery.rs +++ b/datafusion/sql/src/expr/subquery.rs @@ -31,11 +31,10 @@ impl SqlToRel<'_, S> { input_schema: &DFSchema, planner_context: &mut PlannerContext, ) -> Result { - let old_outer_query_schema = - planner_context.set_outer_query_schema(Some(input_schema.clone().into())); + planner_context.append_outer_query_schema(input_schema.clone().into()); let sub_plan = self.query_to_plan(subquery, planner_context)?; let outer_ref_columns = sub_plan.all_out_ref_exprs(); - planner_context.set_outer_query_schema(old_outer_query_schema); + planner_context.pop_outer_query_schema(); Ok(Expr::Exists(Exists { subquery: Subquery { subquery: Arc::new(sub_plan), @@ -54,8 +53,7 @@ impl SqlToRel<'_, S> { input_schema: &DFSchema, planner_context: &mut PlannerContext, ) -> Result { - let old_outer_query_schema = - planner_context.set_outer_query_schema(Some(input_schema.clone().into())); + planner_context.append_outer_query_schema(input_schema.clone().into()); let mut spans = Spans::new(); if let SetExpr::Select(select) = subquery.body.as_ref() { @@ -70,7 +68,7 @@ impl SqlToRel<'_, S> { let sub_plan = self.query_to_plan(subquery, planner_context)?; let outer_ref_columns = sub_plan.all_out_ref_exprs(); - planner_context.set_outer_query_schema(old_outer_query_schema); + planner_context.pop_outer_query_schema(); self.validate_single_column( &sub_plan, @@ -98,8 +96,8 @@ impl SqlToRel<'_, S> { input_schema: &DFSchema, planner_context: &mut PlannerContext, ) -> Result { - let old_outer_query_schema = - planner_context.set_outer_query_schema(Some(input_schema.clone().into())); + planner_context.append_outer_query_schema(input_schema.clone().into()); + let mut spans = Spans::new(); if let SetExpr::Select(select) = subquery.body.as_ref() { for item in &select.projection { @@ -112,7 +110,7 @@ impl SqlToRel<'_, S> { } let sub_plan = self.query_to_plan(subquery, planner_context)?; let outer_ref_columns = sub_plan.all_out_ref_exprs(); - planner_context.set_outer_query_schema(old_outer_query_schema); + planner_context.pop_outer_query_schema(); 
self.validate_single_column( &sub_plan, diff --git a/datafusion/sql/src/planner.rs b/datafusion/sql/src/planner.rs index 26c982690115..bcb7e6fcde18 100644 --- a/datafusion/sql/src/planner.rs +++ b/datafusion/sql/src/planner.rs @@ -198,7 +198,15 @@ pub struct PlannerContext { /// Map of CTE name to logical plan of the WITH clause. /// Use `Arc` to allow cheap cloning ctes: HashMap>, + + /// The schemas of the outer query relations, used to resolve outer-referenced + /// columns in subqueries (aware of nested subqueries) + outer_queries_schemas_stack: Vec, + /// The query schema of the outer query plan, used to resolve the columns in a subquery. + /// This field is maintained to support the deprecated functions + /// `outer_query_schema` and `set_outer_query_schema`, + /// which are only aware of the adjacent outer relation outer_query_schema: Option, /// The joined schemas of all FROM clauses planned so far. When planning LATERAL /// FROM clauses, this should become a suffix of the `outer_query_schema`. @@ -220,6 +228,7 @@ impl PlannerContext { prepare_param_data_types: Arc::new(vec![]), ctes: HashMap::new(), outer_query_schema: None, + outer_queries_schemas_stack: vec![], outer_from_schema: None, create_table_schema: None, } @@ -234,13 +243,22 @@ impl PlannerContext { self } - // Return a reference to the outer query's schema + /// Return a reference to the outer query's schema. + /// This function should not be used together with + /// `outer_queries_schemas`, `append_outer_query_schema`, + /// `latest_outer_query_schema` and `pop_outer_query_schema` + #[deprecated(note = "Use outer_queries_schemas instead")] pub fn outer_query_schema(&self) -> Option<&DFSchema> { self.outer_query_schema.as_ref().map(|s| s.as_ref()) } /// Sets the outer query schema, returning the existing one, if - /// any + /// any. This function should not be used together with + /// `outer_queries_schemas`, `append_outer_query_schema`, + /// `latest_outer_query_schema` and `pop_outer_query_schema` + #[deprecated( + note = "This struct is now aware of a stack of schemas, check pop_outer_query_schema" + )] pub fn set_outer_query_schema( &mut self, mut schema: Option, @@ -249,6 +267,28 @@ impl PlannerContext { schema } + /// Return the stack of outer relations' schemas; the outermost + /// relation is at the first entry + pub fn outer_queries_schemas(&self) -> Vec { + self.outer_queries_schemas_stack.to_vec() + } + + /// Push the given schema onto the stack of + /// outer query schemas + pub fn append_outer_query_schema(&mut self, schema: DFSchemaRef) { + self.outer_queries_schemas_stack.push(schema); + } + + /// Return the schema of the adjacent (innermost) outer relation + pub fn latest_outer_query_schema(&mut self) -> Option { + self.outer_queries_schemas_stack.last().cloned() + } + + /// Remove and return the schema of the adjacent (innermost) outer relation + pub fn pop_outer_query_schema(&mut self) -> Option { + self.outer_queries_schemas_stack.pop() + } + pub fn set_table_schema( &mut self, mut schema: Option, diff --git a/datafusion/sql/src/relation/mod.rs b/datafusion/sql/src/relation/mod.rs index aa37d74fd4d8..e1be8d30ec92 100644 --- a/datafusion/sql/src/relation/mod.rs +++ b/datafusion/sql/src/relation/mod.rs @@ -213,20 +213,24 @@ impl SqlToRel<'_, S> { let old_from_schema = planner_context .set_outer_from_schema(None) .unwrap_or_else(|| Arc::new(DFSchema::empty())); - let new_query_schema = match planner_context.outer_query_schema() { - Some(old_query_schema) => { + let outer_query_schema = planner_context.pop_outer_query_schema(); + let new_query_schema =
match outer_query_schema { + Some(ref old_query_schema) => { let mut new_query_schema = old_from_schema.as_ref().clone(); - new_query_schema.merge(old_query_schema); - Some(Arc::new(new_query_schema)) + new_query_schema.merge(old_query_schema.as_ref()); + Arc::new(new_query_schema) } - None => Some(Arc::clone(&old_from_schema)), + None => Arc::clone(&old_from_schema), }; - let old_query_schema = planner_context.set_outer_query_schema(new_query_schema); + planner_context.append_outer_query_schema(new_query_schema); let plan = self.create_relation(subquery, planner_context)?; let outer_ref_columns = plan.all_out_ref_exprs(); - planner_context.set_outer_query_schema(old_query_schema); + planner_context.pop_outer_query_schema(); + if let Some(schema) = outer_query_schema { + planner_context.append_outer_query_schema(schema); + } planner_context.set_outer_from_schema(Some(old_from_schema)); // We can omit the subquery wrapper if there are no columns diff --git a/datafusion/sql/src/select.rs b/datafusion/sql/src/select.rs index b50fbf68129c..b47c12774bfd 100644 --- a/datafusion/sql/src/select.rs +++ b/datafusion/sql/src/select.rs @@ -29,7 +29,7 @@ use crate::utils::{ use datafusion_common::error::DataFusionErrorBuilder; use datafusion_common::tree_node::{TreeNode, TreeNodeRecursion}; -use datafusion_common::{not_impl_err, plan_err, Result}; +use datafusion_common::{not_impl_err, plan_err, DFSchema, Result}; use datafusion_common::{RecursionUnnestOption, UnnestOptions}; use datafusion_expr::expr::{Alias, PlannedReplaceSelectItem, WildcardOptions}; use datafusion_expr::expr_rewriter::{ @@ -521,12 +521,8 @@ impl SqlToRel<'_, S> { match selection { Some(predicate_expr) => { let fallback_schemas = plan.fallback_normalize_schemas(); - let outer_query_schema = planner_context.outer_query_schema().cloned(); - let outer_query_schema_vec = outer_query_schema - .as_ref() - .map(|schema| vec![schema]) - .unwrap_or_else(Vec::new); + let outer_query_schema_vec = planner_context.outer_queries_schemas(); let filter_expr = self.sql_to_expr(predicate_expr, plan.schema(), planner_context)?; @@ -541,9 +537,19 @@ impl SqlToRel<'_, S> { let mut using_columns = HashSet::new(); expr_to_columns(&filter_expr, &mut using_columns)?; + let mut schema_stack: Vec> = + vec![vec![plan.schema()], fallback_schemas]; + for sc in outer_query_schema_vec.iter().rev() { + schema_stack.push(vec![sc.as_ref()]); + } + let filter_expr = normalize_col_with_schemas_and_ambiguity_check( filter_expr, - &[&[plan.schema()], &fallback_schemas, &outer_query_schema_vec], + schema_stack + .iter() + .map(|sc| sc.as_slice()) + .collect::>() + .as_slice(), &[using_columns], )?; diff --git a/datafusion/sql/src/unparser/plan.rs b/datafusion/sql/src/unparser/plan.rs index 4fb1e42d6028..f6be173051de 100644 --- a/datafusion/sql/src/unparser/plan.rs +++ b/datafusion/sql/src/unparser/plan.rs @@ -124,7 +124,11 @@ impl Unparser<'_> { | LogicalPlan::Copy(_) | LogicalPlan::DescribeTable(_) | LogicalPlan::RecursiveQuery(_) - | LogicalPlan::Unnest(_) => not_impl_err!("Unsupported plan: {plan:?}"), + | LogicalPlan::Unnest(_) + | LogicalPlan::DependentJoin(_) + | LogicalPlan::DelimGet(_) => { + not_impl_err!("Unsupported plan: {plan:?}") + } } } @@ -775,7 +779,8 @@ impl Unparser<'_> { JoinType::Inner | JoinType::Left | JoinType::Right - | JoinType::Full => { + | JoinType::Full + | JoinType::LeftSingle => { let Ok(Some(relation)) = right_relation.build() else { return internal_err!("Failed to build right relation"); }; @@ -1267,8 +1272,8 @@ impl Unparser<'_> { 
JoinType::LeftSemi => ast::JoinOperator::LeftSemi(constraint), JoinType::RightAnti => ast::JoinOperator::RightAnti(constraint), JoinType::RightSemi => ast::JoinOperator::RightSemi(constraint), - JoinType::LeftMark | JoinType::RightMark => { - unimplemented!("Unparsing of Mark join type") + JoinType::LeftMark | JoinType::RightMark | JoinType::LeftSingle => { + unimplemented!("Unparsing of {} join type", join_type) } }) } diff --git a/datafusion/sqllogictest/test_files/dependent_join_temp.slt b/datafusion/sqllogictest/test_files/dependent_join_temp.slt new file mode 100644 index 000000000000..ef95cd38dfd6 --- /dev/null +++ b/datafusion/sqllogictest/test_files/dependent_join_temp.slt @@ -0,0 +1,193 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# make sure to a batch size smaller than row number of the table. +statement ok +set datafusion.execution.batch_size = 2; + +############# +## Subquery Tests +############# + + +############# +## Setup test data table +############# +# there tables for subquery +statement ok +CREATE TABLE t0(t0_id INT, t0_name TEXT, t0_int INT) AS VALUES +(11, 'o', 6), +(22, 'p', 7), +(33, 'q', 8), +(44, 'r', 9); + +statement ok +CREATE TABLE t1(t1_id INT, t1_name TEXT, t1_int INT) AS VALUES +(11, 'a', 1), +(22, 'b', 2), +(33, 'c', 3), +(44, 'd', 4); + +statement ok +CREATE TABLE t2(t2_id INT, t2_name TEXT, t2_int INT) AS VALUES +(11, 'z', 3), +(22, 'y', 1), +(44, 'x', 3), +(55, 'w', 3); + +statement ok +CREATE TABLE t3(t3_id INT PRIMARY KEY, t3_name TEXT, t3_int INT) AS VALUES +(11, 'e', 3), +(22, 'f', 1), +(44, 'g', 3), +(55, 'h', 3); + +statement ok +CREATE EXTERNAL TABLE IF NOT EXISTS customer ( + c_custkey BIGINT, + c_name VARCHAR, + c_address VARCHAR, + c_nationkey BIGINT, + c_phone VARCHAR, + c_acctbal DECIMAL(15, 2), + c_mktsegment VARCHAR, + c_comment VARCHAR, +) STORED AS CSV LOCATION '../core/tests/tpch-csv/customer.csv' OPTIONS ('format.delimiter' ',', 'format.has_header' 'true'); + +statement ok +CREATE EXTERNAL TABLE IF NOT EXISTS orders ( + o_orderkey BIGINT, + o_custkey BIGINT, + o_orderstatus VARCHAR, + o_totalprice DECIMAL(15, 2), + o_orderdate DATE, + o_orderpriority VARCHAR, + o_clerk VARCHAR, + o_shippriority INTEGER, + o_comment VARCHAR, +) STORED AS CSV LOCATION '../core/tests/tpch-csv/orders.csv' OPTIONS ('format.delimiter' ',', 'format.has_header' 'true'); + +statement ok +CREATE EXTERNAL TABLE IF NOT EXISTS lineitem ( + l_orderkey BIGINT, + l_partkey BIGINT, + l_suppkey BIGINT, + l_linenumber INTEGER, + l_quantity DECIMAL(15, 2), + l_extendedprice DECIMAL(15, 2), + l_discount DECIMAL(15, 2), + l_tax DECIMAL(15, 2), + l_returnflag VARCHAR, + l_linestatus VARCHAR, + l_shipdate DATE, + l_commitdate DATE, + l_receiptdate DATE, + l_shipinstruct VARCHAR, + l_shipmode VARCHAR, + l_comment VARCHAR, +) STORED AS 
CSV LOCATION '../core/tests/tpch-csv/lineitem.csv' OPTIONS ('format.delimiter' ',', 'format.has_header' 'true'); + +statement ok +set datafusion.explain.logical_plan_only = true; + +# correlated_recursive_scalar_subquery_with_level_3_scalar_subquery_referencing_level1_relation +query TT +explain select c_custkey from customer +where c_acctbal < ( + select sum(o_totalprice) from orders + where o_custkey = c_custkey + and o_totalprice < ( + select sum(l_extendedprice) as price from lineitem where l_orderkey = o_orderkey + and l_extendedprice < c_acctbal + ) +) order by c_custkey; +---- +logical_plan +01)Sort: customer.c_custkey ASC NULLS LAST +02)--Projection: customer.c_custkey +03)----Filter: CAST(customer.c_acctbal AS Decimal128(25, 2)) < __scalar_sq_2.output +04)------Projection: customer.c_custkey, customer.c_acctbal, __scalar_sq_2.output +05)--------DependentJoin on [customer.c_custkey lvl 1, customer.c_acctbal lvl 2] with expr () depth 1 +06)----------TableScan: customer +07)----------Projection: sum(orders.o_totalprice) +08)------------Aggregate: groupBy=[[]], aggr=[[sum(orders.o_totalprice)]] +09)--------------Projection: orders.o_orderkey, orders.o_custkey, orders.o_orderstatus, orders.o_totalprice, orders.o_orderdate, orders.o_orderpriority, orders.o_clerk, orders.o_shippriority, orders.o_comment +10)----------------Filter: orders.o_custkey = outer_ref(customer.c_custkey) AND CAST(orders.o_totalprice AS Decimal128(25, 2)) < __scalar_sq_1.output +11)------------------DependentJoin on [orders.o_orderkey lvl 2] with expr () depth 2 +12)--------------------TableScan: orders +13)--------------------Projection: sum(lineitem.l_extendedprice) AS price +14)----------------------Aggregate: groupBy=[[]], aggr=[[sum(lineitem.l_extendedprice)]] +15)------------------------Filter: lineitem.l_orderkey = outer_ref(orders.o_orderkey) AND lineitem.l_extendedprice < outer_ref(customer.c_acctbal) +16)--------------------------TableScan: lineitem, partial_filters=[lineitem.l_orderkey = outer_ref(orders.o_orderkey), lineitem.l_extendedprice < outer_ref(customer.c_acctbal)] + +# correlated_recursive_scalar_subquery_with_level_3_exists_subquery_referencing_level1_relation +query TT +explain select c_custkey from customer +where c_acctbal < ( + select sum(o_totalprice) from orders + where o_custkey = c_custkey + and exists ( + select * from lineitem where l_orderkey = o_orderkey + and l_extendedprice < c_acctbal + ) +) order by c_custkey; +---- +logical_plan +01)Sort: customer.c_custkey ASC NULLS LAST +02)--Projection: customer.c_custkey +03)----Filter: CAST(customer.c_acctbal AS Decimal128(25, 2)) < __scalar_sq_2.output +04)------Projection: customer.c_custkey, customer.c_acctbal, __scalar_sq_2.output +05)--------DependentJoin on [customer.c_custkey lvl 1, customer.c_acctbal lvl 2] with expr () depth 1 +06)----------TableScan: customer +07)----------Projection: sum(orders.o_totalprice) +08)------------Aggregate: groupBy=[[]], aggr=[[sum(orders.o_totalprice)]] +09)--------------Projection: orders.o_orderkey, orders.o_custkey, orders.o_orderstatus, orders.o_totalprice, orders.o_orderdate, orders.o_orderpriority, orders.o_clerk, orders.o_shippriority, orders.o_comment +10)----------------Filter: orders.o_custkey = outer_ref(customer.c_custkey) AND __exists_sq_1.output +11)------------------DependentJoin on [orders.o_orderkey lvl 2] with expr EXISTS () depth 2 +12)--------------------TableScan: orders +13)--------------------Projection: lineitem.l_orderkey, lineitem.l_partkey, lineitem.l_suppkey, 
lineitem.l_linenumber, lineitem.l_quantity, lineitem.l_extendedprice, lineitem.l_discount, lineitem.l_tax, lineitem.l_returnflag, lineitem.l_linestatus, lineitem.l_shipdate, lineitem.l_commitdate, lineitem.l_receiptdate, lineitem.l_shipinstruct, lineitem.l_shipmode, lineitem.l_comment +14)----------------------Filter: lineitem.l_orderkey = outer_ref(orders.o_orderkey) AND lineitem.l_extendedprice < outer_ref(customer.c_acctbal) +15)------------------------TableScan: lineitem, partial_filters=[lineitem.l_orderkey = outer_ref(orders.o_orderkey), lineitem.l_extendedprice < outer_ref(customer.c_acctbal)] + +# correlated_recursive_scalar_subquery_with_level_3_in_subquery_referencing_level1_relation +query TT +explain select c_custkey from customer +where c_acctbal < ( + select sum(o_totalprice) from orders + where o_custkey = c_custkey + and o_totalprice in ( + select l_extendedprice as price from lineitem where l_orderkey = o_orderkey + and l_extendedprice < c_acctbal + ) +) order by c_custkey; +---- +logical_plan +01)Sort: customer.c_custkey ASC NULLS LAST +02)--Projection: customer.c_custkey +03)----Filter: CAST(customer.c_acctbal AS Decimal128(25, 2)) < __scalar_sq_2.output +04)------Projection: customer.c_custkey, customer.c_acctbal, __scalar_sq_2.output +05)--------DependentJoin on [customer.c_custkey lvl 1, customer.c_acctbal lvl 2] with expr () depth 1 +06)----------TableScan: customer +07)----------Projection: sum(orders.o_totalprice) +08)------------Aggregate: groupBy=[[]], aggr=[[sum(orders.o_totalprice)]] +09)--------------Projection: orders.o_orderkey, orders.o_custkey, orders.o_orderstatus, orders.o_totalprice, orders.o_orderdate, orders.o_orderpriority, orders.o_clerk, orders.o_shippriority, orders.o_comment +10)----------------Filter: orders.o_custkey = outer_ref(customer.c_custkey) AND __in_sq_1.output +11)------------------DependentJoin on [orders.o_orderkey lvl 2] with expr orders.o_totalprice IN () depth 2 +12)--------------------TableScan: orders +13)--------------------Projection: lineitem.l_extendedprice AS price +14)----------------------Filter: lineitem.l_orderkey = outer_ref(orders.o_orderkey) AND lineitem.l_extendedprice < outer_ref(customer.c_acctbal) +15)------------------------TableScan: lineitem, partial_filters=[lineitem.l_orderkey = outer_ref(orders.o_orderkey), lineitem.l_extendedprice < outer_ref(customer.c_acctbal)] diff --git a/datafusion/sqllogictest/test_files/joins.slt b/datafusion/sqllogictest/test_files/joins.slt index ff6a5e661a98..192ea66a77c1 100644 --- a/datafusion/sqllogictest/test_files/joins.slt +++ b/datafusion/sqllogictest/test_files/joins.slt @@ -4689,6 +4689,30 @@ logical_plan 08)----------TableScan: j3 projection=[j3_string, j3_id] physical_plan_error This feature is not implemented: Physical plan does not support logical expression OuterReferenceColumn(Int32, Column { relation: Some(Bare { table: "j1" }), name: "j1_id" }) +# 2 nested lateral join with the deepest join referencing the outer most relation +query TT +explain SELECT * FROM j1 j1_outer, LATERAL ( + SELECT * FROM j1 j1_inner, LATERAL ( + SELECT * FROM j2 WHERE j1_inner.j1_id = j2_id and j1_outer.j1_id=j2_id + ) as j2 +) as j2; +---- +logical_plan +01)Cross Join: +02)--SubqueryAlias: j1_outer +03)----TableScan: j1 projection=[j1_string, j1_id] +04)--SubqueryAlias: j2 +05)----Subquery: +06)------Cross Join: +07)--------SubqueryAlias: j1_inner +08)----------TableScan: j1 projection=[j1_string, j1_id] +09)--------SubqueryAlias: j2 +10)----------Subquery: +11)------------Filter: 
outer_ref(j1_inner.j1_id) = j2.j2_id AND outer_ref(j1_outer.j1_id) = j2.j2_id +12)--------------TableScan: j2 projection=[j2_string, j2_id] +physical_plan_error This feature is not implemented: Physical plan does not support logical expression OuterReferenceColumn(Int32, Column { relation: Some(Bare { table: "j1_inner" }), name: "j1_id" }) + + query TT explain SELECT * FROM j1, LATERAL (SELECT 1) AS j2; ---- diff --git a/datafusion/sqllogictest/test_files/subquery.slt b/datafusion/sqllogictest/test_files/subquery.slt index 796570633f67..23a7388360ec 100644 --- a/datafusion/sqllogictest/test_files/subquery.slt +++ b/datafusion/sqllogictest/test_files/subquery.slt @@ -439,7 +439,7 @@ SELECT t1_id, t1_name, t1_int, (select t2_id, t2_name FROM t2 WHERE t2.t2_id = t #subquery_not_allowed #In/Exist Subquery is not allowed in ORDER BY clause. -statement error DataFusion error: Invalid \(non-executable\) plan after Analyzer\ncaused by\nError during planning: In/Exist subquery can only be used in Projection, Filter, TableScan, Window functions, Aggregate and Join plan nodes, but was used in \[Sort: t1.t1_int IN \(\) ASC NULLS LAST\] +statement error DataFusion error: Invalid \(non-executable\) plan after Analyzer\ncaused by\nError during planning: In/Exist subquery can only be used in Projection, Filter, TableScan, Window functions, Aggregate, Join and Dependent Join plan nodes, but was used in \[Sort: t1.t1_int IN \(\) ASC NULLS LAST\] SELECT t1_id, t1_name, t1_int FROM t1 order by t1_int in (SELECT t2_int FROM t2 WHERE t1.t1_id > t1.t1_int) #non_aggregated_correlated_scalar_subquery @@ -1482,3 +1482,85 @@ logical_plan statement count 0 drop table person; + +# correlated_recursive_scalar_subquery_with_level_3_scalar_subquery_referencing_level1_relation +query TT +explain select c_custkey from customer +where c_acctbal < ( + select sum(o_totalprice) from orders + where o_custkey = c_custkey + and o_totalprice < ( + select sum(l_extendedprice) as price from lineitem where l_orderkey = o_orderkey + and l_extendedprice < c_acctbal + ) +) order by c_custkey; +---- +logical_plan +01)Sort: customer.c_custkey ASC NULLS LAST +02)--Projection: customer.c_custkey +03)----Inner Join: customer.c_custkey = __scalar_sq_1.o_custkey Filter: CAST(customer.c_acctbal AS Decimal128(25, 2)) < __scalar_sq_1.sum(orders.o_totalprice) +04)------TableScan: customer projection=[c_custkey, c_acctbal] +05)------SubqueryAlias: __scalar_sq_1 +06)--------Projection: sum(orders.o_totalprice), orders.o_custkey +07)----------Aggregate: groupBy=[[orders.o_custkey]], aggr=[[sum(orders.o_totalprice)]] +08)------------Projection: orders.o_custkey, orders.o_totalprice +09)--------------Filter: CAST(orders.o_totalprice AS Decimal128(25, 2)) < () +10)----------------Subquery: +11)------------------Projection: sum(lineitem.l_extendedprice) AS price +12)--------------------Aggregate: groupBy=[[]], aggr=[[sum(lineitem.l_extendedprice)]] +13)----------------------Filter: lineitem.l_orderkey = outer_ref(orders.o_orderkey) AND lineitem.l_extendedprice < outer_ref(customer.c_acctbal) +14)------------------------TableScan: lineitem, partial_filters=[lineitem.l_orderkey = outer_ref(orders.o_orderkey), lineitem.l_extendedprice < outer_ref(customer.c_acctbal)] +15)----------------TableScan: orders projection=[o_orderkey, o_custkey, o_totalprice] + +# correlated_recursive_scalar_subquery_with_level_3_exists_subquery_referencing_level1_relation +query TT +explain select c_custkey from customer +where c_acctbal < ( + select sum(o_totalprice) from orders 
+ where o_custkey = c_custkey + and exists ( + select * from lineitem where l_orderkey = o_orderkey + and l_extendedprice < c_acctbal + ) +) order by c_custkey; +---- +logical_plan +01)Sort: customer.c_custkey ASC NULLS LAST +02)--Projection: customer.c_custkey +03)----Inner Join: customer.c_custkey = __scalar_sq_2.o_custkey Filter: CAST(customer.c_acctbal AS Decimal128(25, 2)) < __scalar_sq_2.sum(orders.o_totalprice) +04)------TableScan: customer projection=[c_custkey, c_acctbal] +05)------SubqueryAlias: __scalar_sq_2 +06)--------Projection: sum(orders.o_totalprice), orders.o_custkey +07)----------Aggregate: groupBy=[[orders.o_custkey]], aggr=[[sum(orders.o_totalprice)]] +08)------------Projection: orders.o_custkey, orders.o_totalprice +09)--------------LeftSemi Join: orders.o_orderkey = __correlated_sq_1.l_orderkey Filter: __correlated_sq_1.l_extendedprice < customer.c_acctbal +10)----------------TableScan: orders projection=[o_orderkey, o_custkey, o_totalprice] +11)----------------SubqueryAlias: __correlated_sq_1 +12)------------------TableScan: lineitem projection=[l_orderkey, l_extendedprice] + +# correlated_recursive_scalar_subquery_with_level_3_in_subquery_referencing_level1_relation +query TT +explain select c_custkey from customer +where c_acctbal < ( + select sum(o_totalprice) from orders + where o_custkey = c_custkey + and o_totalprice in ( + select l_extendedprice as price from lineitem where l_orderkey = o_orderkey + and l_extendedprice < c_acctbal + ) +) order by c_custkey; +---- +logical_plan +01)Sort: customer.c_custkey ASC NULLS LAST +02)--Projection: customer.c_custkey +03)----Inner Join: customer.c_custkey = __scalar_sq_2.o_custkey Filter: CAST(customer.c_acctbal AS Decimal128(25, 2)) < __scalar_sq_2.sum(orders.o_totalprice) +04)------TableScan: customer projection=[c_custkey, c_acctbal] +05)------SubqueryAlias: __scalar_sq_2 +06)--------Projection: sum(orders.o_totalprice), orders.o_custkey +07)----------Aggregate: groupBy=[[orders.o_custkey]], aggr=[[sum(orders.o_totalprice)]] +08)------------Projection: orders.o_custkey, orders.o_totalprice +09)--------------LeftSemi Join: orders.o_totalprice = __correlated_sq_1.price, orders.o_orderkey = __correlated_sq_1.l_orderkey Filter: __correlated_sq_1.l_extendedprice < customer.c_acctbal +10)----------------TableScan: orders projection=[o_orderkey, o_custkey, o_totalprice] +11)----------------SubqueryAlias: __correlated_sq_1 +12)------------------Projection: lineitem.l_extendedprice AS price, lineitem.l_extendedprice, lineitem.l_orderkey +13)--------------------TableScan: lineitem projection=[l_orderkey, l_extendedprice] \ No newline at end of file diff --git a/datafusion/sqllogictest/test_files/subquery_general.slt b/datafusion/sqllogictest/test_files/subquery_general.slt new file mode 100644 index 000000000000..2c27e1b9d455 --- /dev/null +++ b/datafusion/sqllogictest/test_files/subquery_general.slt @@ -0,0 +1,367 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# make sure to a batch size smaller than row number of the table. +statement ok +set datafusion.execution.batch_size = 2; + +############# +## Subquery Tests +############# + + +############# +## Setup test data table +############# +# there tables for subquery +statement ok +CREATE TABLE t0(t0_id INT, t0_name TEXT, t0_int INT) AS VALUES +(11, 'o', 6), +(22, 'p', 7), +(33, 'q', 8), +(44, 'r', 9); + +statement ok +CREATE TABLE t1(t1_id INT, t1_name TEXT, t1_int INT) AS VALUES +(11, 'a', 1), +(22, 'b', 2), +(33, 'c', 3), +(44, 'd', 4); + +statement ok +CREATE TABLE t2(t2_id INT, t2_name TEXT, t2_int INT) AS VALUES +(11, 'z', 3), +(22, 'y', 1), +(44, 'x', 3), +(55, 'w', 3); + +statement ok +CREATE TABLE t3(t3_id INT PRIMARY KEY, t3_name TEXT, t3_int INT) AS VALUES +(11, 'e', 3), +(22, 'f', 1), +(44, 'g', 3), +(55, 'h', 3); + +# in_subquery_to_join_with_correlated_outer_filter +query ITI rowsort +select t1.t1_id, + t1.t1_name, + t1.t1_int +from t1 +where t1.t1_id + 12 in ( + select t2.t2_id + 1 from t2 where t1.t1_int > 0 + ) +---- +11 a 1 +33 c 3 +44 d 4 + +# not_in_subquery_to_join_with_correlated_outer_filter +query ITI rowsort +select t1.t1_id, + t1.t1_name, + t1.t1_int +from t1 +where t1.t1_id + 12 not in ( + select t2.t2_id + 1 from t2 where t1.t1_int > 0 + ) +---- +22 b 2 + +# wrapped_not_in_subquery_to_join_with_correlated_outer_filter +query ITI rowsort +select t1.t1_id, + t1.t1_name, + t1.t1_int +from t1 +where not t1.t1_id + 12 in ( + select t2.t2_id + 1 from t2 where t1.t1_int > 0 + ) +---- +22 b 2 + +query II rowsort +SELECT t1_id, (SELECT sum(t2_int) FROM t2 WHERE t2.t2_id = t1.t1_id) as t2_sum from t1 +---- +11 3 +22 1 +33 NULL +44 3 + +query IR rowsort +SELECT t1_id, (SELECT sum(t2_int * 1.0) + 1 FROM t2 WHERE t2.t2_id = t1.t1_id) as t2_sum from t1 +---- +11 4 +22 2 +33 NULL +44 4 + +query II rowsort +SELECT t1_id, (SELECT sum(t2_int) FROM t2 WHERE t2.t2_id = t1.t1_id group by t2_id, 'a') as t2_sum from t1 +---- +11 3 +22 1 +33 NULL +44 3 + +query II rowsort +SELECT t1_id, (SELECT sum(t2_int) FROM t2 WHERE t2.t2_id = t1.t1_id having sum(t2_int) < 3) as t2_sum from t1 +---- +11 NULL +22 1 +33 NULL +44 NULL + +#non_aggregated_correlated_scalar_subquery_unique +query II rowsort +SELECT t1_id, (SELECT t3_int FROM t3 WHERE t3.t3_id = t1.t1_id) as t3_int from t1 +---- +11 3 +22 1 +33 NULL +44 3 + +query II rowsort +SELECT t1_id, (SELECT a FROM (select 1 as a) WHERE a = t1.t1_int) as t2_int from t1 +---- +11 1 +22 NULL +33 NULL +44 NULL + +query IT rowsort +SELECT t1_id, t1_name FROM t1 WHERE EXISTS (SELECT * FROM t2 WHERE t2_id = t1_id limit 1) +---- +11 a +22 b +44 d + +query ITI rowsort +select t1.t1_id, + t1.t1_name, + t1.t1_int +from t1 +where t1.t1_id > 40 or t1.t1_id in (select t2.t2_id from t2 where t1.t1_int > 0) +---- +11 a 1 +22 b 2 +44 d 4 + +query ITI rowsort +select t1.t1_id, + t1.t1_name, + t1.t1_int +from t1 +where t1.t1_id = 11 or t1.t1_id + 12 not in (select t2.t2_id + 1 from t2 where t1.t1_int > 0) +---- +11 a 1 +22 b 2 + +query ITI rowsort +select t1.t1_id, + t1.t1_name, + t1.t1_int +from t1 +where t1.t1_id > 40 or exists (select * from t2 where t1.t1_id = 
t2.t2_id) +---- +11 a 1 +22 b 2 +44 d 4 + +query ITI rowsort +select t1.t1_id, + t1.t1_name, + t1.t1_int +from t1 +where t1.t1_id > 40 or not exists (select * from t2 where t1.t1_id = t2.t2_id) +---- +33 c 3 +44 d 4 + +# fail +# query ITI rowsort +# select t1.t1_id, +# t1.t1_name, +# t1.t1_int +# from t1 +# where t1.t1_id in (select t3.t3_id from t3) and (t1.t1_id > 40 or t1.t1_id in (select t2.t2_id from t2 where t1.t1_int > 0)) +# ---- +# 11 a 1 +# 22 b 2 +# 44 d 4 + +# Handle duplicate values in exists query +query ITI rowsort +select t1.t1_id, + t1.t1_name, + t1.t1_int +from t1 +where t1.t1_id > 40 or exists (select * from t2 cross join t3 where t1.t1_id = t2.t2_id) +---- +11 a 1 +22 b 2 +44 d 4 + +# Nested subqueries +query ITI rowsort +select t1.t1_id, + t1.t1_name, + t1.t1_int +from t1 +where exists ( + select * from t2 where t1.t1_id = t2.t2_id OR exists ( + select * from t3 where t2.t2_id = t3.t3_id + ) +) +---- +11 a 1 +22 b 2 +33 c 3 +44 d 4 + +#SELECT t1_id, t1_name FROM t1 WHERE NOT EXISTS (SELECT * FROM t2 WHERE t2_id = t1_id limit 0) +#---- +#11 a +#22 b +#33 c +#44 d + +# fail +# query II rowsort +# SELECT t1_id, (SELECT count(*) FROM t2 WHERE t2.t2_int = t1.t1_int) from t1 +# ---- +# 11 1 +# 22 0 +# 33 3 +# 44 0 + +# fail +# query II rowsort +# SELECT t1_id, (SELECT count(*) FROM t2 WHERE t2.t2_int = t1.t1_int) as cnt from t1 +# ---- +# 11 1 +# 22 0 +# 33 3 +# 44 0 + +# fail +# query II rowsort +# SELECT t1_id, (SELECT count(*) as _cnt FROM t2 WHERE t2.t2_int = t1.t1_int) as cnt from t1 +# ---- +# 11 1 +# 22 0 +# 33 3 +# 44 0 + +# fail +# query II rowsort +# SELECT t1_id, (SELECT count(*) + 2 as _cnt FROM t2 WHERE t2.t2_int = t1.t1_int) from t1 +# ---- +# 11 3 +# 22 2 +# 33 5 +# 44 2 + +# fail +# query I rowsort +# select t1.t1_int from t1 where (select count(*) from t2 where t1.t1_id = t2.t2_id) < t1.t1_int +# ---- +# 2 +# 3 +# 4 + +# fail +# query II rowsort +# SELECT t1_id, (SELECT count(*) + 2 as cnt_plus_2 FROM t2 WHERE t2.t2_int = t1.t1_int having count(*) >1) from t1 +# ---- +# 11 NULL +# 22 NULL +# 33 5 +# 44 NULL + +# fail +# query II rowsort +# SELECT t1_id, (SELECT count(*) + 2 as cnt_plus_2 FROM t2 WHERE t2.t2_int = t1.t1_int having count(*) = 0) from t1 +# ---- +# 11 NULL +# 22 2 +# 33 NULL +# 44 2 + +# fail +# query I rowsort +# select t1.t1_int from t1 group by t1.t1_int having (select count(*) from t2 where t1.t1_int = t2.t2_int) = 0 +# ---- +# 2 +# 4 + +# fail +# query I rowsort +# select t1.t1_int from t1 where (select cnt from (select count(*) as cnt, sum(t2_int) from t2 where t1.t1_int = t2.t2_int)) = 0 +# ---- +# 2 +# 4 + +# fail +# query I rowsort +# select t1.t1_int from t1 where ( +# select cnt_plus_one + 1 as cnt_plus_two from ( +# select cnt + 1 as cnt_plus_one from ( +# select count(*) as cnt, sum(t2_int) s from t2 where t1.t1_int = t2.t2_int having cnt = 0 +# ) +# ) +# ) = 2 +# ---- +# 2 +# 4 + +# fail +# query I rowsort +# select t1.t1_int from t1 where +# (select case when count(*) = 1 then null else count(*) end as cnt from t2 where t2.t2_int = t1.t1_int) = 0 +# ---- +# 2 +# 4 + +# fail +# query B rowsort +# select t1_int > (select avg(t1_int) from t1) from t1 +# ---- +# false +# false +# true +# true + +# fail +# query IT rowsort +# SELECT t1_id, (SELECT case when max(t2.t2_id) > 1 then 'a' else 'b' end FROM t2 WHERE t2.t2_int = t1.t1_int) x from t1 +# ---- +# 11 a +# 22 b +# 33 a +# 44 b + +# fail +# query IB rowsort +# SELECT t1_id, (SELECT max(t2.t2_id) is null FROM t2 WHERE t2.t2_int = t1.t1_int) x from t1 +# ---- +# 11 false +# 22 true 
+# 33 false +# 44 true + + diff --git a/datafusion/substrait/src/logical_plan/producer/rel/join.rs b/datafusion/substrait/src/logical_plan/producer/rel/join.rs index 3dbac636feed..65e8cbad1eec 100644 --- a/datafusion/substrait/src/logical_plan/producer/rel/join.rs +++ b/datafusion/substrait/src/logical_plan/producer/rel/join.rs @@ -115,7 +115,7 @@ fn to_substrait_jointype(join_type: JoinType) -> join_rel::JoinType { JoinType::LeftSemi => join_rel::JoinType::LeftSemi, JoinType::LeftMark => join_rel::JoinType::LeftMark, JoinType::RightMark => join_rel::JoinType::RightMark, - JoinType::RightAnti | JoinType::RightSemi => { + JoinType::RightAnti | JoinType::RightSemi | JoinType::LeftSingle => { unimplemented!() } } diff --git a/datafusion/substrait/src/logical_plan/producer/rel/mod.rs b/datafusion/substrait/src/logical_plan/producer/rel/mod.rs index c3599a2635ff..372f59677a77 100644 --- a/datafusion/substrait/src/logical_plan/producer/rel/mod.rs +++ b/datafusion/substrait/src/logical_plan/producer/rel/mod.rs @@ -74,5 +74,11 @@ pub fn to_substrait_rel( LogicalPlan::RecursiveQuery(plan) => { not_impl_err!("Unsupported plan type: {plan:?}")? } + LogicalPlan::DependentJoin(join) => { + not_impl_err!("Unsupported plan type: {join:?}")? + } + LogicalPlan::DelimGet(delim_get) => { + not_impl_err!("Unsupported plan type: {delim_get:?}")? + } } }
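
The `left_match_counts` bookkeeping added to `HashJoinStream` in this patch can be understood in isolation: for a LeftSingle join, each build-side (left) row may match at most one probe-side (right) row across all emitted batches, and a second match is a runtime error. Below is a minimal, dependency-free sketch of that counting logic under that assumption; the `SingleJoinGuard` type and its methods are illustrative names, not part of DataFusion.

```rust
use std::collections::HashMap;

/// Simplified stand-in for the runtime check added to the hash join stream:
/// every build-side (left) row may match at most one probe-side (right) row.
struct SingleJoinGuard {
    /// build-side row index -> number of matches seen so far
    match_counts: HashMap<u64, usize>,
}

impl SingleJoinGuard {
    fn new() -> Self {
        Self { match_counts: HashMap::new() }
    }

    /// Record the left-row indices matched by one output batch.
    /// Returns an error if any left row has now matched more than once.
    fn record(&mut self, left_indices: &[u64]) -> Result<(), String> {
        for &left_idx in left_indices {
            let count = self.match_counts.entry(left_idx).or_insert(0);
            *count += 1;
            if *count > 1 {
                return Err(format!(
                    "LeftSingle join constraint violated: build side row at index {left_idx} has multiple matches"
                ));
            }
        }
        Ok(())
    }
}

fn main() {
    let mut guard = SingleJoinGuard::new();
    // First batch: rows 0 and 1 each match once -> OK.
    assert!(guard.record(&[0, 1]).is_ok());
    // Later batch: row 1 matches again -> cardinality violation.
    assert!(guard.record(&[1]).is_err());
}
```

Because the counts persist across batches in the stream state, the constraint holds even when the matches for one left row arrive in different output batches, which is what the `batch_size = 2` settings in the new `.slt` files are exercising.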
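
The `PlannerContext` change replaces the single outer-query schema with a stack: the planner pushes the enclosing query's schema when it enters a subquery and pops it when it leaves, so a correlated column can be resolved against any enclosing level, not only the adjacent one. The following standalone sketch shows that push/pop discipline and innermost-first resolution; `Schema` and `OuterSchemaStack` are stand-ins for illustration, not DataFusion's `DFSchemaRef` or `PlannerContext`.

```rust
/// Toy schema: just a list of column names.
#[derive(Clone, Debug, PartialEq)]
struct Schema(Vec<String>);

#[derive(Default)]
struct OuterSchemaStack {
    stack: Vec<Schema>,
}

impl OuterSchemaStack {
    /// Entering a subquery: remember the enclosing query's schema.
    fn push(&mut self, schema: Schema) {
        self.stack.push(schema);
    }

    /// Leaving a subquery: drop the schema for that level.
    fn pop(&mut self) -> Option<Schema> {
        self.stack.pop()
    }

    /// Resolve a column by walking enclosing scopes from innermost to
    /// outermost, returning how many levels up it was found (1 = adjacent).
    fn resolve(&self, column: &str) -> Option<usize> {
        self.stack
            .iter()
            .rev()
            .position(|schema| schema.0.iter().any(|c| c.as_str() == column))
            .map(|depth_from_inner| depth_from_inner + 1)
    }
}

fn main() {
    let mut outer = OuterSchemaStack::default();
    outer.push(Schema(vec!["c_custkey".into(), "c_acctbal".into()])); // level 1
    outer.push(Schema(vec!["o_orderkey".into(), "o_custkey".into()])); // level 2

    assert_eq!(outer.resolve("o_orderkey"), Some(1)); // adjacent outer query
    assert_eq!(outer.resolve("c_acctbal"), Some(2));  // two levels up
    assert_eq!(outer.resolve("missing"), None);

    outer.pop();
    assert_eq!(outer.resolve("o_orderkey"), None);
}
```

This mirrors the deepest-level references in the new tests, e.g. the TPCH-style subqueries where `lineitem` predicates reference both `orders.o_orderkey` (one level up) and `customer.c_acctbal` (two levels up).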
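
For the semantics the new hash-join tests exercise: LeftSingle behaves like a left outer join whose right side may match each left row at most once, with unmatched left rows padded with NULL and a second match reported as an error. This toy model over plain vectors is purely illustrative (it is not the operator implementation) and uses the same data as `join_left_single_success` and `join_left_single_cardinality_violation`.

```rust
/// Toy LeftSingle join over (a, key) pairs: emit every left row once,
/// joined to its single right match or padded with None.
fn left_single_join(
    left: &[(i32, i32)],  // (a1, b1)
    right: &[(i32, i32)], // (a2, b1)
) -> Result<Vec<(i32, i32, Option<i32>)>, String> {
    let mut out = Vec::with_capacity(left.len());
    for &(a1, key) in left {
        let mut matched: Option<i32> = None;
        for &(a2, rkey) in right {
            if rkey == key {
                if matched.is_some() {
                    return Err(format!(
                        "LeftSingle join constraint violated for key {key}"
                    ));
                }
                matched = Some(a2);
            }
        }
        out.push((a1, key, matched));
    }
    Ok(out)
}

fn main() {
    // Mirrors the success test: keys 4 and 5 match once, key 7 has no match.
    let ok = left_single_join(
        &[(1, 4), (2, 5), (3, 7)],
        &[(10, 4), (20, 5), (30, 6)],
    )
    .unwrap();
    assert_eq!(ok, vec![(1, 4, Some(10)), (2, 5, Some(20)), (3, 7, None)]);

    // Mirrors the violation test: key 5 matches two right rows.
    assert!(left_single_join(&[(2, 5)], &[(10, 5), (20, 5)]).is_err());
}
```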