Skip to content

Commit aa2fd29

Browse files
author
longshan.lu
committed
fix: Enhance join handling in optimizer rules by introducing new join method and improving cross join elimination logic
1 parent 1f29786 commit aa2fd29

File tree

7 files changed

+208
-80
lines changed

7 files changed

+208
-80
lines changed

qurious/src/execution/session.rs

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ use crate::common::table_relation::TableRelation;
1010
use crate::datasource::memory::MemoryTable;
1111
use crate::error::Error;
1212
use crate::functions::{all_builtin_functions, UserDefinedFunction};
13-
use crate::{internal_err, utils};
13+
use crate::internal_err;
1414
use crate::logical::plan::{
1515
CreateMemoryTable, DdlStatement, DmlOperator, DmlStatement, DropTable, Filter, LogicalPlan,
1616
};
@@ -98,7 +98,6 @@ impl ExecuteSession {
9898
LogicalPlan::Dml(stmt) => self.execute_dml(stmt),
9999
plan => {
100100
let plan = self.optimizer.optimize(plan)?;
101-
println!("{}", utils::format(&plan, 0));
102101
self.planner.create_physical_plan(&plan)?.execute()
103102
}
104103
}

qurious/src/logical/builder.rs

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -85,9 +85,14 @@ impl LogicalPlanBuilder {
8585
})
8686
}
8787

88-
pub fn join_on(self, right: LogicalPlan, join_type: JoinType, filter: Option<LogicalExpr>) -> Result<Self> {
88+
pub fn join(
89+
self,
90+
right: LogicalPlan,
91+
join_type: JoinType,
92+
on: Vec<(LogicalExpr, LogicalExpr)>,
93+
filter: Option<LogicalExpr>,
94+
) -> Result<Self> {
8995
let schema = build_join_schema(join_type, &self.plan.table_schema(), &right.table_schema())?;
90-
let on = vec![];
9196

9297
if join_type != JoinType::Inner && filter.is_none() && on.is_empty() {
9398
return Err(Error::InternalError(format!("join condition should not be empty")));

qurious/src/logical/plan/mod.rs

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -294,11 +294,18 @@ impl TransformNode for LogicalPlan {
294294
})
295295
}
296296

297-
fn apply_children<'n, F>(&'n self, _f: F) -> Result<TreeNodeRecursion>
297+
fn apply_children<'n, F>(&'n self, mut f: F) -> Result<TreeNodeRecursion>
298298
where
299299
F: FnMut(&'n LogicalPlan) -> Result<TreeNodeRecursion>,
300300
{
301-
todo!()
301+
for child in self.children().into_iter().flatten() {
302+
match f(child)? {
303+
TreeNodeRecursion::Continue => {}
304+
TreeNodeRecursion::Stop => return Ok(TreeNodeRecursion::Stop),
305+
}
306+
}
307+
308+
Ok(TreeNodeRecursion::Continue)
302309
}
303310
}
304311

qurious/src/optimizer/rule/eliminate_cross_join.rs

Lines changed: 174 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,24 @@
1-
use indexmap::IndexSet;
1+
use std::{collections::HashMap, sync::Arc};
2+
3+
use indexmap::{Equivalent, IndexSet};
24

35
use crate::{
4-
common::{join_type::JoinType, table_schema::TableSchemaRef, transformed::Transformed},
6+
common::{
7+
join_type::JoinType,
8+
table_schema::TableSchemaRef,
9+
transformed::{TransformNode, Transformed, TransformedResult, TreeNodeRecursion},
10+
},
511
datatypes::operator::Operator,
612
error::Result,
713
logical::{
814
expr::{BinaryExpr, LogicalExpr},
9-
plan::{Filter, Join, LogicalPlan},
15+
plan::{Filter, LogicalPlan},
1016
LogicalPlanBuilder,
1117
},
1218
optimizer::rule::OptimizerRule,
1319
};
1420

21+
/// Eliminate cross joins by rewriting them to inner joins when possible.
1522
pub struct EliminateCrossJoin;
1623

