Skip to content
This repository was archived by the owner on Feb 20, 2023. It is now read-only.

Commit 0ee555d

Browse files
jkosh44lmwnshnmbutrovich
authored
Fix correlated subquery optimizer rules (#1405)
* Fix correlated subquery optimizer rules Part of RewritePullFilterThroughAggregation and DependentSingleJoinToInnerJoin rules adds a group by column to the aggregation from one of the sides of the predicate. When selecting a column we need to make sure that the column is at the same depth as the aggregation or deeper. If the group by column is part of the outer query and the aggregation is part of the inner query, then the aggregation has no way of accessing the column. A higher value for depth means deeper in the query and a lower value for depth means more shallow in the query. The current code was always setting the column from the left side of the predicate as the group by column. This was sometimes causing the group by column to be more shallow than the aggregation itself which caused the query to error out. Fixes #1404 * Add comments * Fix comments and variable naming * Fix comments * Remove dead code from DependentSingleJoinToInnerJoin * Revert "Remove dead code from DependentSingleJoinToInnerJoin" This reverts commit d1255a3. * Extract common code from nested predicate rules * Respond to PR comments Co-authored-by: Wan Shen Lim <[email protected]> Co-authored-by: Matt Butrovich <[email protected]>
1 parent 5f59b7a commit 0ee555d

File tree

4 files changed

+61
-44
lines changed

4 files changed

+61
-44
lines changed

script/testing/junit/traces/nested-query.test

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -174,6 +174,11 @@ SELECT item_name FROM shipment WHERE qty <= (SELECT SUM(amount) FROM part WHERE
174174
----
175175
14 values hashing to eeafa4e65b8e56d2198e1a4703307d57
176176

177+
query T nosort
178+
SELECT item_name FROM shipment WHERE qty <= (SELECT SUM(amount) FROM part WHERE shipment.pno = part.pno) order by item_name;
179+
----
180+
14 values hashing to eeafa4e65b8e56d2198e1a4703307d57
181+
177182
query T nosort
178183
SELECT item_name FROM shipment WHERE qty < (SELECT SUM(amount) FROM part INNER JOIN shipment s on part.pno = s.pno) order by item_name;
179184
----

src/include/optimizer/rules/unnesting_rules.h

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
11
#pragma once
22

33
#include <memory>
4+
#include <string>
5+
#include <tuple>
6+
#include <unordered_set>
47
#include <vector>
58

69
#include "optimizer/rule.h"
@@ -120,4 +123,19 @@ class DependentSingleJoinToInnerJoin : public Rule {
120123
OptimizationContext *context) const override;
121124
};
122125

126+
/**
127+
* Given predicates associated with some aggregate, this function will extract the predicates that are correlated with
128+
* an outer query, as well as the columns from those predicates that aren't correlated.
129+
* @param predicates vector of predicates associated with an aggregate
130+
* @param child_group_aliases_set the table alias set of the predicate's child node
131+
* @param[out] correlated_predicates predicates which are correlated to an outer query
132+
* @param[out] normal_predicates predicates which are not correlated to an outer query
133+
* @param[out] new_groupby_cols columns from correlated predicates which are not correlated to an outer query. These
134+
* will be used as group by columns in the nested aggregate
135+
*/
136+
void ExtractCorrelatedPredicatesWithAggregate(
137+
const std::vector<AnnotatedExpression> &predicates, const std::unordered_set<std::string> &child_group_aliases_set,
138+
std::vector<AnnotatedExpression> *correlated_predicates, std::vector<AnnotatedExpression> *normal_predicates,
139+
std::vector<common::ManagedPointer<parser::AbstractExpression>> *new_groupby_cols);
140+
123141
}; // namespace noisepage::optimizer

src/optimizer/rules/rewrite_rules.cpp

Lines changed: 3 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
#include "optimizer/optimizer_defs.h"
1717
#include "optimizer/physical_operators.h"
1818
#include "optimizer/properties.h"
19+
#include "optimizer/rules/unnesting_rules.h"
1920
#include "optimizer/util.h"
2021
#include "parser/expression_util.h"
2122

