Skip to content

Commit 1e92bbe

Browse files
iwanttobepowerfulmihaibudiu
authored andcommitted
[CALCITE-6942] Support decorrelated for sub-queries with LIMIT 1 and OFFSET
1 parent ae42f77 commit 1e92bbe

File tree

6 files changed

+687
-117
lines changed

6 files changed

+687
-117
lines changed

core/src/main/java/org/apache/calcite/sql2rel/RelDecorrelator.java

Lines changed: 78 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,7 @@
7474
import org.apache.calcite.rex.RexSubQuery;
7575
import org.apache.calcite.rex.RexUtil;
7676
import org.apache.calcite.rex.RexVisitorImpl;
77+
import org.apache.calcite.rex.RexWindowBounds;
7778
import org.apache.calcite.runtime.PairList;
7879
import org.apache.calcite.sql.SqlAggFunction;
7980
import org.apache.calcite.sql.SqlExplainFormat;
@@ -575,16 +576,33 @@ protected RexNode removeCorrelationExpr(
575576
}
576577

577578
if (isCorVarDefined && (rel.fetch != null || rel.offset != null)) {
578-
if (rel.fetch != null
579-
&& rel.offset == null
580-
&& RexLiteral.intValue(rel.fetch) == 1) {
581-
return decorrelateFetchOneSort(rel, frame);
582-
}
583-
// Can not decorrelate if the sort has per-correlate-key attributes like
584-
// offset or fetch limit, because these attributes scope would change to
585-
// global after decorrelation. They should take effect within the scope
586-
// of the correlation key actually.
587-
return null;
579+
if (rel.offset == null && rel.fetch instanceof RexLiteral) {
580+
final RexLiteral fetchLiteral = (RexLiteral) requireNonNull(rel.fetch, "fetch");
581+
final BigDecimal fetch = fetchLiteral.getValueAs(BigDecimal.class);
582+
assert fetch != null;
583+
if (fetch.equals(BigDecimal.ZERO)) {
584+
return null;
585+
}
586+
}
587+
588+
//
589+
// Rewrite logic:
590+
//
591+
// For correlated Sort with LIMIT/OFFSET:
592+
// Special case: if OFFSET is null and FETCH = 1,
593+
// we may rewrite as an Aggregate using MIN/MAX.
594+
Frame aggFrame = decorrelateSortAsAggregate(rel, frame);
595+
if (aggFrame != null) {
596+
return aggFrame;
597+
}
598+
599+
// General case: rewrite as
600+
// Project(original_fields..., corVars..., rn)
601+
// where rn = ROW_NUMBER() OVER (PARTITION BY corVars ORDER BY sortExprs
602+
// ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW)
603+
// Filter(rn > offset, rn <= offset + fetch)
604+
// This preserves per-corVar LIMIT/OFFSET semantics.
605+
return decorrelateSortWithRowNumber(rel, frame);
588606
}
589607

590608
final RelNode newInput = frame.r;
@@ -1036,30 +1054,7 @@ private static void shiftMapping(Map<Integer, Integer> mapping, int startIndex,
10361054
return null;
10371055
}
10381056

