Skip to content
This repository was archived by the owner on Sep 27, 2019. It is now read-only.

Commit 55cadb1

Browse files
committed
Fix tons of bugs in the optimizer again :-)
1 parent 000b9ff commit 55cadb1

16 files changed

+240
-62
lines changed

src/expression/aggregate_expression.cpp

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@ namespace expression {
1717

1818
const std::string AggregateExpression::GetInfo(int num_indent) const {
1919
std::ostringstream os;
20-
2120
os << StringUtil::Indent(num_indent) << "Expression ::\n"
2221
<< StringUtil::Indent(num_indent + 1) << "expression type = Aggregate,\n"
2322
<< StringUtil::Indent(num_indent + 1) << "aggregate type = " << expr_name_

src/expression/tuple_value_expression.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ void TupleValueExpression::PerformBinding(
3636
const auto &context = binding_contexts[GetTupleId()];
3737
ai_ = context->Find(GetColumnId());
3838
PL_ASSERT(ai_ != nullptr);
39-
LOG_TRACE("TVE Column ID %u.%u binds to AI %p (%s)", GetTupleId(),
39+
LOG_DEBUG("TVE Column ID %u.%u binds to AI %p (%s)", GetTupleId(),
4040
GetColumnId(), ai_, ai_->name.c_str());
4141
}
4242

src/include/expression/aggregate_expression.h

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -79,16 +79,16 @@ class AggregateExpression : public AbstractExpression {
7979
}
8080

8181
// Attribute binding
82-
void PerformBinding(const std::vector<const planner::BindingContext *> &
83-
binding_contexts) override {
82+
void PerformBinding(const std::vector<const planner::BindingContext *>
83+
&binding_contexts) override {
8484
const auto &context = binding_contexts[0];
8585
ai_ = context->Find(value_idx_);
8686
PL_ASSERT(ai_ != nullptr);
8787
LOG_DEBUG("AggregateOutput Column ID %u.%u binds to AI %p (%s)", 0,
8888
value_idx_, ai_, ai_->name.c_str());
8989
}
9090

91-
const planner::AttributeInfo* GetAttributeRef() const { return ai_; }
91+
const planner::AttributeInfo *GetAttributeRef() const { return ai_; }
9292

9393
inline void SetValueIdx(int value_idx) { value_idx_ = value_idx; }
9494

@@ -130,7 +130,7 @@ class AggregateExpression : public AbstractExpression {
130130

131131
private:
132132
int value_idx_ = -1;
133-
const planner::AttributeInfo* ai_;
133+
const planner::AttributeInfo *ai_;
134134
};
135135

136136
} // namespace expression

src/include/expression/expression_util.h

Lines changed: 44 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -443,6 +443,42 @@ class ExpressionUtil {
443443
return ordered_expr;
444444
}
445445

446+
/**
447+
* Walks an expression tree and collects every TupleValueExpression and AggregateExpression subtree into expr_set.
448+
*/
449+
static void GetTupleAndAggregateExprs(ExprSet &expr_set,
450+
AbstractExpression *expr) {
451+
std::vector<TupleValueExpression *> tv_exprs;
452+
std::vector<AggregateExpression *> aggr_exprs;
453+
GetAggregateExprs(aggr_exprs, tv_exprs, expr);
454+
for (auto &tv_expr : tv_exprs) {
455+
expr_set.insert(tv_expr);
456+
}
457+
for (auto &aggr_expr : aggr_exprs) {
458+
expr_set.insert(aggr_expr);
459+
}
460+
}
461+
462+
/**
463+
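* Walks an expression tree and adds every TupleValueExpression and AggregateExpression subtree not already in expr_map, mapping it to the map's size at insertion time.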
464+
*/
465+
static void GetTupleAndAggregateExprs(ExprMap &expr_map,
466+
AbstractExpression *expr) {
467+
std::vector<TupleValueExpression *> tv_exprs;
468+
std::vector<AggregateExpression *> aggr_exprs;
469+
GetAggregateExprs(aggr_exprs, tv_exprs, expr);
470+
for (auto &tv_expr : tv_exprs) {
471+
if (!expr_map.count(tv_expr)) {
472+
expr_map.emplace(tv_expr, expr_map.size());
473+
}
474+
}
475+
for (auto &aggr_expr : aggr_exprs) {
476+
if (!expr_map.count(aggr_expr)) {
477+
expr_map.emplace(aggr_expr, expr_map.size());
478+
}
479+
}
480+
}
481+
446482
/**
447483
* Walks an expression tree and finds all AggregationExprs subtrees.
448484
*/
@@ -511,8 +547,9 @@ class ExpressionUtil {
511547
// To evaluate the return type, we need a bottom up approach.
512548
if (expr == nullptr) return;
513549
size_t children_size = expr->GetChildrenSize();
514-
for (size_t i = 0; i < children_size; i++)
550+
for (size_t i = 0; i < children_size; i++) {
515551
EvaluateExpression(expr_maps, expr->GetModifiableChild(i));
552+
}
516553

517554
if (expr->GetExpressionType() == ExpressionType::VALUE_TUPLE) {
518555
// Point to the correct column returned in the logical tuple underneath
@@ -530,9 +567,12 @@ class ExpressionUtil {
530567
}
531568
} else if (IsAggregateExpression(expr->GetExpressionType())) {
532569
auto aggr_expr = (AggregateExpression *)expr;
533-
auto &expr_map = expr_maps[0];
534-
auto iter = expr_map.find(expr);
535-
if (iter != expr_map.end()) aggr_expr->SetValueIdx(iter->second);
570+
for (auto &expr_map : expr_maps) {
571+
auto iter = expr_map.find(expr);
572+
if (iter != expr_map.end()) {
573+
aggr_expr->SetValueIdx(iter->second);
574+
}
575+
}
536576
} else if (expr->GetExpressionType() == ExpressionType::FUNCTION) {
537577
auto func_expr = (expression::FunctionExpression *)expr;
538578
std::vector<type::TypeId> argtypes;

src/include/optimizer/plan_generator.h

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -126,7 +126,8 @@ class PlanGenerator : public OperatorVisitor {
126126
*/
127127
std::unique_ptr<expression::AbstractExpression> GeneratePredicateForScan(
128128
const std::shared_ptr<expression::AbstractExpression> predicate_expr,
129-
const std::string &alias, std::shared_ptr<catalog::TableCatalogObject> table);
129+
const std::string &alias,
130+
std::shared_ptr<catalog::TableCatalogObject> table);
130131

131132
/**
132133
* @brief Generate projection info and projection schema for join
@@ -148,7 +149,7 @@ class PlanGenerator : public OperatorVisitor {
148149
AggregateType aggr_type,
149150
const std::vector<std::shared_ptr<expression::AbstractExpression>>
150151
*groupby_cols,
151-
expression::AbstractExpression *having);
152+
std::unique_ptr<expression::AbstractExpression> having);
152153

153154
/**
154155
* @brief The required output property. Note that we have previously enforced

src/include/optimizer/rule_impls.h

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -290,6 +290,20 @@ class CombineConsecutiveFilter : public Rule {
290290
OptimizeContext *context) const override;
291291
};
292292

293+
/**
294+
* @brief Perform predicate push-down to push a filter through an aggregation; also embeds the filter into the aggregation operator where appropriate.
295+
*/
296+
class PushFilterThroughAggregation : public Rule {
297+
public:
298+
PushFilterThroughAggregation();
299+
300+
bool Check(std::shared_ptr<OperatorExpression> plan,
301+
OptimizeContext *context) const override;
302+
303+
void Transform(std::shared_ptr<OperatorExpression> input,
304+
std::vector<std::shared_ptr<OperatorExpression>> &transformed,
305+
OptimizeContext *context) const override;
306+
};
293307
/**
294308
* @brief Embed a filter into a scan operator. After predicate push-down, we
295309
* eliminate all filters in the operator trees, predicates should be associated

src/include/planner/aggregate_plan.h

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -110,7 +110,10 @@ class AggregatePlan : public AbstractPlan {
110110
std::iota(columns.begin(), columns.end(), 0);
111111
}
112112

113-
const std::string GetInfo() const override { return "AggregatePlan"; }
113+
const std::string GetInfo() const override {
114+
return "AggregatePlan(Having(" +
115+
(predicate_ != nullptr ? predicate_->GetInfo() : "") + "))";
116+
}
114117

115118
const std::vector<oid_t> &GetColumnIds() const { return column_ids_; }
116119

@@ -124,7 +127,9 @@ class AggregatePlan : public AbstractPlan {
124127
std::shared_ptr<const catalog::Schema> output_schema_copy(
125128
catalog::Schema::CopySchema(GetOutputSchema()));
126129
AggregatePlan *new_plan = new AggregatePlan(
127-
project_info_->Copy(), std::unique_ptr<const expression::AbstractExpression>(predicate_->Copy()),
130+
project_info_->Copy(),
131+
std::unique_ptr<const expression::AbstractExpression>(
132+
predicate_->Copy()),
128133
std::move(copied_agg_terms), std::move(copied_groupby_col_ids),
129134
output_schema_copy, agg_strategy_);
130135
return std::unique_ptr<AbstractPlan>(new_plan);
@@ -137,16 +142,17 @@ class AggregatePlan : public AbstractPlan {
137142
return !(*this == rhs);
138143
}
139144

140-
virtual void VisitParameters(codegen::QueryParametersMap &map,
145+
virtual void VisitParameters(
146+
codegen::QueryParametersMap &map,
141147
std::vector<peloton::type::Value> &values,
142148
const std::vector<peloton::type::Value> &values_from_user) override;
143149

144150
private:
145151
bool AreEqual(const std::vector<planner::AggregatePlan::AggTerm> &A,
146152
const std::vector<planner::AggregatePlan::AggTerm> &B) const;
147153

148-
hash_t Hash(const std::vector<planner::AggregatePlan::AggTerm> &agg_terms)
149-
const;
154+
hash_t Hash(
155+
const std::vector<planner::AggregatePlan::AggTerm> &agg_terms) const;
150156

151157
private:
152158
/* For projection */

src/optimizer/input_column_deriver.cpp

Lines changed: 41 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -200,25 +200,39 @@ void InputColumnDeriver::AggregateHelper(const BaseOperatorNode *op) {
200200
// input_cols_set.insert(tv_expr);
201201
}
202202
}
203-
vector<AbstractExpression *> output_cols(output_col_idx, nullptr);
204-
for (auto &expr_idx_pair : output_cols_map) {
205-
output_cols[expr_idx_pair.second] = expr_idx_pair.first;
206-
}
207203

208204
// TODO(boweic): do not use shared_ptr
209205
vector<shared_ptr<AbstractExpression>> groupby_cols;
206+
vector<AnnotatedExpression> having_exprs;
210207
if (op->type() == OpType::HashGroupBy) {
211-
groupby_cols = reinterpret_cast<const PhysicalHashGroupBy *>(op)->columns;
208+
auto groupby = reinterpret_cast<const PhysicalHashGroupBy *>(op);
209+
groupby_cols = groupby->columns;
210+
having_exprs = groupby->having;
212211
} else if (op->type() == OpType::SortGroupBy) {
213-
groupby_cols = reinterpret_cast<const PhysicalSortGroupBy *>(op)->columns;
212+
auto groupby = reinterpret_cast<const PhysicalSortGroupBy *>(op);
213+
groupby_cols = groupby->columns;
214+
having_exprs = groupby->having;
214215
}
215216
for (auto &groupby_col : groupby_cols) {
216217
input_cols_set.insert(groupby_col.get());
217218
}
219+
// Check having predicate, since the predicate may use columns other than
220+
// output columns
221+
for (auto &having_expr : having_exprs) {
222+
expression::ExpressionUtil::GetTupleValueExprs(input_cols_set,
223+
having_expr.expr.get());
224+
expression::ExpressionUtil::GetTupleAndAggregateExprs(
225+
output_cols_map, having_expr.expr.get());
226+
}
218227
vector<AbstractExpression *> input_cols;
219228
for (auto &col : input_cols_set) {
220229
input_cols.push_back(col);
221230
}
231+
output_col_idx = output_cols_map.size();
232+
vector<AbstractExpression *> output_cols(output_col_idx, nullptr);
233+
for (auto &expr_idx_pair : output_cols_map) {
234+
output_cols[expr_idx_pair.second] = expr_idx_pair.first;
235+
}
222236

223237
output_input_cols_ =
224238
pair<vector<AbstractExpression *>, vector<vector<AbstractExpression *>>>{
@@ -248,16 +262,16 @@ void InputColumnDeriver::JoinHelper(const BaseOperatorNode *op) {
248262
PL_ASSERT(right_keys != nullptr);
249263
PL_ASSERT(join_conds != nullptr);
250264
for (auto &left_key : *left_keys) {
251-
expression::ExpressionUtil::GetTupleValueExprs(input_cols_set,
252-
left_key.get());
265+
expression::ExpressionUtil::GetTupleAndAggregateExprs(input_cols_set,
266+
left_key.get());
253267
}
254268
for (auto &right_key : *right_keys) {
255-
expression::ExpressionUtil::GetTupleValueExprs(input_cols_set,
256-
right_key.get());
269+
expression::ExpressionUtil::GetTupleAndAggregateExprs(input_cols_set,
270+
right_key.get());
257271
}
258272
for (auto &join_cond : *join_conds) {
259-
expression::ExpressionUtil::GetTupleValueExprs(input_cols_set,
260-
join_cond.expr.get());
273+
expression::ExpressionUtil::GetTupleAndAggregateExprs(input_cols_set,
274+
join_cond.expr.get());
261275
}
262276
ExprMap output_cols_map;
263277
for (auto expr : required_cols_) {
@@ -274,8 +288,21 @@ void InputColumnDeriver::JoinHelper(const BaseOperatorNode *op) {
274288
UNUSED_ATTRIBUTE auto &probe_table_aliases =
275289
memo_->GetGroupByID(gexpr_->GetChildGroupId(1))->GetTableAliases();
276290
for (auto &col : input_cols_set) {
277-
PL_ASSERT(col->GetExpressionType() == ExpressionType::VALUE_TUPLE);
278-
auto tv_expr = reinterpret_cast<expression::TupleValueExpression *>(col);
291+
expression::TupleValueExpression *tv_expr;
292+
if (col->GetExpressionType() == ExpressionType::VALUE_TUPLE) {
293+
tv_expr = reinterpret_cast<expression::TupleValueExpression *>(col);
294+
} else {
295+
PL_ASSERT(expression::ExpressionUtil::IsAggregateExpression(col));
296+
ExprSet tv_exprs;
297+
expression::ExpressionUtil::GetTupleValueExprs(tv_exprs, col);
298+
if (tv_exprs.empty()) {
299+
// Do not need input columns like COUNT(1)
300+
continue;
301+
}
302+
tv_expr = reinterpret_cast<expression::TupleValueExpression *>(
303+
*(tv_exprs.begin()));
304+
}
305+
279306
if (build_table_aliases.count(tv_expr->GetTableName())) {
280307
build_table_cols_set.insert(col);
281308
} else {

src/optimizer/optimizer_task.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -125,7 +125,7 @@ void OptimizeExpression::execute() {
125125
//===--------------------------------------------------------------------===//
126126
void ExploreGroup::execute() {
127127
if (group_->HasExplored()) return;
128-
LOG_DEBUG("ExploreGroup::execute() ");
128+
// LOG_DEBUG("ExploreGroup::execute() ");
129129

130130
for (auto &logical_expr : group_->GetLogicalExpressions()) {
131131
PushTask(new ExploreExpression(logical_expr.get(), context_));
@@ -140,7 +140,7 @@ void ExploreGroup::execute() {
140140
// ExploreExpression
141141
//===--------------------------------------------------------------------===//
142142
void ExploreExpression::execute() {
143-
LOG_DEBUG("ExploreExpression::execute() ");
143+
// LOG_DEBUG("ExploreExpression::execute() ");
144144
std::vector<RuleWithPromise> valid_rules;
145145

146146
// Construct valid transformation rules from rule set

src/optimizer/plan_generator.cpp

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -144,19 +144,15 @@ void PlanGenerator::Visit(const PhysicalOrderBy *) {
144144
void PlanGenerator::Visit(const PhysicalHashGroupBy *op) {
145145
auto having_predicates =
146146
expression::ExpressionUtil::JoinAnnotatedExprs(op->having);
147-
expression::ExpressionUtil::EvaluateExpression(children_expr_map_,
148-
having_predicates.get());
149147
BuildAggregatePlan(AggregateType::HASH, &op->columns,
150-
having_predicates.release());
148+
std::move(having_predicates));
151149
}
152150

153151
void PlanGenerator::Visit(const PhysicalSortGroupBy *op) {
154152
auto having_predicates =
155153
expression::ExpressionUtil::JoinAnnotatedExprs(op->having);
156-
expression::ExpressionUtil::EvaluateExpression(children_expr_map_,
157-
having_predicates.get());
158154
BuildAggregatePlan(AggregateType::HASH, &op->columns,
159-
having_predicates.release());
155+
std::move(having_predicates));
160156
}
161157

162158
void PlanGenerator::Visit(const PhysicalAggregate *) {
@@ -457,7 +453,7 @@ void PlanGenerator::BuildAggregatePlan(
457453
AggregateType aggr_type,
458454
const std::vector<std::shared_ptr<expression::AbstractExpression>>
459455
*groupby_cols,
460-
expression::AbstractExpression *having_predicate) {
456+
std::unique_ptr<expression::AbstractExpression> having_predicate) {
461457
vector<planner::AggregatePlan::AggTerm> aggr_terms;
462458
vector<catalog::Column> output_schema_columns;
463459
DirectMapList dml;
@@ -466,12 +462,15 @@ void PlanGenerator::BuildAggregatePlan(
466462
auto &child_expr_map = children_expr_map_[0];
467463

468464
auto agg_id = 0;
465+
ExprMap output_expr_map;
469466
for (size_t idx = 0; idx < output_cols_.size(); ++idx) {
470467
auto expr = output_cols_[idx];
468+
output_expr_map[expr] = idx;
471469
expr->DeduceExpressionType();
472470
expression::ExpressionUtil::EvaluateExpression(children_expr_map_, expr);
473471
if (expression::ExpressionUtil::IsAggregateExpression(
474472
expr->GetExpressionType())) {
473+
LOG_DEBUG("Output Column for Agg %lu : %s", idx, expr->GetInfo().c_str());
475474
auto agg_expr = reinterpret_cast<expression::AggregateExpression *>(expr);
476475
auto agg_col = expr->GetModifiableChild(0);
477476
// Maps the aggregate value in the right tuple to the output
@@ -501,7 +500,14 @@ void PlanGenerator::BuildAggregatePlan(
501500
// Generate the Aggregate Plan
502501
unique_ptr<const planner::ProjectInfo> proj_info(
503502
new planner::ProjectInfo(move(tl), move(dml)));
504-
unique_ptr<const expression::AbstractExpression> predicate(having_predicate);
503+
// LOG_DEBUG(
504+
// "Having predicate : %s ",
505+
// having_predicate == nullptr ? "" :
506+
// having_predicate->GetInfo().c_str());
507+
expression::ExpressionUtil::EvaluateExpression({output_expr_map},
508+
having_predicate.get());
509+
unique_ptr<const expression::AbstractExpression> predicate(
510+
having_predicate.release());
505511
// TODO(boweic): Ditto, since the aggregate plan will own the schema, we may
506512
// want to make the parameter a unique_ptr
507513
shared_ptr<const catalog::Schema> output_table_schema(

0 commit comments

Comments
 (0)