Skip to content

Commit 5ec06f2

Browse files
[CALCITE-6942] Support decorrelated for sub-queries with LIMIT 1 and OFFSET
1 parent f19e854 commit 5ec06f2

File tree

5 files changed

+674
-111
lines changed

5 files changed

+674
-111
lines changed

core/src/main/java/org/apache/calcite/sql2rel/RelDecorrelator.java

Lines changed: 68 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,7 @@
7373
import org.apache.calcite.rex.RexSubQuery;
7474
import org.apache.calcite.rex.RexUtil;
7575
import org.apache.calcite.rex.RexVisitorImpl;
76+
import org.apache.calcite.rex.RexWindowBounds;
7677
import org.apache.calcite.runtime.PairList;
7778
import org.apache.calcite.sql.SqlAggFunction;
7879
import org.apache.calcite.sql.SqlExplainFormat;
@@ -574,16 +575,7 @@ protected RexNode removeCorrelationExpr(
574575
}
575576

576577
if (isCorVarDefined && (rel.fetch != null || rel.offset != null)) {
577-
if (rel.fetch != null
578-
&& rel.offset == null
579-
&& RexLiteral.intValue(rel.fetch) == 1) {
580-
return decorrelateFetchOneSort(rel, frame);
581-
}
582-
// Can not decorrelate if the sort has per-correlate-key attributes like
583-
// offset or fetch limit, because these attributes scope would change to
584-
// global after decorrelation. They should take effect within the scope
585-
// of the correlation key actually.
586-
return null;
578+
return decorrelateFetchOneSort(rel, frame);
587579
}
588580

589581
final RelNode newInput = frame.r;
@@ -1045,29 +1037,31 @@ private static void shiftMapping(Map<Integer, Integer> mapping, int startIndex,
10451037
}
10461038

