Skip to content
This repository was archived by the owner on Jan 7, 2025. It is now read-only.

Commit 54b36e9

Browse files
committed
fix(df-repr): enable project merge and project elimination rule (#223)
Signed-off-by: Alex Chi <[email protected]>
1 parent 7dd002d commit 54b36e9

File tree

5 files changed

+92
-98
lines changed

5 files changed

+92
-98
lines changed

optd-datafusion-repr/src/lib.rs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -129,10 +129,10 @@ impl DatafusionOptimizer {
129129
// rule_wrappers.push(RuleWrapper::new_cascades(Arc::new(
130130
// ProjectionPullUpJoin::new(),
131131
// )));
132-
// rule_wrappers.push(RuleWrapper::new_cascades(Arc::new(
133-
// EliminateProjectRule::new(),
134-
// )));
135-
// rule_wrappers.push(RuleWrapper::new_cascades(Arc::new(ProjectMergeRule::new())));
132+
rule_wrappers.push(RuleWrapper::new_cascades(Arc::new(
133+
EliminateProjectRule::new(),
134+
)));
135+
rule_wrappers.push(RuleWrapper::new_cascades(Arc::new(ProjectMergeRule::new())));
136136
// rule_wrappers.push(RuleWrapper::new_cascades(Arc::new(
137137
// EliminateFilterRule::new(),
138138
// )));

optd-datafusion-repr/src/rules.rs

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
mod joins;
66
mod macros;
77
mod physical;
8-
// mod project_transpose;
8+
mod project_transpose;
99
// mod subquery;
1010

1111
// pub use eliminate_duplicated_expr::{
@@ -19,11 +19,7 @@ mod physical;
1919
// };
2020
pub use joins::*;
2121
pub use physical::PhysicalConversionRule;
22-
// pub use project_transpose::{
23-
// project_filter_transpose::{FilterProjectTransposeRule, ProjectFilterTransposeRule},
24-
// project_join_transpose::ProjectionPullUpJoin,
25-
// project_merge::{EliminateProjectRule, ProjectMergeRule},
26-
// };
22+
pub use project_transpose::*;
2723
// pub use subquery::{
2824
// DepInitialDistinct, DepJoinEliminate, DepJoinPastAgg, DepJoinPastFilter, DepJoinPastProj,
2925
// };
Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
1-
pub mod project_filter_transpose;
2-
pub mod project_join_transpose;
1+
// pub mod project_filter_transpose;
2+
// pub mod project_join_transpose;
33
pub mod project_merge;
44
pub mod project_transpose_common;
5+
6+
pub use project_merge::*;
Lines changed: 74 additions & 78 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
1-
use std::collections::HashMap;
2-
1+
use optd_core::nodes::PlanNodeOrGroup;
2+
use optd_core::optimizer::Optimizer;
33
use optd_core::rules::{Rule, RuleMatcher};
4-
use optd_core::{nodes::PlanNode, optimizer::Optimizer};
54

6-
use crate::plan_nodes::{DfNodeType, DfReprPlanNode, DfReprPlanNode, ListPred, LogicalProjection};
5+
use crate::plan_nodes::{
6+
ArcDfPlanNode, ColumnRefPred, DfNodeType, DfReprPlanNode, DfReprPredNode, LogicalProjection,
7+
};
78
use crate::rules::macros::define_rule;
9+
use crate::OptimizerExt;
810

911
use super::project_transpose_common::ProjectionMapping;
1012

@@ -13,20 +15,18 @@ use super::project_transpose_common::ProjectionMapping;
1315
define_rule!(
1416
ProjectMergeRule,
1517
apply_projection_merge,
16-
(Projection, (Projection, child, [exprs2]), [exprs1])
18+
(Projection, (Projection, child))
1719
);
1820

1921
fn apply_projection_merge(
2022
_optimizer: &impl Optimizer<DfNodeType>,
21-
ProjectMergeRulePicks {
22-
child,
23-
exprs1,
24-
exprs2,
25-
}: ProjectMergeRulePicks,
23+
binding: ArcDfPlanNode,
2624
) -> Vec<PlanNodeOrGroup<DfNodeType>> {
27-
let child = DfReprPlanNode::from_group(child.into());
28-
let exprs1 = ListPred::from_rel_node(exprs1.into()).unwrap();
29-
let exprs2 = ListPred::from_rel_node(exprs2.into()).unwrap();
25+
let proj1 = LogicalProjection::from_plan_node(binding).unwrap();
26+
let proj2 = LogicalProjection::from_plan_node(proj1.child().unwrap_plan_node()).unwrap();
27+
let child = proj2.child();
28+
let exprs1 = proj1.exprs();
29+
let exprs2 = proj2.exprs();
3030

3131
let Some(mapping) = ProjectionMapping::build(&exprs1) else {
3232
return vec![];
@@ -36,33 +36,32 @@ fn apply_projection_merge(
3636
return vec![];
3737
};
3838

39-
let node: LogicalProjection = LogicalProjection::new(child, res_exprs);
39+
let node = LogicalProjection::new_unchecked(child, res_exprs);
4040

41-
vec![node.into_rel_node().as_ref().clone()]
41+
vec![node.into_plan_node().into()]
4242
}
4343

4444
// Proj child [identical columns] -> eliminate
4545
define_rule!(
4646
EliminateProjectRule,
4747
apply_eliminate_project,
48-
(Projection, child, [expr])
48+
(Projection, child)
4949
);
5050

5151
fn apply_eliminate_project(
5252
optimizer: &impl Optimizer<DfNodeType>,
53-
EliminateProjectRulePicks { child, expr }: EliminateProjectRulePicks,
54-
) -> Vec<RelNode<DfNodeType>> {
55-
let exprs = ExprList::from_rel_node(expr.into()).unwrap();
56-
let child_columns = optimizer
57-
.get_property::<SchemaPropertyBuilder>(child.clone().into(), 0)
58-
.len();
59-
if child_columns != exprs.len() {
53+
binding: ArcDfPlanNode,
54+
) -> Vec<PlanNodeOrGroup<DfNodeType>> {
55+
let proj = LogicalProjection::from_plan_node(binding).unwrap();
56+
let child = proj.child();
57+
let exprs = proj.exprs();
58+
let child_schema = optimizer.get_schema_of(child.clone());
59+
if child_schema.len() != exprs.len() {
6060
return Vec::new();
6161
}
6262
for i in 0..exprs.len() {
6363
let child_expr = exprs.child(i);
64-
if child_expr.typ() == DfNodeType::ColumnRef {
65-
let child_expr = ColumnRefExpr::from_rel_node(child_expr.into_rel_node()).unwrap();
64+
if let Some(child_expr) = ColumnRefPred::from_pred_node(child_expr) {
6665
if child_expr.index() != i {
6766
return Vec::new();
6867
}
@@ -77,13 +76,10 @@ fn apply_eliminate_project(
7776
mod tests {
7877
use std::sync::Arc;
7978

80-
use optd_core::optimizer::Optimizer;
79+
use super::*;
8180

8281
use crate::{
83-
plan_nodes::{
84-
ColumnRefPred, DfNodeType, DfReprPlanNode, ListPred, LogicalProjection, LogicalScan,
85-
},
86-
rules::ProjectMergeRule,
82+
plan_nodes::{ListPred, LogicalScan},
8783
testing::new_test_optimizer,
8884
};
8985

@@ -95,30 +91,30 @@ mod tests {
9591
let scan = LogicalScan::new("customer".into());
9692

9793
let top_proj_exprs = ListPred::new(vec![
98-
ColumnRefPred::new(2).into_expr(),
99-
ColumnRefPred::new(0).into_expr(),
94+
ColumnRefPred::new(2).into_pred_node(),
95+
ColumnRefPred::new(0).into_pred_node(),
10096
]);
10197

10298
let bot_proj_exprs = ListPred::new(vec![
103-
ColumnRefPred::new(2).into_expr(),
104-
ColumnRefPred::new(0).into_expr(),
105-
ColumnRefPred::new(4).into_expr(),
99+
ColumnRefPred::new(2).into_pred_node(),
100+
ColumnRefPred::new(0).into_pred_node(),
101+
ColumnRefPred::new(4).into_pred_node(),
106102
]);
107103

108104
let bot_proj = LogicalProjection::new(scan.into_plan_node(), bot_proj_exprs);
109105
let top_proj = LogicalProjection::new(bot_proj.into_plan_node(), top_proj_exprs);
110106

111-
let plan = test_optimizer.optimize(top_proj.into_rel_node()).unwrap();
107+
let plan = test_optimizer.optimize(top_proj.into_plan_node()).unwrap();
112108

113109
let res_proj_exprs = ListPred::new(vec![
114-
ColumnRefPred::new(4).into_expr(),
115-
ColumnRefPred::new(2).into_expr(),
110+
ColumnRefPred::new(4).into_pred_node(),
111+
ColumnRefPred::new(2).into_pred_node(),
116112
])
117-
.into_rel_node();
113+
.into_pred_node();
118114

119115
assert_eq!(plan.typ, DfNodeType::Projection);
120-
assert_eq!(plan.child(1), res_proj_exprs);
121-
assert!(matches!(plan.child(0).typ, DfNodeType::Scan));
116+
assert_eq!(plan.predicate(0), res_proj_exprs);
117+
assert!(matches!(plan.child_rel(0).typ, DfNodeType::Scan));
122118
}
123119

124120
#[test]
@@ -129,42 +125,42 @@ mod tests {
129125
let scan = LogicalScan::new("customer".into());
130126

131127
let proj_exprs_1 = ListPred::new(vec![
132-
ColumnRefPred::new(2).into_expr(),
133-
ColumnRefPred::new(0).into_expr(),
134-
ColumnRefPred::new(4).into_expr(),
135-
ColumnRefPred::new(3).into_expr(),
128+
ColumnRefPred::new(2).into_pred_node(),
129+
ColumnRefPred::new(0).into_pred_node(),
130+
ColumnRefPred::new(4).into_pred_node(),
131+
ColumnRefPred::new(3).into_pred_node(),
136132
]);
137133

138134
let proj_exprs_2 = ListPred::new(vec![
139-
ColumnRefPred::new(1).into_expr(),
140-
ColumnRefPred::new(0).into_expr(),
141-
ColumnRefPred::new(3).into_expr(),
135+
ColumnRefPred::new(1).into_pred_node(),
136+
ColumnRefPred::new(0).into_pred_node(),
137+
ColumnRefPred::new(3).into_pred_node(),
142138
]);
143139

144140
let proj_exprs_3 = ListPred::new(vec![
145-
ColumnRefPred::new(1).into_expr(),
146-
ColumnRefPred::new(0).into_expr(),
147-
ColumnRefPred::new(2).into_expr(),
141+
ColumnRefPred::new(1).into_pred_node(),
142+
ColumnRefPred::new(0).into_pred_node(),
143+
ColumnRefPred::new(2).into_pred_node(),
148144
]);
149145

150146
let proj_1 = LogicalProjection::new(scan.into_plan_node(), proj_exprs_1);
151147
let proj_2 = LogicalProjection::new(proj_1.into_plan_node(), proj_exprs_2);
152148
let proj_3 = LogicalProjection::new(proj_2.into_plan_node(), proj_exprs_3);
153149

154150
// needs to be called twice
155-
let plan = test_optimizer.optimize(proj_3.into_rel_node()).unwrap();
151+
let plan = test_optimizer.optimize(proj_3.into_plan_node()).unwrap();
156152
let plan = test_optimizer.optimize(plan).unwrap();
157153

158154
let res_proj_exprs = ListPred::new(vec![
159-
ColumnRefPred::new(2).into_expr(),
160-
ColumnRefPred::new(0).into_expr(),
161-
ColumnRefPred::new(3).into_expr(),
155+
ColumnRefPred::new(2).into_pred_node(),
156+
ColumnRefPred::new(0).into_pred_node(),
157+
ColumnRefPred::new(3).into_pred_node(),
162158
])
163-
.into_rel_node();
159+
.into_pred_node();
164160

165161
assert_eq!(plan.typ, DfNodeType::Projection);
166-
assert_eq!(plan.child(1), res_proj_exprs);
167-
assert!(matches!(plan.child(0).typ, DfNodeType::Scan));
162+
assert_eq!(plan.predicate(0), res_proj_exprs);
163+
assert!(matches!(plan.child_rel(0).typ, DfNodeType::Scan));
168164
}
169165

170166
#[test]
@@ -175,28 +171,28 @@ mod tests {
175171
let scan = LogicalScan::new("customer".into());
176172

177173
let proj_exprs_1 = ListPred::new(vec![
178-
ColumnRefPred::new(2).into_expr(),
179-
ColumnRefPred::new(0).into_expr(),
180-
ColumnRefPred::new(4).into_expr(),
181-
ColumnRefPred::new(3).into_expr(),
174+
ColumnRefPred::new(2).into_pred_node(),
175+
ColumnRefPred::new(0).into_pred_node(),
176+
ColumnRefPred::new(4).into_pred_node(),
177+
ColumnRefPred::new(3).into_pred_node(),
182178
]);
183179

184180
let proj_exprs_2 = ListPred::new(vec![
185-
ColumnRefPred::new(1).into_expr(),
186-
ColumnRefPred::new(0).into_expr(),
187-
ColumnRefPred::new(3).into_expr(),
181+
ColumnRefPred::new(1).into_pred_node(),
182+
ColumnRefPred::new(0).into_pred_node(),
183+
ColumnRefPred::new(3).into_pred_node(),
188184
]);
189185

190186
let proj_exprs_3 = ListPred::new(vec![
191-
ColumnRefPred::new(1).into_expr(),
192-
ColumnRefPred::new(0).into_expr(),
193-
ColumnRefPred::new(2).into_expr(),
187+
ColumnRefPred::new(1).into_pred_node(),
188+
ColumnRefPred::new(0).into_pred_node(),
189+
ColumnRefPred::new(2).into_pred_node(),
194190
]);
195191

196192
let proj_exprs_4 = ListPred::new(vec![
197-
ColumnRefPred::new(0).into_expr(),
198-
ColumnRefPred::new(1).into_expr(),
199-
ColumnRefPred::new(2).into_expr(),
193+
ColumnRefPred::new(0).into_pred_node(),
194+
ColumnRefPred::new(1).into_pred_node(),
195+
ColumnRefPred::new(2).into_pred_node(),
200196
]);
201197

202198
let proj_1 = LogicalProjection::new(scan.into_plan_node(), proj_exprs_1);
@@ -205,19 +201,19 @@ mod tests {
205201
let proj_4 = LogicalProjection::new(proj_3.into_plan_node(), proj_exprs_4);
206202

207203
// needs to be called three times
208-
let plan = test_optimizer.optimize(proj_4.into_rel_node()).unwrap();
204+
let plan = test_optimizer.optimize(proj_4.into_plan_node()).unwrap();
209205
let plan = test_optimizer.optimize(plan).unwrap();
210206
let plan = test_optimizer.optimize(plan).unwrap();
211207

212208
let res_proj_exprs = ListPred::new(vec![
213-
ColumnRefPred::new(2).into_expr(),
214-
ColumnRefPred::new(0).into_expr(),
215-
ColumnRefPred::new(3).into_expr(),
209+
ColumnRefPred::new(2).into_pred_node(),
210+
ColumnRefPred::new(0).into_pred_node(),
211+
ColumnRefPred::new(3).into_pred_node(),
216212
])
217-
.into_rel_node();
213+
.into_pred_node();
218214

219215
assert_eq!(plan.typ, DfNodeType::Projection);
220-
assert_eq!(plan.child(1), res_proj_exprs);
221-
assert!(matches!(plan.child(0).typ, DfNodeType::Scan));
216+
assert_eq!(plan.predicate(0), res_proj_exprs);
217+
assert!(matches!(plan.child_rel(0).typ, DfNodeType::Scan));
222218
}
223219
}

optd-datafusion-repr/src/rules/project_transpose/project_transpose_common.rs

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
use crate::plan_nodes::{ColumnRefPred, DfReprPlanNode, ListPred};
1+
use crate::plan_nodes::{ArcDfPredNode, ColumnRefPred, DfReprPredNode, ListPred, PredExt};
22

33
/// This struct holds the mapping from original columns to projected columns.
44
///
@@ -25,7 +25,7 @@ impl ProjectionMapping {
2525
let mut forward = vec![];
2626
let mut backward = vec![];
2727
for (i, expr) in exprs.to_vec().iter().enumerate() {
28-
let col_expr = ColumnRefPred::from_rel_node(expr.clone().into_rel_node())?;
28+
let col_expr = ColumnRefPred::from_pred_node(expr.clone())?;
2929
let col_idx = col_expr.index();
3030
forward.push(col_idx);
3131
if col_idx >= backward.len() {
@@ -52,9 +52,9 @@ impl ProjectionMapping {
5252
/// Projection { exprs: [#1, #0, #3, #5, #4] } --> has mapping
5353
/// ---->
5454
/// Join { cond: #1=#4 }
55-
pub fn rewrite_join_cond(&self, cond: Expr, child_schema_len: usize) -> Expr {
55+
pub fn rewrite_join_cond(&self, cond: ArcDfPredNode, child_schema_len: usize) -> ArcDfPredNode {
5656
let schema_size = self.forward.len();
57-
cond.rewrite_column_refs(&mut |col_idx| {
57+
cond.rewrite_column_refs(|col_idx| {
5858
if col_idx < schema_size {
5959
self.projection_col_maps_to(col_idx)
6060
} else {
@@ -78,7 +78,7 @@ impl ProjectionMapping {
7878
/// Projection { exprs: [#1, #0, #3, #5, #4] } --> has mapping
7979
/// ---->
8080
/// Filter { cond: #1=0 and #4=1 }
81-
pub fn rewrite_filter_cond(&self, cond: Expr, is_added: bool) -> Expr {
81+
pub fn rewrite_filter_cond(&self, cond: ArcDfPredNode, is_added: bool) -> ArcDfPredNode {
8282
cond.rewrite_column_refs(&mut |col_idx| {
8383
if is_added {
8484
self.original_col_maps_to(col_idx)
@@ -112,10 +112,10 @@ impl ProjectionMapping {
112112
}
113113
} else {
114114
for i in exprs.to_vec() {
115-
let col_ref = ColumnRefPred::from_rel_node(i.into_rel_node()).unwrap();
115+
let col_ref = ColumnRefPred::from_pred_node(i).unwrap();
116116
let col_idx = self.original_col_maps_to(col_ref.index()).unwrap();
117-
let col: Expr = ColumnRefPred::new(col_idx).into_expr();
118-
new_projection_exprs.push(col);
117+
let col = ColumnRefPred::new(col_idx);
118+
new_projection_exprs.push(col.into_pred_node());
119119
}
120120
}
121121
Some(ListPred::new(new_projection_exprs))

0 commit comments

Comments
 (0)