Skip to content

Commit 3eac1f7

Browse files
iwanttobepowerfulmihaibudiu
authored andcommitted
[CALCITE-7272] Subqueries cannot be decorrelated if have set op
1 parent 5c6456d commit 3eac1f7

File tree

4 files changed

+1779
-47
lines changed

4 files changed

+1779
-47
lines changed

core/src/main/java/org/apache/calcite/sql2rel/RelDecorrelator.java

Lines changed: 173 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@
4444
import org.apache.calcite.rel.core.JoinRelType;
4545
import org.apache.calcite.rel.core.Project;
4646
import org.apache.calcite.rel.core.RelFactories;
47+
import org.apache.calcite.rel.core.SetOp;
4748
import org.apache.calcite.rel.core.Sort;
4849
import org.apache.calcite.rel.logical.LogicalAggregate;
4950
import org.apache.calcite.rel.logical.LogicalCorrelate;
@@ -817,8 +818,7 @@ protected RexNode removeCorrelationExpr(
817818
}
818819
}
819820

820-
if (rel.getGroupType() == Aggregate.Group.SIMPLE
821-
&& rel.getGroupSet().isEmpty()
821+
if ((rel.hasEmptyGroup() || rel.getGroupSet().isEmpty())
822822
&& !frame.corDefOutputs.isEmpty()
823823
&& !parentPropagatesNullValues) {
824824
newRel = rewriteScalarAggregate(rel, newRel, outputMap, corDefOutputs);
@@ -930,71 +930,63 @@ private RelNode rewriteScalarAggregate(Aggregate oldRel,
930930
RelNode newRel,
931931
Map<Integer, Integer> outputMap,
932932
NavigableMap<CorDef, Integer> corDefOutputs) {
933-
final Pair<CorrelationId, Frame> outerFramePair = requireNonNull(this.frameStack.peek());
934-
final Frame outFrame = outerFramePair.right;
935-
RexBuilder rexBuilder = relBuilder.getRexBuilder();
933+
final CorelMap localCorelMap = new CorelMapBuilder().build(oldRel);
934+
final List<CorRef> corVarList = new ArrayList<>(localCorelMap.mapRefRelToCorRef.values());
935+
Collections.sort(corVarList);
936936

937-
int groupKeySize = (int) corDefOutputs.keySet().stream()
938-
.filter(a -> a.corr.equals(outerFramePair.left))
939-
.count();
940-
List<RelDataTypeField> newRelFields = newRel.getRowType().getFieldList();
941-
ImmutableBitSet.Builder corFieldBuilder = ImmutableBitSet.builder();
937+
final NavigableMap<CorDef, Integer> valueGenCorDefOutputs = new TreeMap<>();
938+
final RelNode valueGen =
939+
requireNonNull(createValueGenerator(corVarList, 0, valueGenCorDefOutputs));
940+
final int valueGenFieldCount = valueGen.getRowType().getFieldCount();
942941

943-
// Here we record the mapping between the original index and the new project.
944-
// For the count, we map it as `case when x is null then 0 else x`.
942+
// Build join conditions
945943
final Map<Integer, RexNode> newProjectMap = new HashMap<>();
946944
final List<RexNode> conditions = new ArrayList<>();
947945
for (Map.Entry<CorDef, Integer> corDefOutput : corDefOutputs.entrySet()) {
948-
CorDef corDef = corDefOutput.getKey();
949-
Integer corIndex = corDefOutput.getValue();
950-
if (corDef.corr.equals(outerFramePair.left)) {
951-
int newIdx = requireNonNull(outFrame.oldToNewOutputs.get(corDef.field));
952-
corFieldBuilder.set(newIdx);
953-
954-
RelDataType type = outFrame.r.getRowType().getFieldList().get(newIdx).getType();
955-
RexNode left = new RexInputRef(corFieldBuilder.cardinality() - 1, type);
956-
newProjectMap.put(corIndex + groupKeySize, left);
957-
conditions.add(
958-
relBuilder.isNotDistinctFrom(left,
959-
new RexInputRef(corIndex + groupKeySize,
960-
newRelFields.get(corIndex).getType())));
961-
}
962-
}
963-
964-
ImmutableBitSet groupSet = corFieldBuilder.build();
965-
// Build [09] LogicalAggregate(group=[{0}]) to obtain the distinct set of
966-
// corVar from outFrame.
967-
relBuilder.push(outFrame.r)
968-
.aggregate(relBuilder.groupKey(groupSet));
946+
final CorDef corDef = corDefOutput.getKey();
947+
final int leftPos = requireNonNull(valueGenCorDefOutputs.get(corDef));
948+
final int rightPos = corDefOutput.getValue();
949+
final RelDataType leftType = valueGen.getRowType().getFieldList().get(leftPos).getType();
950+
final RelDataType rightType = newRel.getRowType().getFieldList().get(rightPos).getType();
951+
final RexNode leftRef = new RexInputRef(leftPos, leftType);
952+
final RexNode rightRef = new RexInputRef(valueGenFieldCount + rightPos, rightType);
953+
conditions.add(relBuilder.isNotDistinctFrom(leftRef, rightRef));
954+
newProjectMap.put(valueGenFieldCount + rightPos, leftRef);
955+
}
956+
final RexNode joinCond = RexUtil.composeConjunction(relBuilder.getRexBuilder(), conditions);
969957

970958
// Build [08] LogicalJoin(condition=[IS NOT DISTINCT FROM($0, $1)], joinType=[left])
971959
// to ensure each corVar's aggregate result is output.
972-
final RelNode join = relBuilder.push(newRel)
973-
.join(JoinRelType.LEFT, conditions).build();
960+
final RelNode join = relBuilder.push(valueGen).push(newRel)
961+
.join(JoinRelType.LEFT, joinCond).build();
962+
RelDataType joinRowType = join.getRowType();
974963

964+
RexBuilder rexBuilder = relBuilder.getRexBuilder();
965+
// Here we record the mapping between the original index and the new project.
966+
// For the count, we map it as `case when x is null then 0 else x`.
975967
for (int i1 = 0; i1 < oldRel.getAggCallList().size(); i1++) {
976968
AggregateCall aggCall = oldRel.getAggCallList().get(i1);
977969
if (aggCall.getAggregation() instanceof SqlCountAggFunction) {
978970
int index = requireNonNull(outputMap.get(i1 + oldRel.getGroupSet().size()));
979-
final RexInputRef ref = RexInputRef.of(index + groupKeySize, join.getRowType());
980-
RexNode specificCountValue =
981-
rexBuilder.makeCall(SqlStdOperatorTable.CASE,
982-
ImmutableList.of(relBuilder.isNotNull(ref), ref, relBuilder.literal(0)));
971+
final RexInputRef ref = RexInputRef.of(index + valueGenFieldCount, joinRowType);
972+
ImmutableList<RexNode> exprs =
973+
ImmutableList.of(relBuilder.isNotNull(ref), ref, relBuilder.literal(0));
974+
RexNode specificCountValue = rexBuilder.makeCall(SqlStdOperatorTable.CASE, exprs);
983975
newProjectMap.put(ref.getIndex(), specificCountValue);
984976
}
985977
}
986978

979+
// Build [07] LogicalProject(DEPTNO=[$0], EXPR$0=[CASE(IS NOT NULL($2), $2, 0)])
980+
// to handle COUNT function by converting nulls to zero.
987981
final List<RexNode> newProjects = new ArrayList<>();
988-
for (int index : ImmutableBitSet.range(groupKeySize, join.getRowType().getFieldCount())) {
982+
for (int index : ImmutableBitSet.range(valueGenFieldCount, joinRowType.getFieldCount())) {
989983
if (newProjectMap.containsKey(index)) {
990984
newProjects.add(requireNonNull(newProjectMap.get(index)));
991985
} else {
992-
newProjects.add(RexInputRef.of(index, join.getRowType()));
986+
newProjects.add(RexInputRef.of(index, joinRowType));
993987
}
994988
}
995989

996-
// Build [07] LogicalProject(DEPTNO=[$0], EXPR$0=[CASE(IS NOT NULL($2), $2, 0)])
997-
// to handle COUNT function by converting nulls to zero.
998990
return relBuilder.push(join)
999991
.project(newProjects, newRel.getRowType().getFieldNames())
1000992
.build();
@@ -1184,6 +1176,144 @@ private static void shiftMapping(Map<Integer, Integer> mapping, int startIndex,
11841176
return null;
11851177
}
11861178

1179+
/**
1180+
* Given the SQL:
1181+
* SELECT ename,
1182+
* (SELECT sum(c)
1183+
* FROM
1184+
* (SELECT deptno AS c
1185+
* FROM dept
1186+
* WHERE dept.deptno = emp.deptno
1187+
* UNION ALL
1188+
* SELECT 2 AS c
1189+
* FROM bonus
1190+
* WHERE bonus.job = emp.job) AS union_subquery
1191+
* ) AS correlated_sum
1192+
* FROM emp;
1193+
*
1194+
* <p>from:
1195+
* LogicalUnion(all=[true])
1196+
* LogicalProject(C=[CAST($0):INTEGER NOT NULL])
1197+
* LogicalFilter(condition=[=($0, $cor0.DEPTNO)])
1198+
* LogicalTableScan(table=[[scott, DEPT]])
1199+
* LogicalProject(C=[2])
1200+
* LogicalFilter(condition=[=($1, $cor0.JOB)])
1201+
* LogicalTableScan(table=[[scott, BONUS]])
1202+
*
1203+
* <p>to:
1204+
* LogicalUnion(all=[true])
1205+
* LogicalProject(JOB=[$0], DEPTNO=[$1], C=[$2])
1206+
* LogicalJoin(condition=[IS NOT DISTINCT FROM($1, $3)], joinType=[inner])
1207+
* LogicalAggregate(group=[{0, 1}])
1208+
* LogicalProject(JOB=[$2], DEPTNO=[$7])
1209+
* LogicalTableScan(table=[[scott, EMP]])
1210+
* LogicalProject(C=[CAST($0):INTEGER NOT NULL], DEPTNO=[$0])
1211+
* LogicalTableScan(table=[[scott, DEPT]])
1212+
* LogicalProject(JOB=[$0], DEPTNO=[$1], C=[$2])
1213+
* LogicalJoin(condition=[IS NOT DISTINCT FROM($0, $3)], joinType=[inner])
1214+
* LogicalAggregate(group=[{0, 1}])
1215+
* LogicalProject(JOB=[$2], DEPTNO=[$7])
1216+
* LogicalTableScan(table=[[scott, EMP]])
1217+
* LogicalProject(C=[2], JOB=[$1])
1218+
* LogicalFilter(condition=[IS NOT NULL($1)])
1219+
* LogicalTableScan(table=[[scott, BONUS]])
1220+
*/
1221+
public @Nullable Frame decorrelateRel(SetOp rel, boolean isCorVarDefined,
1222+
boolean parentPropagatesNullValues) {
1223+
if (!isCorVarDefined) {
1224+
return decorrelateRel((RelNode) rel, false, parentPropagatesNullValues);
1225+
}
1226+
1227+
final CorelMap localCorelMap = new CorelMapBuilder().build(rel);
1228+
final List<CorRef> corVarList = new ArrayList<>(localCorelMap.mapRefRelToCorRef.values());
1229+
Collections.sort(corVarList);
1230+
1231+
final NavigableMap<CorDef, Integer> valueGenCorDefOutputs = new TreeMap<>();
1232+
final RelNode valueGen =
1233+
requireNonNull(createValueGenerator(corVarList, 0, valueGenCorDefOutputs));
1234+
final int valueGenFieldCount = valueGen.getRowType().getFieldCount();
1235+
// Original SetOp payload width.
1236+
final int payloadFieldCount = rel.getRowType().getFieldCount();
1237+
final List<RelNode> newInputs = new ArrayList<>();
1238+
final Map<Integer, Integer> setOpOldToNewOutputs = new HashMap<>();
1239+
final NavigableMap<CorDef, Integer> setOpCorDefOutputs = new TreeMap<>();
1240+
1241+
for (int i = 0; i < rel.getInputs().size(); i++) {
1242+
RelNode oldInput = rel.getInput(i);
1243+
Frame frame = getInvoke(oldInput, true, rel, parentPropagatesNullValues);
1244+
if (frame == null) {
1245+
// If input has not been rewritten, do not rewrite this rel.
1246+
return null;
1247+
}
1248+
1249+
// Build join conditions: for each CorDef of this branch that belongs
1250+
// to the current outFrameCorrId, equate valueGen(col) with branch(col).
1251+
final List<RexNode> conditions = new ArrayList<>();
1252+
for (Map.Entry<CorDef, Integer> e : frame.corDefOutputs.entrySet()) {
1253+
final CorDef corDef = e.getKey();
1254+
final int leftPos = requireNonNull(valueGenCorDefOutputs.get(corDef));
1255+
final int rightPos = e.getValue();
1256+
final RelDataType leftType = valueGen.getRowType().getFieldList().get(leftPos).getType();
1257+
final RelDataType rightType = frame.r.getRowType().getFieldList().get(rightPos).getType();
1258+
final RexNode leftRef = new RexInputRef(leftPos, leftType);
1259+
final RexNode rightRef = new RexInputRef(valueGenFieldCount + rightPos, rightType);
1260+
conditions.add(relBuilder.isNotDistinctFrom(leftRef, rightRef));
1261+
}
1262+
final RexNode joinCondition =
1263+
RexUtil.composeConjunction(relBuilder.getRexBuilder(), conditions);
1264+
RelNode join = relBuilder.push(valueGen).push(frame.r)
1265+
.join(JoinRelType.INNER, joinCondition).build();
1266+
1267+
final List<RelDataTypeField> joinFields = join.getRowType().getFieldList();
1268+
1269+
// Build the final projection for this branch:
1270+
// all correlated columns (from valueGen), original payload columns (from branch)
1271+
final PairList<RexNode, String> projects = PairList.of();
1272+
final Map<Integer, Integer> childOldToNew = new HashMap<>();
1273+
final NavigableMap<CorDef, Integer> childCorDefOutputs = new TreeMap<>();
1274+
1275+
// a) Correlated columns, in the order of valueGenCorDefOutputs.
1276+
int newPos = 0;
1277+
for (Map.Entry<CorDef, Integer> e : valueGenCorDefOutputs.entrySet()) {
1278+
final int srcIndex = e.getValue();
1279+
RexInputRef inputRef = RexInputRef.of(srcIndex, join.getRowType());
1280+
String name = joinFields.get(srcIndex).getName();
1281+
1282+
projects.add(inputRef, name);
1283+
childCorDefOutputs.put(e.getKey(), newPos);
1284+
newPos++;
1285+
}
1286+
1287+
// b) Original SetOp payload columns.
1288+
for (int oldIndex = 0; oldIndex < payloadFieldCount; oldIndex++) {
1289+
final Integer srcInFrame = requireNonNull(frame.oldToNewOutputs.get(oldIndex));
1290+
final int srcInJoin = valueGenFieldCount + srcInFrame;
1291+
RexInputRef inputRef = RexInputRef.of(srcInJoin, join.getRowType());
1292+
String name = joinFields.get(srcInJoin).getName();
1293+
1294+
projects.add(inputRef, name);
1295+
childOldToNew.put(oldIndex, newPos);
1296+
newPos++;
1297+
}
1298+
1299+
final RelNode newInput = relBuilder.push(join)
1300+
.projectNamed(projects.leftList(), projects.rightList(), true)
1301+
.build();
1302+
newInputs.add(newInput);
1303+
1304+
register(oldInput, newInput, childOldToNew, childCorDefOutputs);
1305+
1306+
// Use the first branch as prototype for the SetOp's frame mappings.
1307+
if (i == 0) {
1308+
setOpOldToNewOutputs.putAll(childOldToNew);
1309+
setOpCorDefOutputs.putAll(childCorDefOutputs);
1310+
}
1311+
}
1312+
1313+
final SetOp newSetOp = rel.copy(rel.getTraitSet(), newInputs, rel.all);
1314+
return register(rel, newSetOp, setOpOldToNewOutputs, setOpCorDefOutputs);
1315+
}
1316+
11871317
public @Nullable Frame decorrelateRel(LogicalProject rel, boolean isCorVarDefined,
11881318
boolean parentPropagatesNullValues) {
11891319
return decorrelateRel((Project) rel, isCorVarDefined, parentPropagatesNullValues);

0 commit comments

Comments
 (0)