10471039
protected @Nullable Frame decorrelateFetchOneSort(Sort sort, final Frame frame) {
1048-
Frame aggFrame = decorrelateSortAsAggregate(sort, frame);
1049-
if (aggFrame != null) {
1050-
return aggFrame;
1040+
if (sort.offset == null
1041+
&& sort.fetch != null
1042+
&& RexLiteral.intValue(sort.fetch) == 0) {
1043+
return null;
10511044
}
1045+
10521046
//
10531047
// Rewrite logic:
10541048
//
1055-
// If sorted without offset and fetch = 1 (enforced by the caller), rewrite the sort to be
1056-
// Aggregate(group=(corVar.. , field..))
1057-
// project(first_value(field) over (partition by corVar order by (sort collation)))
1058-
// input
1049+
// For correlated Sort with LIMIT/OFFSET:
1050+
// 1) Special case: if OFFSET is null and FETCH = 1, we may rewrite as an Aggregate
1051+
// using MIN/MAX (see decorrelateSortAsAggregate).
1052+
// 2) General case: rewrite as
1053+
// Project(original_fields..., corVars..., rn)
1054+
// where rn = ROW_NUMBER() OVER (PARTITION BY corVars ORDER BY sortExprs
1055+
// ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW)
1056+
//      Filter(rn > offset AND rn <= offset + fetch), applied on top of that Project
1057+
// This preserves per-corVar LIMIT/OFFSET semantics.
10591058
//
1060-
// 1. For the original sorted input, apply the FIRST_VALUE window function to produce
1061-
// the result of sorting with LIMIT 1, and the same as the decorrelate of aggregate,
1062-
// add correlated variables in partition list to maintain semantic consistency.
1063-
// 2. To ensure that there is at most one row of output for
1064-
// any combination of correlated variables, distinct for correlated variables.
1065-
// 3. Since we have partitioned by all correlated variables
1066-
// in the sorted output field window, so for any combination of correlated variables,
1067-
// all other field values are unique. So the following two are equivalent:
1068-
// - group by corVar1, covVar2, field1, field2
1069-
// - any_value(fields1), any_value(fields2) group by corVar1, covVar2
1070-
// Here we use the first.
1059+
1060+
Frame aggFrame = decorrelateSortAsAggregate(sort, frame);
1061+
if (aggFrame != null) {
1062+
return aggFrame;
1063+
}
1064+
10711065
final Map<Integer, Integer> mapOldToNewOutputs = new HashMap<>();
10721066
final NavigableMap<CorDef, Integer> corDefOutputs = new TreeMap<>();
10731067

@@ -1099,29 +1093,60 @@ private static void shiftMapping(Map<Integer, Integer> mapping, int startIndex,
10991093
for (RelDataTypeField field : sort.getRowType().getFieldList()) {
11001094
final int newIdx =
11011095
requireNonNull(frame.oldToNewOutputs.get(field.getIndex()));
1102-
1103-
RelBuilder.AggCall aggCall =
1104-
relBuilder.aggregateCall(SqlStdOperatorTable.FIRST_VALUE,
1105-
RexInputRef.of(newIdx, fieldList));
1106-
1107-
// Convert each field from the sorted output to a window function that partitions by
1108-
// correlated variables, orders by the collation, and return the first_value.
1109-
RexNode winCall = aggCall.over()
1110-
.orderBy(sortExprs)
1111-
.partitionBy(corVarProjects.leftList())
1112-
.toRex();
11131096
mapOldToNewOutputs.put(newProjExprs.size(), newProjExprs.size());
1114-
newProjExprs.add(winCall, field.getName());
1097+
newProjExprs.add(RexInputRef.of(newIdx, fieldList), field.getName());
11151098
}
11161099
newProjExprs.addAll(corVarProjects);
1117-
RelNode result = relBuilder.push(frame.r)
1118-
.project(newProjExprs.leftList(), newProjExprs.rightList())
1119-
.distinct().build();
11201100

1101+
relBuilder.push(frame.r);
1102+
1103+
RexNode rowNumberCall = relBuilder.aggregateCall(SqlStdOperatorTable.ROW_NUMBER)
1104+
.over()
1105+
.partitionBy(corVarProjects.leftList())
1106+
.orderBy(sortExprs)
1107+
.let(c -> c.rowsBetween(RexWindowBounds.UNBOUNDED_PRECEDING, RexWindowBounds.CURRENT_ROW))
1108+
.toRex();
1109+
newProjExprs.add(rowNumberCall, "rn"); // Add the row number column
1110+
relBuilder.project(newProjExprs.leftList(), newProjExprs.rightList());
1111+
1112+
List<RexNode> conditions = new ArrayList<>();
1113+
if (sort.offset != null) {
1114+
RexNode greaterThenLowerBound =
1115+
relBuilder.call(
1116+
SqlStdOperatorTable.GREATER_THAN,
1117+
relBuilder.field(newProjExprs.size() - 1),
1118+
sort.offset);
1119+
conditions.add(greaterThenLowerBound);
1120+
}
1121+
if (sort.fetch != null) {
1122+
RexNode upperBound = sort.offset == null
1123+
? sort.fetch
1124+
: relBuilder.call(SqlStdOperatorTable.PLUS, sort.offset, sort.fetch);
1125+
RexNode lessThenOrEqualUpperBound =
1126+
relBuilder.call(
1127+
SqlStdOperatorTable.LESS_THAN_OR_EQUAL,
1128+
relBuilder.field(newProjExprs.size() - 1),
1129+
upperBound);
1130+
conditions.add(lessThenOrEqualUpperBound);
1131+
}
1132+
1133+
RelNode result;
1134+
if (!conditions.isEmpty()) {
1135+
result = relBuilder.filter(conditions).build();
1136+
} else {
1137+
result = relBuilder.build();
1138+
}
11211139
return register(sort, result, mapOldToNewOutputs, corDefOutputs);
11221140
}
11231141

11241142
protected @Nullable Frame decorrelateSortAsAggregate(Sort sort, final Frame frame) {
1143+
if (sort.offset != null) {
1144+
return null;
1145+
}
1146+
if (sort.fetch == null || RexLiteral.intValue(sort.fetch) != 1) {
1147+
return null;
1148+
}
1149+
11251150
final Map<Integer, Integer> mapOldToNewOutputs = new HashMap<>();
11261151
final NavigableMap<CorDef, Integer> corDefOutputs = new TreeMap<>();
11271152
if (sort.getCollation().getFieldCollations().size() == 1

core/src/test/java/org/apache/calcite/sql2rel/RelDecorrelatorTest.java

Lines changed: 199 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -584,4 +584,203 @@ public static Frameworks.ConfigBuilder config() {
584584
+ " LogicalTableScan(table=[[scott, DEPT]])\n";
585585
assertThat(decorrelatedNoRules, hasTree(planDecorrelatedNoRules));
586586
}
587+
588+
@Test void testDecorrelateCorrelatedOrderByLimitToRowNumber() {
589+
final FrameworkConfig frameworkConfig = config().build();
590+
final RelBuilder builder = RelBuilder.create(frameworkConfig);
591+
final RelOptCluster cluster = builder.getCluster();
592+
final Planner planner = Frameworks.getPlanner(frameworkConfig);
593+
final String sql = ""
594+
+ "SELECT dname FROM dept WHERE 2000 > (\n"
595+
+ "SELECT emp.sal FROM emp where dept.deptno = emp.deptno\n"
596+
+ "ORDER BY year(hiredate), emp.sal limit 1)";
597+
final RelNode originalRel;
598+
try {
599+
final SqlNode parse = planner.parse(sql);
600+
final SqlNode validate = planner.validate(parse);
601+
originalRel = planner.rel(validate).rel;
602+
} catch (Exception e) {
603+
throw TestUtil.rethrow(e);
604+
}
605+
606+
final HepProgram hepProgram = HepProgram.builder()
607+
.addRuleCollection(
608+
ImmutableList.of(
609+
// SubQuery program rules
610+
CoreRules.FILTER_SUB_QUERY_TO_CORRELATE,
611+
CoreRules.PROJECT_SUB_QUERY_TO_CORRELATE,
612+
CoreRules.JOIN_SUB_QUERY_TO_CORRELATE))
613+
.build();
614+
final Program program =
615+
Programs.of(hepProgram, true,
616+
requireNonNull(cluster.getMetadataProvider()));
617+
final RelNode before =
618+
program.run(cluster.getPlanner(), originalRel, cluster.traitSet(),
619+
Collections.emptyList(), Collections.emptyList());
620+
final String planBefore = ""
621+
+ "LogicalProject(DNAME=[$1])\n"
622+
+ " LogicalProject(DEPTNO=[$0], DNAME=[$1], LOC=[$2])\n"
623+
+ " LogicalFilter(condition=[>(2000.00, CAST($3):DECIMAL(12, 2))])\n"
624+
+ " LogicalCorrelate(correlation=[$cor0], joinType=[left], requiredColumns=[{0}])\n"
625+
+ " LogicalTableScan(table=[[scott, DEPT]])\n"
626+
+ " LogicalProject(SAL=[$0])\n"
627+
+ " LogicalSort(sort0=[$1], sort1=[$0], dir0=[ASC], dir1=[ASC], fetch=[1])\n"
628+
+ " LogicalProject(SAL=[$5], EXPR$1=[EXTRACT(FLAG(YEAR), $4)])\n"
629+
+ " LogicalFilter(condition=[=($cor0.DEPTNO, $7)])\n"
630+
+ " LogicalTableScan(table=[[scott, EMP]])\n";
631+
assertThat(before, hasTree(planBefore));
632+
633+
// Decorrelate with no rules, using only RelDecorrelator's core decorrelation algorithm
634+
final RelNode after =
635+
RelDecorrelator.decorrelateQuery(before, builder, RuleSets.ofList(Collections.emptyList()),
636+
RuleSets.ofList(Collections.emptyList()));
637+
// Verify plan
638+
final String planAfter = ""
639+
+ "LogicalProject(DNAME=[$1])\n"
640+
+ " LogicalJoin(condition=[=($0, $4)], joinType=[inner])\n"
641+
+ " LogicalTableScan(table=[[scott, DEPT]])\n"
642+
+ " LogicalFilter(condition=[>(2000.00, CAST($0):DECIMAL(12, 2))])\n"
643+
+ " LogicalProject(SAL=[$0], DEPTNO=[$2])\n"
644+
+ " LogicalFilter(condition=[<=($3, 1)])\n"
645+
+ " LogicalProject(SAL=[$5], EXPR$1=[EXTRACT(FLAG(YEAR), $4)], DEPTNO=[$7], rn=[ROW_NUMBER() OVER (PARTITION BY $7 ORDER BY EXTRACT(FLAG(YEAR), $4) NULLS LAST, $5 NULLS LAST)])\n"
646+
+ " LogicalFilter(condition=[IS NOT NULL($7)])\n"
647+
+ " LogicalTableScan(table=[[scott, EMP]])\n";
648+
assertThat(after, hasTree(planAfter));
649+
}
650+
651+
@Test void testDecorrelateCorrelatedOrderByLimitToRowNumber2() {
652+
final FrameworkConfig frameworkConfig = config().build();
653+
final RelBuilder builder = RelBuilder.create(frameworkConfig);
654+
final RelOptCluster cluster = builder.getCluster();
655+
final Planner planner = Frameworks.getPlanner(frameworkConfig);
656+
final String sql = ""
657+
+ "SELECT *\n"
658+
+ "FROM dept d\n"
659+
+ "WHERE d.deptno IN (\n"
660+
+ " SELECT e.deptno\n"
661+
+ " FROM emp e\n"
662+
+ " WHERE d.deptno = e.deptno\n"
663+
+ " LIMIT 10\n"
664+
+ " OFFSET 2\n"
665+
+ ")\n"
666+
+ "LIMIT 2\n"
667+
+ "OFFSET 1";
668+
final RelNode originalRel;
669+
try {
670+
final SqlNode parse = planner.parse(sql);
671+
final SqlNode validate = planner.validate(parse);
672+
originalRel = planner.rel(validate).rel;
673+
} catch (Exception e) {
674+
throw TestUtil.rethrow(e);
675+
}
676+
677+
final HepProgram hepProgram = HepProgram.builder()
678+
.addRuleCollection(
679+
ImmutableList.of(
680+
// SubQuery program rules
681+
CoreRules.FILTER_SUB_QUERY_TO_CORRELATE,
682+
CoreRules.PROJECT_SUB_QUERY_TO_CORRELATE,
683+
CoreRules.JOIN_SUB_QUERY_TO_CORRELATE))
684+
.build();
685+
final Program program =
686+
Programs.of(hepProgram, true,
687+
requireNonNull(cluster.getMetadataProvider()));
688+
final RelNode before =
689+
program.run(cluster.getPlanner(), originalRel, cluster.traitSet(),
690+
Collections.emptyList(), Collections.emptyList());
691+
final String planBefore = ""
692+
+ "LogicalSort(offset=[1], fetch=[2])\n"
693+
+ " LogicalProject(DEPTNO=[$0], DNAME=[$1], LOC=[$2])\n"
694+
+ " LogicalProject(DEPTNO=[$0], DNAME=[$1], LOC=[$2])\n"
695+
+ " LogicalFilter(condition=[=($0, $3)])\n"
696+
+ " LogicalCorrelate(correlation=[$cor0], joinType=[inner], requiredColumns=[{0}])\n"
697+
+ " LogicalTableScan(table=[[scott, DEPT]])\n"
698+
+ " LogicalAggregate(group=[{0}])\n"
699+
+ " LogicalSort(offset=[2], fetch=[10])\n"
700+
+ " LogicalProject(DEPTNO=[$7])\n"
701+
+ " LogicalFilter(condition=[=($cor0.DEPTNO, $7)])\n"
702+
+ " LogicalTableScan(table=[[scott, EMP]])\n";
703+
assertThat(before, hasTree(planBefore));
704+
705+
// Decorrelate with no rules, using only RelDecorrelator's core decorrelation algorithm
706+
final RelNode after =
707+
RelDecorrelator.decorrelateQuery(before, builder, RuleSets.ofList(Collections.emptyList()),
708+
RuleSets.ofList(Collections.emptyList()));
709+
// Verify plan
710+
final String planAfter = ""
711+
+ "LogicalSort(offset=[1], fetch=[2])\n"
712+
+ " LogicalProject(DEPTNO=[$0], DNAME=[$1], LOC=[$2])\n"
713+
+ " LogicalJoin(condition=[=($0, $4)], joinType=[inner])\n"
714+
+ " LogicalTableScan(table=[[scott, DEPT]])\n"
715+
+ " LogicalFilter(condition=[=($1, $0)])\n"
716+
+ " LogicalAggregate(group=[{0, 1}])\n"
717+
+ " LogicalProject(DEPTNO=[$0], DEPTNO1=[$1])\n"
718+
+ " LogicalFilter(condition=[AND(>($2, 2), <=($2, +(2, 10)))])\n"
719+
+ " LogicalProject(DEPTNO=[$7], DEPTNO1=[$7], rn=[ROW_NUMBER() OVER (PARTITION BY $7)])\n"
720+
+ " LogicalFilter(condition=[IS NOT NULL($7)])\n"
721+
+ " LogicalTableScan(table=[[scott, EMP]])\n";
722+
assertThat(after, hasTree(planAfter));
723+
}
724+
725+
@Test void testDecorrelateCorrelatedOrderByLimitToRowNumber3() {
726+
final FrameworkConfig frameworkConfig = config().build();
727+
final RelBuilder builder = RelBuilder.create(frameworkConfig);
728+
final RelOptCluster cluster = builder.getCluster();
729+
final Planner planner = Frameworks.getPlanner(frameworkConfig);
730+
final String sql = ""
731+
+ "SELECT deptno FROM dept WHERE 1000.00 >\n"
732+
+ "(SELECT sal FROM emp WHERE dept.deptno = emp.deptno\n"
733+
+ "order by emp.sal limit 1 offset 10)";
734+
final RelNode originalRel;
735+
try {
736+
final SqlNode parse = planner.parse(sql);
737+
final SqlNode validate = planner.validate(parse);
738+
originalRel = planner.rel(validate).rel;
739+
} catch (Exception e) {
740+
throw TestUtil.rethrow(e);
741+
}
742+
743+
final HepProgram hepProgram = HepProgram.builder()
744+
.addRuleCollection(
745+
ImmutableList.of(
746+
// SubQuery program rules
747+
CoreRules.FILTER_SUB_QUERY_TO_CORRELATE,
748+
CoreRules.PROJECT_SUB_QUERY_TO_CORRELATE,
749+
CoreRules.JOIN_SUB_QUERY_TO_CORRELATE))
750+
.build();
751+
final Program program =
752+
Programs.of(hepProgram, true,
753+
requireNonNull(cluster.getMetadataProvider()));
754+
final RelNode before =
755+
program.run(cluster.getPlanner(), originalRel, cluster.traitSet(),
756+
Collections.emptyList(), Collections.emptyList());
757+
final String planBefore = ""
758+
+ "LogicalProject(DEPTNO=[$0])\n"
759+
+ " LogicalProject(DEPTNO=[$0], DNAME=[$1], LOC=[$2])\n"
760+
+ " LogicalFilter(condition=[>(1000.00, $3)])\n"
761+
+ " LogicalCorrelate(correlation=[$cor0], joinType=[left], requiredColumns=[{0}])\n"
762+
+ " LogicalTableScan(table=[[scott, DEPT]])\n"
763+
+ " LogicalSort(sort0=[$0], dir0=[ASC], offset=[10], fetch=[1])\n"
764+
+ " LogicalProject(SAL=[$5])\n"
765+
+ " LogicalFilter(condition=[=($cor0.DEPTNO, $7)])\n"
766+
+ " LogicalTableScan(table=[[scott, EMP]])\n";
767+
assertThat(before, hasTree(planBefore));
768+
769+
// Decorrelate with no rules, using only RelDecorrelator's core decorrelation algorithm
770+
final RelNode after =
771+
RelDecorrelator.decorrelateQuery(before, builder, RuleSets.ofList(Collections.emptyList()),
772+
RuleSets.ofList(Collections.emptyList()));
773+
// Verify plan
774+
final String planAfter = ""
775+
+ "LogicalProject(DEPTNO=[$0])\n"
776+
+ " LogicalProject(DEPTNO=[$0], DNAME=[$1], LOC=[$2], SAL=[$3], DEPTNO0=[$4], rn=[CAST($5):BIGINT])\n"
777+
+ " LogicalJoin(condition=[=($0, $4)], joinType=[inner])\n"
778+
+ " LogicalTableScan(table=[[scott, DEPT]])\n"
779+
+ " LogicalFilter(condition=[>(1000.00, $0)])\n"
780+
+ " LogicalFilter(condition=[AND(>($2, 10), <=($2, +(10, 1)))])\n"
781+
+ " LogicalProject(SAL=[$5], DEPTNO=[$7], rn=[ROW_NUMBER() OVER (PARTITION BY $7 ORDER BY $5 NULLS LAST)])\n"
782+
+ " LogicalFilter(condition=[IS NOT NULL($7)])\n"
783+
+ " LogicalTableScan(table=[[scott, EMP]])\n";
784+
assertThat(after, hasTree(planAfter));
785+
}
587786
}

0 commit comments

Comments
 (0)