Skip to content

Commit 46e621a

Browse files
Draft fix
1 parent 5cfe2a5 commit 46e621a

File tree

5 files changed

+216
-86
lines changed

5 files changed

+216
-86
lines changed

core/src/main/java/org/apache/calcite/sql2rel/RelDecorrelator.java

Lines changed: 41 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,7 @@
7373
import org.apache.calcite.rex.RexSubQuery;
7474
import org.apache.calcite.rex.RexUtil;
7575
import org.apache.calcite.rex.RexVisitorImpl;
76+
import org.apache.calcite.rex.RexWindowBounds;
7677
import org.apache.calcite.runtime.PairList;
7778
import org.apache.calcite.sql.SqlAggFunction;
7879
import org.apache.calcite.sql.SqlExplainFormat;
@@ -574,9 +575,7 @@ protected RexNode removeCorrelationExpr(
574575
}
575576

576577
if (isCorVarDefined && (rel.fetch != null || rel.offset != null)) {
577-
if (rel.fetch != null
578-
&& rel.offset == null
579-
&& RexLiteral.intValue(rel.fetch) == 1) {
578+
if (rel.fetch != null && rel.offset == null && RexLiteral.intValue(rel.fetch) >= 1) {
580579
return decorrelateFetchOneSort(rel, frame);
581580
}
582581
// Can not decorrelate if the sort has per-correlate-key attributes like
@@ -1099,25 +1098,49 @@ private static void shiftMapping(Map<Integer, Integer> mapping, int startIndex,
10991098
for (RelDataTypeField field : sort.getRowType().getFieldList()) {
11001099
final int newIdx =
11011100
requireNonNull(frame.oldToNewOutputs.get(field.getIndex()));
1102-
1103-
RelBuilder.AggCall aggCall =
1104-
relBuilder.aggregateCall(SqlStdOperatorTable.FIRST_VALUE,
1105-
RexInputRef.of(newIdx, fieldList));
1106-
1107-
// Convert each field from the sorted output to a window function that partitions by
1108-
// correlated variables, orders by the collation, and return the first_value.
1109-
RexNode winCall = aggCall.over()
1110-
.orderBy(sortExprs)
1111-
.partitionBy(corVarProjects.leftList())
1112-
.toRex();
11131101
mapOldToNewOutputs.put(newProjExprs.size(), newProjExprs.size());
1114-
newProjExprs.add(winCall, field.getName());
1102+
newProjExprs.add(RexInputRef.of(newIdx, fieldList), field.getName());
11151103
}
11161104
newProjExprs.addAll(corVarProjects);
1117-
RelNode result = relBuilder.push(frame.r)
1118-
.project(newProjExprs.leftList(), newProjExprs.rightList())
1119-
.distinct().build();
11201105

1106+
relBuilder.push(frame.r);
1107+
1108+
RexNode rowNumberCall = relBuilder.aggregateCall(SqlStdOperatorTable.ROW_NUMBER)
1109+
.over()
1110+
.partitionBy(corVarProjects.leftList())
1111+
.orderBy(sortExprs)
1112+
.let(c -> c.rowsBetween(RexWindowBounds.UNBOUNDED_PRECEDING, RexWindowBounds.CURRENT_ROW))
1113+
.toRex();
1114+
newProjExprs.add(rowNumberCall, "rn"); // Add the row number column
1115+
relBuilder.project(newProjExprs.leftList(), newProjExprs.rightList());
1116+
1117+
List<RexNode> conditions = new ArrayList<>();
1118+
if (sort.offset != null) {
1119+
RexNode greaterThenLowerBound =
1120+
relBuilder.call(
1121+
SqlStdOperatorTable.GREATER_THAN,
1122+
relBuilder.field(newProjExprs.size() - 1),
1123+
sort.offset);
1124+
conditions.add(greaterThenLowerBound);
1125+
}
1126+
if (sort.fetch != null) {
1127+
RexNode upperBound = sort.offset == null
1128+
? sort.fetch
1129+
: relBuilder.call(SqlStdOperatorTable.PLUS, sort.offset, sort.fetch);
1130+
RexNode lessThenOrEqualUpperBound =
1131+
relBuilder.call(
1132+
SqlStdOperatorTable.LESS_THAN_OR_EQUAL,
1133+
relBuilder.field(newProjExprs.size() - 1),
1134+
upperBound);
1135+
conditions.add(lessThenOrEqualUpperBound);
1136+
}
1137+
1138+
RelNode result;
1139+
if (!conditions.isEmpty()) {
1140+
result = relBuilder.filter(conditions).build();
1141+
} else {
1142+
result = relBuilder.build();
1143+
}
11211144
return register(sort, result, mapOldToNewOutputs, corDefOutputs);
11221145
}
11231146

core/src/test/java/org/apache/calcite/sql2rel/RelDecorrelatorTest.java

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -584,4 +584,67 @@ public static Frameworks.ConfigBuilder config() {
584584
+ " LogicalTableScan(table=[[scott, DEPT]])\n";
585585
assertThat(decorrelatedNoRules, hasTree(planDecorrelatedNoRules));
586586
}
587+
588+
@Test void testDecorrelateCorrelatedOrderByLimitToRowNumberRows() {
589+
final FrameworkConfig frameworkConfig = config().build();
590+
final RelBuilder builder = RelBuilder.create(frameworkConfig);
591+
final RelOptCluster cluster = builder.getCluster();
592+
final Planner planner = Frameworks.getPlanner(frameworkConfig);
593+
final String sql = ""
594+
+ "SELECT dname FROM dept WHERE 2000 > (\n"
595+
+ "SELECT emp.sal FROM emp where dept.deptno = emp.deptno\n"
596+
+ "ORDER BY year(hiredate), emp.sal limit 1)";
597+
final RelNode originalRel;
598+
try {
599+
final SqlNode parse = planner.parse(sql);
600+
final SqlNode validate = planner.validate(parse);
601+
originalRel = planner.rel(validate).rel;
602+
} catch (Exception e) {
603+
throw TestUtil.rethrow(e);
604+
}
605+
606+
final HepProgram hepProgram = HepProgram.builder()
607+
.addRuleCollection(
608+
ImmutableList.of(
609+
// SubQuery program rules
610+
CoreRules.FILTER_SUB_QUERY_TO_CORRELATE,
611+
CoreRules.PROJECT_SUB_QUERY_TO_CORRELATE,
612+
CoreRules.JOIN_SUB_QUERY_TO_CORRELATE))
613+
.build();
614+
final Program program =
615+
Programs.of(hepProgram, true,
616+
requireNonNull(cluster.getMetadataProvider()));
617+
final RelNode before =
618+
program.run(cluster.getPlanner(), originalRel, cluster.traitSet(),
619+
Collections.emptyList(), Collections.emptyList());
620+
final String planBefore = ""
621+
+ "LogicalProject(DNAME=[$1])\n"
622+
+ " LogicalProject(DEPTNO=[$0], DNAME=[$1], LOC=[$2])\n"
623+
+ " LogicalFilter(condition=[>(2000.00, CAST($3):DECIMAL(12, 2))])\n"
624+
+ " LogicalCorrelate(correlation=[$cor0], joinType=[left], requiredColumns=[{0}])\n"
625+
+ " LogicalTableScan(table=[[scott, DEPT]])\n"
626+
+ " LogicalProject(SAL=[$0])\n"
627+
+ " LogicalSort(sort0=[$1], sort1=[$0], dir0=[ASC], dir1=[ASC], fetch=[1])\n"
628+
+ " LogicalProject(SAL=[$5], EXPR$1=[EXTRACT(FLAG(YEAR), $4)])\n"
629+
+ " LogicalFilter(condition=[=($cor0.DEPTNO, $7)])\n"
630+
+ " LogicalTableScan(table=[[scott, EMP]])\n";
631+
assertThat(before, hasTree(planBefore));
632+
633+
// Decorrelate with no rules, so only RelDecorrelator's own algorithm is exercised
634+
final RelNode after =
635+
RelDecorrelator.decorrelateQuery(before, builder, RuleSets.ofList(Collections.emptyList()),
636+
RuleSets.ofList(Collections.emptyList()));
637+
// Verify plan
638+
final String planAfter = ""
639+
+ "LogicalProject(DNAME=[$1])\n"
640+
+ " LogicalJoin(condition=[=($0, $4)], joinType=[inner])\n"
641+
+ " LogicalTableScan(table=[[scott, DEPT]])\n"
642+
+ " LogicalFilter(condition=[>(2000.00, CAST($0):DECIMAL(12, 2))])\n"
643+
+ " LogicalProject(SAL=[$0], DEPTNO=[$2])\n"
644+
+ " LogicalFilter(condition=[<=($3, 1)])\n"
645+
+ " LogicalProject(SAL=[$5], EXPR$1=[EXTRACT(FLAG(YEAR), $4)], DEPTNO=[$7], rn=[ROW_NUMBER() OVER (PARTITION BY $7 ORDER BY EXTRACT(FLAG(YEAR), $4) NULLS LAST, $5 NULLS LAST)])\n"
646+
+ " LogicalFilter(condition=[IS NOT NULL($7)])\n"
647+
+ " LogicalTableScan(table=[[scott, EMP]])\n";
648+
assertThat(after, hasTree(planAfter));
649+
}
587650
}

core/src/test/resources/org/apache/calcite/test/RelOptRulesTest.xml

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -2680,12 +2680,12 @@ LogicalProject(NAME=[$1])
26802680
<Resource name="planAfter">
26812681
<![CDATA[
26822682
LogicalProject(NAME=[$1])
2683-
LogicalProject(DEPTNO=[$0], NAME=[$1], SAL=[CAST($2):INTEGER], DEPTNO0=[CAST($3):INTEGER])
2683+
LogicalProject(DEPTNO=[$0], NAME=[$1], SAL=[CAST($2):INTEGER], DEPTNO0=[CAST($3):INTEGER], rn=[CAST($4):BIGINT])
26842684
LogicalJoin(condition=[=($0, $3)], joinType=[inner])
26852685
LogicalTableScan(table=[[CATALOG, SALES, DEPT]])
26862686
LogicalFilter(condition=[>(10, $0)])
2687-
LogicalAggregate(group=[{0, 1}])
2688-
LogicalProject(SAL=[FIRST_VALUE($5) OVER (PARTITION BY $7 ORDER BY $5 DESC NULLS FIRST RANGE BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING)], DEPTNO=[$7])
2687+
LogicalFilter(condition=[<=($2, 1)])
2688+
LogicalProject(SAL=[$5], DEPTNO=[$7], rn=[ROW_NUMBER() OVER (PARTITION BY $7 ORDER BY $5 DESC NULLS FIRST)])
26892689
LogicalTableScan(table=[[CATALOG, SALES, EMP]])
26902690
]]>
26912691
</Resource>
@@ -2729,8 +2729,8 @@ LogicalProject(NAME=[$1])
27292729
LogicalTableScan(table=[[CATALOG, SALES, DEPT]])
27302730
LogicalFilter(condition=[>(10, $0)])
27312731
LogicalProject(SAL=[$0], DEPTNO=[$2])
2732-
LogicalAggregate(group=[{0, 1, 2}])
2733-
LogicalProject(SAL=[FIRST_VALUE($5) OVER (PARTITION BY $7 ORDER BY EXTRACT(FLAG(YEAR), $4) NULLS LAST, $5 DESC NULLS FIRST RANGE BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING)], EXPR$1=[FIRST_VALUE(EXTRACT(FLAG(YEAR), $4)) OVER (PARTITION BY $7 ORDER BY EXTRACT(FLAG(YEAR), $4) NULLS LAST, $5 DESC NULLS FIRST RANGE BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING)], DEPTNO=[$7])
2732+
LogicalFilter(condition=[<=($3, 1)])
2733+
LogicalProject(SAL=[$5], EXPR$1=[EXTRACT(FLAG(YEAR), $4)], DEPTNO=[$7], rn=[ROW_NUMBER() OVER (PARTITION BY $7 ORDER BY EXTRACT(FLAG(YEAR), $4) NULLS LAST, $5 DESC NULLS FIRST)])
27342734
LogicalTableScan(table=[[CATALOG, SALES, EMP]])
27352735
]]>
27362736
</Resource>
@@ -2843,8 +2843,8 @@ LogicalProject(NAME=[$1], EXPR$1=[$2])
28432843
LogicalJoin(condition=[=($0, $3)], joinType=[left])
28442844
LogicalTableScan(table=[[CATALOG, SALES, DEPT]])
28452845
LogicalProject(SAL=[$0], DEPTNO=[$2])
2846-
LogicalAggregate(group=[{0, 1, 2}])
2847-
LogicalProject(SAL=[FIRST_VALUE($5) OVER (PARTITION BY $7 ORDER BY EXTRACT(FLAG(YEAR), $4) NULLS LAST, $5 NULLS LAST RANGE BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING)], EXPR$1=[FIRST_VALUE(EXTRACT(FLAG(YEAR), $4)) OVER (PARTITION BY $7 ORDER BY EXTRACT(FLAG(YEAR), $4) NULLS LAST, $5 NULLS LAST RANGE BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING)], DEPTNO=[$7])
2846+
LogicalFilter(condition=[<=($3, 1)])
2847+
LogicalProject(SAL=[$5], EXPR$1=[EXTRACT(FLAG(YEAR), $4)], DEPTNO=[$7], rn=[ROW_NUMBER() OVER (PARTITION BY $7 ORDER BY EXTRACT(FLAG(YEAR), $4) NULLS LAST, $5 NULLS LAST)])
28482848
LogicalTableScan(table=[[CATALOG, SALES, EMP]])
28492849
]]>
28502850
</Resource>
@@ -2884,8 +2884,8 @@ LogicalProject(NAME=[$1], EXPR$1=[$4])
28842884
LogicalProject(DEPTNO=[$0], NAME=[$1], DEPTNO0=[$0], NAME0=[CAST($1):VARCHAR(20) NOT NULL])
28852885
LogicalTableScan(table=[[CATALOG, SALES, DEPT]])
28862886
LogicalProject(SAL=[$0], DEPTNO=[$2], ENAME=[$3])
2887-
LogicalAggregate(group=[{0, 1, 2, 3}])
2888-
LogicalProject(SAL=[FIRST_VALUE($5) OVER (PARTITION BY $7, $1 ORDER BY EXTRACT(FLAG(YEAR), $4) NULLS LAST, $5 NULLS LAST RANGE BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING)], EXPR$1=[FIRST_VALUE(EXTRACT(FLAG(YEAR), $4)) OVER (PARTITION BY $7, $1 ORDER BY EXTRACT(FLAG(YEAR), $4) NULLS LAST, $5 NULLS LAST RANGE BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING)], DEPTNO=[$7], ENAME=[$1])
2887+
LogicalFilter(condition=[<=($4, 1)])
2888+
LogicalProject(SAL=[$5], EXPR$1=[EXTRACT(FLAG(YEAR), $4)], DEPTNO=[$7], ENAME=[$1], rn=[ROW_NUMBER() OVER (PARTITION BY $7, $1 ORDER BY EXTRACT(FLAG(YEAR), $4) NULLS LAST, $5 NULLS LAST)])
28892889
LogicalTableScan(table=[[CATALOG, SALES, EMP]])
28902890
]]>
28912891
</Resource>
@@ -8855,8 +8855,8 @@ LogicalProject(A=[$0], TS=[$1], X=[$2], X0=[$4])
88558855
LogicalJoin(condition=[AND(=($1, $5), =($2, $6))], joinType=[left])
88568856
LogicalProject(EXPR$0=[$0], EXPR$1=[$1], EXPR$00=[CAST($0):VARCHAR(20) NOT NULL], EXPR$10=[$1])
88578857
LogicalValues(tuples=[[{ 'a', 1 }]])
8858-
LogicalAggregate(group=[{0, 1, 2}])
8859-
LogicalProject(X=[FIRST_VALUE($2) OVER (PARTITION BY $3, $4)], EXPR$1=[$3], EXPR$00=[$4])
8858+
LogicalFilter(condition=[<=($3, 1)])
8859+
LogicalProject(X=[$2], EXPR$1=[$3], EXPR$00=[$4], rn=[ROW_NUMBER() OVER (PARTITION BY $3, $4)])
88608860
LogicalJoin(condition=[AND(=($0, $4), <=($1, $3))], joinType=[inner])
88618861
LogicalProject(A=[$1], TS=[$0], X=[$3])
88628862
LogicalTableScan(table=[[CATALOG, SALES, EMP]])

