|
44 | 44 | import org.apache.calcite.rel.core.JoinRelType; |
45 | 45 | import org.apache.calcite.rel.core.Project; |
46 | 46 | import org.apache.calcite.rel.core.RelFactories; |
| 47 | +import org.apache.calcite.rel.core.SetOp; |
47 | 48 | import org.apache.calcite.rel.core.Sort; |
48 | 49 | import org.apache.calcite.rel.logical.LogicalAggregate; |
49 | 50 | import org.apache.calcite.rel.logical.LogicalCorrelate; |
@@ -817,8 +818,7 @@ protected RexNode removeCorrelationExpr( |
817 | 818 | } |
818 | 819 | } |
819 | 820 |
|
820 | | - if (rel.getGroupType() == Aggregate.Group.SIMPLE |
821 | | - && rel.getGroupSet().isEmpty() |
| 821 | + if ((rel.hasEmptyGroup() || rel.getGroupSet().isEmpty()) |
822 | 822 | && !frame.corDefOutputs.isEmpty() |
823 | 823 | && !parentPropagatesNullValues) { |
824 | 824 | newRel = rewriteScalarAggregate(rel, newRel, outputMap, corDefOutputs); |
@@ -930,71 +930,63 @@ private RelNode rewriteScalarAggregate(Aggregate oldRel, |
930 | 930 | RelNode newRel, |
931 | 931 | Map<Integer, Integer> outputMap, |
932 | 932 | NavigableMap<CorDef, Integer> corDefOutputs) { |
933 | | - final Pair<CorrelationId, Frame> outerFramePair = requireNonNull(this.frameStack.peek()); |
934 | | - final Frame outFrame = outerFramePair.right; |
935 | | - RexBuilder rexBuilder = relBuilder.getRexBuilder(); |
| 933 | + final CorelMap localCorelMap = new CorelMapBuilder().build(oldRel); |
| 934 | + final List<CorRef> corVarList = new ArrayList<>(localCorelMap.mapRefRelToCorRef.values()); |
| 935 | + Collections.sort(corVarList); |
936 | 936 |
|
937 | | - int groupKeySize = (int) corDefOutputs.keySet().stream() |
938 | | - .filter(a -> a.corr.equals(outerFramePair.left)) |
939 | | - .count(); |
940 | | - List<RelDataTypeField> newRelFields = newRel.getRowType().getFieldList(); |
941 | | - ImmutableBitSet.Builder corFieldBuilder = ImmutableBitSet.builder(); |
| 937 | + final NavigableMap<CorDef, Integer> valueGenCorDefOutputs = new TreeMap<>(); |
| 938 | + final RelNode valueGen = |
| 939 | + requireNonNull(createValueGenerator(corVarList, 0, valueGenCorDefOutputs)); |
| 940 | + final int valueGenFieldCount = valueGen.getRowType().getFieldCount(); |
942 | 941 |
|
943 | | - // Here we record the mapping between the original index and the new project. |
944 | | - // For the count, we map it as `case when x is null then 0 else x`. |
| 942 | + // Build join conditions |
945 | 943 | final Map<Integer, RexNode> newProjectMap = new HashMap<>(); |
946 | 944 | final List<RexNode> conditions = new ArrayList<>(); |
947 | 945 | for (Map.Entry<CorDef, Integer> corDefOutput : corDefOutputs.entrySet()) { |
948 | | - CorDef corDef = corDefOutput.getKey(); |
949 | | - Integer corIndex = corDefOutput.getValue(); |
950 | | - if (corDef.corr.equals(outerFramePair.left)) { |
951 | | - int newIdx = requireNonNull(outFrame.oldToNewOutputs.get(corDef.field)); |
952 | | - corFieldBuilder.set(newIdx); |
953 | | - |
954 | | - RelDataType type = outFrame.r.getRowType().getFieldList().get(newIdx).getType(); |
955 | | - RexNode left = new RexInputRef(corFieldBuilder.cardinality() - 1, type); |
956 | | - newProjectMap.put(corIndex + groupKeySize, left); |
957 | | - conditions.add( |
958 | | - relBuilder.isNotDistinctFrom(left, |
959 | | - new RexInputRef(corIndex + groupKeySize, |
960 | | - newRelFields.get(corIndex).getType()))); |
961 | | - } |
962 | | - } |
963 | | - |
964 | | - ImmutableBitSet groupSet = corFieldBuilder.build(); |
965 | | - // Build [09] LogicalAggregate(group=[{0}]) to obtain the distinct set of |
966 | | - // corVar from outFrame. |
967 | | - relBuilder.push(outFrame.r) |
968 | | - .aggregate(relBuilder.groupKey(groupSet)); |
| 946 | + final CorDef corDef = corDefOutput.getKey(); |
| 947 | + final int leftPos = requireNonNull(valueGenCorDefOutputs.get(corDef)); |
| 948 | + final int rightPos = corDefOutput.getValue(); |
| 949 | + final RelDataType leftType = valueGen.getRowType().getFieldList().get(leftPos).getType(); |
| 950 | + final RelDataType rightType = newRel.getRowType().getFieldList().get(rightPos).getType(); |
| 951 | + final RexNode leftRef = new RexInputRef(leftPos, leftType); |
| 952 | + final RexNode rightRef = new RexInputRef(valueGenFieldCount + rightPos, rightType); |
| 953 | + conditions.add(relBuilder.isNotDistinctFrom(leftRef, rightRef)); |
| 954 | + newProjectMap.put(valueGenFieldCount + rightPos, leftRef); |
| 955 | + } |
| 956 | + final RexNode joinCond = RexUtil.composeConjunction(relBuilder.getRexBuilder(), conditions); |
969 | 957 |
|
970 | 958 | // Build [08] LogicalJoin(condition=[IS NOT DISTINCT FROM($0, $1)], joinType=[left]) |
971 | 959 | // to ensure each corVar's aggregate result is output. |
972 | | - final RelNode join = relBuilder.push(newRel) |
973 | | - .join(JoinRelType.LEFT, conditions).build(); |
| 960 | + final RelNode join = relBuilder.push(valueGen).push(newRel) |
| 961 | + .join(JoinRelType.LEFT, joinCond).build(); |
| 962 | + RelDataType joinRowType = join.getRowType(); |
974 | 963 |
|
| 964 | + RexBuilder rexBuilder = relBuilder.getRexBuilder(); |
| 965 | + // Here we record the mapping between the original index and the new project. |
| 966 | + // For the count, we map it as `case when x is null then 0 else x`. |
975 | 967 | for (int i1 = 0; i1 < oldRel.getAggCallList().size(); i1++) { |
976 | 968 | AggregateCall aggCall = oldRel.getAggCallList().get(i1); |
977 | 969 | if (aggCall.getAggregation() instanceof SqlCountAggFunction) { |
978 | 970 | int index = requireNonNull(outputMap.get(i1 + oldRel.getGroupSet().size())); |
979 | | - final RexInputRef ref = RexInputRef.of(index + groupKeySize, join.getRowType()); |
980 | | - RexNode specificCountValue = |
981 | | - rexBuilder.makeCall(SqlStdOperatorTable.CASE, |
982 | | - ImmutableList.of(relBuilder.isNotNull(ref), ref, relBuilder.literal(0))); |
| 971 | + final RexInputRef ref = RexInputRef.of(index + valueGenFieldCount, joinRowType); |
| 972 | + ImmutableList<RexNode> exprs = |
| 973 | + ImmutableList.of(relBuilder.isNotNull(ref), ref, relBuilder.literal(0)); |
| 974 | + RexNode specificCountValue = rexBuilder.makeCall(SqlStdOperatorTable.CASE, exprs); |
983 | 975 | newProjectMap.put(ref.getIndex(), specificCountValue); |
984 | 976 | } |
985 | 977 | } |
986 | 978 |
|
| 979 | + // Build [07] LogicalProject(DEPTNO=[$0], EXPR$0=[CASE(IS NOT NULL($2), $2, 0)]) |
| 980 | + // to handle COUNT function by converting nulls to zero. |
987 | 981 | final List<RexNode> newProjects = new ArrayList<>(); |
988 | | - for (int index : ImmutableBitSet.range(groupKeySize, join.getRowType().getFieldCount())) { |
| 982 | + for (int index : ImmutableBitSet.range(valueGenFieldCount, joinRowType.getFieldCount())) { |
989 | 983 | if (newProjectMap.containsKey(index)) { |
990 | 984 | newProjects.add(requireNonNull(newProjectMap.get(index))); |
991 | 985 | } else { |
992 | | - newProjects.add(RexInputRef.of(index, join.getRowType())); |
| 986 | + newProjects.add(RexInputRef.of(index, joinRowType)); |
993 | 987 | } |
994 | 988 | } |
995 | 989 |
|
996 | | - // Build [07] LogicalProject(DEPTNO=[$0], EXPR$0=[CASE(IS NOT NULL($2), $2, 0)]) |
997 | | - // to handle COUNT function by converting nulls to zero. |
998 | 990 | return relBuilder.push(join) |
999 | 991 | .project(newProjects, newRel.getRowType().getFieldNames()) |
1000 | 992 | .build(); |
@@ -1184,6 +1176,144 @@ private static void shiftMapping(Map<Integer, Integer> mapping, int startIndex, |
1184 | 1176 | return null; |
1185 | 1177 | } |
1186 | 1178 |
|
| 1179 | + /** |
| 1180 | + * Given the SQL: |
| 1181 | + * SELECT ename, |
| 1182 | + * (SELECT sum(c) |
| 1183 | + * FROM |
| 1184 | + * (SELECT deptno AS c |
| 1185 | + * FROM dept |
| 1186 | + * WHERE dept.deptno = emp.deptno |
| 1187 | + * UNION ALL |
| 1188 | + * SELECT 2 AS c |
| 1189 | + * FROM bonus |
| 1190 | + * WHERE bonus.job = emp.job) AS union_subquery |
| 1191 | + * ) AS correlated_sum |
| 1192 | + * FROM emp; |
| 1193 | + * |
| 1194 | + * <p>from: |
| 1195 | + * LogicalUnion(all=[true]) |
| 1196 | + * LogicalProject(C=[CAST($0):INTEGER NOT NULL]) |
| 1197 | + * LogicalFilter(condition=[=($0, $cor0.DEPTNO)]) |
| 1198 | + * LogicalTableScan(table=[[scott, DEPT]]) |
| 1199 | + * LogicalProject(C=[2]) |
| 1200 | + * LogicalFilter(condition=[=($1, $cor0.JOB)]) |
| 1201 | + * LogicalTableScan(table=[[scott, BONUS]]) |
| 1202 | + * |
| 1203 | + * <p>to: |
| 1204 | + * LogicalUnion(all=[true]) |
| 1205 | + * LogicalProject(JOB=[$0], DEPTNO=[$1], C=[$2]) |
| 1206 | + * LogicalJoin(condition=[IS NOT DISTINCT FROM($1, $3)], joinType=[inner]) |
| 1207 | + * LogicalAggregate(group=[{0, 1}]) |
| 1208 | + * LogicalProject(JOB=[$2], DEPTNO=[$7]) |
| 1209 | + * LogicalTableScan(table=[[scott, EMP]]) |
| 1210 | + * LogicalProject(C=[CAST($0):INTEGER NOT NULL], DEPTNO=[$0]) |
| 1211 | + * LogicalTableScan(table=[[scott, DEPT]]) |
| 1212 | + * LogicalProject(JOB=[$0], DEPTNO=[$1], C=[$2]) |
| 1213 | + * LogicalJoin(condition=[IS NOT DISTINCT FROM($0, $3)], joinType=[inner]) |
| 1214 | + * LogicalAggregate(group=[{0, 1}]) |
| 1215 | + * LogicalProject(JOB=[$2], DEPTNO=[$7]) |
| 1216 | + * LogicalTableScan(table=[[scott, EMP]]) |
| 1217 | + * LogicalProject(C=[2], JOB=[$1]) |
| 1218 | + * LogicalFilter(condition=[IS NOT NULL($1)]) |
| 1219 | + * LogicalTableScan(table=[[scott, BONUS]]) |
| 1220 | + */ |
| 1221 | + public @Nullable Frame decorrelateRel(SetOp rel, boolean isCorVarDefined, |
| 1222 | + boolean parentPropagatesNullValues) { |
| 1223 | + if (!isCorVarDefined) { |
| 1224 | + return decorrelateRel((RelNode) rel, false, parentPropagatesNullValues); |
| 1225 | + } |
| 1226 | + |
| 1227 | + final CorelMap localCorelMap = new CorelMapBuilder().build(rel); |
| 1228 | + final List<CorRef> corVarList = new ArrayList<>(localCorelMap.mapRefRelToCorRef.values()); |
| 1229 | + Collections.sort(corVarList); |
| 1230 | + |
| 1231 | + final NavigableMap<CorDef, Integer> valueGenCorDefOutputs = new TreeMap<>(); |
| 1232 | + final RelNode valueGen = |
| 1233 | + requireNonNull(createValueGenerator(corVarList, 0, valueGenCorDefOutputs)); |
| 1234 | + final int valueGenFieldCount = valueGen.getRowType().getFieldCount(); |
| 1235 | + // Original SetOp payload width. |
| 1236 | + final int payloadFieldCount = rel.getRowType().getFieldCount(); |
| 1237 | + final List<RelNode> newInputs = new ArrayList<>(); |
| 1238 | + final Map<Integer, Integer> setOpOldToNewOutputs = new HashMap<>(); |
| 1239 | + final NavigableMap<CorDef, Integer> setOpCorDefOutputs = new TreeMap<>(); |
| 1240 | + |
| 1241 | + for (int i = 0; i < rel.getInputs().size(); i++) { |
| 1242 | + RelNode oldInput = rel.getInput(i); |
| 1243 | + Frame frame = getInvoke(oldInput, true, rel, parentPropagatesNullValues); |
| 1244 | + if (frame == null) { |
| 1245 | + // If input has not been rewritten, do not rewrite this rel. |
| 1246 | + return null; |
| 1247 | + } |
| 1248 | + |
| 1249 | + // Build join conditions: for each CorDef of this branch that belongs |
| 1250 | + // to the current outFrameCorrId, equate valueGen(col) with branch(col). |
| 1251 | + final List<RexNode> conditions = new ArrayList<>(); |
| 1252 | + for (Map.Entry<CorDef, Integer> e : frame.corDefOutputs.entrySet()) { |
| 1253 | + final CorDef corDef = e.getKey(); |
| 1254 | + final int leftPos = requireNonNull(valueGenCorDefOutputs.get(corDef)); |
| 1255 | + final int rightPos = e.getValue(); |
| 1256 | + final RelDataType leftType = valueGen.getRowType().getFieldList().get(leftPos).getType(); |
| 1257 | + final RelDataType rightType = frame.r.getRowType().getFieldList().get(rightPos).getType(); |
| 1258 | + final RexNode leftRef = new RexInputRef(leftPos, leftType); |
| 1259 | + final RexNode rightRef = new RexInputRef(valueGenFieldCount + rightPos, rightType); |
| 1260 | + conditions.add(relBuilder.isNotDistinctFrom(leftRef, rightRef)); |
| 1261 | + } |
| 1262 | + final RexNode joinCondition = |
| 1263 | + RexUtil.composeConjunction(relBuilder.getRexBuilder(), conditions); |
| 1264 | + RelNode join = relBuilder.push(valueGen).push(frame.r) |
| 1265 | + .join(JoinRelType.INNER, joinCondition).build(); |
| 1266 | + |
| 1267 | + final List<RelDataTypeField> joinFields = join.getRowType().getFieldList(); |
| 1268 | + |
| 1269 | + // Build the final projection for this branch: |
| 1270 | + // all correlated columns (from valueGen), original payload columns (from branch) |
| 1271 | + final PairList<RexNode, String> projects = PairList.of(); |
| 1272 | + final Map<Integer, Integer> childOldToNew = new HashMap<>(); |
| 1273 | + final NavigableMap<CorDef, Integer> childCorDefOutputs = new TreeMap<>(); |
| 1274 | + |
| 1275 | + // a) Correlated columns, in the order of valueGenCorDefOutputs. |
| 1276 | + int newPos = 0; |
| 1277 | + for (Map.Entry<CorDef, Integer> e : valueGenCorDefOutputs.entrySet()) { |
| 1278 | + final int srcIndex = e.getValue(); |
| 1279 | + RexInputRef inputRef = RexInputRef.of(srcIndex, join.getRowType()); |
| 1280 | + String name = joinFields.get(srcIndex).getName(); |
| 1281 | + |
| 1282 | + projects.add(inputRef, name); |
| 1283 | + childCorDefOutputs.put(e.getKey(), newPos); |
| 1284 | + newPos++; |
| 1285 | + } |
| 1286 | + |
| 1287 | + // b) Original SetOp payload columns. |
| 1288 | + for (int oldIndex = 0; oldIndex < payloadFieldCount; oldIndex++) { |
| 1289 | + final Integer srcInFrame = requireNonNull(frame.oldToNewOutputs.get(oldIndex)); |
| 1290 | + final int srcInJoin = valueGenFieldCount + srcInFrame; |
| 1291 | + RexInputRef inputRef = RexInputRef.of(srcInJoin, join.getRowType()); |
| 1292 | + String name = joinFields.get(srcInJoin).getName(); |
| 1293 | + |
| 1294 | + projects.add(inputRef, name); |
| 1295 | + childOldToNew.put(oldIndex, newPos); |
| 1296 | + newPos++; |
| 1297 | + } |
| 1298 | + |
| 1299 | + final RelNode newInput = relBuilder.push(join) |
| 1300 | + .projectNamed(projects.leftList(), projects.rightList(), true) |
| 1301 | + .build(); |
| 1302 | + newInputs.add(newInput); |
| 1303 | + |
| 1304 | + register(oldInput, newInput, childOldToNew, childCorDefOutputs); |
| 1305 | + |
| 1306 | + // Use the first branch as prototype for the SetOp's frame mappings. |
| 1307 | + if (i == 0) { |
| 1308 | + setOpOldToNewOutputs.putAll(childOldToNew); |
| 1309 | + setOpCorDefOutputs.putAll(childCorDefOutputs); |
| 1310 | + } |
| 1311 | + } |
| 1312 | + |
| 1313 | + final SetOp newSetOp = rel.copy(rel.getTraitSet(), newInputs, rel.all); |
| 1314 | + return register(rel, newSetOp, setOpOldToNewOutputs, setOpCorDefOutputs); |
| 1315 | + } |
| 1316 | + |
1187 | 1317 | public @Nullable Frame decorrelateRel(LogicalProject rel, boolean isCorVarDefined, |
1188 | 1318 | boolean parentPropagatesNullValues) { |
1189 | 1319 | return decorrelateRel((Project) rel, isCorVarDefined, parentPropagatesNullValues); |
|
0 commit comments