@@ -493,21 +494,8 @@ void RewritePullFilterThroughAggregation::Transform(common::ManagedPointer<Abstr
493494
std::vector<AnnotatedExpression> correlated_predicates;
494495
std::vector<AnnotatedExpression> normal_predicates;
495496
std::vector<common::ManagedPointer<parser::AbstractExpression>> new_groupby_cols;
496-
for (auto &predicate : predicates) {
497-
if (OptimizerUtil::IsSubset(child_group_aliases_set, predicate.GetTableAliasSet())) {
498-
normal_predicates.emplace_back(predicate);
499-
} else {
500-
// Correlated predicate, already in the form of
501-
// (outer_relation.a = (expr))
502-
correlated_predicates.emplace_back(predicate);
503-
auto root_expr = predicate.GetExpr();
504-
if (root_expr->GetChild(0)->GetDepth() < root_expr->GetDepth()) {
505-
new_groupby_cols.emplace_back(root_expr->GetChild(1).Get());
506-
} else {
507-
new_groupby_cols.emplace_back(root_expr->GetChild(0).Get());
508-
}
509-
}
510-
}
497+
ExtractCorrelatedPredicatesWithAggregate(predicates, child_group_aliases_set, &correlated_predicates,
498+
&normal_predicates, &new_groupby_cols);
511499

512500
if (correlated_predicates.empty()) {
513501
// No need to pull

src/optimizer/rules/unnesting_rules.cpp

Lines changed: 35 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -150,32 +150,11 @@ void DependentSingleJoinToInnerJoin::Transform(common::ManagedPointer<AbstractOp
150150
const auto &agg_group_aliases_set = memo.GetGroupByID(agg_group_id)->GetTableAliases();
151151
auto &filter_predicates = filter_expr->Contents()->GetContentsAs<LogicalFilter>()->GetPredicates();
152152

153-
std::vector<AnnotatedExpression> ancestor_predicates;
154-
std::vector<AnnotatedExpression> descendant_predicates;
153+
std::vector<AnnotatedExpression> correlated_predicates;
154+
std::vector<AnnotatedExpression> normal_predicates;
155155
std::vector<common::ManagedPointer<parser::AbstractExpression>> new_groupby_cols;
156-
157-
// loop over all predicates check each of them if they refer table not contained in agg
158-
// from RewritePullFilterThroughAggregation
159-
for (auto &predicate : filter_predicates) {
160-
if (OptimizerUtil::IsSubset(agg_group_aliases_set, predicate.GetTableAliasSet())) {
161-
descendant_predicates.emplace_back(predicate);
162-
} else {
163-
// Correlated predicate, already in the form of
164-
// (outer_relation.a = (expr))
165-
ancestor_predicates.emplace_back(predicate);
166-
auto root_expr = predicate.GetExpr();
167-
// If the sub-query depth level of the first child is less than the current expression
168-
// the first child is outer_relation.a and the second child is a (expr)
169-
// The second child expression shall be evaluated as a part of the new aggregation before the new filter
170-
if (root_expr->GetChild(0)->GetDepth() < root_expr->GetDepth()) {
171-
new_groupby_cols.emplace_back(root_expr->GetChild(1).Get());
172-
} else {
173-
// Otherwise, the first child is a (expr) and the second child is outer_relation.a
174-
// The first child expression shall be evaluated as a part of the new aggregation before the new filter
175-
new_groupby_cols.emplace_back(root_expr->GetChild(0).Get());
176-
}
177-
}
178-
}
156+
ExtractCorrelatedPredicatesWithAggregate(filter_predicates, agg_group_aliases_set, &correlated_predicates,
157+
&normal_predicates, &new_groupby_cols);
179158

180159
// Create a new agg node
181160
auto aggregation = agg_expr->Contents()->GetContentsAs<LogicalAggregateAndGroupBy>();
@@ -193,10 +172,10 @@ void DependentSingleJoinToInnerJoin::Transform(common::ManagedPointer<AbstractOp
193172
// Create a new inner join node from single join
194173
std::vector<std::unique_ptr<AbstractOptimizerNode>> inner_node;
195174
inner_node.emplace_back(input->GetChildren()[0]->Copy());
196-
if (!descendant_predicates.empty()) {
175+
if (!normal_predicates.empty()) {
197176
std::vector<std::unique_ptr<AbstractOptimizerNode>> child_node;
198177
child_node.emplace_back(std::move(new_aggr));
199-
auto filter = std::make_unique<OperatorNode>(LogicalFilter::Make(std::move(descendant_predicates))
178+
auto filter = std::make_unique<OperatorNode>(LogicalFilter::Make(std::move(normal_predicates))
200179
.RegisterWithTxnContext(context->GetOptimizerContext()->GetTxn()),
201180
std::move(child_node), context->GetOptimizerContext()->GetTxn());
202181
inner_node.emplace_back(std::move(filter));
@@ -210,10 +189,10 @@ void DependentSingleJoinToInnerJoin::Transform(common::ManagedPointer<AbstractOp
210189
std::unique_ptr<OperatorNode> output;
211190
// Create new filter nodes
212191
// Construct a top filter if any
213-
if (!ancestor_predicates.empty()) {
192+
if (!correlated_predicates.empty()) {
214193
std::vector<std::unique_ptr<AbstractOptimizerNode>> root_node;
215194
root_node.emplace_back(std::move(new_inner));
216-
output = std::make_unique<OperatorNode>(LogicalFilter::Make(std::move(ancestor_predicates))
195+
output = std::make_unique<OperatorNode>(LogicalFilter::Make(std::move(correlated_predicates))
217196
.RegisterWithTxnContext(context->GetOptimizerContext()->GetTxn()),
218197
std::move(root_node), context->GetOptimizerContext()->GetTxn());
219198

@@ -223,4 +202,31 @@ void DependentSingleJoinToInnerJoin::Transform(common::ManagedPointer<AbstractOp
223202
transformed->emplace_back(std::move(output));
224203
}
225204

205+
void ExtractCorrelatedPredicatesWithAggregate(
206+
const std::vector<AnnotatedExpression> &predicates, const std::unordered_set<std::string> &child_group_aliases_set,
207+
std::vector<AnnotatedExpression> *correlated_predicates, std::vector<AnnotatedExpression> *normal_predicates,
208+
std::vector<common::ManagedPointer<parser::AbstractExpression>> *new_groupby_cols) {
209+
for (auto &predicate : predicates) {
210+
if (OptimizerUtil::IsSubset(child_group_aliases_set, predicate.GetTableAliasSet())) {
211+
normal_predicates->emplace_back(predicate);
212+
} else {
213+
// Correlated predicate, predicate in nested query references column in outer query
214+
correlated_predicates->emplace_back(predicate);
215+
auto root_expr = predicate.GetExpr();
216+
// See https://github.com/cmu-db/noisepage/issues/1404
217+
// The higher the depth of an expression the deeper/more nested it is.
218+
// If the sub-query depth level of the left side of the predicate is greater than the current expression then the
219+
// left side doesn't reference the outer query (i.e. the right side references the outer query).
220+
// The left side shall be evaluated as a part of the new aggregation before the new filter
221+
if (root_expr->GetChild(0)->GetDepth() > root_expr->GetDepth()) {
222+
new_groupby_cols->emplace_back(root_expr->GetChild(0).Get());
223+
} else {
224+
// Otherwise, the right side of the predicate doesn't reference the outer query (i.e. the left side references
225+
// the outer query). The right side shall be evaluated as a part of the new aggregation before the new filter
226+
new_groupby_cols->emplace_back(root_expr->GetChild(1).Get());
227+
}
228+
}
229+
}
230+
}
231+
226232
} // namespace noisepage::optimizer

0 commit comments

Comments
 (0)