Skip to content
This repository was archived by the owner on Jan 7, 2025. It is now read-only.

Commit 5f26d36

Browse files
authored
Filter/Agg transpose bugfix (#252)
The check for whether or not a filter expression could be pushed beyond an agg node was incorrect. It was checking if the column was in the group by columns (checking the equality of the numbers), when it should be checking based on indices if we are only referring to columns that are *emitted* from the agg node as group by columns. For example, if we see: ``` Filter #1 > 100 Agg { groups: [#1], agg: Sum() } ``` We should *not* push down because `#1` refers to the sum column. In the current main branch, it is pushed down because it sees that `#1` equals a column in the `groups` field. It should be checking that every column is `< groups.len()` instead.
1 parent 2dd2a31 commit 5f26d36

File tree

4 files changed

+60
-39
lines changed

4 files changed

+60
-39
lines changed

optd-datafusion-repr/src/rules/filter_pushdown.rs

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -346,7 +346,10 @@ fn apply_filter_agg_transpose(
346346
let mut group_by_cols_only = true;
347347
for child in children {
348348
if let Some(col_ref) = ColumnRefPred::from_pred_node(child.clone()) {
349-
if !group_cols.contains(&col_ref.index()) {
349+
// The agg schema is (group columns) + (expr columns),
350+
// so if the column ref is < group_cols.len(), it is
351+
// a group column.
352+
if col_ref.index() >= group_cols.len() {
350353
group_by_cols_only = false;
351354
break;
352355
}
Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
statement ok
2+
create table t1(v1 int, v2 int);
3+
4+
statement ok
5+
create table t2(v3 int, v4 int);
6+
7+
statement ok
8+
insert into t1 values (1, 100), (2, 200), (2, 250), (3, 300), (3, 300);
9+
10+
statement ok
11+
insert into t2 values (2, 200), (2, 250), (3, 300);

optd-sqllogictest/slt/unnest-dup.slt

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
include _basic_tables.slt.part
2+
3+
query
4+
select * from t1 where (select sum(v4) from t2 where v3 = v1) > 100;
5+
----
6+
2 200
7+
2 250
8+
3 300
9+
3 300

optd-sqlplannertest/tests/subqueries/subquery_unnesting.planner.sql

Lines changed: 36 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -53,25 +53,24 @@ LogicalProjection { exprs: [ #0, #1 ] }
5353
├── LogicalAgg { exprs: [], groups: [ #0 ] }
5454
│ └── LogicalScan { table: t1 }
5555
└── LogicalScan { table: t2 }
56-
PhysicalProjection { exprs: [ #2, #3 ], cost: {compute=8019,io=3000}, stat: {row_cnt=1} }
57-
└── PhysicalHashJoin { join_type: Inner, left_keys: [ #0 ], right_keys: [ #0 ], cost: {compute=8016,io=3000}, stat: {row_cnt=1} }
58-
├── PhysicalAgg
59-
│ ├── aggrs:Agg(Sum)
60-
│ │ ── [ Cast { cast_to: Int64, child: #2 } ]
61-
├── groups: [ #1 ]
62-
│ ├── cost: {compute=7014,io=2000}
56+
PhysicalProjection { exprs: [ #2, #3 ], cost: {compute=18005,io=3000}, stat: {row_cnt=1} }
57+
└── PhysicalHashJoin { join_type: Inner, left_keys: [ #0 ], right_keys: [ #0 ], cost: {compute=18002,io=3000}, stat: {row_cnt=1} }
58+
├── PhysicalFilter
59+
│ ├── cond:Gt
60+
│ │ ── #1
61+
│ └── 100(i64)
62+
│ ├── cost: {compute=17000,io=2000}
6363
│ ├── stat: {row_cnt=1}
64-
│ └── PhysicalProjection { exprs: [ #2, #0, #1 ], cost: {compute=7006,io=2000}, stat: {row_cnt=1} }
65-
│ └── PhysicalHashJoin { join_type: Inner, left_keys: [ #0 ], right_keys: [ #0 ], cost: {compute=7002,io=2000}, stat: {row_cnt=1} }
66-
│ ├── PhysicalFilter
67-
│ │ ├── cond:Gt
68-
│ │ │ ├── #0
69-
│ │ │ └── 100(i64)
70-
│ │ ├── cost: {compute=3000,io=1000}
71-
│ │ ├── stat: {row_cnt=1}
72-
│ │ └── PhysicalScan { table: t2, cost: {compute=0,io=1000}, stat: {row_cnt=1000} }
73-
│ └── PhysicalAgg { aggrs: [], groups: [ #0 ], cost: {compute=3000,io=1000}, stat: {row_cnt=1000} }
74-
│ └── PhysicalScan { table: t1, cost: {compute=0,io=1000}, stat: {row_cnt=1000} }
64+
│ └── PhysicalAgg
65+
│ ├── aggrs:Agg(Sum)
66+
│ │ └── [ Cast { cast_to: Int64, child: #2 } ]
67+
│ ├── groups: [ #1 ]
68+
│ ├── cost: {compute=14000,io=2000}
69+
│ ├── stat: {row_cnt=1000}
70+
│ └── PhysicalHashJoin { join_type: Inner, left_keys: [ #0 ], right_keys: [ #0 ], cost: {compute=6000,io=2000}, stat: {row_cnt=1000} }
71+
│ ├── PhysicalAgg { aggrs: [], groups: [ #0 ], cost: {compute=3000,io=1000}, stat: {row_cnt=1000} }
72+
│ │ └── PhysicalScan { table: t1, cost: {compute=0,io=1000}, stat: {row_cnt=1000} }
73+
│ └── PhysicalScan { table: t2, cost: {compute=0,io=1000}, stat: {row_cnt=1000} }
7574
└── PhysicalScan { table: t1, cost: {compute=0,io=1000}, stat: {row_cnt=1000} }
7675
*/
7776

@@ -135,27 +134,26 @@ LogicalProjection { exprs: [ #0, #1 ] }
135134
└── LogicalJoin { join_type: Cross, cond: true }
136135
├── LogicalScan { table: t2 }
137136
└── LogicalScan { table: t3 }
138-
PhysicalProjection { exprs: [ #2, #3 ], cost: {compute=9021,io=4000}, stat: {row_cnt=1} }
139-
└── PhysicalHashJoin { join_type: Inner, left_keys: [ #0 ], right_keys: [ #0 ], cost: {compute=9018,io=4000}, stat: {row_cnt=1} }
140-
├── PhysicalAgg
141-
│ ├── aggrs:Agg(Sum)
142-
│ │ ── [ Cast { cast_to: Int64, child: #2 } ]
143-
├── groups: [ #1 ]
144-
│ ├── cost: {compute=8016,io=3000}
137+
PhysicalProjection { exprs: [ #2, #3 ], cost: {compute=21005,io=4000}, stat: {row_cnt=1} }
138+
└── PhysicalHashJoin { join_type: Inner, left_keys: [ #0 ], right_keys: [ #0 ], cost: {compute=21002,io=4000}, stat: {row_cnt=1} }
139+
├── PhysicalFilter
140+
│ ├── cond:Gt
141+
│ │ ── #1
142+
│ └── 100(i64)
143+
│ ├── cost: {compute=20000,io=3000}
145144
│ ├── stat: {row_cnt=1}
146-
│ └── PhysicalHashJoin { join_type: Inner, left_keys: [ #2 ], right_keys: [ #0 ], cost: {compute=8008,io=3000}, stat: {row_cnt=1} }
147-
│ ├── PhysicalProjection { exprs: [ #2, #0, #1 ], cost: {compute=7006,io=2000}, stat: {row_cnt=1} }
148-
│ │ └── PhysicalHashJoin { join_type: Inner, left_keys: [ #0 ], right_keys: [ #0 ], cost: {compute=7002,io=2000}, stat: {row_cnt=1} }
149-
│ │ ├── PhysicalFilter
150-
│ │ │ ├── cond:Gt
151-
│ │ │ │ ├── #0
152-
│ │ │ │ └── 100(i64)
153-
│ │ │ ├── cost: {compute=3000,io=1000}
154-
│ │ │ ├── stat: {row_cnt=1}
155-
│ │ │ └── PhysicalScan { table: t2, cost: {compute=0,io=1000}, stat: {row_cnt=1000} }
156-
│ │ └── PhysicalAgg { aggrs: [], groups: [ #0 ], cost: {compute=3000,io=1000}, stat: {row_cnt=1000} }
157-
│ │ └── PhysicalScan { table: t1, cost: {compute=0,io=1000}, stat: {row_cnt=1000} }
158-
│ └── PhysicalScan { table: t3, cost: {compute=0,io=1000}, stat: {row_cnt=1000} }
145+
│ └── PhysicalAgg
146+
│ ├── aggrs:Agg(Sum)
147+
│ │ └── [ Cast { cast_to: Int64, child: #2 } ]
148+
│ ├── groups: [ #1 ]
149+
│ ├── cost: {compute=17000,io=3000}
150+
│ ├── stat: {row_cnt=1000}
151+
│ └── PhysicalHashJoin { join_type: Inner, left_keys: [ #2 ], right_keys: [ #0 ], cost: {compute=9000,io=3000}, stat: {row_cnt=1000} }
152+
│ ├── PhysicalHashJoin { join_type: Inner, left_keys: [ #0 ], right_keys: [ #0 ], cost: {compute=6000,io=2000}, stat: {row_cnt=1000} }
153+
│ │ ├── PhysicalAgg { aggrs: [], groups: [ #0 ], cost: {compute=3000,io=1000}, stat: {row_cnt=1000} }
154+
│ │ │ └── PhysicalScan { table: t1, cost: {compute=0,io=1000}, stat: {row_cnt=1000} }
155+
│ │ └── PhysicalScan { table: t2, cost: {compute=0,io=1000}, stat: {row_cnt=1000} }
156+
│ └── PhysicalScan { table: t3, cost: {compute=0,io=1000}, stat: {row_cnt=1000} }
159157
└── PhysicalScan { table: t1, cost: {compute=0,io=1000}, stat: {row_cnt=1000} }
160158
*/
161159

0 commit comments

Comments
 (0)