Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
78 changes: 54 additions & 24 deletions core/src/main/java/org/apache/calcite/sql2rel/RelDecorrelator.java
Original file line number Diff line number Diff line change
Expand Up @@ -1543,17 +1543,17 @@ private RelNode getCorRel(CorRef corVar) {
/** Adds a value generator to satisfy the correlating variables used by
* a relational expression, if those variables are not already provided by
* its input. */
private Frame maybeAddValueGenerator(RelNode rel, Frame frame) {
final CorelMap cm1 = new CorelMapBuilder().build(frame.r, rel);
private Frame maybeAddValueGenerator(RelNode rel, Frame inputFrame) {
final CorelMap cm1 = new CorelMapBuilder().build(inputFrame.r, rel);
if (!cm1.mapRefRelToCorRef.containsKey(rel)) {
return frame;
return inputFrame;
}
final Collection<CorRef> needs = cm1.mapRefRelToCorRef.get(rel);
final ImmutableSortedSet<CorDef> haves = frame.corDefOutputs.keySet();
final ImmutableSortedSet<CorDef> haves = inputFrame.corDefOutputs.keySet();
if (hasAll(needs, haves)) {
return frame;
return inputFrame;
}
return decorrelateInputWithValueGenerator(rel, frame);
return decorrelateInputWithValueGenerator(rel, inputFrame);
}

/** Returns whether all of a collection of {@link CorRef}s are satisfied
Expand All @@ -1579,13 +1579,13 @@ private static boolean has(Collection<CorDef> corDefs, CorRef corr) {
return false;
}

private Frame decorrelateInputWithValueGenerator(RelNode rel, Frame frame) {
private Frame decorrelateInputWithValueGenerator(RelNode rel, Frame inputFrame) {
// currently only handles one input
assert rel.getInputs().size() == 1;
RelNode oldInput = frame.r;
RelNode oldInput = inputFrame.r;

final NavigableMap<CorDef, Integer> corDefOutputs =
new TreeMap<>(frame.corDefOutputs);
new TreeMap<>(inputFrame.corDefOutputs);

final Collection<CorRef> corVarList = cm.mapRefRelToCorRef.get(rel);

Expand All @@ -1606,16 +1606,15 @@ private Frame decorrelateInputWithValueGenerator(RelNode rel, Frame frame) {
if (node instanceof RexInputRef) {
map.put(def, ((RexInputRef) node).getIndex());
} else {
map.put(def,
frame.r.getRowType().getFieldCount() + projects.size());
map.put(def, inputFrame.r.getRowType().getFieldCount() + projects.size());
projects.add((RexNode) node);
}
}
}
// If all correlation variables are now satisfied, skip creating a value
// generator.
if (map.size() == corVarList.size()) {
map.putAll(frame.corDefOutputs);
map.putAll(inputFrame.corDefOutputs);
final RelNode r;
if (!projects.isEmpty()) {
relBuilder.push(oldInput)
Expand All @@ -1624,17 +1623,40 @@ private Frame decorrelateInputWithValueGenerator(RelNode rel, Frame frame) {
} else {
r = oldInput;
}
return register(rel.getInput(0), r,
frame.oldToNewOutputs, map);
return register(rel.getInput(0), r, inputFrame.oldToNewOutputs, map);
}
}

int leftInputOutputCount = frame.r.getRowType().getFieldCount();
return createFrameWithValueGenerator(rel.getInput(0), inputFrame, corVarList, corDefOutputs);
}

/**
* Creates a new {@link Frame} for the given rel by joining its current
* decorrelated rel with a value generator that produces the required
* correlation variables.
*
* <p>The value generator is built from {@code corVarList} and joined with
* {@code frame.r} using an INNER join. The provided
* {@code corDefOutputs} map is updated to reflect the positions of all
* correlation definitions in the join output, and the resulting frame is
* registered for {@code rel}.
*
* @param rel target RelNode whose frame is updated to use the join of
* {@code frame.r} and the value generator
* @param frame existing Frame of the rel
* @param corVarList correlated variables that still need to be produced
* @param corDefOutputs mapping from {@link CorDef} to output positions; updated in place
* to include positions in the new join
* @return a new Frame describing {@code rel} after attaching the value generator
*/
private Frame createFrameWithValueGenerator(RelNode rel, Frame frame,
Collection<CorRef> corVarList, NavigableMap<CorDef, Integer> corDefOutputs) {
int leftFieldCount = frame.r.getRowType().getFieldCount();

// can directly add positions into corDefOutputs since join
// does not change the output ordering from the inputs.
final RelNode valueGen =
createValueGenerator(corVarList, leftInputOutputCount, corDefOutputs);
createValueGenerator(corVarList, leftFieldCount, corDefOutputs);
requireNonNull(valueGen, "valueGen");

RelNode join =
Expand All @@ -1647,8 +1669,7 @@ private Frame decorrelateInputWithValueGenerator(RelNode rel, Frame frame) {
// Join or Filter does not change the old input ordering. All
// input fields from newLeftInput (i.e. the original input to the old
// Filter) are in the output and in the same position.
return register(rel.getInput(0), join, frame.oldToNewOutputs,
corDefOutputs);
return register(rel, join, frame.oldToNewOutputs, corDefOutputs);
}

/** Finds a {@link RexInputRef} that is equivalent to a {@link CorRef},
Expand Down Expand Up @@ -1931,8 +1952,19 @@ private static boolean isWidening(RelDataType type, RelDataType type1) {
return null;
}

Frame newLeftFrame = leftFrame;
boolean joinConditionContainsFieldAccess = RexUtil.containsFieldAccess(rel.getCondition());
if (joinConditionContainsFieldAccess && isCorVarDefined) {
final CorelMap localCorelMap = new CorelMapBuilder().build(rel);
final List<CorRef> corVarList = new ArrayList<>(localCorelMap.mapRefRelToCorRef.values());
Collections.sort(corVarList);

final NavigableMap<CorDef, Integer> corDefOutputs = new TreeMap<>();
newLeftFrame = createFrameWithValueGenerator(oldLeft, leftFrame, corVarList, corDefOutputs);
}

RelNode newJoin = relBuilder
.push(leftFrame.r)
.push(newLeftFrame.r)
.push(rightFrame.r)
.join(rel.getJoinType(),
decorrelateExpr(castNonNull(currentRel), map, cm, rel.getCondition()),
Expand All @@ -1944,25 +1976,23 @@ private static boolean isWidening(RelDataType type, RelDataType type1) {
Map<Integer, Integer> mapOldToNewOutputs = new HashMap<>();

int oldLeftFieldCount = oldLeft.getRowType().getFieldCount();
int newLeftFieldCount = leftFrame.r.getRowType().getFieldCount();
int newLeftFieldCount = newLeftFrame.r.getRowType().getFieldCount();

int oldRightFieldCount = oldRight.getRowType().getFieldCount();
//noinspection AssertWithSideEffects
assert rel.getRowType().getFieldCount()
== oldLeftFieldCount + oldRightFieldCount;

// Left input positions are not changed.
mapOldToNewOutputs.putAll(leftFrame.oldToNewOutputs);

mapOldToNewOutputs.putAll(newLeftFrame.oldToNewOutputs);
// Right input positions are shifted by newLeftFieldCount.
for (int i = 0; i < oldRightFieldCount; i++) {
mapOldToNewOutputs.put(i + oldLeftFieldCount,
requireNonNull(rightFrame.oldToNewOutputs.get(i)) + newLeftFieldCount);
}

final NavigableMap<CorDef, Integer> corDefOutputs =
new TreeMap<>(leftFrame.corDefOutputs);

new TreeMap<>(newLeftFrame.corDefOutputs);
// Right input positions are shifted by newLeftFieldCount.
for (Map.Entry<CorDef, Integer> entry
: rightFrame.corDefOutputs.entrySet()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1198,4 +1198,78 @@ public static Frameworks.ConfigBuilder config() {
+ " LogicalTableScan(table=[[scott, EMP]])\n";
assertThat(after, hasTree(planAfter));
}

/** Test case for <a href="https://issues.apache.org/jira/browse/CALCITE-7257">[CALCITE-7257]
* Subqueries cannot be decorrelated if join condition contains RexFieldAccess</a>. */
@Test void testJoinConditionContainsRexFieldAccess() {
final FrameworkConfig frameworkConfig = config().build();
final RelBuilder builder = RelBuilder.create(frameworkConfig);
final RelOptCluster cluster = builder.getCluster();
final Planner planner = Frameworks.getPlanner(frameworkConfig);
final String sql = ""
+ "SELECT E1.* \n"
+ "FROM\n"
+ " EMP E1\n"
+ "WHERE\n"
+ " E1.EMPNO = (\n"
+ " SELECT D1.DEPTNO FROM DEPT D1\n"
+ " WHERE E1.ENAME IN (SELECT B1.ENAME FROM BONUS B1))";
final RelNode originalRel;
try {
final SqlNode parse = planner.parse(sql);
final SqlNode validate = planner.validate(parse);
originalRel = planner.rel(validate).rel;
} catch (Exception e) {
throw TestUtil.rethrow(e);
}

final HepProgram hepProgram = HepProgram.builder()
.addRuleCollection(
ImmutableList.of(
// SubQuery program rules
CoreRules.FILTER_SUB_QUERY_TO_CORRELATE,
CoreRules.PROJECT_SUB_QUERY_TO_CORRELATE,
CoreRules.JOIN_SUB_QUERY_TO_CORRELATE))
.build();
final Program program =
Programs.of(hepProgram, true,
requireNonNull(cluster.getMetadataProvider()));
final RelNode before =
program.run(cluster.getPlanner(), originalRel, cluster.traitSet(),
Collections.emptyList(), Collections.emptyList());
final String planBefore = ""
+ "LogicalProject(EMPNO=[$0], ENAME=[$1], JOB=[$2], MGR=[$3], HIREDATE=[$4], SAL=[$5], COMM=[$6], DEPTNO=[$7])\n"
+ " LogicalProject(EMPNO=[$0], ENAME=[$1], JOB=[$2], MGR=[$3], HIREDATE=[$4], SAL=[$5], COMM=[$6], DEPTNO=[$7])\n"
+ " LogicalFilter(condition=[=($0, CAST($8):SMALLINT)])\n"
+ " LogicalCorrelate(correlation=[$cor0], joinType=[left], requiredColumns=[{1}])\n"
+ " LogicalTableScan(table=[[scott, EMP]])\n"
+ " LogicalAggregate(group=[{}], agg#0=[SINGLE_VALUE($0)])\n"
+ " LogicalProject(DEPTNO=[$0])\n"
+ " LogicalProject(DEPTNO=[$0], DNAME=[$1], LOC=[$2])\n"
+ " LogicalJoin(condition=[=($cor0.ENAME, $3)], joinType=[inner])\n"
+ " LogicalTableScan(table=[[scott, DEPT]])\n"
+ " LogicalProject(ENAME=[$0])\n"
+ " LogicalTableScan(table=[[scott, BONUS]])\n";
assertThat(before, hasTree(planBefore));

// Decorrelate without any rules, just "purely" decorrelation algorithm on RelDecorrelator
final RelNode after =
RelDecorrelator.decorrelateQuery(before, builder, RuleSets.ofList(Collections.emptyList()),
RuleSets.ofList(Collections.emptyList()));
final String planAfter = ""
+ "LogicalProject(EMPNO=[$0], ENAME=[$1], JOB=[$2], MGR=[$3], HIREDATE=[$4], SAL=[$5], COMM=[$6], DEPTNO=[$7])\n"
+ " LogicalProject(EMPNO=[$0], ENAME=[$1], JOB=[$2], MGR=[$3], HIREDATE=[$4], SAL=[$5], COMM=[$6], DEPTNO=[$7], ENAME0=[$8], $f1=[CAST($9):TINYINT])\n"
+ " LogicalJoin(condition=[AND(=($1, $8), =($0, CAST($9):SMALLINT))], joinType=[inner])\n"
+ " LogicalTableScan(table=[[scott, EMP]])\n"
+ " LogicalAggregate(group=[{0}], agg#0=[SINGLE_VALUE($1)])\n"
+ " LogicalProject(ENAME=[$3], DEPTNO=[$0])\n"
+ " LogicalJoin(condition=[=($3, $4)], joinType=[inner])\n"
+ " LogicalJoin(condition=[true], joinType=[inner])\n"
+ " LogicalTableScan(table=[[scott, DEPT]])\n"
+ " LogicalProject(ENAME=[$1])\n"
+ " LogicalTableScan(table=[[scott, EMP]])\n"
+ " LogicalProject(ENAME=[$0])\n"
+ " LogicalTableScan(table=[[scott, BONUS]])\n";
assertThat(after, hasTree(planAfter));
}
}
Loading
Loading