Skip to content

Commit 99a08c4

Browse files
xiedeyantumihaibudiu
authored andcommitted
[CALCITE-7008] Extend MinusToAntiJoinRule to support n-way inputs
1 parent b603c23 commit 99a08c4

File tree

6 files changed

+191
-30
lines changed

6 files changed

+191
-30
lines changed

core/src/main/java/org/apache/calcite/rel/rules/CoreRules.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -418,7 +418,7 @@ private CoreRules() {}
418418
MinusToDistinctRule.Config.DEFAULT.toRule();
419419

420420
/** Rule to translates a {@link Minus} to {@link Join} anti-join}. */
421-
public static final MinusToAntiJoinRule MINUS_TO_ANTI_JOIN_RULE =
421+
public static final MinusToAntiJoinRule MINUS_TO_ANTI_JOIN =
422422
MinusToAntiJoinRule.Config.DEFAULT.toRule();
423423

424424
/** Rule that converts a {@link LogicalMatch} to the result of calling

core/src/main/java/org/apache/calcite/rel/rules/MinusToAntiJoinRule.java

Lines changed: 79 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -36,11 +36,18 @@
3636
/**
3737
* Planner rule that translates a {@link Minus}
3838
* to a series of {@link org.apache.calcite.rel.core.Join} that type is
39-
* {@link JoinRelType#ANTI}. This rule supports 2-way Minus conversion,
39+
* {@link JoinRelType#ANTI}. This rule supports n-way Minus conversion,
4040
* as this rule can be repeatedly applied during query optimization to
4141
* refine the plan.
4242
*
43-
* <h2>Example</h2>
43+
* <p>Example for 2-way
44+
*
45+
* <p>Original sql:
46+
* <pre>{@code
47+
* select ename from emp where deptno = 10
48+
* except
49+
* select ename from emp where deptno = 20
50+
* }</pre>
4451
*
4552
* <p>Original plan:
4653
* <pre>{@code
@@ -50,7 +57,6 @@
5057
* LogicalTableScan(table=[[CATALOG, SALES, EMP]])
5158
* LogicalProject(ENAME=[$1])
5259
* LogicalFilter(condition=[=($7, 20)])
53-
* LogicalTableScan(table=[[CATALOG, SALES, EMP]])
5460
* }</pre>
5561
*
5662
* <p>Plan after conversion:
@@ -64,6 +70,46 @@
6470
* LogicalFilter(condition=[=($7, 20)])
6571
* LogicalTableScan(table=[[CATALOG, SALES, EMP]])
6672
* }</pre>
73+
*
74+
* <p>Example for n-way
75+
*
76+
* <p>Original sql:
77+
* <pre>{@code
78+
* select ename from emp where deptno = 10
79+
* except
80+
* select deptno from emp where ename in ('a', 'b')
81+
* except
82+
* select ename from empnullables
83+
* }</pre>
84+
*
85+
* <p>Original plan:
86+
* <pre>{@code
87+
* LogicalMinus(all=[false])
88+
* LogicalProject(ENAME=[$1])
89+
* LogicalFilter(condition=[=($7, 10)])
90+
* LogicalTableScan(table=[[CATALOG, SALES, EMP]])
91+
* LogicalProject(DEPTNO=[CAST($7):VARCHAR NOT NULL])
92+
* LogicalFilter(condition=[OR(=($1, 'a'), =($1, 'b'))])
93+
* LogicalTableScan(table=[[CATALOG, SALES, EMP]])
94+
* LogicalProject(ENAME=[$1])
95+
* LogicalTableScan(table=[[CATALOG, SALES, EMPNULLABLES]])
96+
* }</pre>
97+
*
98+
* <p>Plan after conversion:
99+
* <pre>{@code
100+
* LogicalProject(ENAME=[CAST($0):VARCHAR])
101+
* LogicalAggregate(group=[{0}])
102+
* LogicalJoin(condition=[<=>(CAST($0):VARCHAR, CAST($1):VARCHAR)], joinType=[anti])
103+
* LogicalJoin(condition=[=(CAST($0):VARCHAR, $1)], joinType=[anti])
104+
* LogicalProject(ENAME=[$1])
105+
* LogicalFilter(condition=[=($7, 10)])
106+
* LogicalTableScan(table=[[CATALOG, SALES, EMP]])
107+
* LogicalProject(DEPTNO=[CAST($7):VARCHAR NOT NULL])
108+
* LogicalFilter(condition=[OR(=($1, 'a'), =($1, 'b'))])
109+
* LogicalTableScan(table=[[CATALOG, SALES, EMP]])
110+
* LogicalProject(ENAME=[$1])
111+
* LogicalTableScan(table=[[CATALOG, SALES, EMPNULLABLES]])
112+
* }</pre>
67113
*/
68114
@Value.Enclosing
69115
public class MinusToAntiJoinRule
@@ -84,37 +130,44 @@ protected MinusToAntiJoinRule(Config config) {
84130
}
85131

