Skip to content

Commit b92e227

Browse files
author
longshan.lu
committed
refactor: Introduce column existence check and update optimizer rules for filter handling in joins
1 parent e2b96b0 commit b92e227

File tree

8 files changed

+278
-34
lines changed

8 files changed

+278
-34
lines changed

qurious/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ url = { workspace = true }
1212
dashmap = { workspace = true }
1313
log = { workspace = true }
1414
itertools = "0.13.0"
15+
indexmap = "2.11.1"
1516

1617

1718
[dev-dependencies]

qurious/src/common/table_schema.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,10 @@ impl TableSchema {
5858
}
5959
}
6060

61+
/// Reports whether `column` is present in this schema.
///
/// Thin convenience wrapper over `has_field`: the column's relation
/// (borrowed via `as_ref()`) and its name are forwarded unchanged.
pub fn has_column(&self, column: &Column) -> bool {
    let qualifier = column.relation.as_ref();
    let field_name = &column.name;
    self.has_field(qualifier, field_name)
}
64+
6165
pub fn columns(&self) -> Vec<Column> {
6266
self.schema
6367
.fields()
qurious/src/optimizer/rule/eliminate_cross_join.rs

Lines changed: 234 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,234 @@
1+
use indexmap::IndexSet;
2+
3+
use crate::{
4+
common::{join_type::JoinType, table_schema::TableSchemaRef},
5+
datatypes::operator::Operator,
6+
error::Result,
7+
logical::{
8+
expr::{BinaryExpr, LogicalExpr},
9+
plan::{Filter, Join, LogicalPlan},
10+
LogicalPlanBuilder,
11+
},
12+
optimizer::rule::OptimizerRule,
13+
};
14+
15+
pub struct EliminateCrossJoin;
16+
17+
impl OptimizerRule for EliminateCrossJoin {
18+
fn name(&self) -> &str {
19+
"eliminate_cross_join"
20+
}
21+
22+
fn rewrite(&self, plan: LogicalPlan) -> Result<LogicalPlan> {
23+
match plan {
24+
LogicalPlan::Filter(filter) if matches!(filter.input.as_ref(), LogicalPlan::CrossJoin(_)) => {
25+
let LogicalPlan::CrossJoin(cross_join) = *filter.input else {
26+
return Ok(LogicalPlan::Filter(filter));
27+
};
28+
29+
let left_schema = cross_join.left.table_schema();
30+
let right_schema = cross_join.right.table_schema();
31+
32+
let (join_keys, remaining_predicate) = extract_join_pairs(&filter.expr, &left_schema, &right_schema);
33+
34+
if join_keys.is_empty() {
35+
Ok(LogicalPlan::Filter(Filter {
36+
input: Box::new(LogicalPlan::CrossJoin(cross_join)),
37+
expr: filter.expr,
38+
}))
39+
} else {
40+
let inner_join_plan = LogicalPlan::Join(Join {
41+
left: cross_join.left,
42+
right: cross_join.right,
43+
join_type: JoinType::Inner,
44+
on: join_keys.into_iter().map(|(l, r)| (l.clone(), r.clone())).collect(),
45+
filter: None,
46+
schema: cross_join.schema,
47+
});
48+
49+
if let Some(predicate) = remaining_predicate {
50+
LogicalPlanBuilder::filter(inner_join_plan, predicate)
51+
} else {
52+
Ok(inner_join_plan)
53+
}
54+
}
55+
}
56+
_ => Ok(plan),
57+
}
58+
}
59+
}
60+
61+
fn extract_join_pairs<'a>(
62+
expr: &'a LogicalExpr,
63+
left_schema: &TableSchemaRef,
64+
right_schema: &TableSchemaRef,
65+
) -> (IndexSet<(&'a LogicalExpr, &'a LogicalExpr)>, Option<LogicalExpr>) {
66+
let mut join_keys = IndexSet::new();
67+
68+
match expr {
69+
LogicalExpr::BinaryExpr(BinaryExpr {
70+
left,
71+
op: Operator::Eq,
72+
right,
73+
}) => {
74+
let left_col = left.try_as_column();
75+
let right_col = right.try_as_column();
76+
77+
if let (Some(left_col), Some(right_col)) = (left_col, right_col) {
78+
if (left_schema.has_column(left_col) || right_schema.has_column(left_col))
79+
&& (left_schema.has_column(right_col) || right_schema.has_column(right_col))
80+
{
81+
join_keys.insert((left.as_ref(), right.as_ref()));
82+
83+
return (join_keys, None);
84+
}
85+
}
86+
87+
(join_keys, Some(expr.clone()))
88+
}
89+
LogicalExpr::BinaryExpr(BinaryExpr {
90+
left,
91+
op: Operator::And,
92+
right,
93+
}) => {
94+
let (left_join_keys, left_predicate) = extract_join_pairs(left, left_schema, right_schema);
95+
let (right_join_keys, right_predicate) = extract_join_pairs(right, left_schema, right_schema);
96+
97+
join_keys.extend(left_join_keys);
98+
join_keys.extend(right_join_keys);
99+
100+
let predicate = match (left_predicate, right_predicate) {
101+
(Some(left_predicate), Some(right_predicate)) => Some(LogicalExpr::BinaryExpr(BinaryExpr {
102+
left: Box::new(left_predicate),
103+
op: Operator::And,
104+
right: Box::new(right_predicate),
105+
})),
106+
(l, r) => l.or(r),
107+
};
108+
109+
(join_keys, predicate)
110+
}
111+
LogicalExpr::BinaryExpr(BinaryExpr {
112+
left,
113+
op: Operator::Or,
114+
right,
115+
}) => {
116+
let (left_join_keys, left_predicate) = extract_join_pairs(left, left_schema, right_schema);
117+
let (right_join_keys, right_predicate) = extract_join_pairs(right, left_schema, right_schema);
118+
119+
for (l, r) in left_join_keys {
120+
if right_join_keys.contains(&(l, r)) || right_join_keys.contains(&(r, l)) {
121+
join_keys.insert((l, r));
122+
}
123+
}
124+
125+
let predicate = match (left_predicate, right_predicate) {
126+
(Some(l), Some(r)) => Some(LogicalExpr::BinaryExpr(BinaryExpr {
127+
left: Box::new(l),
128+
op: Operator::Or,
129+
right: Box::new(r),
130+
})),
131+
(l, r) => l.or(r),
132+
};
133+
134+
(join_keys, predicate)
135+
}
136+
137+
_ => (join_keys, Some(expr.clone())),
138+
}
139+
}
140+
141+
#[cfg(test)]
mod tests {
    use super::EliminateCrossJoin;
    use crate::test_utils::assert_after_optimizer;

