Skip to content
This repository was archived by the owner on Sep 27, 2019. It is now read-only.

Commit f371356

Browse files
authored
Merge pull request #1147 from chenboy/nested_query
Add unnesting for correlated subquery in where clause with aggregation
2 parents 23e1005 + ca9880a commit f371356

28 files changed

+890
-543
lines changed

src/binder/bind_node_visitor.cpp

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -75,10 +75,8 @@ void BindNodeVisitor::Visit(parser::SelectStatement *node) {
7575
}
7676
select_element->DeriveSubqueryFlag();
7777

78-
// Recursively deduce expression value type
79-
expression::ExpressionUtil::EvaluateExpression({ExprMap()},
80-
select_element.get());
81-
// Recursively deduce expression name
78+
// Traverse the expression to deduce expression value type and name
79+
select_element->DeduceExpressionType();
8280
select_element->DeduceExpressionName();
8381
new_select_list.push_back(std::move(select_element));
8482
}

src/common/internal_types.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -978,6 +978,9 @@ std::string ExpressionTypeToString(ExpressionType type, bool short_str) {
978978
case ExpressionType::CAST: {
979979
return ("CAST");
980980
}
981+
case ExpressionType::OPERATOR_IS_NOT_NULL: {
982+
return ("IS_NOT_NULL");
983+
}
981984
default: {
982985
throw ConversionException(StringUtil::Format(
983986
"No string conversion for ExpressionType value '%d'",

src/expression/aggregate_expression.cpp

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@ namespace expression {
1717

1818
const std::string AggregateExpression::GetInfo(int num_indent) const {
1919
std::ostringstream os;
20-
2120
os << StringUtil::Indent(num_indent) << "Expression ::\n"
2221
<< StringUtil::Indent(num_indent + 1) << "expression type = Aggregate,\n"
2322
<< StringUtil::Indent(num_indent + 1) << "aggregate type = " << expr_name_

src/include/common/internal_types.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1336,6 +1336,7 @@ enum class RuleType : uint32_t {
13361336
MARK_JOIN_INNER_JOIN_TO_INNER_JOIN,
13371337
MARK_JOIN_FILTER_TO_INNER_JOIN,
13381338
PULL_FILTER_THROUGH_MARK_JOIN,
1339+
PULL_FILTER_THROUGH_AGGREGATION,
13391340

13401341
// Place holder to generate number of rules compile time
13411342
NUM_RULES

src/include/expression/aggregate_expression.h

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -79,16 +79,16 @@ class AggregateExpression : public AbstractExpression {
7979
}
8080

8181
// Attribute binding
82-
void PerformBinding(const std::vector<const planner::BindingContext *> &
83-
binding_contexts) override {
82+
void PerformBinding(const std::vector<const planner::BindingContext *>
83+
&binding_contexts) override {
8484
const auto &context = binding_contexts[0];
8585
ai_ = context->Find(value_idx_);
8686
PL_ASSERT(ai_ != nullptr);
87-
LOG_DEBUG("AggregateOutput Column ID %u.%u binds to AI %p (%s)", 0,
87+
LOG_TRACE("AggregateOutput Column ID %u.%u binds to AI %p (%s)", 0,
8888
value_idx_, ai_, ai_->name.c_str());
8989
}
9090

91-
const planner::AttributeInfo* GetAttributeRef() const { return ai_; }
91+
const planner::AttributeInfo *GetAttributeRef() const { return ai_; }
9292

9393
inline void SetValueIdx(int value_idx) { value_idx_ = value_idx; }
9494

@@ -130,7 +130,7 @@ class AggregateExpression : public AbstractExpression {
130130

131131
private:
132132
int value_idx_ = -1;
133-
const planner::AttributeInfo* ai_;
133+
const planner::AttributeInfo *ai_;
134134
};
135135

136136
} // namespace expression

src/include/expression/expression_util.h

Lines changed: 67 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -401,35 +401,43 @@ class ExpressionUtil {
401401
}
402402

403403
/**
404-
* TODO(boweic): this function may not be efficient, in the future we may want
405-
* to add expressions to groups so that we do not need to walk through the
406-
* expression tree when judging '==' each time
404+
* @brief TODO(boweic): this function may not be efficient, in the future we
405+
* may want to add expressions to groups so that we do not need to walk
406+
* through the expression tree when judging '==' each time
407407
*
408408
* Convert all expression in the current expression tree that is in
409409
* child_expr_map to tuple value expression with corresponding column offset
410410
* of the input child tuple. This is used for handling projection
411411
* on situations like aggregate function (e.g. SELECT sum(a)+max(b) FROM ...
412412
* GROUP BY ...) when input columns contain sum(a) and sum(b). We need to
413413
* treat them as tuple value expression in the projection plan. This function
414-
*should always be called before calling EvaluateExpression
414+
* should always be called before calling EvaluateExpression
415415
*
416416
* Please notice that this function should only apply to copied expression
417-
*since it would modify the current expression. We do not want to modify the
418-
*original expression since it may be referenced in other places
417+
* since it would modify the current expression. We do not want to modify the
418+
* original expression since it may be referenced in other places
419+
*
420+
* @param expr The expression to modify
421+
* @param child_expr_maps map from child column ids to expression
419422
*/
420423
static void ConvertToTvExpr(AbstractExpression *expr,
421-
ExprMap &child_expr_map) {
424+
std::vector<ExprMap> child_expr_maps) {
425+
if (expr == nullptr) {
426+
return;
427+
};
422428
for (size_t i = 0; i < expr->GetChildrenSize(); i++) {
423429
auto child_expr = expr->GetModifiableChild(i);
424-
if (child_expr->GetExpressionType() != ExpressionType::VALUE_TUPLE &&
425-
child_expr_map.count(child_expr)) {
426-
// EvaluateExpression({child_expr_map}, child_expr);
427-
expr->SetChild(i,
428-
new TupleValueExpression(child_expr->GetValueType(), 0,
429-
child_expr_map[child_expr]));
430-
} else {
431-
ConvertToTvExpr(child_expr, child_expr_map);
430+
for (size_t tuple_idx = 0; tuple_idx < child_expr_maps.size();
431+
++tuple_idx) {
432+
if (child_expr->GetExpressionType() != ExpressionType::VALUE_TUPLE &&
433+
child_expr_maps[tuple_idx].count(child_expr)) {
434+
expr->SetChild(i, new TupleValueExpression(
435+
child_expr->GetValueType(), tuple_idx,
436+
child_expr_maps[tuple_idx][child_expr]));
437+
break;
438+
}
432439
}
440+
ConvertToTvExpr(expr->GetModifiableChild(i), child_expr_maps);
433441
}
434442
}
435443

@@ -443,6 +451,42 @@ class ExpressionUtil {
443451
return ordered_expr;
444452
}
445453

454+
/**
455+
* Walks an expression trees and find all AggregationExprs subtrees.
456+
*/
457+
static void GetTupleAndAggregateExprs(ExprSet &expr_set,
458+
AbstractExpression *expr) {
459+
std::vector<TupleValueExpression *> tv_exprs;
460+
std::vector<AggregateExpression *> aggr_exprs;
461+
GetAggregateExprs(aggr_exprs, tv_exprs, expr);
462+
for (auto &tv_expr : tv_exprs) {
463+
expr_set.insert(tv_expr);
464+
}
465+
for (auto &aggr_expr : aggr_exprs) {
466+
expr_set.insert(aggr_expr);
467+
}
468+
}
469+
470+
/**
471+
* Walks an expression trees and find all AggregationExprs subtrees.
472+
*/
473+
static void GetTupleAndAggregateExprs(ExprMap &expr_map,
474+
AbstractExpression *expr) {
475+
std::vector<TupleValueExpression *> tv_exprs;
476+
std::vector<AggregateExpression *> aggr_exprs;
477+
GetAggregateExprs(aggr_exprs, tv_exprs, expr);
478+
for (auto &tv_expr : tv_exprs) {
479+
if (!expr_map.count(tv_expr)) {
480+
expr_map.emplace(tv_expr, expr_map.size());
481+
}
482+
}
483+
for (auto &aggr_expr : aggr_exprs) {
484+
if (!expr_map.count(aggr_expr)) {
485+
expr_map.emplace(aggr_expr, expr_map.size());
486+
}
487+
}
488+
}
489+
446490
/**
447491
* Walks an expression trees and find all AggregationExprs subtrees.
448492
*/
@@ -511,8 +555,9 @@ class ExpressionUtil {
511555
// To evaluate the return type, we need a bottom up approach.
512556
if (expr == nullptr) return;
513557
size_t children_size = expr->GetChildrenSize();
514-
for (size_t i = 0; i < children_size; i++)
558+
for (size_t i = 0; i < children_size; i++) {
515559
EvaluateExpression(expr_maps, expr->GetModifiableChild(i));
560+
}
516561

517562
if (expr->GetExpressionType() == ExpressionType::VALUE_TUPLE) {
518563
// Point to the correct column returned in the logical tuple underneath
@@ -530,9 +575,12 @@ class ExpressionUtil {
530575
}
531576
} else if (IsAggregateExpression(expr->GetExpressionType())) {
532577
auto aggr_expr = (AggregateExpression *)expr;
533-
auto &expr_map = expr_maps[0];
534-
auto iter = expr_map.find(expr);
535-
if (iter != expr_map.end()) aggr_expr->SetValueIdx(iter->second);
578+
for (auto &expr_map : expr_maps) {
579+
auto iter = expr_map.find(expr);
580+
if (iter != expr_map.end()) {
581+
aggr_expr->SetValueIdx(iter->second);
582+
}
583+
}
536584
} else if (expr->GetExpressionType() == ExpressionType::FUNCTION) {
537585
auto func_expr = (expression::FunctionExpression *)expr;
538586
std::vector<type::TypeId> argtypes;

src/include/optimizer/operators.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -226,6 +226,9 @@ class LogicalAggregateAndGroupBy
226226
public:
227227
static Operator make();
228228

229+
static Operator make(
230+
std::vector<std::shared_ptr<expression::AbstractExpression>> &columns);
231+
229232
static Operator make(
230233
std::vector<std::shared_ptr<expression::AbstractExpression>> &columns,
231234
std::vector<AnnotatedExpression> &having);

src/include/optimizer/plan_generator.h

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -126,7 +126,8 @@ class PlanGenerator : public OperatorVisitor {
126126
*/
127127
std::unique_ptr<expression::AbstractExpression> GeneratePredicateForScan(
128128
const std::shared_ptr<expression::AbstractExpression> predicate_expr,
129-
const std::string &alias, std::shared_ptr<catalog::TableCatalogObject> table);
129+
const std::string &alias,
130+
std::shared_ptr<catalog::TableCatalogObject> table);
130131

131132
/**
132133
* @brief Generate projection info and projection schema for join
@@ -148,7 +149,7 @@ class PlanGenerator : public OperatorVisitor {
148149
AggregateType aggr_type,
149150
const std::vector<std::shared_ptr<expression::AbstractExpression>>
150151
*groupby_cols,
151-
expression::AbstractExpression *having);
152+
std::unique_ptr<expression::AbstractExpression> having);
152153

153154
/**
154155
* @brief The required output property. Note that we have previously enforced

src/include/optimizer/query_to_operator_transformer.h

Lines changed: 24 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,7 @@ class QueryToOperatorTransformer : public SqlNodeVisitor {
6666
void Visit(expression::ComparisonExpression *expr) override;
6767
void Visit(expression::OperatorExpression *expr) override;
6868

69+
private:
6970
inline oid_t GetAndIncreaseGetId() { return get_id++; }
7071

7172
/**
@@ -80,12 +81,10 @@ class QueryToOperatorTransformer : public SqlNodeVisitor {
8081
*
8182
* @param expr The original predicate
8283
*/
83-
void CollectPredicates(expression::AbstractExpression *expr);
84+
std::vector<AnnotatedExpression> CollectPredicates(
85+
expression::AbstractExpression *expr,
86+
std::vector<AnnotatedExpression> predicates = {});
8487

85-
// TODO(boweic): Since we haven't migrated all the functionalities needed to
86-
// generate mark-join and single-join to the optimizer, currently this
87-
// function has not been tested, and it may be a bit hard to understand. We
88-
// may integrate the unnesting functionality in the next PR
8988
/**
9089
* @brief Transform a sub-query in an expression to use
9190
*
@@ -96,14 +95,28 @@ class QueryToOperatorTransformer : public SqlNodeVisitor {
9695
* @return If the expression could be transformed into sub-query, return true,
9796
* return false otherwise
9897
*/
99-
bool GenerateSubquerytree(
100-
expression::AbstractExpression *expr,
101-
std::vector<expression::AbstractExpression *> &select_list,
102-
bool single_join = false);
98+
bool GenerateSubquerytree(expression::AbstractExpression *expr,
99+
oid_t child_id, bool single_join = false);
103100

101+
/**
102+
* @brief Decide if a conjunctive predicate is supported. We need to extract
103+
* conjunction predicate first then call this function to decide if the
104+
* predicate is supported by our system
105+
*
106+
* @param expr The conjunctive predicate provided
107+
*
108+
* @return True if supported, false otherwise
109+
*/
110+
bool IsSupportedConjunctivePredicate(expression::AbstractExpression *expr);
111+
/**
112+
* @brief Check if a sub-select statement is supported.
113+
*
114+
* @param op The select statement
115+
*
116+
* @return True if supported, false otherwise
117+
*/
118+
bool IsSupportedSubSelect(const parser::SelectStatement *op);
104119
static bool RequireAggregation(const parser::SelectStatement *op);
105-
106-
private:
107120
std::shared_ptr<OperatorExpression> output_expr_;
108121

109122
concurrency::TransactionContext *txn_;

src/include/optimizer/rule.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,7 @@ struct RuleWithPromise {
108108
int promise;
109109

110110
bool operator<(const RuleWithPromise &r) const { return promise < r.promise; }
111+
bool operator>(const RuleWithPromise &r) const { return promise > r.promise; }
111112
};
112113

113114
enum class RewriteRuleSetName : uint32_t {

0 commit comments

Comments
 (0)