1039-
protected @Nullable Frame decorrelateFetchOneSort(Sort sort, final Frame frame) {
1040-
Frame aggFrame = decorrelateSortAsAggregate(sort, frame);
1041-
if (aggFrame != null) {
1042-
return aggFrame;
1043-
}
1044-
//
1045-
// Rewrite logic:
1046-
//
1047-
// If sorted without offset and fetch = 1 (enforced by the caller), rewrite the sort to be
1048-
// Aggregate(group=(corVar.. , field..))
1049-
// project(first_value(field) over (partition by corVar order by (sort collation)))
1050-
// input
1051-
//
1052-
// 1. For the original sorted input, apply the FIRST_VALUE window function to produce
1053-
// the result of sorting with LIMIT 1, and the same as the decorrelate of aggregate,
1054-
// add correlated variables in partition list to maintain semantic consistency.
1055-
// 2. To ensure that there is at most one row of output for
1056-
// any combination of correlated variables, distinct for correlated variables.
1057-
// 3. Since we have partitioned by all correlated variables
1058-
// in the sorted output field window, so for any combination of correlated variables,
1059-
// all other field values are unique. So the following two are equivalent:
1060-
// - group by corVar1, covVar2, field1, field2
1061-
// - any_value(fields1), any_value(fields2) group by corVar1, covVar2
1062-
// Here we use the first.
1057+
protected @Nullable Frame decorrelateSortWithRowNumber(Sort sort, final Frame frame) {
10631058
final Map<Integer, Integer> mapOldToNewOutputs = new HashMap<>();
10641059
final NavigableMap<CorDef, Integer> corDefOutputs = new TreeMap<>();
10651060

@@ -1091,29 +1086,63 @@ private static void shiftMapping(Map<Integer, Integer> mapping, int startIndex,
10911086
for (RelDataTypeField field : sort.getRowType().getFieldList()) {
10921087
final int newIdx =
10931088
requireNonNull(frame.oldToNewOutputs.get(field.getIndex()));
1094-
1095-
RelBuilder.AggCall aggCall =
1096-
relBuilder.aggregateCall(SqlStdOperatorTable.FIRST_VALUE,
1097-
RexInputRef.of(newIdx, fieldList));
1098-
1099-
// Convert each field from the sorted output to a window function that partitions by
1100-
// correlated variables, orders by the collation, and return the first_value.
1101-
RexNode winCall = aggCall.over()
1102-
.orderBy(sortExprs)
1103-
.partitionBy(corVarProjects.leftList())
1104-
.toRex();
11051089
mapOldToNewOutputs.put(newProjExprs.size(), newProjExprs.size());
1106-
newProjExprs.add(winCall, field.getName());
1090+
newProjExprs.add(RexInputRef.of(newIdx, fieldList), field.getName());
11071091
}
11081092
newProjExprs.addAll(corVarProjects);
1109-
RelNode result = relBuilder.push(frame.r)
1110-
.project(newProjExprs.leftList(), newProjExprs.rightList())
1111-
.distinct().build();
11121093

1094+
relBuilder.push(frame.r);
1095+
1096+
RexNode rowNumberCall = relBuilder.aggregateCall(SqlStdOperatorTable.ROW_NUMBER)
1097+
.over()
1098+
.partitionBy(corVarProjects.leftList())
1099+
.orderBy(sortExprs)
1100+
.let(c -> c.rowsBetween(RexWindowBounds.UNBOUNDED_PRECEDING, RexWindowBounds.CURRENT_ROW))
1101+
.toRex();
1102+
newProjExprs.add(rowNumberCall, "rn"); // Add the row number column
1103+
relBuilder.project(newProjExprs.leftList(), newProjExprs.rightList());
1104+
1105+
List<RexNode> conditions = new ArrayList<>();
1106+
if (sort.offset != null) {
1107+
RexNode greaterThenLowerBound =
1108+
relBuilder.call(
1109+
SqlStdOperatorTable.GREATER_THAN,
1110+
relBuilder.field(newProjExprs.size() - 1),
1111+
sort.offset);
1112+
conditions.add(greaterThenLowerBound);
1113+
}
1114+
if (sort.fetch != null) {
1115+
RexNode upperBound = sort.offset == null
1116+
? sort.fetch
1117+
: relBuilder.call(SqlStdOperatorTable.PLUS, sort.offset, sort.fetch);
1118+
RexNode lessThenOrEqualUpperBound =
1119+
relBuilder.call(
1120+
SqlStdOperatorTable.LESS_THAN_OR_EQUAL,
1121+
relBuilder.field(newProjExprs.size() - 1),
1122+
upperBound);
1123+
conditions.add(lessThenOrEqualUpperBound);
1124+
}
1125+
1126+
RelNode result;
1127+
if (!conditions.isEmpty()) {
1128+
result = relBuilder.filter(conditions).build();
1129+
} else {
1130+
result = relBuilder.build();
1131+
}
11131132
return register(sort, result, mapOldToNewOutputs, corDefOutputs);
11141133
}
11151134

11161135
protected @Nullable Frame decorrelateSortAsAggregate(Sort sort, final Frame frame) {
1136+
if (sort.offset != null || sort.fetch == null) {
1137+
return null;
1138+
}
1139+
1140+
final BigDecimal fetch = ((RexLiteral) sort.fetch).getValueAs(BigDecimal.class);
1141+
assert fetch != null;
1142+
if (!fetch.equals(BigDecimal.ONE)) {
1143+
return null;
1144+
}
1145+
11171146
final Map<Integer, Integer> mapOldToNewOutputs = new HashMap<>();
11181147
final NavigableMap<CorDef, Integer> corDefOutputs = new TreeMap<>();
11191148
if (sort.getCollation().getFieldCollations().size() == 1

core/src/test/java/org/apache/calcite/sql2rel/RelDecorrelatorTest.java

Lines changed: 199 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -999,4 +999,203 @@ public static Frameworks.ConfigBuilder config() {
999999
+ " LogicalTableScan(table=[[scott, DEPT]])\n";
10001000
assertThat(decorrelatedNoRules, hasTree(planDecorrelatedNoRules));
10011001
}
1002+
1003+
@Test void testDecorrelateCorrelatedOrderByLimitToRowNumber() {
1004+
final FrameworkConfig frameworkConfig = config().build();
1005+
final RelBuilder builder = RelBuilder.create(frameworkConfig);
1006+
final RelOptCluster cluster = builder.getCluster();
1007+
final Planner planner = Frameworks.getPlanner(frameworkConfig);
1008+
final String sql = ""
1009+
+ "SELECT dname FROM dept WHERE 2000 > (\n"
1010+
+ "SELECT emp.sal FROM emp where dept.deptno = emp.deptno\n"
1011+
+ "ORDER BY year(hiredate), emp.sal limit 1)";
1012+
final RelNode originalRel;
1013+
try {
1014+
final SqlNode parse = planner.parse(sql);
1015+
final SqlNode validate = planner.validate(parse);
1016+
originalRel = planner.rel(validate).rel;
1017+
} catch (Exception e) {
1018+
throw TestUtil.rethrow(e);
1019+
}
1020+
1021+
final HepProgram hepProgram = HepProgram.builder()
1022+
.addRuleCollection(
1023+
ImmutableList.of(
1024+
// SubQuery program rules
1025+
CoreRules.FILTER_SUB_QUERY_TO_CORRELATE,
1026+
CoreRules.PROJECT_SUB_QUERY_TO_CORRELATE,
1027+
CoreRules.JOIN_SUB_QUERY_TO_CORRELATE))
1028+
.build();
1029+
final Program program =
1030+
Programs.of(hepProgram, true,
1031+
requireNonNull(cluster.getMetadataProvider()));
1032+
final RelNode before =
1033+
program.run(cluster.getPlanner(), originalRel, cluster.traitSet(),
1034+
Collections.emptyList(), Collections.emptyList());
1035+
final String planBefore = ""
1036+
+ "LogicalProject(DNAME=[$1])\n"
1037+
+ " LogicalProject(DEPTNO=[$0], DNAME=[$1], LOC=[$2])\n"
1038+
+ " LogicalFilter(condition=[>(2000.00, CAST($3):DECIMAL(12, 2))])\n"
1039+
+ " LogicalCorrelate(correlation=[$cor0], joinType=[left], requiredColumns=[{0}])\n"
1040+
+ " LogicalTableScan(table=[[scott, DEPT]])\n"
1041+
+ " LogicalProject(SAL=[$0])\n"
1042+
+ " LogicalSort(sort0=[$1], sort1=[$0], dir0=[ASC], dir1=[ASC], fetch=[1])\n"
1043+
+ " LogicalProject(SAL=[$5], EXPR$1=[EXTRACT(FLAG(YEAR), $4)])\n"
1044+
+ " LogicalFilter(condition=[=($cor0.DEPTNO, $7)])\n"
1045+
+ " LogicalTableScan(table=[[scott, EMP]])\n";
1046+
assertThat(before, hasTree(planBefore));
1047+
1048+
// Decorrelate without any rules, just "purely" decorrelation algorithm on RelDecorrelator
1049+
final RelNode after =
1050+
RelDecorrelator.decorrelateQuery(before, builder, RuleSets.ofList(Collections.emptyList()),
1051+
RuleSets.ofList(Collections.emptyList()));
1052+
// Verify plan
1053+
final String planAfter = ""
1054+
+ "LogicalProject(DNAME=[$1])\n"
1055+
+ " LogicalJoin(condition=[=($0, $4)], joinType=[inner])\n"
1056+
+ " LogicalTableScan(table=[[scott, DEPT]])\n"
1057+
+ " LogicalFilter(condition=[>(2000.00, CAST($0):DECIMAL(12, 2))])\n"
1058+
+ " LogicalProject(SAL=[$0], DEPTNO=[$2])\n"
1059+
+ " LogicalFilter(condition=[<=($3, 1)])\n"
1060+
+ " LogicalProject(SAL=[$5], EXPR$1=[EXTRACT(FLAG(YEAR), $4)], DEPTNO=[$7], rn=[ROW_NUMBER() OVER (PARTITION BY $7 ORDER BY EXTRACT(FLAG(YEAR), $4) NULLS LAST, $5 NULLS LAST)])\n"
1061+
+ " LogicalFilter(condition=[IS NOT NULL($7)])\n"
1062+
+ " LogicalTableScan(table=[[scott, EMP]])\n";
1063+
assertThat(after, hasTree(planAfter));
1064+
}
1065+
1066+
@Test void testDecorrelateCorrelatedOrderByLimitToRowNumber2() {
1067+
final FrameworkConfig frameworkConfig = config().build();
1068+
final RelBuilder builder = RelBuilder.create(frameworkConfig);
1069+
final RelOptCluster cluster = builder.getCluster();
1070+
final Planner planner = Frameworks.getPlanner(frameworkConfig);
1071+
final String sql = ""
1072+
+ "SELECT *\n"
1073+
+ "FROM dept d\n"
1074+
+ "WHERE d.deptno IN (\n"
1075+
+ " SELECT e.deptno\n"
1076+
+ " FROM emp e\n"
1077+
+ " WHERE d.deptno = e.deptno\n"
1078+
+ " LIMIT 10\n"
1079+
+ " OFFSET 2\n"
1080+
+ ")\n"
1081+
+ "LIMIT 2\n"
1082+
+ "OFFSET 1";
1083+
final RelNode originalRel;
1084+
try {
1085+
final SqlNode parse = planner.parse(sql);
1086+
final SqlNode validate = planner.validate(parse);
1087+
originalRel = planner.rel(validate).rel;
1088+
} catch (Exception e) {
1089+
throw TestUtil.rethrow(e);
1090+
}
1091+
1092+
final HepProgram hepProgram = HepProgram.builder()
1093+
.addRuleCollection(
1094+
ImmutableList.of(
1095+
// SubQuery program rules
1096+
CoreRules.FILTER_SUB_QUERY_TO_CORRELATE,
1097+
CoreRules.PROJECT_SUB_QUERY_TO_CORRELATE,
1098+
CoreRules.JOIN_SUB_QUERY_TO_CORRELATE))
1099+
.build();
1100+
final Program program =
1101+
Programs.of(hepProgram, true,
1102+
requireNonNull(cluster.getMetadataProvider()));
1103+
final RelNode before =
1104+
program.run(cluster.getPlanner(), originalRel, cluster.traitSet(),
1105+
Collections.emptyList(), Collections.emptyList());
1106+
final String planBefore = ""
1107+
+ "LogicalSort(offset=[1], fetch=[2])\n"
1108+
+ " LogicalProject(DEPTNO=[$0], DNAME=[$1], LOC=[$2])\n"
1109+
+ " LogicalProject(DEPTNO=[$0], DNAME=[$1], LOC=[$2])\n"
1110+
+ " LogicalFilter(condition=[=($0, $3)])\n"
1111+
+ " LogicalCorrelate(correlation=[$cor0], joinType=[inner], requiredColumns=[{0}])\n"
1112+
+ " LogicalTableScan(table=[[scott, DEPT]])\n"
1113+
+ " LogicalAggregate(group=[{0}])\n"
1114+
+ " LogicalSort(offset=[2], fetch=[10])\n"
1115+
+ " LogicalProject(DEPTNO=[$7])\n"
1116+
+ " LogicalFilter(condition=[=($cor0.DEPTNO, $7)])\n"
1117+
+ " LogicalTableScan(table=[[scott, EMP]])\n";
1118+
assertThat(before, hasTree(planBefore));
1119+
1120+
// Decorrelate without any rules, just "purely" decorrelation algorithm on RelDecorrelator
1121+
final RelNode after =
1122+
RelDecorrelator.decorrelateQuery(before, builder, RuleSets.ofList(Collections.emptyList()),
1123+
RuleSets.ofList(Collections.emptyList()));
1124+
// Verify plan
1125+
final String planAfter = ""
1126+
+ "LogicalSort(offset=[1], fetch=[2])\n"
1127+
+ " LogicalProject(DEPTNO=[$0], DNAME=[$1], LOC=[$2])\n"
1128+
+ " LogicalJoin(condition=[=($0, $4)], joinType=[inner])\n"
1129+
+ " LogicalTableScan(table=[[scott, DEPT]])\n"
1130+
+ " LogicalFilter(condition=[=($1, $0)])\n"
1131+
+ " LogicalAggregate(group=[{0, 1}])\n"
1132+
+ " LogicalProject(DEPTNO=[$0], DEPTNO1=[$1])\n"
1133+
+ " LogicalFilter(condition=[AND(>($2, 2), <=($2, +(2, 10)))])\n"
1134+
+ " LogicalProject(DEPTNO=[$7], DEPTNO1=[$7], rn=[ROW_NUMBER() OVER (PARTITION BY $7)])\n"
1135+
+ " LogicalFilter(condition=[IS NOT NULL($7)])\n"
1136+
+ " LogicalTableScan(table=[[scott, EMP]])\n";
1137+
assertThat(after, hasTree(planAfter));
1138+
}
1139+
1140+
@Test void testDecorrelateCorrelatedOrderByLimitToRowNumber3() {
1141+
final FrameworkConfig frameworkConfig = config().build();
1142+
final RelBuilder builder = RelBuilder.create(frameworkConfig);
1143+
final RelOptCluster cluster = builder.getCluster();
1144+
final Planner planner = Frameworks.getPlanner(frameworkConfig);
1145+
final String sql = ""
1146+
+ "SELECT deptno FROM dept WHERE 1000.00 >\n"
1147+
+ "(SELECT sal FROM emp WHERE dept.deptno = emp.deptno\n"
1148+
+ "order by emp.sal limit 1 offset 10)";
1149+
final RelNode originalRel;
1150+
try {
1151+
final SqlNode parse = planner.parse(sql);
1152+
final SqlNode validate = planner.validate(parse);
1153+
originalRel = planner.rel(validate).rel;
1154+
} catch (Exception e) {
1155+
throw TestUtil.rethrow(e);
1156+
}
1157+
1158+
final HepProgram hepProgram = HepProgram.builder()
1159+
.addRuleCollection(
1160+
ImmutableList.of(
1161+
// SubQuery program rules
1162+
CoreRules.FILTER_SUB_QUERY_TO_CORRELATE,
1163+
CoreRules.PROJECT_SUB_QUERY_TO_CORRELATE,
1164+
CoreRules.JOIN_SUB_QUERY_TO_CORRELATE))
1165+
.build();
1166+
final Program program =
1167+
Programs.of(hepProgram, true,
1168+
requireNonNull(cluster.getMetadataProvider()));
1169+
final RelNode before =
1170+
program.run(cluster.getPlanner(), originalRel, cluster.traitSet(),
1171+
Collections.emptyList(), Collections.emptyList());
1172+
final String planBefore = ""
1173+
+ "LogicalProject(DEPTNO=[$0])\n"
1174+
+ " LogicalProject(DEPTNO=[$0], DNAME=[$1], LOC=[$2])\n"
1175+
+ " LogicalFilter(condition=[>(1000.00, $3)])\n"
1176+
+ " LogicalCorrelate(correlation=[$cor0], joinType=[left], requiredColumns=[{0}])\n"
1177+
+ " LogicalTableScan(table=[[scott, DEPT]])\n"
1178+
+ " LogicalSort(sort0=[$0], dir0=[ASC], offset=[10], fetch=[1])\n"
1179+
+ " LogicalProject(SAL=[$5])\n"
1180+
+ " LogicalFilter(condition=[=($cor0.DEPTNO, $7)])\n"
1181+
+ " LogicalTableScan(table=[[scott, EMP]])\n";
1182+
assertThat(before, hasTree(planBefore));
1183+
1184+
// Decorrelate without any rules, just "purely" decorrelation algorithm on RelDecorrelator
1185+
final RelNode after =
1186+
RelDecorrelator.decorrelateQuery(before, builder, RuleSets.ofList(Collections.emptyList()),
1187+
RuleSets.ofList(Collections.emptyList()));
1188+
// Verify plan
1189+
final String planAfter = ""
1190+
+ "LogicalProject(DEPTNO=[$0])\n"
1191+
+ " LogicalProject(DEPTNO=[$0], DNAME=[$1], LOC=[$2], SAL=[$3], DEPTNO0=[$4], rn=[CAST($5):BIGINT])\n"
1192+
+ " LogicalJoin(condition=[=($0, $4)], joinType=[inner])\n"
1193+
+ " LogicalTableScan(table=[[scott, DEPT]])\n"
1194+
+ " LogicalFilter(condition=[>(1000.00, $0)])\n"
1195+
+ " LogicalFilter(condition=[AND(>($2, 10), <=($2, +(10, 1)))])\n"
1196+
+ " LogicalProject(SAL=[$5], DEPTNO=[$7], rn=[ROW_NUMBER() OVER (PARTITION BY $7 ORDER BY $5 NULLS LAST)])\n"
1197+
+ " LogicalFilter(condition=[IS NOT NULL($7)])\n"
1198+
+ " LogicalTableScan(table=[[scott, EMP]])\n";
1199+
assertThat(after, hasTree(planAfter));
1200+
}
10021201
}

0 commit comments

Comments
 (0)