Skip to content

Commit c55cc3f

Browse files
[CALCITE-7257] Subqueries cannot be decorrelated if join condition contains RexFieldAccess
1 parent e4b2baa commit c55cc3f

File tree

3 files changed

+490
-24
lines changed

3 files changed

+490
-24
lines changed

core/src/main/java/org/apache/calcite/sql2rel/RelDecorrelator.java

Lines changed: 54 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1384,17 +1384,17 @@ private RelNode getCorRel(CorRef corVar) {
13841384
/** Adds a value generator to satisfy the correlating variables used by
13851385
* a relational expression, if those variables are not already provided by
13861386
* its input. */
1387-
private Frame maybeAddValueGenerator(RelNode rel, Frame frame) {
1388-
final CorelMap cm1 = new CorelMapBuilder().build(frame.r, rel);
1387+
private Frame maybeAddValueGenerator(RelNode rel, Frame inputFrame) {
1388+
final CorelMap cm1 = new CorelMapBuilder().build(inputFrame.r, rel);
13891389
if (!cm1.mapRefRelToCorRef.containsKey(rel)) {
1390-
return frame;
1390+
return inputFrame;
13911391
}
13921392
final Collection<CorRef> needs = cm1.mapRefRelToCorRef.get(rel);
1393-
final ImmutableSortedSet<CorDef> haves = frame.corDefOutputs.keySet();
1393+
final ImmutableSortedSet<CorDef> haves = inputFrame.corDefOutputs.keySet();
13941394
if (hasAll(needs, haves)) {
1395-
return frame;
1395+
return inputFrame;
13961396
}
1397-
return decorrelateInputWithValueGenerator(rel, frame);
1397+
return decorrelateInputWithValueGenerator(rel, inputFrame);
13981398
}
13991399

14001400
/** Returns whether all of a collection of {@link CorRef}s are satisfied
@@ -1420,13 +1420,13 @@ private static boolean has(Collection<CorDef> corDefs, CorRef corr) {
14201420
return false;
14211421
}
14221422

1423-
private Frame decorrelateInputWithValueGenerator(RelNode rel, Frame frame) {
1423+
private Frame decorrelateInputWithValueGenerator(RelNode rel, Frame inputFrame) {
14241424
// currently only handles one input
14251425
assert rel.getInputs().size() == 1;
1426-
RelNode oldInput = frame.r;
1426+
RelNode oldInput = inputFrame.r;
14271427

14281428
final NavigableMap<CorDef, Integer> corDefOutputs =
1429-
new TreeMap<>(frame.corDefOutputs);
1429+
new TreeMap<>(inputFrame.corDefOutputs);
14301430

14311431
final Collection<CorRef> corVarList = cm.mapRefRelToCorRef.get(rel);
14321432

@@ -1447,16 +1447,15 @@ private Frame decorrelateInputWithValueGenerator(RelNode rel, Frame frame) {
14471447
if (node instanceof RexInputRef) {
14481448
map.put(def, ((RexInputRef) node).getIndex());
14491449
} else {
1450-
map.put(def,
1451-
frame.r.getRowType().getFieldCount() + projects.size());
1450+
map.put(def, inputFrame.r.getRowType().getFieldCount() + projects.size());
14521451
projects.add((RexNode) node);
14531452
}
14541453
}
14551454
}
14561455
// If all correlation variables are now satisfied, skip creating a value
14571456
// generator.
14581457
if (map.size() == corVarList.size()) {
1459-
map.putAll(frame.corDefOutputs);
1458+
map.putAll(inputFrame.corDefOutputs);
14601459
final RelNode r;
14611460
if (!projects.isEmpty()) {
14621461
relBuilder.push(oldInput)
@@ -1465,17 +1464,40 @@ private Frame decorrelateInputWithValueGenerator(RelNode rel, Frame frame) {
14651464
} else {
14661465
r = oldInput;
14671466
}
1468-
return register(rel.getInput(0), r,
1469-
frame.oldToNewOutputs, map);
1467+
return register(rel.getInput(0), r, inputFrame.oldToNewOutputs, map);
14701468
}
14711469
}
14721470

1473-
int leftInputOutputCount = frame.r.getRowType().getFieldCount();
1471+
return createFrameWithValueGenerator(rel.getInput(0), inputFrame, corVarList, corDefOutputs);
1472+
}
1473+
1474+
/**
1475+
* Creates a new {@link Frame} for the given rel by joining its current
1476+
* decorrelated rel with a value generator that produces the required
1477+
* correlation variables.
1478+
*
1479+
* <p>The value generator is built from {@code corVarList} and joined with
1480+
* {@code frame.r} using an INNER join. The provided
1481+
* {@code corDefOutputs} map is updated to reflect the positions of all
1482+
* correlation definitions in the join output, and the resulting frame is
1483+
* registered for {@code rel}.
1484+
*
1485+
* @param rel target RelNode whose frame is updated to use the join of
1486+
* {@code frame.r} and the value generator
1487+
* @param frame existing Frame of the rel
1488+
* @param corVarList correlated variables that still need to be produced
1489+
* @param corDefOutputs mapping from {@link CorDef} to output positions; updated in place
1490+
* to include positions in the new join
1491+
* @return a new Frame describing {@code rel} after attaching the value generator
1492+
*/
1493+
private Frame createFrameWithValueGenerator(RelNode rel, Frame frame,
1494+
Collection<CorRef> corVarList, NavigableMap<CorDef, Integer> corDefOutputs) {
1495+
int leftFieldCount = frame.r.getRowType().getFieldCount();
14741496

14751497
// can directly add positions into corDefOutputs since join
14761498
// does not change the output ordering from the inputs.
14771499
final RelNode valueGen =
1478-
createValueGenerator(corVarList, leftInputOutputCount, corDefOutputs);
1500+
createValueGenerator(corVarList, leftFieldCount, corDefOutputs);
14791501
requireNonNull(valueGen, "valueGen");
14801502

14811503
RelNode join =
@@ -1488,8 +1510,7 @@ private Frame decorrelateInputWithValueGenerator(RelNode rel, Frame frame) {
14881510
// The join with the value generator does not change the old input
14891511
// ordering. All fields of frame.r (the decorrelated form of rel) are in
14901512
// the join output and in the same position.
1491-
return register(rel.getInput(0), join, frame.oldToNewOutputs,
1492-
corDefOutputs);
1513+
return register(rel, join, frame.oldToNewOutputs, corDefOutputs);
14931514
}
14941515

14951516
/** Finds a {@link RexInputRef} that is equivalent to a {@link CorRef},
@@ -1772,8 +1793,19 @@ private static boolean isWidening(RelDataType type, RelDataType type1) {
17721793
return null;
17731794
}
17741795

1796+
Frame newLeftFrame = leftFrame;
1797+
boolean joinConditionContainsFieldAccess = RexUtil.containsFieldAccess(rel.getCondition());
1798+
if (joinConditionContainsFieldAccess && isCorVarDefined) {
1799+
final CorelMap localCorelMap = new CorelMapBuilder().build(rel);
1800+
final List<CorRef> corVarList = new ArrayList<>(localCorelMap.mapRefRelToCorRef.values());
1801+
Collections.sort(corVarList);
1802+
1803+
final NavigableMap<CorDef, Integer> corDefOutputs = new TreeMap<>();
1804+
newLeftFrame = createFrameWithValueGenerator(oldLeft, leftFrame, corVarList, corDefOutputs);
1805+
}
1806+
17751807
RelNode newJoin = relBuilder
1776-
.push(leftFrame.r)
1808+
.push(newLeftFrame.r)
17771809
.push(rightFrame.r)
17781810
.join(rel.getJoinType(),
17791811
decorrelateExpr(castNonNull(currentRel), map, cm, rel.getCondition()),
@@ -1785,25 +1817,23 @@ private static boolean isWidening(RelDataType type, RelDataType type1) {
17851817
Map<Integer, Integer> mapOldToNewOutputs = new HashMap<>();
17861818

17871819
int oldLeftFieldCount = oldLeft.getRowType().getFieldCount();
1788-
int newLeftFieldCount = leftFrame.r.getRowType().getFieldCount();
1820+
int newLeftFieldCount = newLeftFrame.r.getRowType().getFieldCount();
17891821

17901822
int oldRightFieldCount = oldRight.getRowType().getFieldCount();
17911823
//noinspection AssertWithSideEffects
17921824
assert rel.getRowType().getFieldCount()
17931825
== oldLeftFieldCount + oldRightFieldCount;
17941826

17951827
// Left input positions are not changed.
1796-
mapOldToNewOutputs.putAll(leftFrame.oldToNewOutputs);
1797-
1828+
mapOldToNewOutputs.putAll(newLeftFrame.oldToNewOutputs);
17981829
// Right input positions are shifted by newLeftFieldCount.
17991830
for (int i = 0; i < oldRightFieldCount; i++) {
18001831
mapOldToNewOutputs.put(i + oldLeftFieldCount,
18011832
requireNonNull(rightFrame.oldToNewOutputs.get(i)) + newLeftFieldCount);
18021833
}
18031834

18041835
final NavigableMap<CorDef, Integer> corDefOutputs =
1805-
new TreeMap<>(leftFrame.corDefOutputs);
1806-
1836+
new TreeMap<>(newLeftFrame.corDefOutputs);
18071837
// Right input positions are shifted by newLeftFieldCount.
18081838
for (Map.Entry<CorDef, Integer> entry
18091839
: rightFrame.corDefOutputs.entrySet()) {

core/src/test/java/org/apache/calcite/sql2rel/RelDecorrelatorTest.java

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -584,4 +584,78 @@ public static Frameworks.ConfigBuilder config() {
584584
+ " LogicalTableScan(table=[[scott, DEPT]])\n";
585585
assertThat(decorrelatedNoRules, hasTree(planDecorrelatedNoRules));
586586
}
587+
588+
/** Test case for <a href="https://issues.apache.org/jira/browse/CALCITE-7257">[CALCITE-7257]
589+
* Subqueries cannot be decorrelated if join condition contains RexFieldAccess</a>. */
590+
@Test void testJoinConditionContainsRexFieldAccess() {
591+
final FrameworkConfig frameworkConfig = config().build();
592+
final RelBuilder builder = RelBuilder.create(frameworkConfig);
593+
final RelOptCluster cluster = builder.getCluster();
594+
final Planner planner = Frameworks.getPlanner(frameworkConfig);
595+
final String sql = ""
596+
+ "SELECT E1.* \n"
597+
+ "FROM\n"
598+
+ " EMP E1\n"
599+
+ "WHERE\n"
600+
+ " E1.EMPNO = (\n"
601+
+ " SELECT D1.DEPTNO FROM DEPT D1\n"
602+
+ " WHERE E1.ENAME IN (SELECT B1.ENAME FROM BONUS B1))";
603+
final RelNode originalRel;
604+
try {
605+
final SqlNode parse = planner.parse(sql);
606+
final SqlNode validate = planner.validate(parse);
607+
originalRel = planner.rel(validate).rel;
608+
} catch (Exception e) {
609+
throw TestUtil.rethrow(e);
610+
}
611+
612+
final HepProgram hepProgram = HepProgram.builder()
613+
.addRuleCollection(
614+
ImmutableList.of(
615+
// SubQuery program rules
616+
CoreRules.FILTER_SUB_QUERY_TO_CORRELATE,
617+
CoreRules.PROJECT_SUB_QUERY_TO_CORRELATE,
618+
CoreRules.JOIN_SUB_QUERY_TO_CORRELATE))
619+
.build();
620+
final Program program =
621+
Programs.of(hepProgram, true,
622+
requireNonNull(cluster.getMetadataProvider()));
623+
final RelNode before =
624+
program.run(cluster.getPlanner(), originalRel, cluster.traitSet(),
625+
Collections.emptyList(), Collections.emptyList());
626+
final String planBefore = ""
627+
+ "LogicalProject(EMPNO=[$0], ENAME=[$1], JOB=[$2], MGR=[$3], HIREDATE=[$4], SAL=[$5], COMM=[$6], DEPTNO=[$7])\n"
628+
+ " LogicalProject(EMPNO=[$0], ENAME=[$1], JOB=[$2], MGR=[$3], HIREDATE=[$4], SAL=[$5], COMM=[$6], DEPTNO=[$7])\n"
629+
+ " LogicalFilter(condition=[=($0, CAST($8):SMALLINT)])\n"
630+
+ " LogicalCorrelate(correlation=[$cor0], joinType=[left], requiredColumns=[{1}])\n"
631+
+ " LogicalTableScan(table=[[scott, EMP]])\n"
632+
+ " LogicalAggregate(group=[{}], agg#0=[SINGLE_VALUE($0)])\n"
633+
+ " LogicalProject(DEPTNO=[$0])\n"
634+
+ " LogicalProject(DEPTNO=[$0], DNAME=[$1], LOC=[$2])\n"
635+
+ " LogicalJoin(condition=[=($cor0.ENAME, $3)], joinType=[inner])\n"
636+
+ " LogicalTableScan(table=[[scott, DEPT]])\n"
637+
+ " LogicalProject(ENAME=[$0])\n"
638+
+ " LogicalTableScan(table=[[scott, BONUS]])\n";
639+
assertThat(before, hasTree(planBefore));
640+
641+
// Decorrelate without any rules, so only RelDecorrelator's core algorithm is exercised
642+
final RelNode after =
643+
RelDecorrelator.decorrelateQuery(before, builder, RuleSets.ofList(Collections.emptyList()),
644+
RuleSets.ofList(Collections.emptyList()));
645+
final String planAfter = ""
646+
+ "LogicalProject(EMPNO=[$0], ENAME=[$1], JOB=[$2], MGR=[$3], HIREDATE=[$4], SAL=[$5], COMM=[$6], DEPTNO=[$7])\n"
647+
+ " LogicalProject(EMPNO=[$0], ENAME=[$1], JOB=[$2], MGR=[$3], HIREDATE=[$4], SAL=[$5], COMM=[$6], DEPTNO=[$7], ENAME0=[$8], $f1=[CAST($9):TINYINT])\n"
648+
+ " LogicalJoin(condition=[AND(=($1, $8), =($0, CAST($9):SMALLINT))], joinType=[inner])\n"
649+
+ " LogicalTableScan(table=[[scott, EMP]])\n"
650+
+ " LogicalAggregate(group=[{0}], agg#0=[SINGLE_VALUE($1)])\n"
651+
+ " LogicalProject(ENAME=[$3], DEPTNO=[$0])\n"
652+
+ " LogicalJoin(condition=[=($3, $4)], joinType=[inner])\n"
653+
+ " LogicalJoin(condition=[true], joinType=[inner])\n"
654+
+ " LogicalTableScan(table=[[scott, DEPT]])\n"
655+
+ " LogicalProject(ENAME=[$1])\n"
656+
+ " LogicalTableScan(table=[[scott, EMP]])\n"
657+
+ " LogicalProject(ENAME=[$0])\n"
658+
+ " LogicalTableScan(table=[[scott, BONUS]])\n";
659+
assertThat(after, hasTree(planAfter));
660+
}
587661
}

0 commit comments

Comments
 (0)