core/src/test/resources/org/apache/calcite/test/SqlToRelConverterTest.xml

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -931,14 +931,13 @@ FROM
931931
<Resource name="plan">
932932
<![CDATA[
933933
LogicalProject(DEPTNO=[$0], ENAME=[$1])
934-
LogicalCorrelate(correlation=[$cor0], joinType=[inner], requiredColumns=[{0}])
934+
LogicalJoin(condition=[=($0, $3)], joinType=[inner])
935935
LogicalAggregate(group=[{0}])
936936
LogicalProject(DEPTNO=[$7])
937937
LogicalTableScan(table=[[CATALOG, SALES, EMP]])
938-
LogicalSort(sort0=[$1], dir0=[DESC], fetch=[3])
939-
LogicalProject(ENAME=[$1], SAL=[$5])
940-
LogicalFilter(condition=[=($7, $cor0.DEPTNO)])
941-
LogicalTableScan(table=[[CATALOG, SALES, EMP]])
938+
LogicalFilter(condition=[<=($3, 3)])
939+
LogicalProject(ENAME=[$1], SAL=[$5], DEPTNO=[$7], rn=[ROW_NUMBER() OVER (PARTITION BY $7 ORDER BY $5 DESC NULLS FIRST)])
940+
LogicalTableScan(table=[[CATALOG, SALES, EMP]])
942941
]]>
943942
</Resource>
944943
</TestCase>
@@ -2196,8 +2195,8 @@ LogicalProject(EMPNO=[$0], ENAME=[$1], JOB=[$2], MGR=[$3], HIREDATE=[$4], SAL=[$
21962195
LogicalTableScan(table=[[CATALOG, SALES, EMP]])
21972196
LogicalAggregate(group=[{0}], agg#0=[MIN($1)])
21982197
LogicalProject(DEPTNO=[$1], $f0=[true])
2199-
LogicalAggregate(group=[{0, 1}])
2200-
LogicalProject(EXPR$0=[FIRST_VALUE(1) OVER (PARTITION BY $0)], DEPTNO=[$0])
2198+
LogicalFilter(condition=[<=($2, 1)])
2199+
LogicalProject(EXPR$0=[1], DEPTNO=[$0], rn=[ROW_NUMBER() OVER (PARTITION BY $0)])
22012200
LogicalTableScan(table=[[CATALOG, SALES, DEPT]])
22022201
]]>
22032202
</Resource>

0 commit comments

Comments
 (0)