    /// Root line shared by every expected plan in this module.
    const PROJECTION: &str =
        "Projection: (users.email, repos.id, users.id, repos.name, users.name, repos.owner_id)";

    /// A top-level AND yields a join on the key and a filter on the rest.
    #[test]
    fn test_eliminate_cross_join_simple_and() {
        let sql = "SELECT * FROM users, repos WHERE users.id = repos.owner_id AND users.id = 10";
        let expected = vec![
            PROJECTION,
            " Filter: users.id = Int64(10)",
            " Inner Join: On: (users.id, repos.owner_id)",
            " TableScan: users",
            " TableScan: repos",
        ];

        assert_after_optimizer(sql, Box::new(EliminateCrossJoin), expected);
    }

    /// A key that appears in only one OR branch must not be extracted.
    #[test]
    fn test_eliminate_cross_join_simple_or() {
        let sql = "SELECT * FROM users, repos WHERE users.id = repos.owner_id OR users.id = 10";
        let expected = vec![
            PROJECTION,
            " Filter: users.id = repos.owner_id OR users.id = Int64(10)",
            " CrossJoin",
            " TableScan: users",
            " TableScan: repos",
        ];

        assert_after_optimizer(sql, Box::new(EliminateCrossJoin), expected);
    }

    /// The same key repeated across nested ANDs is emitted once.
    #[test]
    fn test_eliminate_cross_join_and() {
        let sql = "SELECT * FROM users, repos WHERE (users.id = repos.owner_id and users.name < 'a') AND (users.id = repos.owner_id and users.name = 'b')";
        let expected = vec![
            PROJECTION,
            " Filter: users.name < Utf8('a') AND users.name = Utf8('b')",
            " Inner Join: On: (users.id, repos.owner_id)",
            " TableScan: users",
            " TableScan: repos",
        ];

        assert_after_optimizer(sql, Box::new(EliminateCrossJoin), expected);
    }