86132
List<RelNode> inputs = minus.getInputs();
87-
if (inputs.size() != 2) {
133+
if (inputs.size() < 2) {
88134
return;
89135
}
90136

91137
final RelBuilder relBuilder = call.builder();
92138
final RexBuilder rexBuilder = relBuilder.getRexBuilder();
93139

94-
RelNode left = inputs.get(0);
95-
RelNode right = inputs.get(1);
96-
97-
List<RexNode> conditions = new ArrayList<>();
98-
int fieldCount = left.getRowType().getFieldCount();
99-
100-
for (int i = 0; i < fieldCount; i++) {
101-
RelDataType leftFieldType = left.getRowType().getFieldList().get(i).getType();
102-
RelDataType rightFieldType = right.getRowType().getFieldList().get(i).getType();
103-
104-
// No further optimization will be performed based on field nullability,
105-
// as this can be uniformly optimized by other rules.
106-
conditions.add(
107-
relBuilder.isNotDistinctFrom(
108-
rexBuilder.makeInputRef(leftFieldType, i),
109-
rexBuilder.makeInputRef(rightFieldType, i + fieldCount)));
140+
final RelDataType leastRowType = minus.getRowType();
141+
RelNode current = inputs.get(0);
142+
relBuilder.push(current);
143+
144+
for (int i = 1; i < inputs.size(); i++) {
145+
RelNode next = inputs.get(i);
146+
int fieldCount = current.getRowType().getFieldCount();
147+
148+
List<RexNode> conditions = new ArrayList<>();
149+
for (int j = 0; j < fieldCount; j++) {
150+
RelDataType leftFieldType = current.getRowType().getFieldList().get(j).getType();
151+
RelDataType rightFieldType = next.getRowType().getFieldList().get(j).getType();
152+
RelDataType leastFieldType = leastRowType.getFieldList().get(j).getType();
153+
154+
conditions.add(
155+
relBuilder.isNotDistinctFrom(
156+
rexBuilder.makeCast(leastFieldType,
157+
rexBuilder.makeInputRef(leftFieldType, j)),
158+
rexBuilder.makeCast(leastFieldType,
159+
rexBuilder.makeInputRef(rightFieldType, j + fieldCount))));
160+
}
161+
RexNode condition = RexUtil.composeConjunction(rexBuilder, conditions);
162+
163+
relBuilder.push(next)
164+
.join(JoinRelType.ANTI, condition);
165+
166+
current = relBuilder.peek();
110167
}
111-
RexNode condition = RexUtil.composeConjunction(rexBuilder, conditions);
112-
113-
relBuilder.push(left)
114-
.push(right)
115-
.join(JoinRelType.ANTI, condition)
116-
.distinct();
117168

169+
relBuilder.distinct()
170+
.convert(leastRowType, true);
118171
call.transformTo(relBuilder.build());
119172
}
120173

core/src/test/java/org/apache/calcite/test/JdbcTest.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4230,7 +4230,7 @@ public void checkOrderBy(final boolean desc,
42304230
p -> {
42314231
p.removeRule(CoreRules.MINUS_TO_DISTINCT);
42324232
p.removeRule(ENUMERABLE_MINUS_RULE);
4233-
p.addRule(CoreRules.MINUS_TO_ANTI_JOIN_RULE);
4233+
p.addRule(CoreRules.MINUS_TO_ANTI_JOIN);
42344234
})
42354235
.explainContains("joinType=[anti]")
42364236
.returnsUnordered(returns);

core/src/test/java/org/apache/calcite/test/RelOptRulesTest.java

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3684,7 +3684,18 @@ private void checkPushJoinThroughUnionOnRightDoesNotMatchSemiOrAntiJoin(JoinRelT
36843684
final String sql = "select ename from emp where deptno = 10\n"
36853685
+ "except\n"
36863686
+ "select ename from emp where deptno = 20\n";
3687-
sql(sql).withRule(CoreRules.MINUS_TO_ANTI_JOIN_RULE)
3687+
sql(sql).withRule(CoreRules.MINUS_TO_ANTI_JOIN)
3688+
.check();
3689+
}
3690+
3691+
@Test void testMinusToAntiJoinRuleMultiInputs() {
3692+
final String sql = "select ename from emp where deptno = 10\n"
3693+
+ "except\n"
3694+
+ "select deptno from emp where ename in ('a', 'b')\n"
3695+
+ "except\n"
3696+
+ "select ename from empnullables\n";
3697+
sql(sql).withPreRule(CoreRules.MINUS_MERGE)
3698+
.withRule(CoreRules.MINUS_TO_ANTI_JOIN)
36883699
.check();
36893700
}
36903701

core/src/test/resources/org/apache/calcite/test/RelOptRulesTest.xml

Lines changed: 40 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9113,13 +9113,52 @@ LogicalMinus(all=[false])
91139113
<Resource name="planAfter">
91149114
<![CDATA[
91159115
LogicalAggregate(group=[{0}])
9116-
LogicalJoin(condition=[IS NOT DISTINCT FROM($0, $1)], joinType=[anti])
9116+
LogicalJoin(condition=[=($0, $1)], joinType=[anti])
91179117
LogicalProject(ENAME=[$1])
91189118
LogicalFilter(condition=[=($7, 10)])
91199119
LogicalTableScan(table=[[CATALOG, SALES, EMP]])
91209120
LogicalProject(ENAME=[$1])
91219121
LogicalFilter(condition=[=($7, 20)])
91229122
LogicalTableScan(table=[[CATALOG, SALES, EMP]])
9123+
]]>
9124+
</Resource>
9125+
</TestCase>
9126+
<TestCase name="testMinusToAntiJoinRuleMultiInputs">
9127+
<Resource name="sql">
9128+
<![CDATA[select ename from emp where deptno = 10
9129+
except
9130+
select deptno from emp where ename in ('a', 'b')
9131+
except
9132+
select ename from empnullables
9133+
]]>
9134+
</Resource>
9135+
<Resource name="planBefore">
9136+
<![CDATA[
9137+
LogicalMinus(all=[false])
9138+
LogicalProject(ENAME=[$1])
9139+
LogicalFilter(condition=[=($7, 10)])
9140+
LogicalTableScan(table=[[CATALOG, SALES, EMP]])
9141+
LogicalProject(DEPTNO=[CAST($7):VARCHAR NOT NULL])
9142+
LogicalFilter(condition=[OR(=($1, 'a'), =($1, 'b'))])
9143+
LogicalTableScan(table=[[CATALOG, SALES, EMP]])
9144+
LogicalProject(ENAME=[$1])
9145+
LogicalTableScan(table=[[CATALOG, SALES, EMPNULLABLES]])
9146+
]]>
9147+
</Resource>
9148+
<Resource name="planAfter">
9149+
<![CDATA[
9150+
LogicalProject(ENAME=[CAST($0):VARCHAR])
9151+
LogicalAggregate(group=[{0}])
9152+
LogicalJoin(condition=[IS NOT DISTINCT FROM(CAST($0):VARCHAR, CAST($1):VARCHAR)], joinType=[anti])
9153+
LogicalJoin(condition=[=(CAST($0):VARCHAR, $1)], joinType=[anti])
9154+
LogicalProject(ENAME=[$1])
9155+
LogicalFilter(condition=[=($7, 10)])
9156+
LogicalTableScan(table=[[CATALOG, SALES, EMP]])
9157+
LogicalProject(DEPTNO=[CAST($7):VARCHAR NOT NULL])
9158+
LogicalFilter(condition=[OR(=($1, 'a'), =($1, 'b'))])
9159+
LogicalTableScan(table=[[CATALOG, SALES, EMP]])
9160+
LogicalProject(ENAME=[$1])
9161+
LogicalTableScan(table=[[CATALOG, SALES, EMPNULLABLES]])
91239162
]]>
91249163
</Resource>
91259164
</TestCase>

core/src/test/resources/sql/planner.iq

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -203,6 +203,64 @@ EnumerableIntersect(all=[false])
203203
EnumerableValues(tuples=[[{ 1.0 }, { 4.0 }, { null }]])
204204
!plan
205205

206+
# [CALCITE-7008] Extend MinusToAntiJoinRule to support n-way inputs
207+
!set planner-rules "
208+
-EnumerableRules.ENUMERABLE_MINUS_RULE,
209+
-CoreRules.MINUS_TO_DISTINCT,
210+
+CoreRules.MINUS_TO_ANTI_JOIN"
211+
select a from (values (1.0), (2.0), (3.0), (4.0), (5.0)) as t1 (a)
212+
except
213+
select a from (values (1), (2)) as t2 (a)
214+
except
215+
select a from (values (1.0), (4.0), (null)) as t3 (a);
216+
+-----+
217+
| A |
218+
+-----+
219+
| 3.0 |
220+
| 5.0 |
221+
+-----+
222+
(2 rows)
223+
224+
!ok
225+
226+
EnumerableCalc(expr#0=[{inputs}], expr#1=[CAST($t0):DECIMAL(11, 1)], A=[$t1])
227+
EnumerableNestedLoopJoin(condition=[OR(AND(IS NULL(CAST($0):DECIMAL(11, 1)), IS NULL(CAST($1):DECIMAL(11, 1))), =(CAST($0):DECIMAL(11, 1), CAST($1):DECIMAL(11, 1)))], joinType=[anti])
228+
EnumerableAggregate(group=[{0}])
229+
EnumerableNestedLoopJoin(condition=[=(CAST($0):DECIMAL(11, 1) NOT NULL, CAST($1):DECIMAL(11, 1) NOT NULL)], joinType=[anti])
230+
EnumerableCalc(expr#0=[{inputs}], expr#1=[CAST($t0):DECIMAL(11, 1) NOT NULL], A=[$t1])
231+
EnumerableValues(tuples=[[{ 1.0 }, { 2.0 }, { 3.0 }, { 4.0 }, { 5.0 }]])
232+
EnumerableCalc(expr#0=[{inputs}], expr#1=[CAST($t0):DECIMAL(11, 1) NOT NULL], A=[$t1])
233+
EnumerableValues(tuples=[[{ 1 }, { 2 }]])
234+
EnumerableCalc(expr#0=[{inputs}], expr#1=[CAST($t0):DECIMAL(11, 1)], A=[$t1])
235+
EnumerableValues(tuples=[[{ 1.0 }, { 4.0 }, { null }]])
236+
!plan
237+
!set planner-rules original
238+
239+
# [CALCITE-7008] Extend MinusToAntiJoinRule to support n-way inputs
240+
select a from (values (1.0), (2.0), (3.0), (4.0), (5.0)) as t1 (a)
241+
except
242+
select a from (values (1), (2)) as t2 (a)
243+
except
244+
select a from (values (1.0), (4.0), (null)) as t3 (a);
245+
+-----+
246+
| A |
247+
+-----+
248+
| 3.0 |
249+
| 5.0 |
250+
+-----+
251+
(2 rows)
252+
253+
!ok
254+
255+
EnumerableMinus(all=[false])
256+
EnumerableCalc(expr#0=[{inputs}], expr#1=[CAST($t0):DECIMAL(11, 1) NOT NULL], A=[$t1])
257+
EnumerableValues(tuples=[[{ 1.0 }, { 2.0 }, { 3.0 }, { 4.0 }, { 5.0 }]])
258+
EnumerableCalc(expr#0=[{inputs}], expr#1=[CAST($t0):DECIMAL(11, 1) NOT NULL], A=[$t1])
259+
EnumerableValues(tuples=[[{ 1 }, { 2 }]])
260+
EnumerableCalc(expr#0=[{inputs}], expr#1=[CAST($t0):DECIMAL(11, 1)], A=[$t1])
261+
EnumerableValues(tuples=[[{ 1.0 }, { 4.0 }, { null }]])
262+
!plan
263+
206264
# Test predicate push down with/without expand disjunction.
207265
with t1 (id1, col11, col12) as (values (1, 11, 111), (2, 12, 122), (3, 13, 133), (4, 14, 144), (5, 15, 155)),
208266
t2 (id2, col21, col22) as (values (1, 21, 211), (2, 22, 222), (3, 23, 233), (4, 24, 244), (5, 25, 255)),

0 commit comments

Comments
 (0)