Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 23 additions & 12 deletions core/src/main/java/org/apache/calcite/sql2rel/RelDecorrelator.java
Original file line number Diff line number Diff line change
Expand Up @@ -954,16 +954,12 @@ private RelNode rewriteScalarAggregate(Aggregate oldRel,
RelNode newRel,
Map<Integer, Integer> outputMap,
NavigableMap<CorDef, Integer> corDefOutputs) {
final CorelMap localCorelMap = new CorelMapBuilder().build(oldRel);
final List<CorRef> corVarList = new ArrayList<>(localCorelMap.mapRefRelToCorRef.values());
Collections.sort(corVarList);

final List<CorRef> corVarList = collectExternalCorVars(oldRel);
final NavigableMap<CorDef, Integer> valueGenCorDefOutputs = new TreeMap<>();
final RelNode valueGen =
requireNonNull(createValueGenerator(corVarList, 0, valueGenCorDefOutputs));
final int valueGenFieldCount = valueGen.getRowType().getFieldCount();

// Build join conditions
final Map<Integer, RexNode> newProjectMap = new HashMap<>();
for (Map.Entry<CorDef, Integer> corDefOutput : corDefOutputs.entrySet()) {
final CorDef corDef = corDefOutput.getKey();
Expand All @@ -974,6 +970,7 @@ private RelNode rewriteScalarAggregate(Aggregate oldRel,
newProjectMap.put(valueGenFieldCount + rightPos, leftRef);
}

// Build join conditions
final List<RexNode> conditions =
buildCorDefJoinConditions(valueGenCorDefOutputs, corDefOutputs,
valueGen, newRel, relBuilder);
Expand Down Expand Up @@ -1260,10 +1257,7 @@ private static void shiftMapping(Map<Integer, Integer> mapping, int startIndex,
return decorrelateRel((RelNode) rel, false, parentPropagatesNullValues);
}

final CorelMap localCorelMap = new CorelMapBuilder().build(rel);
final List<CorRef> corVarList = new ArrayList<>(localCorelMap.mapRefRelToCorRef.values());
Collections.sort(corVarList);

final List<CorRef> corVarList = collectExternalCorVars(rel);
final NavigableMap<CorDef, Integer> valueGenCorDefOutputs = new TreeMap<>();
final RelNode valueGen =
requireNonNull(createValueGenerator(corVarList, 0, valueGenCorDefOutputs));
Expand Down Expand Up @@ -1958,9 +1952,7 @@ private static boolean isWidening(RelDataType type, RelDataType type1) {
}

// 1. Collect all CorRefs involved
final CorelMap localCorelMap = new CorelMapBuilder().build(rel);
final List<CorRef> corVarList = new ArrayList<>(localCorelMap.mapRefRelToCorRef.values());
Collections.sort(corVarList);
final List<CorRef> corVarList = collectExternalCorVars(rel);

// 2. Ensure CorVars are present in inputs (adding ValueGenerators if needed)
Frame newLeftFrame = leftFrame;
Expand Down Expand Up @@ -3849,6 +3841,25 @@ private static boolean isFieldNotNullRecursive(RelNode rel, int index) {
}
}

/**
 * Gathers the correlated variable references that occur in {@code rel} but
 * whose correlation is not defined inside {@code rel} itself.
 *
 * @param rel The relational expression to inspect
 * @return The external correlated variable references, in sorted order
 */
private static List<CorRef> collectExternalCorVars(RelNode rel) {
  final CorelMap corelMap = new CorelMapBuilder().build(rel);
  // Start from every correlated reference recorded for the tree ...
  final List<CorRef> externalRefs = new ArrayList<>(corelMap.mapRefRelToCorRef.values());
  // ... then discard the ones whose correlation is also a key of mapCorToCorRel,
  // i.e. the ones that are defined within rel.
  externalRefs.removeIf(ref -> corelMap.mapCorToCorRel.containsKey(ref.corr));
  Collections.sort(externalRefs);
  return externalRefs;
}

/**
* Ensures that the correlated variables in {@code allCorDefs} are present
* in the output of the frame.
Expand Down
132 changes: 132 additions & 0 deletions core/src/test/java/org/apache/calcite/sql2rel/RelDecorrelatorTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -356,6 +356,138 @@ public static Frameworks.ConfigBuilder config() {
assertThat(after, hasTree(planAfter));
}

/** Test case for <a href="https://issues.apache.org/jira/browse/CALCITE-7394">[CALCITE-7394]
* Nested sub-query with multiple levels of correlation returns incorrect results</a>.
*
* <p>The scalar {@code COUNT(*)} sub-query references {@code d} and contains an
* {@code EXISTS} sub-query that references both {@code e} (defined inside the
* scalar sub-query) and {@code d} (defined outside it). The test first expands
* the sub-queries into correlates, checks that plan, then decorrelates and
* checks the resulting plan. */
@Test void testNestedSubQueryWithMultiLevelCorrelation() {
final FrameworkConfig frameworkConfig = config().build();
final RelBuilder builder = RelBuilder.create(frameworkConfig);
final RelOptCluster cluster = builder.getCluster();
final Planner planner = Frameworks.getPlanner(frameworkConfig);
final String sql = ""
+ "select d.dname,\n"
+ " (select count(*)\n"
+ " from emp e\n"
+ " where e.deptno = d.deptno\n"
+ " and exists (\n"
+ " select 1\n"
+ " from (values (1000), (2000), (3000)) as v(sal)\n"
+ " where e.sal > v.sal\n"
+ " and d.deptno * 100 < v.sal\n"
+ " )\n"
+ " ) as c\n"
+ "from dept d\n"
+ "order by d.dname";
// Parse, validate and convert the SQL to a relational expression; any
// failure in these steps is rethrown so that the test fails.
final RelNode originalRel;
try {
final SqlNode parse = planner.parse(sql);
final SqlNode validate = planner.validate(parse);
originalRel = planner.rel(validate).rel;
} catch (Exception e) {
throw TestUtil.rethrow(e);
}

// Run only the three SUB_QUERY_TO_CORRELATE rules, so that the plan handed
// to the decorrelator contains correlates rather than sub-query expressions.
final HepProgram hepProgram = HepProgram.builder()
.addRuleCollection(
ImmutableList.of(
// SubQuery program rules
CoreRules.FILTER_SUB_QUERY_TO_CORRELATE,
CoreRules.PROJECT_SUB_QUERY_TO_CORRELATE,
CoreRules.JOIN_SUB_QUERY_TO_CORRELATE))
.build();
final Program program =
Programs.of(hepProgram, true,
requireNonNull(cluster.getMetadataProvider()));
final RelNode before =
program.run(cluster.getPlanner(), originalRel, cluster.traitSet(),
Collections.emptyList(), Collections.emptyList());
// Expected plan before decorrelation: $cor0 is set by the outer correlate
// (over DEPT), while $cor1 is set by the inner correlate (over EMP), which
// sits below the scalar aggregate. The innermost filter uses both.
final String planBefore = ""
+ "LogicalSort(sort0=[$0], dir0=[ASC])\n"
+ " LogicalProject(DNAME=[$1], C=[$3])\n"
+ " LogicalCorrelate(correlation=[$cor0], joinType=[left], requiredColumns=[{0}])\n"
+ " LogicalTableScan(table=[[scott, DEPT]])\n"
+ " LogicalAggregate(group=[{}], EXPR$0=[COUNT()])\n"
+ " LogicalProject(EMPNO=[$0], ENAME=[$1], JOB=[$2], MGR=[$3], HIREDATE=[$4], SAL=[$5], COMM=[$6], DEPTNO=[$7])\n"
+ " LogicalFilter(condition=[=($7, $cor0.DEPTNO)])\n"
+ " LogicalCorrelate(correlation=[$cor1], joinType=[inner], requiredColumns=[{5}])\n"
+ " LogicalTableScan(table=[[scott, EMP]])\n"
+ " LogicalAggregate(group=[{0}])\n"
+ " LogicalProject(i=[true])\n"
+ " LogicalFilter(condition=[AND(>(CAST($cor1.SAL):DECIMAL(12, 2), CAST($0):DECIMAL(12, 2) NOT NULL), <(*($cor0.DEPTNO, 100), $0))])\n"
+ " LogicalValues(tuples=[[{ 1000 }, { 2000 }, { 3000 }]])\n";
assertThat(before, hasTree(planBefore));

// Decorrelate with empty rule sets, so that the resulting plan reflects only
// the decorrelation algorithm of RelDecorrelator itself.
final RelNode after =
RelDecorrelator.decorrelateQuery(before, builder, RuleSets.ofList(Collections.emptyList()),
RuleSets.ofList(Collections.emptyList()));
// Plan produced before the fix, kept for reference. In the lines marked
// "error part", the left input of the IS NOT DISTINCT FROM join was a cross
// join of the DEPT-based values with distinct SAL0 values taken from EMP;
// the expected plan below has only the DEPT-based projection in that place.
//
// LogicalSort(sort0=[$0], dir0=[ASC])
// LogicalProject(DNAME=[$1], C=[$7])
// LogicalJoin(condition=[AND(=($0, $5), =($4, $6))], joinType=[left])
// LogicalProject(DEPTNO=[$0], DNAME=[$1], LOC=[$2], DEPTNO0=[$0], $f4=[*($0, 100)])
// LogicalTableScan(table=[[scott, DEPT]])
// LogicalProject(DEPTNO8=[$0], $f4=[$1], EXPR$0=[CASE(IS NOT NULL($5), $5, 0)])
// LogicalJoin(condition=[AND(IS NOT DISTINCT FROM($0, $3),
// IS NOT DISTINCT FROM($1, $4))], joinType=[left])
// LogicalJoin(condition=[true], joinType=[inner]) // <---- error part
// LogicalProject(DEPTNO=[$0], $f4=[*($0, 100)])
// LogicalTableScan(table=[[scott, DEPT]])
// LogicalAggregate(group=[{0}]) // <---- error part
// LogicalProject(SAL0=[CAST($5):DECIMAL(12, 2)]) // <---- error part
// LogicalTableScan(table=[[scott, EMP]]) // <---- error part
// LogicalAggregate(group=[{0, 1}], EXPR$0=[COUNT()])
// LogicalProject(DEPTNO8=[$7], $f4=[$9])
// LogicalFilter(condition=[IS NOT NULL($7)])
// LogicalProject(..., DEPTNO=[$7], i=[$11], $f4=[$9])
// LogicalJoin(condition=[=($8, $10)], joinType=[inner])
// LogicalProject(..., SAL0=[CAST($5):DECIMAL(12, 2)])
// LogicalTableScan(table=[[scott, EMP]])
// LogicalProject($f4=[$0], SAL0=[$1], $f2=[true])
// LogicalAggregate(group=[{0, 1}])
// LogicalProject($f4=[$1], SAL0=[$2])
// LogicalJoin(condition=[AND(>($2, CAST($0):DECIMAL(12, 2) NOT NULL),
// <($1, $0))], joinType=[inner])
// LogicalValues(tuples=[[{ 1000 }, { 2000 }, { 3000 }]])
// LogicalJoin(condition=[true], joinType=[inner])
// LogicalAggregate(group=[{0}])
// LogicalProject($f4=[*($0, 100)])
// LogicalTableScan(table=[[scott, DEPT]])
// LogicalAggregate(group=[{0}])
// LogicalProject(SAL0=[CAST($5):DECIMAL(12, 2)])
// LogicalTableScan(table=[[scott, EMP]])
//
// Expected plan after the fix.
final String planAfter = ""
+ "LogicalSort(sort0=[$0], dir0=[ASC])\n"
+ " LogicalProject(DNAME=[$1], C=[$7])\n"
+ " LogicalJoin(condition=[AND(=($0, $5), =($4, $6))], joinType=[left])\n"
+ " LogicalProject(DEPTNO=[$0], DNAME=[$1], LOC=[$2], DEPTNO0=[$0], $f4=[*($0, 100)])\n"
+ " LogicalTableScan(table=[[scott, DEPT]])\n"
+ " LogicalProject(DEPTNO8=[$0], $f4=[$1], EXPR$0=[CASE(IS NOT NULL($4), $4, 0)])\n"
+ " LogicalJoin(condition=[AND(IS NOT DISTINCT FROM($0, $2), IS NOT DISTINCT FROM($1, $3))], joinType=[left])\n"
+ " LogicalProject(DEPTNO=[$0], $f4=[*($0, 100)])\n"
+ " LogicalTableScan(table=[[scott, DEPT]])\n"
+ " LogicalAggregate(group=[{0, 1}], EXPR$0=[COUNT()])\n"
+ " LogicalProject(DEPTNO8=[$7], $f4=[$9])\n"
+ " LogicalFilter(condition=[IS NOT NULL($7)])\n"
+ " LogicalProject(EMPNO=[$0], ENAME=[$1], JOB=[$2], MGR=[$3], HIREDATE=[$4], SAL=[$5], COMM=[$6], DEPTNO=[$7], i=[$11], $f4=[$9])\n"
+ " LogicalJoin(condition=[=($8, $10)], joinType=[inner])\n"
+ " LogicalProject(EMPNO=[$0], ENAME=[$1], JOB=[$2], MGR=[$3], HIREDATE=[$4], SAL=[$5], COMM=[$6], DEPTNO=[$7], SAL0=[CAST($5):DECIMAL(12, 2)])\n"
+ " LogicalTableScan(table=[[scott, EMP]])\n"
+ " LogicalProject($f4=[$0], SAL0=[$1], $f2=[true])\n"
+ " LogicalAggregate(group=[{0, 1}])\n"
+ " LogicalProject($f4=[$1], SAL0=[$2])\n"
+ " LogicalJoin(condition=[AND(>($2, CAST($0):DECIMAL(12, 2) NOT NULL), <($1, $0))], joinType=[inner])\n"
+ " LogicalValues(tuples=[[{ 1000 }, { 2000 }, { 3000 }]])\n"
+ " LogicalJoin(condition=[true], joinType=[inner])\n"
+ " LogicalAggregate(group=[{0}])\n"
+ " LogicalProject($f4=[*($0, 100)])\n"
+ " LogicalTableScan(table=[[scott, DEPT]])\n"
+ " LogicalAggregate(group=[{0}])\n"
+ " LogicalProject(SAL0=[CAST($5):DECIMAL(12, 2)])\n"
+ " LogicalTableScan(table=[[scott, EMP]])\n";
assertThat(after, hasTree(planAfter));
}

/** Test case for <a href="https://issues.apache.org/jira/browse/CALCITE-7297">[CALCITE-7297]
* The result is incorrect when the GROUP BY key in a subquery is a RexFieldAccess</a>. */
@Test void testSkipsRedundantValueGenerator() {
Expand Down
175 changes: 175 additions & 0 deletions core/src/test/resources/sql/sub-query.iq
Original file line number Diff line number Diff line change
Expand Up @@ -5617,6 +5617,181 @@ ORDER BY deptno;

!ok

# [CALCITE-7394] Nested sub-query with multiple levels of correlation returns incorrect results
# Scalar COUNT sub-query correlated with d; it contains a scalar MIN sub-query
# that references both e (from the COUNT sub-query) and d (from the outer query).
select d.dname,
(select count(*)
from emp e
where e.deptno = d.deptno
and e.sal > (
select min(s.losal)
from (VALUES (1, 700, 1200), (2, 1201, 1400), (3, 1401, 2000), (4, 2001, 3000), (5, 3001, 9999)) AS s(grade, losal, hisal)
where e.sal BETWEEN s.losal AND s.hisal
and s.hisal > d.deptno * 10
)
) as high_paid_count
from dept d
order by d.dname;
+------------+-----------------+
| DNAME | HIGH_PAID_COUNT |
+------------+-----------------+
| ACCOUNTING | 3 |
| OPERATIONS | 0 |
| RESEARCH | 5 |
| SALES | 6 |
+------------+-----------------+
(4 rows)

!ok

# [CALCITE-7394] Nested sub-query with multiple levels of correlation returns incorrect results
# Scalar AVG sub-query in WHERE, correlated with e; it contains an EXISTS
# sub-query that references both e2 (from the AVG sub-query) and e (from the
# outer query).
select e.ename
from emp e
where e.sal > (
select avg(e2.sal)
from emp e2
where e2.deptno = e.deptno
and exists (
select 1
from (values (7369, 20)) as b(empno, deptno)
where b.empno = e2.empno
and b.deptno = e.deptno
)
)
and e.sal < 2000
order by e.ename;
+-------+
| ENAME |
+-------+
| ADAMS |
+-------+
(1 row)

!ok

# [CALCITE-7394] Nested sub-query with multiple levels of correlation returns incorrect results
# EXISTS nested inside EXISTS; the inner one references both e (from the
# enclosing EXISTS) and d (from the outer query). The expected result is empty.
select d.deptno
from dept d
where exists (
select 1
from emp e
where e.deptno = d.deptno
and exists (
select 1
from (VALUES (1, 700, 1200), (2, 1201, 1400), (3, 1401, 2000), (4, 2001, 3000), (5, 3001, 9999)) AS s(grade, losal, hisal)
where s.grade = 1
and s.hisal >= e.sal
and s.losal <= d.deptno * 20
)
)
order by d.deptno;
+--------+
| DEPTNO |
+--------+
+--------+
(0 rows)

!ok

# [CALCITE-7394] Nested sub-query with multiple levels of correlation returns incorrect results
# IN sub-query over a UNION, both branches correlated with e; the second branch
# contains an EXISTS sub-query that references both d (from that branch) and e
# (from the outer query).
select e.ename
from emp e
where e.deptno in (
select d.deptno
from dept d
where d.deptno = e.deptno and d.deptno = 10
union
select d.deptno
from dept d
where d.deptno = e.deptno
and exists (
select 1
from emp e2
where e2.deptno = d.deptno
and e2.empno = e.empno
and e2.sal > 2000
)
)
order by e.ename;
+--------+
| ENAME |
+--------+
| BLAKE |
| CLARK |
| FORD |
| JONES |
| KING |
| MILLER |
| SCOTT |
+--------+
(7 rows)

!ok

# [CALCITE-7394] Nested sub-query with multiple levels of correlation returns incorrect results
# EXISTS over a join of dept and emp, correlated with e; it contains an EXISTS
# sub-query that references both e2 (an input of the join) and e (from the
# outer query).
select e.ename
from emp e
where exists (
select 1
from dept d
join emp e2 on d.deptno = e2.deptno
where d.deptno = e.deptno
and exists (
select 1
from (values (10), (20), (30)) as v(deptno)
where v.deptno = e2.deptno
and v.deptno = e.deptno
)
and e2.empno = e.empno
)
order by e.ename;
+--------+
| ENAME |
+--------+
| ADAMS |
| ALLEN |
| BLAKE |
| CLARK |
| FORD |
| JAMES |
| JONES |
| KING |
| MARTIN |
| MILLER |
| SCOTT |
| SMITH |
| TURNER |
| WARD |
+--------+
(14 rows)

!ok

# [CALCITE-7394] Nested sub-query with multiple levels of correlation returns incorrect results
# Scalar COUNT sub-query correlated with d; it contains an EXISTS sub-query
# that references both e (from the COUNT sub-query) and d (from the outer query).
# The same query text is used by
# RelDecorrelatorTest.testNestedSubQueryWithMultiLevelCorrelation.
select d.dname,
(select count(*)
from emp e
where e.deptno = d.deptno
and exists (
select 1
from (values (1000), (2000), (3000)) as v(sal)
where e.sal > v.sal
and d.deptno * 100 < v.sal
)
) as c
from dept d
order by d.dname;
+------------+---+
| DNAME | C |
+------------+---+
| ACCOUNTING | 2 |
| OPERATIONS | 0 |
| RESEARCH | 0 |
| SALES | 0 |
+------------+---+
(4 rows)

!ok

# [CALCITE-7303] Subqueries cannot be decorrelated if filter condition have multi CorrelationId
SELECT deptno
FROM emp e
Expand Down
Loading