    /// A key common to both OR branches is factored out of the disjunction.
    #[test]
    fn test_eliminate_cross_join_or() {
        let sql = "SELECT * FROM users, repos WHERE (users.id = repos.owner_id and users.name < 'a') OR (users.id = repos.owner_id and users.name = 'b')";
        let expected = vec![
            PROJECTION,
            " Filter: users.name < Utf8('a') OR users.name = Utf8('b')",
            " Inner Join: On: (users.id, repos.owner_id)",
            " TableScan: users",
            " TableScan: repos",
        ];

        assert_after_optimizer(sql, Box::new(EliminateCrossJoin), expected);
    }

    /// OR branches with different keys leave the cross join in place.
    #[test]
    fn test_eliminate_cross_join_or_with_not_valid_join_pair_case1() {
        let sql = "SELECT * FROM users, repos WHERE (users.id = repos.owner_id and users.name < 'a') OR (users.id = repos.id and users.name = 'b')";
        let expected = vec![
            PROJECTION,
            " Filter: users.id = repos.owner_id AND users.name < Utf8('a') OR users.id = repos.id AND users.name = Utf8('b')",
            " CrossJoin",
            " TableScan: users",
            " TableScan: repos",
        ];

        assert_after_optimizer(sql, Box::new(EliminateCrossJoin), expected);
    }

    /// A nested OR that makes the key optional leaves the cross join in place.
    #[test]
    fn test_eliminate_cross_join_or_with_not_valid_join_pair_case2() {
        let sql = "SELECT * FROM users, repos WHERE (users.id = repos.owner_id and users.name < 'a') OR (users.id = repos.owner_id OR users.name = 'b')";
        let expected = vec![
            PROJECTION,
            " Filter: users.id = repos.owner_id AND users.name < Utf8('a') OR users.id = repos.owner_id OR users.name = Utf8('b')",
            " CrossJoin",
            " TableScan: users",
            " TableScan: repos",
        ];

        assert_after_optimizer(sql, Box::new(EliminateCrossJoin), expected);
    }
}

qurious/src/optimizer/rule/extract_equijoin_predicate.rs

Lines changed: 2 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ use crate::common::table_schema::TableSchemaRef;
22
use crate::common::transformed::{TransformNode, Transformed, TransformedResult};
33
use crate::error::Result;
44
use crate::logical::expr::Column;
5-
use crate::utils::expr::split_conjunctive_predicates;
5+
use crate::utils::expr::{check_all_columns_from_schema, split_conjunctive_predicates};
66
use crate::{
77
datatypes::operator::Operator,
88
logical::{
@@ -110,14 +110,7 @@ fn extract_equijoin_predicates(
110110
)
111111
}
112112

113-
/// Returns `true` when every column in `columns` is a field of `schema`.
///
/// Each column is looked up with `has_field` using its relation and name.
/// An empty `columns` set yields `true`.
fn check_all_columns_from_schema(columns: &HashSet<Column>, schema: &TableSchemaRef) -> bool {
    columns
        .iter()
        .all(|column| schema.has_field(column.relation.as_ref(), &column.name))
}
113+
121114

122115
#[cfg(test)]
123116
mod tests {

qurious/src/optimizer/rule/mod.rs

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,16 @@
1-
2-
3-
41
mod count_wildcard_rule;
2+
mod eliminate_cross_join;
53
mod extract_equijoin_predicate;
6-
mod pushdown_filter_inner_join;
4+
mod pushdown_filter_join;
75
mod rule_optimizer;
86
mod scalar_subquery_to_join;
9-
mod type_coercion;
107
mod simplify_exprs;
8+
mod type_coercion;
119

1210
pub use rule_optimizer::*;
1311

1412
pub use count_wildcard_rule::*;
1513
pub use extract_equijoin_predicate::*;
16-
pub use pushdown_filter_inner_join::*;
14+
pub use pushdown_filter_join::*;
1715
pub use scalar_subquery_to_join::*;
1816
pub use type_coercion::*;

0 commit comments

Comments
 (0)