1724
impl OptimizerRule for EliminateCrossJoin {
@@ -20,40 +27,74 @@ impl OptimizerRule for EliminateCrossJoin {
2027
}
2128

2229
fn rewrite(&self, plan: LogicalPlan) -> Result<Transformed<LogicalPlan>> {
23-
match plan {
24-
LogicalPlan::Filter(filter) if matches!(filter.input.as_ref(), LogicalPlan::CrossJoin(_)) => {
25-
let LogicalPlan::CrossJoin(cross_join) = *filter.input else {
26-
return Ok(Transformed::no(LogicalPlan::Filter(filter)));
27-
};
30+
let LogicalPlan::Filter(Filter { input, expr: predicate }) = plan else {
31+
return Ok(Transformed::no(plan));
32+
};
2833

29-
let left_schema = cross_join.left.table_schema();
30-
let right_schema = cross_join.right.table_schema();
34+
let mut all_corss_joins = vec![];
3135

32-
let (join_keys, remaining_predicate) = extract_join_pairs(&filter.expr, &left_schema, &right_schema);
36+
// collect all cross joins and filter predicates in order
37+
input.apply(|plan| {
38+
if let LogicalPlan::CrossJoin(cross_join) = plan {
39+
all_corss_joins.push(cross_join);
40+
}
41+
Ok(TreeNodeRecursion::Continue)
42+
})?;
3343

34-
if join_keys.is_empty() {
35-
Ok(Transformed::no(LogicalPlan::Filter(Filter {
36-
input: Box::new(LogicalPlan::CrossJoin(cross_join)),
37-
expr: filter.expr,
38-
})))
39-
} else {
40-
let inner_join_plan = LogicalPlan::Join(Join {
41-
left: cross_join.left,
42-
right: cross_join.right,
43-
join_type: JoinType::Inner,
44-
on: join_keys.into_iter().map(|(l, r)| (l.clone(), r.clone())).collect(),
45-
filter: None,
46-
schema: cross_join.schema,
47-
});
48-
49-
if let Some(predicate) = remaining_predicate {
50-
LogicalPlanBuilder::filter(inner_join_plan, predicate).map(Transformed::yes)
51-
} else {
52-
Ok(Transformed::yes(inner_join_plan))
53-
}
54-
}
44+
if all_corss_joins.is_empty() {
45+
return Ok(Transformed::no(LogicalPlan::Filter(Filter { input, expr: predicate })));
46+
}
47+
48+
let mut all_join_keys = IndexSet::new();
49+
let mut replaced_cross_joins = HashMap::new();
50+
let len = all_corss_joins.len();
51+
// iteratively rewrite cross joins to inner joins from bottom to top
52+
for (index, cross_join) in all_corss_joins.into_iter().rev().enumerate() {
53+
let left_schema = cross_join.left.table_schema();
54+
let right_schema = cross_join.right.table_schema();
55+
56+
let join_keys = extract_join_pairs(&predicate, &left_schema, &right_schema);
57+
58+
all_join_keys.extend(join_keys.clone());
59+
60+
if !join_keys.is_empty() {
61+
let inner_join_plan = LogicalPlanBuilder::from(Arc::unwrap_or_clone(cross_join.left.clone()))
62+
.join(
63+
Arc::unwrap_or_clone(cross_join.right.clone()),
64+
JoinType::Inner,
65+
join_keys.into_iter().collect(),
66+
None,
67+
)?
68+
.build();
69+
70+
// this index should be from the top to the bottom
71+
replaced_cross_joins.insert((len - 1) - index, inner_join_plan);
5572
}
56-
_ => Ok(Transformed::no(plan)),
73+
}
74+
75+
if replaced_cross_joins.is_empty() {
76+
return Ok(Transformed::no(LogicalPlan::Filter(Filter { input, expr: predicate })));
77+
}
78+
79+
// combine all predicates and replaced cross joins
80+
let mut index = 0;
81+
let new_input = input
82+
.transform(|plan| {
83+
let result = if let Some(replaced_join) = replaced_cross_joins.remove(&index) {
84+
Ok(Transformed::yes(replaced_join))
85+
} else {
86+
Ok(Transformed::no(plan))
87+
};
88+
index += 1;
89+
result
90+
})
91+
.data()?;
92+
93+
// remove all join keys from original predicates
94+
if let Some(predicate) = remove_join_keys(predicate, &all_join_keys) {
95+
LogicalPlanBuilder::filter(new_input, predicate).map(Transformed::yes)
96+
} else {
97+
Ok(Transformed::yes(new_input))
5798
}
5899
}
59100
}
@@ -62,7 +103,7 @@ fn extract_join_pairs<'a>(
62103
expr: &'a LogicalExpr,
63104
left_schema: &TableSchemaRef,
64105
right_schema: &TableSchemaRef,
65-
) -> (IndexSet<(&'a LogicalExpr, &'a LogicalExpr)>, Option<LogicalExpr>) {
106+
) -> IndexSet<(LogicalExpr, LogicalExpr)> {
66107
let mut join_keys = IndexSet::new();
67108

68109
match expr {
@@ -75,69 +116,92 @@ fn extract_join_pairs<'a>(
75116
let right_col = right.try_as_column();
76117

77118
if let (Some(left_col), Some(right_col)) = (left_col, right_col) {
78-
if (left_schema.has_column(left_col) || right_schema.has_column(left_col))
79-
&& (left_schema.has_column(right_col) || right_schema.has_column(right_col))
80-
{
81-
join_keys.insert((left.as_ref(), right.as_ref()));
82-
83-
return (join_keys, None);
119+
if left_schema.has_column(left_col) && right_schema.has_column(right_col) {
120+
join_keys.insert((left.as_ref().clone(), right.as_ref().clone()));
121+
} else if right_schema.has_column(left_col) && left_schema.has_column(right_col) {
122+
join_keys.insert((right.as_ref().clone(), left.as_ref().clone()));
84123
}
85124
}
86-
87-
(join_keys, Some(expr.clone()))
88125
}
89126
LogicalExpr::BinaryExpr(BinaryExpr {
90127
left,
91128
op: Operator::And,
92129
right,
93130
}) => {
94-
let (left_join_keys, left_predicate) = extract_join_pairs(left, left_schema, right_schema);
95-
let (right_join_keys, right_predicate) = extract_join_pairs(right, left_schema, right_schema);
131+
let left_join_keys = extract_join_pairs(left, left_schema, right_schema);
132+
let right_join_keys = extract_join_pairs(right, left_schema, right_schema);
96133

97134
join_keys.extend(left_join_keys);
98135
join_keys.extend(right_join_keys);
99-
100-
let predicate = match (left_predicate, right_predicate) {
101-
(Some(left_predicate), Some(right_predicate)) => Some(LogicalExpr::BinaryExpr(BinaryExpr {
102-
left: Box::new(left_predicate),
103-
op: Operator::And,
104-
right: Box::new(right_predicate),
105-
})),
106-
(l, r) => l.or(r),
107-
};
108-
109-
(join_keys, predicate)
110136
}
111137
LogicalExpr::BinaryExpr(BinaryExpr {
112138
left,
113139
op: Operator::Or,
114140
right,
115141
}) => {
116-
let (left_join_keys, left_predicate) = extract_join_pairs(left, left_schema, right_schema);
117-
let (right_join_keys, right_predicate) = extract_join_pairs(right, left_schema, right_schema);
142+
let left_join_keys = extract_join_pairs(left, left_schema, right_schema);
143+
let right_join_keys = extract_join_pairs(right, left_schema, right_schema);
118144

119145
for (l, r) in left_join_keys {
120-
if right_join_keys.contains(&(l, r)) || right_join_keys.contains(&(r, l)) {
146+
if right_join_keys.contains(&ExprPair::new(&l, &r)) || right_join_keys.contains(&ExprPair::new(&r, &l))
147+
{
121148
join_keys.insert((l, r));
122149
}
123150
}
151+
}
152+
_ => {}
153+
}
124154

125-
let predicate = match (left_predicate, right_predicate) {
126-
(Some(l), Some(r)) => Some(LogicalExpr::BinaryExpr(BinaryExpr {
127-
left: Box::new(l),
128-
op: Operator::Or,
129-
right: Box::new(r),
130-
})),
131-
(l, r) => l.or(r),
132-
};
155+
join_keys
156+
}
133157

134-
(join_keys, predicate)
158+
fn remove_join_keys(expr: LogicalExpr, join_keys: &IndexSet<(LogicalExpr, LogicalExpr)>) -> Option<LogicalExpr> {
159+
match expr {
160+
LogicalExpr::BinaryExpr(BinaryExpr {
161+
left,
162+
op: Operator::Eq,
163+
right,
164+
}) if join_keys.contains(&ExprPair::new(left.as_ref(), right.as_ref()))
165+
|| join_keys.contains(&ExprPair::new(right.as_ref(), left.as_ref())) =>
166+
{
167+
None
168+
}
169+
LogicalExpr::BinaryExpr(BinaryExpr { left, op, right }) if op == Operator::And => {
170+
let l = remove_join_keys(*left, join_keys);
171+
let r = remove_join_keys(*right, join_keys);
172+
match (l, r) {
173+
(Some(ll), Some(rr)) => Some(LogicalExpr::BinaryExpr(BinaryExpr::new(ll, op, rr))),
174+
(Some(ll), _) => Some(ll),
175+
(_, Some(rr)) => Some(rr),
176+
_ => None,
177+
}
135178
}
179+
LogicalExpr::BinaryExpr(BinaryExpr { left, op, right }) if op == Operator::Or => {
180+
let l = remove_join_keys(*left, join_keys);
181+
let r = remove_join_keys(*right, join_keys);
182+
match (l, r) {
183+
(Some(ll), Some(rr)) => Some(LogicalExpr::BinaryExpr(BinaryExpr::new(ll, op, rr))),
184+
_ => None,
185+
}
186+
}
187+
_ => Some(expr),
188+
}
189+
}
190+
191+
#[derive(Debug, Eq, PartialEq, Hash)]
192+
struct ExprPair<'a>(&'a LogicalExpr, &'a LogicalExpr);
136193

137-
_ => (join_keys, Some(expr.clone())),
194+
impl<'a> ExprPair<'a> {
195+
fn new(left: &'a LogicalExpr, right: &'a LogicalExpr) -> Self {
196+
Self(left, right)
138197
}
139198
}
140199

200+
impl Equivalent<(LogicalExpr, LogicalExpr)> for ExprPair<'_> {
201+
fn equivalent(&self, other: &(LogicalExpr, LogicalExpr)) -> bool {
202+
self.0 == &other.0 && self.1 == &other.1
203+
}
204+
}
141205
#[cfg(test)]
142206
mod tests {
143207
use crate::{optimizer::rule::eliminate_cross_join::EliminateCrossJoin, test_utils::assert_after_optimizer};
@@ -231,4 +295,46 @@ mod tests {
231295
],
232296
);
233297
}
298+
299+
#[test]
300+
fn test_tpch_03() {
301+
assert_after_optimizer(
302+
" select
303+
l_orderkey,
304+
sum(l_extendedprice * (1 - l_discount)) as revenue,
305+
o_orderdate,
306+
o_shippriority
307+
from
308+
customer,
309+
orders,
310+
lineitem
311+
where
312+
c_mktsegment = 'BUILDING'
313+
and c_custkey = o_custkey
314+
and l_orderkey = o_orderkey
315+
and o_orderdate < date '1995-03-15'
316+
and l_shipdate > date '1995-03-15'
317+
group by
318+
l_orderkey,
319+
o_orderdate,
320+
o_shippriority
321+
order by
322+
revenue desc,
323+
o_orderdate
324+
limit 10;",
325+
vec![Box::new(EliminateCrossJoin)],
326+
vec![
327+
"Limit: fetch=10, skip=0",
328+
" Sort: revenue DESC, orders.o_orderdate ASC",
329+
" Projection: (lineitem.l_orderkey, SUM(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount) AS revenue, orders.o_orderdate, orders.o_shippriority)",
330+
" Aggregate: group_expr=[lineitem.l_orderkey,orders.o_orderdate,orders.o_shippriority], aggregat_expr=[SUM(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount)]",
331+
" Filter: customer.c_mktsegment = Utf8('BUILDING') AND orders.o_orderdate < CAST(Utf8('1995-03-15') AS Date32) AND lineitem.l_shipdate > CAST(Utf8('1995-03-15') AS Date32)",
332+
" Inner Join: On: (orders.o_orderkey, lineitem.l_orderkey)",
333+
" Inner Join: On: (customer.c_custkey, orders.o_custkey)",
334+
" TableScan: customer",
335+
" TableScan: orders",
336+
" TableScan: lineitem",
337+
],
338+
);
339+
}
234340
}

qurious/src/optimizer/rule/scalar_subquery_to_join.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,7 @@ impl OptimizerRule for ScalarSubqueryToJoin {
8484
})?;
8585

8686
LogicalPlanBuilder::from(cur_input)
87-
.join_on(new_subquery_plan, JoinType::Left, join_filter)?
87+
.join(new_subquery_plan, JoinType::Left, vec![], join_filter)?
8888
.build()
8989
}
9090
};

0 commit comments

Comments
 (0)