diff --git a/src/include/common/internal_types.h b/src/include/common/internal_types.h index 96b45f9e42b..21de29a080e 100644 --- a/src/include/common/internal_types.h +++ b/src/include/common/internal_types.h @@ -1383,6 +1383,10 @@ enum class RuleType : uint32_t { PULL_FILTER_THROUGH_MARK_JOIN, PULL_FILTER_THROUGH_AGGREGATION, + // AST rewrite rules (logical -> logical) + // Removes ConstantValueExpression = ConstantValueExpression + COMP_EQUALITY_ELIMINATION, + // Place holder to generate number of rules compile time NUM_RULES diff --git a/src/include/optimizer/absexpr_expression.h b/src/include/optimizer/absexpr_expression.h new file mode 100644 index 00000000000..745881ccfb0 --- /dev/null +++ b/src/include/optimizer/absexpr_expression.h @@ -0,0 +1,183 @@ +//===----------------------------------------------------------------------===// +// +// Peloton +// +// absexpr_expression.h +// +// Identification: src/include/optimizer/absexpr_expression.h +// +//===----------------------------------------------------------------------===// + +#pragma once + +// AbstractExpression Definition +#include "expression/abstract_expression.h" +#include "expression/conjunction_expression.h" +#include "expression/comparison_expression.h" +#include "expression/constant_value_expression.h" + +#include +#include + +namespace peloton { +namespace optimizer { + +// (TODO): rethink the AbsExpr_Container/Expression approach in comparison to abstract +// Most of the core rule/optimizer code relies on the concept of an Operator / +// OperatorExpression and the interface that the two functions respectively expose. +// +// The annoying part is that an AbstractExpression blends together an Operator +// and OperatorExpression. Second part, the AbstractExpression does not export the +// correct interface that the rest of the system depends on. +// +// As an extreme level of simplification (sort of hacky), an AbsExpr_Container is +// analogous to Operator and wraps a single AbstractExpression node. 
AbsExpr_Expression +// is analogous to OperatorExpression. +// +// AbsExpr_Container does *not* handle memory correctly w.r.t internal instantiations +// from Rule transformation. This is since Peloton itself mixes unique_ptrs and +// hands out raw pointers which makes adding a shared_ptr here extremely problematic. +// terrier uses only shared_ptr when dealing with AbstractExpression trees. + +class AbsExpr_Container { + public: + AbsExpr_Container(); + + AbsExpr_Container(const expression::AbstractExpression *expr) { + node = expr; + } + + // Return operator type + ExpressionType GetType() const { + if (IsDefined()) { + return node->GetExpressionType(); + } + return ExpressionType::INVALID; + } + + const expression::AbstractExpression *GetExpr() const { + return node; + } + + // Operator contains Logical node + bool IsLogical() const { + return true; + } + + // Operator contains Physical node + bool IsPhysical() const { + return false; + } + + std::string GetName() const { + if (IsDefined()) { + return node->GetExpressionName(); + } + + return "Undefined"; + } + + hash_t Hash() const { + if (IsDefined()) { + return node->Hash(); + } + return 0; + } + + bool operator==(const AbsExpr_Container &r) { + if (IsDefined() && r.IsDefined()) { + // (TODO): need a better way to determine deep equality + + // NOTE: + // Without proper equality determinations, the groups will + // not be assigned correctly. Arguably, terrier does this + // better because a blind ExactlyEquals on different types + // of ConstantValueExpression under Peloton will crash! + + // For now, just return (false). + // I don't anticipate this will affect correctness, just + // performance, since duplicate trees will have to be evaluated + // over and over again, rather than being able to "borrow" + // a previous tree's rewrite. + // + // Probably not worth creating a "validator" since porting + // this to terrier anyways (?). == does not check Value + // so it's broken. 
ExactlyEqual requires precondition checking. + return false; + } else if (!IsDefined() && !r.IsDefined()) { + return true; + } + return false; + } + + // Operator contains physical or logical operator node + bool IsDefined() const { + return node != nullptr; + } + + //(TODO): fix memory management once go to terrier + expression::AbstractExpression *Rebuild(std::vector children) { + switch (GetType()) { + case ExpressionType::COMPARE_EQUAL: + case ExpressionType::COMPARE_NOTEQUAL: + case ExpressionType::COMPARE_LESSTHAN: + case ExpressionType::COMPARE_GREATERTHAN: + case ExpressionType::COMPARE_LESSTHANOREQUALTO: + case ExpressionType::COMPARE_GREATERTHANOREQUALTO: + case ExpressionType::COMPARE_LIKE: + case ExpressionType::COMPARE_NOTLIKE: + case ExpressionType::COMPARE_IN: + case ExpressionType::COMPARE_DISTINCT_FROM: { + PELOTON_ASSERT(children.size() == 2); + return new expression::ComparisonExpression(GetType(), children[0], children[1]); + } + case ExpressionType::CONJUNCTION_AND: + case ExpressionType::CONJUNCTION_OR: { + PELOTON_ASSERT(children.size() == 2); + return new expression::ConjunctionExpression(GetType(), children[0], children[1]); + } + case ExpressionType::VALUE_CONSTANT: { + PELOTON_ASSERT(children.size() == 0); + auto cve = static_cast(node); + return new expression::ConstantValueExpression(cve->GetValue()); + } + default: { + int type = static_cast(GetType()); + LOG_ERROR("Unimplemented Rebuild() for %d found", type); + return nullptr; + } + } + } + + private: + const expression::AbstractExpression *node; +}; + +class AbsExpr_Expression { + public: + AbsExpr_Expression(AbsExpr_Container op): op(op) {}; + + void PushChild(std::shared_ptr op) { + children.push_back(op); + } + + void PopChild() { + children.pop_back(); + } + + const std::vector> &Children() const { + return children; + } + + const AbsExpr_Container &Op() const { + return op; + } + + private: + AbsExpr_Container op; + std::vector> children; +}; + +} // namespace optimizer +} // 
namespace peloton + diff --git a/src/include/optimizer/binding.h b/src/include/optimizer/binding.h index 7a6d772813d..57756b07b83 100644 --- a/src/include/optimizer/binding.h +++ b/src/include/optimizer/binding.h @@ -24,63 +24,71 @@ namespace peloton { namespace optimizer { class Optimizer; + +template class Memo; //===--------------------------------------------------------------------===// // Binding Iterator //===--------------------------------------------------------------------===// +template class BindingIterator { public: - BindingIterator(Memo& memo) : memo_(memo) {} + BindingIterator(Memo& memo) : memo_(memo) {} virtual ~BindingIterator(){}; virtual bool HasNext() = 0; - virtual std::shared_ptr Next() = 0; + virtual std::shared_ptr Next() = 0; protected: - Memo &memo_; + Memo &memo_; }; -class GroupBindingIterator : public BindingIterator { +template +class GroupBindingIterator : public BindingIterator { public: - GroupBindingIterator(Memo& memo, GroupID id, - std::shared_ptr pattern); + GroupBindingIterator(Memo& memo, + GroupID id, + std::shared_ptr> pattern); bool HasNext() override; - std::shared_ptr Next() override; + std::shared_ptr Next() override; private: GroupID group_id_; - std::shared_ptr pattern_; - Group *target_group_; + std::shared_ptr> pattern_; + Group *target_group_; size_t num_group_items_; + // Internal function for HasNext() + bool HasNextBinding(); + size_t current_item_index_; - std::unique_ptr current_iterator_; + std::unique_ptr> current_iterator_; }; -class GroupExprBindingIterator : public BindingIterator { +template +class GroupExprBindingIterator : public BindingIterator { public: - GroupExprBindingIterator(Memo& memo, - GroupExpression *gexpr, - std::shared_ptr pattern); + GroupExprBindingIterator(Memo& memo, + GroupExpression *gexpr, + std::shared_ptr> pattern); bool HasNext() override; - std::shared_ptr Next() override; + std::shared_ptr Next() override; private: - GroupExpression* gexpr_; - std::shared_ptr pattern_; + 
GroupExpression* gexpr_; + std::shared_ptr> pattern_; bool first_; bool has_next_; - std::shared_ptr current_binding_; - std::vector>> - children_bindings_; + std::shared_ptr current_binding_; + std::vector>> children_bindings_; std::vector children_bindings_pos_; }; diff --git a/src/include/optimizer/child_property_deriver.h b/src/include/optimizer/child_property_deriver.h index 914cc77ab27..6ec2c09400a 100644 --- a/src/include/optimizer/child_property_deriver.h +++ b/src/include/optimizer/child_property_deriver.h @@ -13,10 +13,12 @@ #pragma once #include #include "optimizer/operator_visitor.h" +#include "optimizer/operator_expression.h" namespace peloton { namespace optimizer { +template class Memo; } @@ -33,8 +35,10 @@ class ChildPropertyDeriver : public OperatorVisitor { public: std::vector, std::vector>>> - GetProperties(GroupExpression *gexpr, - std::shared_ptr requirements, Memo *memo); + + GetProperties(GroupExpression *gexpr, + std::shared_ptr requirements, + Memo *memo); void Visit(const DummyScan *) override; void Visit(const PhysicalSeqScan *) override; @@ -74,8 +78,8 @@ class ChildPropertyDeriver : public OperatorVisitor { * @brief We need the memo and gexpr because some property may depend on * child's schema */ - Memo *memo_; - GroupExpression *gexpr_; + Memo *memo_; + GroupExpression *gexpr_; }; } // namespace optimizer diff --git a/src/include/optimizer/cost_model/abstract_cost_model.h b/src/include/optimizer/cost_model/abstract_cost_model.h index 95a593f04d9..e01548739b1 100644 --- a/src/include/optimizer/cost_model/abstract_cost_model.h +++ b/src/include/optimizer/cost_model/abstract_cost_model.h @@ -13,10 +13,12 @@ #pragma once #include "optimizer/operator_visitor.h" +#include "optimizer/operator_expression.h" namespace peloton { namespace optimizer { +template class Memo; // Default cost when cost model cannot compute correct cost. 
@@ -34,7 +36,8 @@ static constexpr double DEFAULT_OPERATOR_COST = 0.0025; class AbstractCostModel : public OperatorVisitor { public: - virtual double CalculateCost(GroupExpression *gexpr, Memo *memo, + virtual double CalculateCost(GroupExpression *gexpr, + Memo *memo, concurrency::TransactionContext *txn) = 0; }; diff --git a/src/include/optimizer/cost_model/default_cost_model.h b/src/include/optimizer/cost_model/default_cost_model.h index a92cb091db7..a89bd4ee3a3 100644 --- a/src/include/optimizer/cost_model/default_cost_model.h +++ b/src/include/optimizer/cost_model/default_cost_model.h @@ -23,14 +23,17 @@ namespace peloton { namespace optimizer { +template class Memo; + // Derive cost for a physical group expression class DefaultCostModel : public AbstractCostModel { public: DefaultCostModel(){}; - double CalculateCost(GroupExpression *gexpr, Memo *memo, - concurrency::TransactionContext *txn) { + double CalculateCost(GroupExpression *gexpr, + Memo *memo, + concurrency::TransactionContext *txn) { gexpr_ = gexpr; memo_ = memo; txn_ = txn; @@ -151,8 +154,8 @@ class DefaultCostModel : public AbstractCostModel { return child_num_rows * DEFAULT_TUPLE_COST; } - GroupExpression *gexpr_; - Memo *memo_; + GroupExpression *gexpr_; + Memo *memo_; concurrency::TransactionContext *txn_; double output_cost_ = 0; }; diff --git a/src/include/optimizer/cost_model/postgres_cost_model.h b/src/include/optimizer/cost_model/postgres_cost_model.h index 2632a247a39..523983a89d1 100644 --- a/src/include/optimizer/cost_model/postgres_cost_model.h +++ b/src/include/optimizer/cost_model/postgres_cost_model.h @@ -28,13 +28,16 @@ namespace peloton { namespace optimizer { +template class Memo; + // Derive cost for a physical group expression class PostgresCostModel : public AbstractCostModel { public: PostgresCostModel(){}; - double CalculateCost(GroupExpression *gexpr, Memo *memo, + double CalculateCost(GroupExpression *gexpr, + Memo *memo, concurrency::TransactionContext *txn) override { 
gexpr_ = gexpr; memo_ = memo; @@ -230,8 +233,8 @@ class PostgresCostModel : public AbstractCostModel { } - GroupExpression *gexpr_; - Memo *memo_; + GroupExpression *gexpr_; + Memo *memo_; concurrency::TransactionContext *txn_; double output_cost_ = 0; @@ -279,4 +282,4 @@ class PostgresCostModel : public AbstractCostModel { }; } // namespace optimizer -} // namespace peloton \ No newline at end of file +} // namespace peloton diff --git a/src/include/optimizer/cost_model/trivial_cost_model.h b/src/include/optimizer/cost_model/trivial_cost_model.h index 2c5994ee728..f755626f083 100644 --- a/src/include/optimizer/cost_model/trivial_cost_model.h +++ b/src/include/optimizer/cost_model/trivial_cost_model.h @@ -31,12 +31,15 @@ namespace peloton { namespace optimizer { +template class Memo; + class TrivialCostModel : public AbstractCostModel { public: TrivialCostModel(){}; - double CalculateCost(GroupExpression *gexpr, Memo *memo, + double CalculateCost(GroupExpression *gexpr, + Memo *memo, concurrency::TransactionContext *txn) override { gexpr_ = gexpr; memo_ = memo; @@ -109,11 +112,11 @@ class TrivialCostModel : public AbstractCostModel { } private: - GroupExpression *gexpr_; - Memo *memo_; + GroupExpression *gexpr_; + Memo *memo_; concurrency::TransactionContext *txn_; double output_cost_ = 0; }; } // namespace optimizer -} // namespace peloton \ No newline at end of file +} // namespace peloton diff --git a/src/include/optimizer/group.h b/src/include/optimizer/group.h index a0606d1597c..9129a4952a8 100644 --- a/src/include/optimizer/group.h +++ b/src/include/optimizer/group.h @@ -32,6 +32,7 @@ class ColumnStats; //===--------------------------------------------------------------------===// // Group //===--------------------------------------------------------------------===// +template class Group : public Printable { public: Group(GroupID id, std::unordered_set table_alias); @@ -39,29 +40,30 @@ class Group : public Printable { // If the GroupExpression is generated 
by applying a // property enforcer, we add them to enforced_exprs_ // which will not be enumerated during OptimizeExpression - void AddExpression(std::shared_ptr expr, bool enforced); + void AddExpression(std::shared_ptr> expr, + bool enforced); void RemoveLogicalExpression(size_t idx) { logical_expressions_.erase(logical_expressions_.begin() + idx); } - bool SetExpressionCost(GroupExpression *expr, double cost, + bool SetExpressionCost(GroupExpression *expr, double cost, std::shared_ptr &properties); - GroupExpression *GetBestExpression(std::shared_ptr &properties); + GroupExpression *GetBestExpression(std::shared_ptr &properties); inline const std::unordered_set &GetTableAliases() const { return table_aliases_; } // TODO: thread safety? - const std::vector> GetLogicalExpressions() + const std::vector>> GetLogicalExpressions() const { return logical_expressions_; } // TODO: thread safety? - const std::vector> GetPhysicalExpressions() + const std::vector>> GetPhysicalExpressions() const { return physical_expressions_; } @@ -105,7 +107,7 @@ class Group : public Printable { // This should only be called in rewrite phase to retrieve the only logical // expr in the group - inline GroupExpression *GetLogicalExpression() { + inline GroupExpression *GetLogicalExpression() { PELOTON_ASSERT(logical_expressions_.size() == 1); PELOTON_ASSERT(physical_expressions_.size() == 0); return logical_expressions_[0].get(); @@ -117,15 +119,15 @@ class Group : public Printable { // TODO(boweic) Do not use string, store table alias id std::unordered_set table_aliases_; std::unordered_map, - std::tuple, PropSetPtrHash, + std::tuple *>, PropSetPtrHash, PropSetPtrEq> lowest_cost_expressions_; // Whether equivalent logical expressions have been explored for this group bool has_explored_; - std::vector> logical_expressions_; - std::vector> physical_expressions_; - std::vector> enforced_exprs_; + std::vector>> logical_expressions_; + std::vector>> physical_expressions_; + std::vector>> 
enforced_exprs_; // We'll add stats lazily // TODO(boweic): diff --git a/src/include/optimizer/group_expression.h b/src/include/optimizer/group_expression.h index 303ebaf036e..af71c9e75e2 100644 --- a/src/include/optimizer/group_expression.h +++ b/src/include/optimizer/group_expression.h @@ -25,6 +25,7 @@ namespace peloton { namespace optimizer { +template class Rule; using GroupID = int32_t; @@ -32,9 +33,10 @@ using GroupID = int32_t; //===--------------------------------------------------------------------===// // Group Expression //===--------------------------------------------------------------------===// +template class GroupExpression { public: - GroupExpression(Operator op, std::vector child_groups); + GroupExpression(Node op, std::vector child_groups); GroupID GetGroupID() const; @@ -46,7 +48,7 @@ class GroupExpression { GroupID GetChildGroupId(int child_idx) const; - Operator Op() const; + Node Op() const; double GetCost(std::shared_ptr& requirements) const; @@ -61,11 +63,11 @@ class GroupExpression { hash_t Hash() const; - bool operator==(const GroupExpression &r); + bool operator==(const GroupExpression &r); - void SetRuleExplored(Rule *rule); + void SetRuleExplored(Rule *rule); - bool HasRuleExplored(Rule *rule); + bool HasRuleExplored(Rule *rule); void SetDerivedStats() { stats_derived_ = true; } @@ -75,7 +77,7 @@ class GroupExpression { private: GroupID group_id; - Operator op; + Node op; std::vector child_groups; std::bitset(RuleType::NUM_RULES)> rule_mask_; bool stats_derived_; @@ -92,9 +94,9 @@ class GroupExpression { namespace std { -template <> -struct hash { - typedef peloton::optimizer::GroupExpression argument_type; +template +struct hash> { + typedef peloton::optimizer::GroupExpression argument_type; typedef std::size_t result_type; result_type operator()(argument_type const &s) const { return s.Hash(); } }; diff --git a/src/include/optimizer/input_column_deriver.h b/src/include/optimizer/input_column_deriver.h index ef66823bba0..dd368f8636f 
100644 --- a/src/include/optimizer/input_column_deriver.h +++ b/src/include/optimizer/input_column_deriver.h @@ -27,6 +27,8 @@ class AggregatePlan; namespace optimizer { class OperatorExpression; + +template class Memo; } @@ -44,8 +46,9 @@ class InputColumnDeriver : public OperatorVisitor { std::pair, std::vector>> DeriveInputColumns( - GroupExpression *gexpr, std::shared_ptr properties, - std::vector required_cols, Memo *memo); + GroupExpression *gexpr, std::shared_ptr properties, + std::vector required_cols, + Memo *memo); void Visit(const DummyScan *) override; @@ -108,8 +111,8 @@ class InputColumnDeriver : public OperatorVisitor { * property */ void Passdown(); - GroupExpression *gexpr_; - Memo *memo_; + GroupExpression *gexpr_; + Memo *memo_; /** * @brief The derived output columns and input columns, note that the current diff --git a/src/include/optimizer/memo.h b/src/include/optimizer/memo.h index 951caa4c94d..4bc77009de8 100644 --- a/src/include/optimizer/memo.h +++ b/src/include/optimizer/memo.h @@ -22,13 +22,15 @@ namespace peloton { namespace optimizer { +template struct GExprPtrHash { - std::size_t operator()(GroupExpression* const& s) const { return s->Hash(); } + std::size_t operator()(GroupExpression* const& s) const { return s->Hash(); } }; +template struct GExprPtrEq { - bool operator()(GroupExpression* const& t1, - GroupExpression* const& t2) const { + bool operator()(GroupExpression* const& t1, + GroupExpression* const& t2) const { return *t1 == *t2; } }; @@ -36,6 +38,7 @@ struct GExprPtrEq { //===--------------------------------------------------------------------===// // Memo //===--------------------------------------------------------------------===// +template class Memo { public: Memo(); @@ -48,15 +51,17 @@ class Memo { * target_group: an optional target group to insert expression into * return: existing expression if found. 
Otherwise, return the new expr */ - GroupExpression* InsertExpression(std::shared_ptr gexpr, - bool enforced); + GroupExpression* InsertExpression( + std::shared_ptr> gexpr, + bool enforced); - GroupExpression* InsertExpression(std::shared_ptr gexpr, - GroupID target_group, bool enforced); + GroupExpression* InsertExpression( + std::shared_ptr> gexpr, + GroupID target_group, bool enforced); - std::vector>& Groups(); + std::vector>>& Groups(); - Group* GetGroupByID(GroupID id); + Group* GetGroupByID(GroupID id); const std::string GetInfo(int num_indent) const; const std::string GetInfo() const; @@ -68,10 +73,10 @@ class Memo { //===--------------------------------------------------------------------===// // For rewrite phase: remove and add expression directly for the set //===--------------------------------------------------------------------===// - void RemoveParExpressionForRewirte(GroupExpression* gexpr) { + void RemoveParExpressionForRewirte(GroupExpression* gexpr) { group_expressions_.erase(gexpr); } - void AddParExpressionForRewrite(GroupExpression* gexpr) { + void AddParExpressionForRewrite(GroupExpression* gexpr) { group_expressions_.insert(gexpr); } // When a rewrite rule is applied, we need to replace the original gexpr with @@ -84,12 +89,18 @@ class Memo { } private: - GroupID AddNewGroup(std::shared_ptr gexpr); + GroupID AddNewGroup(std::shared_ptr> gexpr); + + // Internal InsertExpression function + GroupExpression* InsertExpr( + std::shared_ptr> gexpr, + GroupID target_group, bool enforced); // The group owns the group expressions, not the memo - std::unordered_set - group_expressions_; - std::vector> groups_; + std::unordered_set*, + GExprPtrHash, + GExprPtrEq> group_expressions_; + std::vector>> groups_; size_t rule_set_size_; }; diff --git a/src/include/optimizer/optimize_context.h b/src/include/optimizer/optimize_context.h index b5568208d9e..15747a44b5a 100644 --- a/src/include/optimizer/optimize_context.h +++ 
b/src/include/optimizer/optimize_context.h @@ -22,18 +22,20 @@ namespace peloton { namespace optimizer { +template class OptimizerMetadata; +template class OptimizeContext { public: - OptimizeContext(OptimizerMetadata *metadata, + OptimizeContext(OptimizerMetadata *metadata, std::shared_ptr required_prop, double cost_upper_bound = std::numeric_limits::max()) : metadata(metadata), required_prop(required_prop), cost_upper_bound(cost_upper_bound) {} - OptimizerMetadata *metadata; + OptimizerMetadata *metadata; std::shared_ptr required_prop; double cost_upper_bound; }; diff --git a/src/include/optimizer/optimizer.h b/src/include/optimizer/optimizer.h index ebf82d625b4..668049b5333 100644 --- a/src/include/optimizer/optimizer.h +++ b/src/include/optimizer/optimizer.h @@ -60,7 +60,10 @@ enum CostModels {DEFAULT, POSTGRES, TRIVIAL}; // Optimizer //===--------------------------------------------------------------------===// class Optimizer : public AbstractOptimizer { + template friend class BindingIterator; + + template friend class GroupBindingIterator; friend class ::peloton::test:: @@ -85,16 +88,18 @@ class Optimizer : public AbstractOptimizer { void Reset() override; - OptimizerMetadata &GetMetadata() { return metadata_; } + OptimizerMetadata &GetMetadata() { return metadata_; } /* For test purposes only */ - std::shared_ptr TestInsertQueryTree( - parser::SQLStatement *tree, concurrency::TransactionContext *txn) { + std::shared_ptr> TestInsertQueryTree( + parser::SQLStatement *tree, + concurrency::TransactionContext *txn) { + return InsertQueryTree(tree, txn); } /* For test purposes only */ - void TestExecuteTaskStack(OptimizerTaskStack &task_stack, int root_group_id, - std::shared_ptr root_context) { + void TestExecuteTaskStack(OptimizerTaskStack &task_stack, int root_group_id, + std::shared_ptr> root_context) { return ExecuteTaskStack(task_stack, root_group_id, root_context); } @@ -119,7 +124,7 @@ class Optimizer : public AbstractOptimizer { * tree: a peloton query 
tree representing a select query * return: the root group expression for the inserted query */ - std::shared_ptr InsertQueryTree( + std::shared_ptr> InsertQueryTree( parser::SQLStatement *tree, concurrency::TransactionContext *txn); /* GetQueryTreeRequiredProperties - get the required physical properties for @@ -161,12 +166,12 @@ class Optimizer : public AbstractOptimizer { * root_context: the OptimizerContext to use that maintains required *properties */ - void ExecuteTaskStack(OptimizerTaskStack &task_stack, int root_group_id, - std::shared_ptr root_context); + void ExecuteTaskStack(OptimizerTaskStack &task_stack, int root_group_id, + std::shared_ptr> root_context); ////////////////////////////////////////////////////////////////////////////// /// Metadata - OptimizerMetadata metadata_; + OptimizerMetadata metadata_; std::unique_ptr cost_model_; }; diff --git a/src/include/optimizer/optimizer_metadata.h b/src/include/optimizer/optimizer_metadata.h index 3f33e3ee8b1..84a6977ee09 100644 --- a/src/include/optimizer/optimizer_metadata.h +++ b/src/include/optimizer/optimizer_metadata.h @@ -26,9 +26,13 @@ class CatalogCache; } namespace optimizer { +template class OptimizerTaskPool; + +template class RuleSet; +template class OptimizerMetadata { public: @@ -37,45 +41,45 @@ class OptimizerMetadata { settings::SettingId::task_execution_timeout)), timer(Timer()) {} - Memo memo; - RuleSet rule_set; - OptimizerTaskPool *task_pool; + Memo memo; + RuleSet rule_set; + OptimizerTaskPool *task_pool; std::unique_ptr cost_model; catalog::CatalogCache *catalog_cache; unsigned int timeout_limit; Timer timer; concurrency::TransactionContext* txn; - void SetTaskPool(OptimizerTaskPool *task_pool) { + void SetTaskPool(OptimizerTaskPool *task_pool) { this->task_pool = task_pool; } - std::shared_ptr MakeGroupExpression( - std::shared_ptr expr) { + std::shared_ptr> MakeGroupExpression( + std::shared_ptr expr) { std::vector child_groups; for (auto &child : expr->Children()) { auto gexpr = 
MakeGroupExpression(child); memo.InsertExpression(gexpr, false); child_groups.push_back(gexpr->GetGroupID()); } - return std::make_shared(expr->Op(), - std::move(child_groups)); + return std::make_shared>(expr->Op(), + std::move(child_groups)); } - bool RecordTransformedExpression(std::shared_ptr expr, - std::shared_ptr &gexpr) { + bool RecordTransformedExpression(std::shared_ptr expr, + std::shared_ptr> &gexpr) { return RecordTransformedExpression(expr, gexpr, UNDEFINED_GROUP); } - bool RecordTransformedExpression(std::shared_ptr expr, - std::shared_ptr &gexpr, + bool RecordTransformedExpression(std::shared_ptr expr, + std::shared_ptr> &gexpr, GroupID target_group) { gexpr = MakeGroupExpression(expr); return (memo.InsertExpression(gexpr, target_group, false) == gexpr.get()); } // TODO(boweic): check if we really need to use shared_ptr - void ReplaceRewritedExpression(std::shared_ptr expr, + void ReplaceRewritedExpression(std::shared_ptr expr, GroupID target_group) { memo.EraseExpression(target_group); memo.InsertExpression(MakeGroupExpression(expr), target_group, false); diff --git a/src/include/optimizer/optimizer_task.h b/src/include/optimizer/optimizer_task.h index fb2edeaa5db..173c64075c6 100644 --- a/src/include/optimizer/optimizer_task.h +++ b/src/include/optimizer/optimizer_task.h @@ -24,14 +24,33 @@ class AbstractExpression; } namespace optimizer { +template class OptimizeContext; + +template class Memo; + +template class Rule; + +template struct RuleWithPromise; + +template class RuleSet; + +template class Group; + +template class GroupExpression; + +template class OptimizerMetadata; + +enum class OpType; +class Operator; +class OperatorExpression; class PropertySet; enum class RewriteRuleSetName : uint32_t; using GroupID = int32_t; @@ -53,9 +72,10 @@ enum class OptimizerTaskType { /** * @brief The base class for tasks in the optimizer */ +template class OptimizerTask { public: - OptimizerTask(std::shared_ptr context, + OptimizerTask(std::shared_ptr> 
context, OptimizerTaskType type) : type_(type), context_(context) {} @@ -71,24 +91,24 @@ class OptimizerTask { * @param valid_rules The valid rules to apply in the current rule set will be * append to valid_rules, with their promises */ - static void ConstructValidRules(GroupExpression *group_expr, - OptimizeContext *context, - std::vector> &rules, - std::vector &valid_rules); + static void ConstructValidRules(GroupExpression *group_expr, + OptimizeContext *context, + std::vector>> &rules, + std::vector> &valid_rules); virtual void execute() = 0; - void PushTask(OptimizerTask *task); + void PushTask(OptimizerTask *task); - inline Memo &GetMemo() const; + inline Memo &GetMemo() const; - inline RuleSet &GetRuleSet() const; + inline RuleSet &GetRuleSet() const; virtual ~OptimizerTask(){}; protected: OptimizerTaskType type_; - std::shared_ptr context_; + std::shared_ptr> context_; }; /** @@ -96,15 +116,16 @@ class OptimizerTask { * equivalent operator trees if not already explored 2. Cost all physical * operator trees given the current context */ -class OptimizeGroup : public OptimizerTask { +class OptimizeGroup : public OptimizerTask { public: - OptimizeGroup(Group *group, std::shared_ptr context) + OptimizeGroup(Group *group, + std::shared_ptr> context) : OptimizerTask(context, OptimizerTaskType::OPTIMIZE_GROUP), group_(group) {} virtual void execute() override; private: - Group *group_; + Group *group_; }; /** @@ -114,31 +135,32 @@ class OptimizeGroup : public OptimizerTask { * promises so that a physical transformation rule is applied before a logical * transformation rule */ -class OptimizeExpression : public OptimizerTask { +class OptimizeExpression : public OptimizerTask { public: - OptimizeExpression(GroupExpression *group_expr, - std::shared_ptr context) + OptimizeExpression(GroupExpression *group_expr, + std::shared_ptr> context) : OptimizerTask(context, OptimizerTaskType::OPTIMIZE_EXPR), group_expr_(group_expr) {} virtual void execute() override; private: - 
GroupExpression *group_expr_; + GroupExpression *group_expr_; }; /** * @brief Generate all logical transformation rules by applying logical * transformation rules to logical operators in the group until saturated */ -class ExploreGroup : public OptimizerTask { +class ExploreGroup : public OptimizerTask { public: - ExploreGroup(Group *group, std::shared_ptr context) + ExploreGroup(Group *group, + std::shared_ptr> context) : OptimizerTask(context, OptimizerTaskType::EXPLORE_GROUP), group_(group) {} virtual void execute() override; private: - Group *group_; + Group *group_; }; /** @@ -146,16 +168,16 @@ class ExploreGroup : public OptimizerTask { * pattern * in the same group is found, also apply logical transformation rule for it. */ -class ExploreExpression : public OptimizerTask { +class ExploreExpression : public OptimizerTask { public: - ExploreExpression(GroupExpression *group_expr, - std::shared_ptr context) + ExploreExpression(GroupExpression *group_expr, + std::shared_ptr> context) : OptimizerTask(context, OptimizerTaskType::EXPLORE_EXPR), group_expr_(group_expr) {} virtual void execute() override; private: - GroupExpression *group_expr_; + GroupExpression *group_expr_; }; /** @@ -164,10 +186,11 @@ class ExploreExpression : public OptimizerTask { * to the new group expression based on the explore flag. 
If the rule is a * physical implementation rule, we directly cost the physical expression */ -class ApplyRule : public OptimizerTask { +class ApplyRule : public OptimizerTask { public: - ApplyRule(GroupExpression *group_expr, Rule *rule, - std::shared_ptr context, bool explore = false) + ApplyRule(GroupExpression *group_expr, + Rule *rule, + std::shared_ptr> context, bool explore = false) : OptimizerTask(context, OptimizerTaskType::APPLY_RULE), group_expr_(group_expr), rule_(rule), @@ -175,8 +198,8 @@ class ApplyRule : public OptimizerTask { virtual void execute() override; private: - GroupExpression *group_expr_; - Rule *rule_; + GroupExpression *group_expr_; + Rule *rule_; bool explore_only; }; @@ -187,10 +210,10 @@ class ApplyRule : public OptimizerTask { * current expression's cost is larger than the upper bound of the current * group */ -class OptimizeInputs : public OptimizerTask { +class OptimizeInputs : public OptimizerTask { public: - OptimizeInputs(GroupExpression *group_expr, - std::shared_ptr context) + OptimizeInputs(GroupExpression *group_expr, + std::shared_ptr> context) : OptimizerTask(context, OptimizerTaskType::OPTIMIZE_INPUTS), group_expr_(group_expr) {} @@ -208,7 +231,7 @@ class OptimizeInputs : public OptimizerTask { std::vector, std::vector>>> output_input_properties_; - GroupExpression *group_expr_; + GroupExpression *group_expr_; double cur_total_cost_; int cur_child_idx_ = -1; int prev_child_idx_ = -1; @@ -220,11 +243,11 @@ class OptimizeInputs : public OptimizerTask { * child group have the stats, if not, recursively derive the stats. 
This would * lazily collect the stats for the column needed */ -class DeriveStats : public OptimizerTask { +class DeriveStats : public OptimizerTask { public: - DeriveStats(GroupExpression *gexpr, + DeriveStats(GroupExpression *gexpr, ExprSet required_cols, - std::shared_ptr context) + std::shared_ptr> context) : OptimizerTask(context, OptimizerTaskType::DERIVE_STATS), gexpr_(gexpr), required_cols_(required_cols) {} @@ -237,7 +260,7 @@ class DeriveStats : public OptimizerTask { virtual void execute() override; private: - GroupExpression *gexpr_; + GroupExpression *gexpr_; ExprSet required_cols_; }; @@ -247,11 +270,13 @@ class DeriveStats : public OptimizerTask { * level rewrite. An example is predicate push-down. We only push the predicates * from the upper level to the lower level. */ -class TopDownRewrite : public OptimizerTask { +template +class TopDownRewrite : public OptimizerTask { public: - TopDownRewrite(GroupID group_id, std::shared_ptr context, + TopDownRewrite(GroupID group_id, + std::shared_ptr> context, RewriteRuleSetName rule_set_name) - : OptimizerTask(context, OptimizerTaskType::TOP_DOWN_REWRITE), + : OptimizerTask(context, OptimizerTaskType::TOP_DOWN_REWRITE), group_id_(group_id), rule_set_name_(rule_set_name) {} virtual void execute() override; @@ -266,11 +291,13 @@ class TopDownRewrite : public OptimizerTask { * that the upper level rewrite in the operator tree will not enable lower * level rewrite. 
*/ -class BottomUpRewrite : public OptimizerTask { +template +class BottomUpRewrite : public OptimizerTask { public: - BottomUpRewrite(GroupID group_id, std::shared_ptr context, + BottomUpRewrite(GroupID group_id, + std::shared_ptr> context, RewriteRuleSetName rule_set_name, bool has_optimized_child) - : OptimizerTask(context, OptimizerTaskType::BOTTOM_UP_REWRITE), + : OptimizerTask(context, OptimizerTaskType::BOTTOM_UP_REWRITE), group_id_(group_id), rule_set_name_(rule_set_name), has_optimized_child_(has_optimized_child) {} diff --git a/src/include/optimizer/optimizer_task_pool.h b/src/include/optimizer/optimizer_task_pool.h index a14789df64a..2ce755e8de0 100644 --- a/src/include/optimizer/optimizer_task_pool.h +++ b/src/include/optimizer/optimizer_task_pool.h @@ -24,32 +24,35 @@ namespace optimizer { * is identical to a stack but we may need to implement a different data * structure for multi-threaded optimization */ + +template class OptimizerTaskPool { public: - virtual std::unique_ptr Pop() = 0; - virtual void Push(OptimizerTask *task) = 0; + virtual std::unique_ptr> Pop() = 0; + virtual void Push(OptimizerTask *task) = 0; virtual bool Empty() = 0; }; /** * @brief Stack implementation of the task pool */ -class OptimizerTaskStack : public OptimizerTaskPool { +template +class OptimizerTaskStack : public OptimizerTaskPool { public: - virtual std::unique_ptr Pop() { + virtual std::unique_ptr> Pop() { auto task = std::move(task_stack_.top()); task_stack_.pop(); return task; } - virtual void Push(OptimizerTask *task) { - task_stack_.push(std::unique_ptr(task)); + virtual void Push(OptimizerTask *task) { + task_stack_.push(std::unique_ptr>(task)); } virtual bool Empty() { return task_stack_.empty(); } private: - std::stack> task_stack_; + std::stack>> task_stack_; }; } // namespace optimizer diff --git a/src/include/optimizer/pattern.h b/src/include/optimizer/pattern.h index 67c52592889..176fb382b9a 100644 --- a/src/include/optimizer/pattern.h +++ 
b/src/include/optimizer/pattern.h @@ -20,9 +20,13 @@ namespace peloton { namespace optimizer { +/** + * template parameter should *really* only be OpType or ExpressionType + */ +template class Pattern { public: - Pattern(OpType op); + Pattern(OperatorType op); void AddChild(std::shared_ptr child); @@ -30,10 +34,10 @@ class Pattern { inline size_t GetChildPatternsSize() const { return children.size(); } - OpType Type() const; + OperatorType Type() const; private: - OpType _type; + OperatorType _type; std::vector> children; }; diff --git a/src/include/optimizer/property_enforcer.h b/src/include/optimizer/property_enforcer.h index e82b802d84c..c826edbe54d 100644 --- a/src/include/optimizer/property_enforcer.h +++ b/src/include/optimizer/property_enforcer.h @@ -30,8 +30,8 @@ class PropertyEnforcer : public PropertyVisitor { public: - std::shared_ptr EnforceProperty( - GroupExpression* gexpr, Property* property); + std::shared_ptr> EnforceProperty( + GroupExpression* gexpr, Property* property); virtual void Visit(const PropertyColumns *) override; virtual void Visit(const PropertySort *) override; @@ -39,8 +39,8 @@ class PropertyEnforcer : public PropertyVisitor { virtual void Visit(const PropertyLimit *) override; private: - GroupExpression* input_gexpr_; - std::shared_ptr output_gexpr_; + GroupExpression* input_gexpr_; + std::shared_ptr> output_gexpr_; }; } // namespace optimizer diff --git a/src/include/optimizer/rewriter.h b/src/include/optimizer/rewriter.h new file mode 100644 index 00000000000..796b10f7779 --- /dev/null +++ b/src/include/optimizer/rewriter.h @@ -0,0 +1,51 @@ +//===----------------------------------------------------------------------===// +// +// Peloton +// +// rewriter.h +// +// Identification: src/include/optimizer/rewriter.h +// +// Copyright (c) 2015-2018, Carnegie Mellon University Database Group +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include + +#include 
"expression/abstract_expression.h" +#include "optimizer/optimizer_metadata.h" +#include "optimizer/optimizer_task_pool.h" +#include "optimizer/absexpr_expression.h" + +namespace peloton { +namespace optimizer { + +class Rewriter { + + public: + Rewriter(const Rewriter &) = delete; + Rewriter &operator=(const Rewriter &) = delete; + Rewriter(Rewriter &&) = delete; + Rewriter &operator=(Rewriter &&) = delete; + + Rewriter(); + + expression::AbstractExpression* RewriteExpression(const expression::AbstractExpression *expr); + void Reset(); + + OptimizerMetadata &GetMetadata() { return metadata_; } + + std::shared_ptr ConvertToAbsExpr(const expression::AbstractExpression *expr); + + private: + expression::AbstractExpression* RebuildExpression(int root_group); + void ExecuteTaskStack(OptimizerTaskStack &task_stack); + void RewriteLoop(int root_group_id); + std::shared_ptr> ConvertTree(const expression::AbstractExpression *expr); + OptimizerMetadata metadata_; +}; + +} // namespace optimizer +} // namespace peloton diff --git a/src/include/optimizer/rule.h b/src/include/optimizer/rule.h index 4ea78a630c6..b7681433405 100644 --- a/src/include/optimizer/rule.h +++ b/src/include/optimizer/rule.h @@ -21,19 +21,18 @@ namespace peloton { namespace optimizer { +template class GroupExpression; #define PHYS_PROMISE 3 #define LOG_PROMISE 1 -/** - * @brief The base class of all rules - */ +template class Rule { public: virtual ~Rule(){}; - std::shared_ptr GetMatchPattern() const { return match_pattern; } + std::shared_ptr> GetMatchPattern() const { return match_pattern; } bool IsPhysical() const { return type_ > RuleType::LogicalPhysicalDelimiter && @@ -58,8 +57,8 @@ class Rule { * @return The promise, the higher the promise, the rule should be applied * sooner */ - virtual int Promise(GroupExpression *group_expr, - OptimizeContext *context) const; + virtual int Promise(GroupExpression *group_expr, + OptimizeContext *context) const; /** * @brief Check if the rule is applicable for 
the operator expression. The @@ -74,8 +73,8 @@ class Rule { * * @return If the rule is applicable, return true, otherwise return false */ - virtual bool Check(std::shared_ptr expr, - OptimizeContext *context) const = 0; + virtual bool Check(std::shared_ptr expr, + OptimizeContext *context) const = 0; /** * @brief Convert a "before" operator tree to an "after" operator tree @@ -85,76 +84,83 @@ class Rule { * @param context The current optimization context */ virtual void Transform( - std::shared_ptr input, - std::vector> &transformed, - OptimizeContext *context) const = 0; + std::shared_ptr input, + std::vector> &transformed, + OptimizeContext *context) const = 0; inline RuleType GetType() { return type_; } inline uint32_t GetRuleIdx() { return static_cast(type_); } protected: - std::shared_ptr match_pattern; + std::shared_ptr> match_pattern; RuleType type_; }; /** * @brief A struct to store a rule together with its promise */ +template struct RuleWithPromise { - RuleWithPromise(Rule *rule, int promise) : rule(rule), promise(promise) {} + RuleWithPromise(Rule *rule, int promise) : rule(rule), promise(promise) {} - Rule *rule; + Rule *rule; int promise; - bool operator<(const RuleWithPromise &r) const { return promise < r.promise; } - bool operator>(const RuleWithPromise &r) const { return promise > r.promise; } + bool operator<(const RuleWithPromise &r) const { return promise < r.promise; } + bool operator>(const RuleWithPromise &r) const { return promise > r.promise; } }; enum class RewriteRuleSetName : uint32_t { PREDICATE_PUSH_DOWN = 0, - UNNEST_SUBQUERY + UNNEST_SUBQUERY, + COMPARATOR_ELIMINATION }; /** * @brief All the rule sets, including logical transformation rules, physical * implementation rules and rewrite rules */ +template class RuleSet { public: // RuleSet will take the ownership of the rule object RuleSet(); - inline void AddTransformationRule(Rule* rule) { transformation_rules_.emplace_back(rule); } + inline void AddTransformationRule(Rule* rule) { 
transformation_rules_.emplace_back(rule); } - inline void AddImplementationRule(Rule* rule) { transformation_rules_.emplace_back(rule); } + inline void AddImplementationRule(Rule* rule) { transformation_rules_.emplace_back(rule); } - inline void AddRewriteRule(RewriteRuleSetName set, Rule* rule) { + inline void AddRewriteRule(RewriteRuleSetName set, Rule* rule) { rewrite_rules_map_[static_cast(set)].emplace_back(rule); } - std::vector> &GetTransformationRules() { + std::vector>> &GetTransformationRules() { return transformation_rules_; } - std::vector> &GetImplementationRules() { + std::vector>> &GetImplementationRules() { return implementation_rules_; } - std::vector> &GetRewriteRulesByName( + std::vector>> &GetRewriteRulesByName( RewriteRuleSetName set) { return rewrite_rules_map_[static_cast(set)]; } - std::unordered_map>> &GetRewriteRulesMap() { return rewrite_rules_map_; } + std::unordered_map>>> &GetRewriteRulesMap() { + return rewrite_rules_map_; + } - std::vector> &GetPredicatePushDownRules() { return predicate_push_down_rules_; } + std::vector>> &GetPredicatePushDownRules() { + return predicate_push_down_rules_; + } private: - std::vector> transformation_rules_; - std::vector> implementation_rules_; - std::unordered_map>> rewrite_rules_map_; - std::vector> predicate_push_down_rules_; + std::vector>> transformation_rules_; + std::vector>> implementation_rules_; + std::unordered_map>>> rewrite_rules_map_; + std::vector>> predicate_push_down_rules_; }; } // namespace optimizer diff --git a/src/include/optimizer/rule_impls.h b/src/include/optimizer/rule_impls.h index 57902e744a9..810c3b8e8bb 100644 --- a/src/include/optimizer/rule_impls.h +++ b/src/include/optimizer/rule_impls.h @@ -26,32 +26,32 @@ namespace optimizer { /** * @brief (A join B) -> (B join A) */ -class InnerJoinCommutativity : public Rule { +class InnerJoinCommutativity : public Rule { public: InnerJoinCommutativity(); bool Check(std::shared_ptr plan, - OptimizeContext *context) const override; 
+ OptimizeContext *context) const override; void Transform(std::shared_ptr input, std::vector> &transformed, - OptimizeContext *context) const override; + OptimizeContext *context) const override; }; /** * @brief (A join B) join C -> A join (B join C) */ -class InnerJoinAssociativity : public Rule { +class InnerJoinAssociativity : public Rule { public: InnerJoinAssociativity(); bool Check(std::shared_ptr plan, - OptimizeContext *context) const override; + OptimizeContext *context) const override; void Transform(std::shared_ptr input, std::vector> &transformed, - OptimizeContext *context) const override; + OptimizeContext *context) const override; }; //===--------------------------------------------------------------------===// @@ -61,239 +61,239 @@ class InnerJoinAssociativity : public Rule { /** * @brief (Logical Scan -> Sequential Scan) */ -class GetToSeqScan : public Rule { +class GetToSeqScan : public Rule { public: GetToSeqScan(); bool Check(std::shared_ptr plan, - OptimizeContext *context) const override; + OptimizeContext *context) const override; void Transform(std::shared_ptr input, std::vector> &transformed, - OptimizeContext *context) const override; + OptimizeContext *context) const override; }; -class LogicalExternalFileGetToPhysical : public Rule { +class LogicalExternalFileGetToPhysical : public Rule { public: LogicalExternalFileGetToPhysical(); bool Check(std::shared_ptr plan, - OptimizeContext *context) const override; + OptimizeContext *context) const override; void Transform(std::shared_ptr input, std::vector> &transformed, - OptimizeContext *context) const override; + OptimizeContext *context) const override; }; /** * @brief Generate dummy scan for queries like "SELECT 1", there's no actual * table to generate */ -class GetToDummyScan : public Rule { +class GetToDummyScan : public Rule { public: GetToDummyScan(); bool Check(std::shared_ptr plan, - OptimizeContext *context) const override; + OptimizeContext *context) const override; void 
Transform(std::shared_ptr input, std::vector> &transformed, - OptimizeContext *context) const override; + OptimizeContext *context) const override; }; /** * @brief (Logical Scan -> Index Scan) */ -class GetToIndexScan : public Rule { +class GetToIndexScan : public Rule { public: GetToIndexScan(); bool Check(std::shared_ptr plan, - OptimizeContext *context) const override; + OptimizeContext *context) const override; void Transform(std::shared_ptr input, std::vector> &transformed, - OptimizeContext *context) const override; + OptimizeContext *context) const override; }; /** * @brief Transforming query derived scan for nested query */ -class LogicalQueryDerivedGetToPhysical : public Rule { +class LogicalQueryDerivedGetToPhysical : public Rule { public: LogicalQueryDerivedGetToPhysical(); bool Check(std::shared_ptr plan, - OptimizeContext *context) const override; + OptimizeContext *context) const override; void Transform(std::shared_ptr input, std::vector> &transformed, - OptimizeContext *context) const override; + OptimizeContext *context) const override; }; /** * @brief (Logical Delete -> Physical Delete) */ -class LogicalDeleteToPhysical : public Rule { +class LogicalDeleteToPhysical : public Rule { public: LogicalDeleteToPhysical(); bool Check(std::shared_ptr plan, - OptimizeContext *context) const override; + OptimizeContext *context) const override; void Transform(std::shared_ptr input, std::vector> &transformed, - OptimizeContext *context) const override; + OptimizeContext *context) const override; }; /** * @brief (Logical Update -> Physical Update) */ -class LogicalUpdateToPhysical : public Rule { +class LogicalUpdateToPhysical : public Rule { public: LogicalUpdateToPhysical(); bool Check(std::shared_ptr plan, - OptimizeContext *context) const override; + OptimizeContext *context) const override; void Transform(std::shared_ptr input, std::vector> &transformed, - OptimizeContext *context) const override; + OptimizeContext *context) const override; }; /** * 
@brief (Logical Insert -> Physical Insert) */ -class LogicalInsertToPhysical : public Rule { +class LogicalInsertToPhysical : public Rule { public: LogicalInsertToPhysical(); bool Check(std::shared_ptr plan, - OptimizeContext *context) const override; + OptimizeContext *context) const override; void Transform(std::shared_ptr input, std::vector> &transformed, - OptimizeContext *context) const override; + OptimizeContext *context) const override; }; /** * @brief (Logical Insert Select -> Physical Insert Select) */ -class LogicalInsertSelectToPhysical : public Rule { +class LogicalInsertSelectToPhysical : public Rule { public: LogicalInsertSelectToPhysical(); bool Check(std::shared_ptr plan, - OptimizeContext *context) const override; + OptimizeContext *context) const override; void Transform(std::shared_ptr input, std::vector> &transformed, - OptimizeContext *context) const override; + OptimizeContext *context) const override; }; /** * @brief (Logical Group by -> Hash Group by) */ -class LogicalGroupByToHashGroupBy : public Rule { +class LogicalGroupByToHashGroupBy : public Rule { public: LogicalGroupByToHashGroupBy(); bool Check(std::shared_ptr plan, - OptimizeContext *context) const override; + OptimizeContext *context) const override; void Transform(std::shared_ptr input, std::vector> &transformed, - OptimizeContext *context) const override; + OptimizeContext *context) const override; }; /** * @brief (Logical Aggregate -> Physical Aggregate) */ -class LogicalAggregateToPhysical : public Rule { +class LogicalAggregateToPhysical : public Rule { public: LogicalAggregateToPhysical(); bool Check(std::shared_ptr plan, - OptimizeContext *context) const override; + OptimizeContext *context) const override; void Transform(std::shared_ptr input, std::vector> &transformed, - OptimizeContext *context) const override; + OptimizeContext *context) const override; }; /** * @brief (Logical Inner Join -> Inner Nested-Loop Join) */ -class InnerJoinToInnerNLJoin : public Rule { 
+class InnerJoinToInnerNLJoin : public Rule { public: InnerJoinToInnerNLJoin(); bool Check(std::shared_ptr plan, - OptimizeContext *context) const override; + OptimizeContext *context) const override; void Transform(std::shared_ptr input, std::vector> &transformed, - OptimizeContext *context) const override; + OptimizeContext *context) const override; }; /** * @brief (Logical Inner Join -> Inner Hash Join) */ -class InnerJoinToInnerHashJoin : public Rule { +class InnerJoinToInnerHashJoin : public Rule { public: InnerJoinToInnerHashJoin(); bool Check(std::shared_ptr plan, - OptimizeContext *context) const override; + OptimizeContext *context) const override; void Transform(std::shared_ptr input, std::vector> &transformed, - OptimizeContext *context) const override; + OptimizeContext *context) const override; }; /** * @brief (Logical Distinct -> Physical Distinct) */ -class ImplementDistinct : public Rule { +class ImplementDistinct : public Rule { public: ImplementDistinct(); bool Check(std::shared_ptr plan, - OptimizeContext *context) const override; + OptimizeContext *context) const override; void Transform(std::shared_ptr input, std::vector> &transformed, - OptimizeContext *context) const override; + OptimizeContext *context) const override; }; /** * @brief (Logical Limit -> Physical Limit) */ -class ImplementLimit : public Rule { +class ImplementLimit : public Rule { public: ImplementLimit(); bool Check(std::shared_ptr plan, - OptimizeContext *context) const override; + OptimizeContext *context) const override; void Transform(std::shared_ptr input, std::vector> &transformed, - OptimizeContext *context) const override; + OptimizeContext *context) const override; }; /** * @brief Logical Export to External File -> Physical Export to External file */ -class LogicalExportToPhysicalExport : public Rule { +class LogicalExportToPhysicalExport : public Rule { public: LogicalExportToPhysicalExport(); bool Check(std::shared_ptr plan, - OptimizeContext *context) const 
override; + OptimizeContext *context) const override; void Transform(std::shared_ptr input, std::vector> &transformed, - OptimizeContext *context) const override; + OptimizeContext *context) const override; }; //===--------------------------------------------------------------------===// @@ -306,63 +306,63 @@ class LogicalExportToPhysicalExport : public Rule { * we could push "test.a=5" through the join to evaluate at the table scan * level */ -class PushFilterThroughJoin : public Rule { +class PushFilterThroughJoin : public Rule { public: PushFilterThroughJoin(); bool Check(std::shared_ptr plan, - OptimizeContext *context) const override; + OptimizeContext *context) const override; void Transform(std::shared_ptr input, std::vector> &transformed, - OptimizeContext *context) const override; + OptimizeContext *context) const override; }; /** * @brief Combine multiple filters into one single filter using conjunction */ -class CombineConsecutiveFilter : public Rule { +class CombineConsecutiveFilter : public Rule { public: CombineConsecutiveFilter(); bool Check(std::shared_ptr plan, - OptimizeContext *context) const override; + OptimizeContext *context) const override; void Transform(std::shared_ptr input, std::vector> &transformed, - OptimizeContext *context) const override; + OptimizeContext *context) const override; }; /** * @brief perform predicate push-down to push a filter through aggregation, also * will embed filter into aggregation operator if appropriate. */ -class PushFilterThroughAggregation : public Rule { +class PushFilterThroughAggregation : public Rule { public: PushFilterThroughAggregation(); bool Check(std::shared_ptr plan, - OptimizeContext *context) const override; + OptimizeContext *context) const override; void Transform(std::shared_ptr input, std::vector> &transformed, - OptimizeContext *context) const override; + OptimizeContext *context) const override; }; /** * @brief Embed a filter into a scan operator. 
After predicate push-down, we * eliminate all filters in the operator trees, predicates should be associated * with get or join */ -class EmbedFilterIntoGet : public Rule { +class EmbedFilterIntoGet : public Rule { public: EmbedFilterIntoGet(); bool Check(std::shared_ptr plan, - OptimizeContext *context) const override; + OptimizeContext *context) const override; void Transform(std::shared_ptr input, std::vector> &transformed, - OptimizeContext *context) const override; + OptimizeContext *context) const override; }; /////////////////////////////////////////////////////////////////////////////// @@ -377,69 +377,69 @@ enum class UnnestPromise { Low = 1, High }; // should not use the following rules in the rewrite phase /////////////////////////////////////////////////////////////////////////////// /// MarkJoinGetToInnerJoin -class MarkJoinToInnerJoin : public Rule { +class MarkJoinToInnerJoin : public Rule { public: MarkJoinToInnerJoin(); - int Promise(GroupExpression *group_expr, - OptimizeContext *context) const override; + int Promise(GroupExpression *group_expr, + OptimizeContext *context) const override; bool Check(std::shared_ptr plan, - OptimizeContext *context) const override; + OptimizeContext *context) const override; void Transform(std::shared_ptr input, std::vector> &transformed, - OptimizeContext *context) const override; + OptimizeContext *context) const override; }; /////////////////////////////////////////////////////////////////////////////// /// SingleJoinToInnerJoin -class SingleJoinToInnerJoin : public Rule { +class SingleJoinToInnerJoin : public Rule { public: SingleJoinToInnerJoin(); - int Promise(GroupExpression *group_expr, - OptimizeContext *context) const override; + int Promise(GroupExpression *group_expr, + OptimizeContext *context) const override; bool Check(std::shared_ptr plan, - OptimizeContext *context) const override; + OptimizeContext *context) const override; void Transform(std::shared_ptr input, std::vector> &transformed, - 
OptimizeContext *context) const override; + OptimizeContext *context) const override; }; /////////////////////////////////////////////////////////////////////////////// /// PullFilterThroughMarkJoin -class PullFilterThroughMarkJoin : public Rule { +class PullFilterThroughMarkJoin : public Rule { public: PullFilterThroughMarkJoin(); - int Promise(GroupExpression *group_expr, - OptimizeContext *context) const override; + int Promise(GroupExpression *group_expr, + OptimizeContext *context) const override; bool Check(std::shared_ptr plan, - OptimizeContext *context) const override; + OptimizeContext *context) const override; void Transform(std::shared_ptr input, std::vector> &transformed, - OptimizeContext *context) const override; + OptimizeContext *context) const override; }; /////////////////////////////////////////////////////////////////////////////// /// PullFilterThroughAggregation -class PullFilterThroughAggregation : public Rule { +class PullFilterThroughAggregation : public Rule { public: PullFilterThroughAggregation(); - int Promise(GroupExpression *group_expr, - OptimizeContext *context) const override; + int Promise(GroupExpression *group_expr, + OptimizeContext *context) const override; bool Check(std::shared_ptr plan, - OptimizeContext *context) const override; + OptimizeContext *context) const override; void Transform(std::shared_ptr input, std::vector> &transformed, - OptimizeContext *context) const override; + OptimizeContext *context) const override; }; } // namespace optimizer } // namespace peloton diff --git a/src/include/optimizer/rule_rewrite.h b/src/include/optimizer/rule_rewrite.h new file mode 100644 index 00000000000..fe0f2b829bf --- /dev/null +++ b/src/include/optimizer/rule_rewrite.h @@ -0,0 +1,45 @@ +//===----------------------------------------------------------------------===// +// +// Peloton +// +// rule_rewrite.h +// +// Identification: src/include/optimizer/rule_rewrite.h +// +// Copyright (c) 2015-16, Carnegie Mellon University 
Database Group +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "optimizer/rule.h" +#include "optimizer/absexpr_expression.h" + +#include + +namespace peloton { +namespace optimizer { + +/* Rules are applied from high to low priority */ +enum class RulePriority : int { + HIGH = 3, + MEDIUM = 2, + LOW = 1 +}; + +class ComparatorElimination: public Rule { + public: + ComparatorElimination(); + + int Promise(GroupExpression *group_expr, + OptimizeContext *context) const override; + + bool Check(std::shared_ptr plan, + OptimizeContext *context) const override; + + void Transform(std::shared_ptr input, + std::vector> &transformed, + OptimizeContext *context) const override; +}; +} // namespace optimizer +} // namespace peloton diff --git a/src/include/optimizer/stats/child_stats_deriver.h b/src/include/optimizer/stats/child_stats_deriver.h index d0c72f9bf9b..cfca18e30d9 100644 --- a/src/include/optimizer/stats/child_stats_deriver.h +++ b/src/include/optimizer/stats/child_stats_deriver.h @@ -21,15 +21,19 @@ class AbstractExpression; } namespace optimizer { +template class Memo; +class OperatorExpression; + // Derive child stats that has not yet been calculated for a logical group // expression class ChildStatsDeriver : public OperatorVisitor { public: std::vector DeriveInputStats( - GroupExpression *gexpr, - ExprSet required_cols, Memo *memo); + GroupExpression *gexpr, + ExprSet required_cols, + Memo *memo); void Visit(const LogicalQueryDerivedGet *) override; void Visit(const LogicalInnerJoin *) override; @@ -43,8 +47,8 @@ class ChildStatsDeriver : public OperatorVisitor { void PassDownRequiredCols(); void PassDownColumn(expression::AbstractExpression* col); ExprSet required_cols_; - GroupExpression *gexpr_; - Memo *memo_; + GroupExpression *gexpr_; + Memo *memo_; std::vector output_; }; diff --git a/src/include/optimizer/stats/stats_calculator.h b/src/include/optimizer/stats/stats_calculator.h index 
9637db2f224..6fed68370f9 100644 --- a/src/include/optimizer/stats/stats_calculator.h +++ b/src/include/optimizer/stats/stats_calculator.h @@ -17,8 +17,10 @@ namespace peloton { namespace optimizer { +template class Memo; class TableStats; +class OperatorExpression; /** * @brief Derive stats for the root group using a group expression's children's @@ -26,8 +28,10 @@ class TableStats; */ class StatsCalculator : public OperatorVisitor { public: - void CalculateStats(GroupExpression *gexpr, ExprSet required_cols, - Memo *memo, concurrency::TransactionContext* txn); + void CalculateStats(GroupExpression *gexpr, + ExprSet required_cols, + Memo *memo, + concurrency::TransactionContext* txn); void Visit(const LogicalGet *) override; void Visit(const LogicalQueryDerivedGet *) override; @@ -72,9 +76,9 @@ class StatsCalculator : public OperatorVisitor { const std::shared_ptr predicate_table_stats, const expression::AbstractExpression *expr); - GroupExpression *gexpr_; + GroupExpression *gexpr_; ExprSet required_cols_; - Memo *memo_; + Memo *memo_; concurrency::TransactionContext* txn_; }; diff --git a/src/optimizer/binding.cpp b/src/optimizer/binding.cpp index 9651ce8102c..2975dce336c 100644 --- a/src/optimizer/binding.cpp +++ b/src/optimizer/binding.cpp @@ -15,6 +15,7 @@ #include "common/logger.h" #include "optimizer/operator_visitor.h" #include "optimizer/optimizer.h" +#include "optimizer/absexpr_expression.h" namespace peloton { namespace optimizer { @@ -22,23 +23,22 @@ namespace optimizer { //===--------------------------------------------------------------------===// // Group Binding Iterator //===--------------------------------------------------------------------===// -GroupBindingIterator::GroupBindingIterator(Memo &memo, GroupID id, - std::shared_ptr pattern) - : BindingIterator(memo), +template +GroupBindingIterator::GroupBindingIterator( + Memo &memo, + GroupID id, + std::shared_ptr> pattern) + : BindingIterator(memo), group_id_(id), pattern_(pattern), - 
target_group_(memo_.GetGroupByID(id)), + target_group_(this->memo_.GetGroupByID(id)), num_group_items_(target_group_->GetLogicalExpressions().size()), current_item_index_(0) { LOG_TRACE("Attempting to bind on group %d", id); } -bool GroupBindingIterator::HasNext() { - LOG_TRACE("HasNext"); - if (pattern_->Type() == OpType::Leaf) { - return current_item_index_ == 0; - } - +template +bool GroupBindingIterator::HasNextBinding() { if (current_iterator_) { // Check if still have bindings in current item if (!current_iterator_->HasNext()) { @@ -50,8 +50,8 @@ bool GroupBindingIterator::HasNext() { if (current_iterator_ == nullptr) { // Keep checking item iterators until we find a match while (current_item_index_ < num_group_items_) { - current_iterator_.reset(new GroupExprBindingIterator( - memo_, + current_iterator_.reset(new GroupExprBindingIterator( + this->memo_, target_group_->GetLogicalExpressions()[current_item_index_].get(), pattern_)); @@ -67,7 +67,31 @@ bool GroupBindingIterator::HasNext() { return current_iterator_ != nullptr; } -std::shared_ptr GroupBindingIterator::Next() { +template +bool GroupBindingIterator::HasNext() { + return HasNextBinding(); +} + +// Specialization +template <> +bool GroupBindingIterator::HasNext() { + LOG_TRACE("HasNext"); + + if (pattern_->Type() == OpType::Leaf) { + return current_item_index_ == 0; + } + + return HasNextBinding(); +} + +template +std::shared_ptr GroupBindingIterator::Next() { + return current_iterator_->Next(); +} + +// Specialization +template <> +std::shared_ptr GroupBindingIterator::Next() { if (pattern_->Type() == OpType::Leaf) { current_item_index_ = num_group_items_; return std::make_shared(LeafOperator::make(group_id_)); @@ -78,20 +102,23 @@ std::shared_ptr GroupBindingIterator::Next() { //===--------------------------------------------------------------------===// // Item Binding Iterator //===--------------------------------------------------------------------===// 
-GroupExprBindingIterator::GroupExprBindingIterator( - Memo &memo, GroupExpression *gexpr, std::shared_ptr pattern) - : BindingIterator(memo), +template +GroupExprBindingIterator::GroupExprBindingIterator( + Memo &memo, + GroupExpression *gexpr, + std::shared_ptr> pattern) + : BindingIterator(memo), gexpr_(gexpr), pattern_(pattern), first_(true), has_next_(false), - current_binding_(std::make_shared(gexpr->Op())) { + current_binding_(std::make_shared(gexpr->Op())) { if (gexpr->Op().GetType() != pattern->Type()) { return; } const std::vector &child_groups = gexpr->GetChildGroupIDs(); - const std::vector> &child_patterns = + const std::vector>> &child_patterns = pattern->Children(); if (child_groups.size() != child_patterns.size()) { @@ -107,9 +134,9 @@ GroupExprBindingIterator::GroupExprBindingIterator( children_bindings_pos_.resize(child_groups.size(), 0); for (size_t i = 0; i < child_groups.size(); ++i) { // Try to find a match in the given group - std::vector> &child_bindings = + std::vector> &child_bindings = children_bindings_[i]; - GroupBindingIterator iterator(memo_, child_groups[i], child_patterns[i]); + GroupBindingIterator iterator(this->memo_, child_groups[i], child_patterns[i]); // Get all bindings while (iterator.HasNext()) { @@ -126,7 +153,8 @@ GroupExprBindingIterator::GroupExprBindingIterator( has_next_ = true; } -bool GroupExprBindingIterator::HasNext() { +template +bool GroupExprBindingIterator::HasNext() { LOG_TRACE("HasNext"); if (has_next_ && first_) { first_ = false; @@ -137,8 +165,7 @@ bool GroupExprBindingIterator::HasNext() { // The first child to be modified int first_modified_idx = children_bindings_pos_.size() - 1; for (; first_modified_idx >= 0; --first_modified_idx) { - const std::vector> &child_binding = - children_bindings_[first_modified_idx]; + const std::vector> &child_binding = children_bindings_[first_modified_idx]; // Try to increment idx from the back size_t new_pos = ++children_bindings_pos_[first_modified_idx]; @@ -154,17 
+181,14 @@ bool GroupExprBindingIterator::HasNext() { has_next_ = false; } else { // Pop all updated childrens - for (size_t idx = first_modified_idx; idx < children_bindings_pos_.size(); - idx++) { + for (size_t idx = first_modified_idx; idx < children_bindings_pos_.size(); idx++) { current_binding_->PopChild(); } // Add new children to end for (size_t offset = first_modified_idx; offset < children_bindings_pos_.size(); ++offset) { - const std::vector> &child_binding = - children_bindings_[offset]; - std::shared_ptr binding = - child_binding[children_bindings_pos_[offset]]; + const std::vector> &child_binding = children_bindings_[offset]; + std::shared_ptr binding = child_binding[children_bindings_pos_[offset]]; current_binding_->PushChild(binding); } } @@ -172,9 +196,17 @@ bool GroupExprBindingIterator::HasNext() { return has_next_; } -std::shared_ptr GroupExprBindingIterator::Next() { +template +std::shared_ptr GroupExprBindingIterator::Next() { return current_binding_; } +// Explicitly instantiate +template class GroupBindingIterator; +template class GroupExprBindingIterator; + +template class GroupBindingIterator; +template class GroupExprBindingIterator; + } // namespace optimizer } // namespace peloton diff --git a/src/optimizer/child_property_deriver.cpp b/src/optimizer/child_property_deriver.cpp index b432067fae1..c025eed7dff 100644 --- a/src/optimizer/child_property_deriver.cpp +++ b/src/optimizer/child_property_deriver.cpp @@ -31,9 +31,9 @@ namespace peloton { namespace optimizer { vector, vector>>> -ChildPropertyDeriver::GetProperties(GroupExpression *gexpr, +ChildPropertyDeriver::GetProperties(GroupExpression *gexpr, shared_ptr requirements, - Memo *memo) { + Memo *memo) { requirements_ = requirements; output_.clear(); memo_ = memo; @@ -218,7 +218,7 @@ void ChildPropertyDeriver::DeriveForJoin() { if (prop->Type() == PropertyType::SORT) { auto sort_prop = prop->As(); size_t sort_col_size = sort_prop->GetSortColumnSize(); - Group *probe_group = 
memo_->GetGroupByID(gexpr_->GetChildGroupId(1)); + Group *probe_group = memo_->GetGroupByID(gexpr_->GetChildGroupId(1)); bool can_pass_down = true; for (size_t idx = 0; idx < sort_col_size; ++idx) { ExprSet tuples; diff --git a/src/optimizer/group.cpp b/src/optimizer/group.cpp index 673a7a1b8bd..99f9efd9171 100644 --- a/src/optimizer/group.cpp +++ b/src/optimizer/group.cpp @@ -11,6 +11,8 @@ //===----------------------------------------------------------------------===// #include "optimizer/group.h" +#include "optimizer/operator_expression.h" +#include "optimizer/absexpr_expression.h" #include "common/logger.h" @@ -20,13 +22,23 @@ namespace optimizer { //===--------------------------------------------------------------------===// // Group //===--------------------------------------------------------------------===// -Group::Group(GroupID id, std::unordered_set table_aliases) +template +Group::Group(GroupID id, std::unordered_set table_aliases) : id_(id), table_aliases_(std::move(table_aliases)) { has_explored_ = false; } -void Group::AddExpression(std::shared_ptr expr, - bool enforced) { +template +void Group::AddExpression( + std::shared_ptr> expr, + bool enforced) { + + // Additional assertion checks for AddExpression() with AST rewriting + if (std::is_same::value) { + PELOTON_ASSERT(!enforced); + PELOTON_ASSERT(!expr->Op().IsPhysical()); + } + // Do duplicate detection expr->SetGroupID(id_); if (enforced) @@ -37,8 +49,12 @@ void Group::AddExpression(std::shared_ptr expr, logical_expressions_.push_back(expr); } -bool Group::SetExpressionCost(GroupExpression *expr, double cost, - std::shared_ptr &properties) { +template +bool Group::SetExpressionCost( + GroupExpression *expr, + double cost, + std::shared_ptr &properties) { + LOG_TRACE("Adding expression cost on group %d with op %s, req %s", expr->GetGroupID(), expr->Op().GetName().c_str(), properties->ToString().c_str()); @@ -51,8 +67,11 @@ bool Group::SetExpressionCost(GroupExpression *expr, double cost, } return 
false; } -GroupExpression *Group::GetBestExpression( + +template +GroupExpression *Group::GetBestExpression( std::shared_ptr &properties) { + auto it = lowest_cost_expressions_.find(properties); if (it != lowest_cost_expressions_.end()) { return std::get<1>(it->second); @@ -62,20 +81,22 @@ GroupExpression *Group::GetBestExpression( return nullptr; } -bool Group::HasExpressions( - const std::shared_ptr &properties) const { +template +bool Group::HasExpressions(const std::shared_ptr &properties) const { const auto &it = lowest_cost_expressions_.find(properties); return (it != lowest_cost_expressions_.end()); } -std::shared_ptr Group::GetStats(std::string column_name) { +template +std::shared_ptr Group::GetStats(std::string column_name) { if (!stats_.count(column_name)) { return nullptr; } return stats_[column_name]; } -const std::string Group::GetInfo(int num_indent) const { +template +const std::string Group::GetInfo(int num_indent) const { std::ostringstream os; os << StringUtil::Indent(num_indent) << "GroupID: " << GetID() << std::endl; @@ -134,22 +155,29 @@ const std::string Group::GetInfo(int num_indent) const { return os.str(); } -const std::string Group::GetInfo() const { +template +const std::string Group::GetInfo() const { std::ostringstream os; os << GetInfo(0); return os.str(); } -void Group::AddStats(std::string column_name, +template +void Group::AddStats(std::string column_name, std::shared_ptr stats) { PELOTON_ASSERT((size_t)GetNumRows() == stats->num_rows); stats_[column_name] = stats; } -bool Group::HasColumnStats(std::string column_name) { +template +bool Group::HasColumnStats(std::string column_name) { return stats_.count(column_name); } +// Explicitly instantiate +template class Group; +template class Group; + } // namespace optimizer } // namespace peloton diff --git a/src/optimizer/group_expression.cpp b/src/optimizer/group_expression.cpp index 498c949b583..98540606558 100644 --- a/src/optimizer/group_expression.cpp +++ 
b/src/optimizer/group_expression.cpp @@ -11,6 +11,7 @@ //===----------------------------------------------------------------------===// #include "common/internal_types.h" +#include "optimizer/absexpr_expression.h" #include "optimizer/group_expression.h" #include "optimizer/group.h" #include "optimizer/rule.h" @@ -21,41 +22,51 @@ namespace optimizer { //===--------------------------------------------------------------------===// // Group Expression //===--------------------------------------------------------------------===// -GroupExpression::GroupExpression(Operator op, std::vector child_groups) +template +GroupExpression::GroupExpression(Node op, std::vector child_groups) : group_id(UNDEFINED_GROUP), op(op), child_groups(child_groups), stats_derived_(false) {} -GroupID GroupExpression::GetGroupID() const { return group_id; } +template +GroupID GroupExpression::GetGroupID() const { return group_id; } -void GroupExpression::SetGroupID(GroupID id) { group_id = id; } +template +void GroupExpression::SetGroupID(GroupID id) { group_id = id; } -void GroupExpression::SetChildGroupID(int child_group_idx, GroupID group_id) { +template +void GroupExpression::SetChildGroupID(int child_group_idx, GroupID group_id) { child_groups[child_group_idx] = group_id; } -const std::vector &GroupExpression::GetChildGroupIDs() const { +template +const std::vector &GroupExpression::GetChildGroupIDs() const { return child_groups; } -GroupID GroupExpression::GetChildGroupId(int child_idx) const { +template +GroupID GroupExpression::GetChildGroupId(int child_idx) const { return child_groups[child_idx]; } -Operator GroupExpression::Op() const { return op; } +template +Node GroupExpression::Op() const { return op; } -double GroupExpression::GetCost( +template +double GroupExpression::GetCost( std::shared_ptr &requirements) const { return std::get<0>(lowest_cost_table_.find(requirements)->second); } -std::vector> GroupExpression::GetInputProperties( +template +std::vector> 
GroupExpression::GetInputProperties( std::shared_ptr requirements) const { return std::get<1>(lowest_cost_table_.find(requirements)->second); } -void GroupExpression::SetLocalHashTable( +template +void GroupExpression::SetLocalHashTable( const std::shared_ptr &output_properties, const std::vector> &input_properties_list, double cost) { @@ -73,7 +84,8 @@ void GroupExpression::SetLocalHashTable( } } -hash_t GroupExpression::Hash() const { +template +hash_t GroupExpression::Hash() const { size_t hash = op.Hash(); for (size_t i = 0; i < child_groups.size(); ++i) { @@ -84,17 +96,24 @@ hash_t GroupExpression::Hash() const { return hash; } -bool GroupExpression::operator==(const GroupExpression &r) { +template +bool GroupExpression::operator==(const GroupExpression &r) { return (op == r.Op()) && (child_groups == r.child_groups); } -void GroupExpression::SetRuleExplored(Rule *rule) { +template +void GroupExpression::SetRuleExplored(Rule *rule) { rule_mask_.set(rule->GetRuleIdx(), true); } -bool GroupExpression::HasRuleExplored(Rule *rule) { +template +bool GroupExpression::HasRuleExplored(Rule *rule) { return rule_mask_.test(rule->GetRuleIdx()); } +// Explicitly instantiate to prevent linker errors +template class GroupExpression; +template class GroupExpression; + } // namespace optimizer } // namespace peloton diff --git a/src/optimizer/input_column_deriver.cpp b/src/optimizer/input_column_deriver.cpp index fdffb7e79a6..30ee095a379 100644 --- a/src/optimizer/input_column_deriver.cpp +++ b/src/optimizer/input_column_deriver.cpp @@ -37,8 +37,9 @@ InputColumnDeriver::InputColumnDeriver() {} pair, vector>> InputColumnDeriver::DeriveInputColumns( - GroupExpression *gexpr, shared_ptr properties, - vector required_cols, Memo *memo) { + GroupExpression *gexpr, shared_ptr properties, + vector required_cols, + Memo *memo) { properties_ = properties; gexpr_ = gexpr; required_cols_ = move(required_cols); diff --git a/src/optimizer/memo.cpp b/src/optimizer/memo.cpp index 
ca68a52c1d0..691cc9d3832 100644 --- a/src/optimizer/memo.cpp +++ b/src/optimizer/memo.cpp @@ -14,6 +14,7 @@ #include "optimizer/memo.h" #include "optimizer/operators.h" #include "optimizer/stats/stats_calculator.h" +#include "optimizer/absexpr_expression.h" namespace peloton { namespace optimizer { @@ -21,27 +22,63 @@ namespace optimizer { //===--------------------------------------------------------------------===// // Memo //===--------------------------------------------------------------------===// -Memo::Memo() {} +template +Memo::Memo() {} -GroupExpression *Memo::InsertExpression(std::shared_ptr gexpr, - bool enforced) { - return InsertExpression(gexpr, UNDEFINED_GROUP, enforced); +//===--------------------------------------------------------------------===// +// Memo::AddNewGroup (declare here to prevent specialization error) +//===--------------------------------------------------------------------===// +template +GroupID Memo::AddNewGroup(std::shared_ptr> gexpr) { + (void)gexpr; + + GroupID new_group_id = groups_.size(); + // Find out the table alias that this group represents + std::unordered_set table_aliases; + + groups_.emplace_back( + new Group(new_group_id, std::move(table_aliases))); + return new_group_id; } -GroupExpression *Memo::InsertExpression(std::shared_ptr gexpr, - GroupID target_group, bool enforced) { - // If leaf, then just return - if (gexpr->Op().GetType() == OpType::Leaf) { - const LeafOperator *leaf = gexpr->Op().As(); - PELOTON_ASSERT(target_group == UNDEFINED_GROUP || - target_group == leaf->origin_group); - gexpr->SetGroupID(leaf->origin_group); - return nullptr; +template <> +GroupID Memo::AddNewGroup(std::shared_ptr> gexpr) { + GroupID new_group_id = groups_.size(); + // Find out the table alias that this group represents + std::unordered_set table_aliases; + auto op_type = gexpr->Op().GetType(); + if (op_type == OpType::Get) { + // For base group, the table alias can get directly from logical get + const LogicalGet *logical_get 
= gexpr->Op().As(); + table_aliases.insert(logical_get->table_alias); + } else if (op_type == OpType::LogicalQueryDerivedGet) { + const LogicalQueryDerivedGet *query_get = + gexpr->Op().As(); + table_aliases.insert(query_get->table_alias); + } else { + // For other groups, need to aggregate the table alias from children + for (auto child_group_id : gexpr->GetChildGroupIDs()) { + Group *child_group = GetGroupByID(child_group_id); + for (auto &table_alias : child_group->GetTableAliases()) { + table_aliases.insert(table_alias); + } + } } - // Lookup in hash table - auto it = group_expressions_.find(gexpr.get()); + groups_.emplace_back( + new Group(new_group_id, std::move(table_aliases))); + return new_group_id; +} +//===--------------------------------------------------------------------===// +// Memo remaining interface functions +//===--------------------------------------------------------------------===// +template +GroupExpression* Memo::InsertExpr( + std::shared_ptr> gexpr, + GroupID target_group, bool enforced) { + + auto it = group_expressions_.find(gexpr.get()); if (it != group_expressions_.end()) { gexpr->SetGroupID((*it)->GetGroupID()); return *it; @@ -55,19 +92,59 @@ GroupExpression *Memo::InsertExpression(std::shared_ptr gexpr, } else { group_id = target_group; } - Group *group = GetGroupByID(group_id); + + Group *group = GetGroupByID(group_id); group->AddExpression(gexpr, enforced); return gexpr.get(); } } -std::vector> &Memo::Groups() { +template +GroupExpression *Memo::InsertExpression( + std::shared_ptr> gexpr, + bool enforced) { + + return InsertExpression(gexpr, UNDEFINED_GROUP, enforced); +} + +template +GroupExpression *Memo::InsertExpression( + std::shared_ptr> gexpr, + GroupID target_group, + bool enforced) { + + return InsertExpr(gexpr, target_group, enforced); +} + +// Specialization for Memo::InsertExpression due to OpType +template <> +GroupExpression *Memo::InsertExpression( + std::shared_ptr> gexpr, + GroupID target_group, + bool enforced) 
{ + + // If leaf, then just return + if (gexpr->Op().GetType() == OpType::Leaf) { + const LeafOperator *leaf = gexpr->Op().As(); + PELOTON_ASSERT(target_group == UNDEFINED_GROUP || + target_group == leaf->origin_group); + gexpr->SetGroupID(leaf->origin_group); + return nullptr; + } + + return InsertExpr(gexpr, target_group, enforced); +} + +template +std::vector>> &Memo::Groups() { return groups_; } -Group *Memo::GetGroupByID(GroupID id) { return groups_[id].get(); } +template +Group *Memo::GetGroupByID(GroupID id) { return groups_[id].get(); } -const std::string Memo::GetInfo(int num_indent) const { +template +const std::string Memo::GetInfo(int num_indent) const { std::ostringstream os; os << StringUtil::Indent(num_indent) << "Memo::\n"; os << StringUtil::Indent(num_indent + 1) @@ -80,40 +157,16 @@ const std::string Memo::GetInfo(int num_indent) const { return os.str(); } -const std::string Memo::GetInfo() const { +template +const std::string Memo::GetInfo() const { std::ostringstream os; os << GetInfo(0); return os.str(); } - -GroupID Memo::AddNewGroup(std::shared_ptr gexpr) { - GroupID new_group_id = groups_.size(); - // Find out the table alias that this group represents - std::unordered_set table_aliases; - auto op_type = gexpr->Op().GetType(); - if (op_type == OpType::Get) { - // For base group, the table alias can get directly from logical get - const LogicalGet *logical_get = gexpr->Op().As(); - table_aliases.insert(logical_get->table_alias); - } else if (op_type == OpType::LogicalQueryDerivedGet) { - const LogicalQueryDerivedGet *query_get = - gexpr->Op().As(); - table_aliases.insert(query_get->table_alias); - } else { - // For other groups, need to aggregate the table alias from children - for (auto child_group_id : gexpr->GetChildGroupIDs()) { - Group *child_group = GetGroupByID(child_group_id); - for (auto &table_alias : child_group->GetTableAliases()) { - table_aliases.insert(table_alias); - } - } - } - - groups_.emplace_back( - new 
Group(new_group_id, std::move(table_aliases))); - return new_group_id; -} +// Explicitly instantiate template +template class Memo; +template class Memo; } // namespace optimizer } // namespace peloton diff --git a/src/optimizer/optimizer.cpp b/src/optimizer/optimizer.cpp index 83bcadde4de..5c6d8ac304c 100644 --- a/src/optimizer/optimizer.cpp +++ b/src/optimizer/optimizer.cpp @@ -65,15 +65,15 @@ Optimizer::Optimizer(const CostModels cost_model) : metadata_(nullptr) { switch (cost_model) { case CostModels::DEFAULT: { - metadata_ = OptimizerMetadata(std::unique_ptr(new DefaultCostModel)); + metadata_ = OptimizerMetadata(std::unique_ptr(new DefaultCostModel)); break; } case CostModels::POSTGRES: { - metadata_ = OptimizerMetadata(std::unique_ptr(new PostgresCostModel)); + metadata_ = OptimizerMetadata(std::unique_ptr(new PostgresCostModel)); break; } case CostModels::TRIVIAL: { - metadata_ = OptimizerMetadata(std::unique_ptr(new TrivialCostModel)); + metadata_ = OptimizerMetadata(std::unique_ptr(new TrivialCostModel)); break; } default: @@ -83,17 +83,17 @@ Optimizer::Optimizer(const CostModels cost_model) : metadata_(nullptr) { void Optimizer::OptimizeLoop(int root_group_id, std::shared_ptr required_props) { - std::shared_ptr root_context = - std::make_shared(&metadata_, required_props); + std::shared_ptr> root_context = + std::make_shared>(&metadata_, required_props); auto task_stack = - std::unique_ptr(new OptimizerTaskStack()); + std::unique_ptr>(new OptimizerTaskStack()); metadata_.SetTaskPool(task_stack.get()); // Perform rewrite first - task_stack->Push(new TopDownRewrite(root_group_id, root_context, + task_stack->Push(new TopDownRewrite(root_group_id, root_context, RewriteRuleSetName::PREDICATE_PUSH_DOWN)); - task_stack->Push(new BottomUpRewrite( + task_stack->Push(new BottomUpRewrite( root_group_id, root_context, RewriteRuleSetName::UNNEST_SUBQUERY, false)); ExecuteTaskStack(*task_stack, root_group_id, root_context); @@ -132,7 +132,7 @@ shared_ptr 
Optimizer::BuildPelotonPlanTree( metadata_.txn = txn; // Generate initial operator tree from query tree - shared_ptr gexpr = InsertQueryTree(parse_tree, txn); + shared_ptr> gexpr = InsertQueryTree(parse_tree, txn); GroupID root_id = gexpr->GetGroupID(); // Get the physical properties the final plan must output auto query_info = GetQueryInfo(parse_tree); @@ -158,7 +158,7 @@ shared_ptr Optimizer::BuildPelotonPlanTree( } void Optimizer::Reset() { - metadata_ = OptimizerMetadata(std::move(metadata_.cost_model)); + metadata_ = OptimizerMetadata(std::move(metadata_.cost_model)); } unique_ptr Optimizer::HandleDDLStatement( @@ -247,12 +247,12 @@ unique_ptr Optimizer::HandleDDLStatement( return ddl_plan; } -shared_ptr Optimizer::InsertQueryTree( +shared_ptr> Optimizer::InsertQueryTree( parser::SQLStatement *tree, concurrency::TransactionContext *txn) { QueryToOperatorTransformer converter(txn); shared_ptr initial = converter.ConvertToOpExpression(tree); - shared_ptr gexpr; + shared_ptr> gexpr; metadata_.RecordTransformedExpression(initial, gexpr); return gexpr; } @@ -323,7 +323,7 @@ const std::string Optimizer::GetOperatorInfo( int num_indent) { std::ostringstream os; - Group *group = metadata_.memo.GetGroupByID(id); + Group *group = metadata_.memo.GetGroupByID(id); auto gexpr = group->GetBestExpression(required_props); os << std::endl << StringUtil::Indent(num_indent) << "operator name: " @@ -347,7 +347,7 @@ const std::string Optimizer::GetOperatorInfo( unique_ptr Optimizer::ChooseBestPlan( GroupID id, std::shared_ptr required_props, std::vector required_cols) { - Group *group = metadata_.memo.GetGroupByID(id); + Group *group = metadata_.memo.GetGroupByID(id); LOG_TRACE("Choosing with property : %s", required_props->ToString().c_str()); auto gexpr = group->GetBestExpression(required_props); @@ -395,8 +395,8 @@ unique_ptr Optimizer::ChooseBestPlan( } void Optimizer::ExecuteTaskStack( - OptimizerTaskStack &task_stack, int root_group_id, - std::shared_ptr root_context) { + 
OptimizerTaskStack &task_stack, int root_group_id, + std::shared_ptr> root_context) { auto root_group = metadata_.memo.GetGroupByID(root_group_id); auto &timer = metadata_.timer; const auto timeout_limit = metadata_.timeout_limit; diff --git a/src/optimizer/optimizer_task.cpp b/src/optimizer/optimizer_task.cpp index e1cfac5643d..d8fc17b7e27 100644 --- a/src/optimizer/optimizer_task.cpp +++ b/src/optimizer/optimizer_task.cpp @@ -18,6 +18,7 @@ #include "optimizer/child_property_deriver.h" #include "optimizer/stats/stats_calculator.h" #include "optimizer/stats/child_stats_deriver.h" +#include "optimizer/absexpr_expression.h" namespace peloton { namespace optimizer { @@ -25,10 +26,12 @@ namespace optimizer { //===--------------------------------------------------------------------===// // Base class //===--------------------------------------------------------------------===// -void OptimizerTask::ConstructValidRules( - GroupExpression *group_expr, OptimizeContext *context, - std::vector> &rules, - std::vector &valid_rules) { +template +void OptimizerTask::ConstructValidRules( + GroupExpression *group_expr, + OptimizeContext *context, + std::vector>> &rules, + std::vector> &valid_rules) { for (auto &rule : rules) { // Check if we can apply the rule bool root_pattern_mismatch = @@ -45,13 +48,16 @@ void OptimizerTask::ConstructValidRules( } } -void OptimizerTask::PushTask(OptimizerTask *task) { +template +void OptimizerTask::PushTask(OptimizerTask *task) { context_->metadata->task_pool->Push(task); } -Memo &OptimizerTask::GetMemo() const { return context_->metadata->memo; } +template +Memo &OptimizerTask::GetMemo() const { return context_->metadata->memo; } -RuleSet &OptimizerTask::GetRuleSet() const { +template +RuleSet &OptimizerTask::GetRuleSet() const { return context_->metadata->rule_set; } @@ -86,14 +92,14 @@ void OptimizeGroup::execute() { // OptimizeExpression //===--------------------------------------------------------------------===// void 
OptimizeExpression::execute() { - std::vector valid_rules; + std::vector> valid_rules; // Construct valid transformation rules from rule set - ConstructValidRules(group_expr_, context_.get(), - GetRuleSet().GetTransformationRules(), valid_rules); + this->ConstructValidRules(group_expr_, context_.get(), + GetRuleSet().GetTransformationRules(), valid_rules); // Construct valid implementation rules from rule set - ConstructValidRules(group_expr_, context_.get(), - GetRuleSet().GetImplementationRules(), valid_rules); + this->ConstructValidRules(group_expr_, context_.get(), + GetRuleSet().GetImplementationRules(), valid_rules); std::sort(valid_rules.begin(), valid_rules.end()); LOG_DEBUG("OptimizeExpression::execute() op %d, valid rules : %lu", @@ -138,7 +144,7 @@ void ExploreGroup::execute() { //===--------------------------------------------------------------------===// void ExploreExpression::execute() { LOG_TRACE("ExploreExpression::execute() "); - std::vector valid_rules; + std::vector> valid_rules; // Construct valid transformation rules from rule set ConstructValidRules(group_expr_, context_.get(), @@ -172,8 +178,8 @@ void ApplyRule::execute() { LOG_TRACE("ApplyRule::execute() for rule: %d", rule_->GetRuleIdx()); if (group_expr_->HasRuleExplored(rule_)) return; - GroupExprBindingIterator iterator(GetMemo(), group_expr_, - rule_->GetMatchPattern()); + GroupExprBindingIterator iterator(GetMemo(), group_expr_, + rule_->GetMatchPattern()); while (iterator.HasNext()) { auto before = iterator.Next(); if (!rule_->Check(before, context_.get())) { @@ -183,7 +189,7 @@ void ApplyRule::execute() { std::vector> after; rule_->Transform(before, after, context_.get()); for (auto &new_expr : after) { - std::shared_ptr new_gexpr; + std::shared_ptr> new_gexpr; if (context_->metadata->RecordTransformedExpression( new_expr, new_gexpr, group_expr_->GetGroupID())) { // A new group expression is generated @@ -315,7 +321,7 @@ void OptimizeInputs::execute() { prev_child_idx_ = 
cur_child_idx_; PushTask(new OptimizeInputs(this)); PushTask(new OptimizeGroup( - child_group, std::make_shared( + child_group, std::make_shared>( context_->metadata, i_prop, context_->cost_upper_bound - cur_total_cost_))); return; } else { // If we return from OptimizeGroup, then there is no expr for @@ -336,7 +342,7 @@ void OptimizeInputs::execute() { // Enforce property if the requirement does not meet PropertyEnforcer prop_enforcer; auto extended_output_properties = output_prop->Properties(); - GroupExpression *memo_enforced_expr = nullptr; + GroupExpression *memo_enforced_expr = nullptr; bool meet_requirement = true; // TODO: For now, we enforce the missing properties in the order of how we // find them. This may @@ -402,29 +408,38 @@ void OptimizeInputs::execute() { } } -void TopDownRewrite::execute() { - std::vector valid_rules; +template +void TopDownRewrite::execute() { + std::vector> valid_rules; - auto cur_group = GetMemo().GetGroupByID(group_id_); + auto cur_group = this->GetMemo().GetGroupByID(group_id_); auto cur_group_expr = cur_group->GetLogicalExpression(); // Construct valid transformation rules from rule set - ConstructValidRules(cur_group_expr, context_.get(), - GetRuleSet().GetRewriteRulesByName(rule_set_name_), - valid_rules); + this->ConstructValidRules(cur_group_expr, this->context_.get(), + this->GetRuleSet().GetRewriteRulesByName(rule_set_name_), + valid_rules); // Sort so that we apply rewrite rules with higher promise first std::sort(valid_rules.begin(), valid_rules.end(), - std::greater()); + std::greater>()); for (auto &r : valid_rules) { - GroupExprBindingIterator iterator(GetMemo(), cur_group_expr, - r.rule->GetMatchPattern()); + GroupExprBindingIterator iterator(this->GetMemo(), cur_group_expr, + r.rule->GetMatchPattern()); if (iterator.HasNext()) { auto before = iterator.Next(); PELOTON_ASSERT(!iterator.HasNext()); - std::vector> after; - r.rule->Transform(before, after, context_.get()); + + // (TODO): pending terrier issue #332 + 
// Check whether rule actually can be applied + // as opposed to a structural level test + if (!r.rule->Check(before, this->context_.get())) { + continue; + } + + std::vector> after; + r.rule->Transform(before, after, this->context_.get()); // Rewrite rule should provide at most 1 expression PELOTON_ASSERT(after.size() <= 1); @@ -433,8 +448,8 @@ void TopDownRewrite::execute() { // saturated if (!after.empty()) { auto &new_expr = after[0]; - context_->metadata->ReplaceRewritedExpression(new_expr, group_id_); - PushTask(new TopDownRewrite(group_id_, context_, rule_set_name_)); + this->context_->metadata->ReplaceRewritedExpression(new_expr, group_id_); + this->PushTask(new TopDownRewrite(group_id_, this->context_, rule_set_name_)); return; } } @@ -445,47 +460,56 @@ void TopDownRewrite::execute() { child_group_idx < cur_group_expr->GetChildrenGroupsSize(); child_group_idx++) { // Need to rewrite all sub trees first - PushTask( - new TopDownRewrite(cur_group_expr->GetChildGroupId(child_group_idx), - context_, rule_set_name_)); + this->PushTask( + new TopDownRewrite(cur_group_expr->GetChildGroupId(child_group_idx), + this->context_, rule_set_name_)); } } -void BottomUpRewrite::execute() { - std::vector valid_rules; +template +void BottomUpRewrite::execute() { + std::vector> valid_rules; - auto cur_group = GetMemo().GetGroupByID(group_id_); + auto cur_group = this->GetMemo().GetGroupByID(group_id_); auto cur_group_expr = cur_group->GetLogicalExpression(); if (!has_optimized_child_) { - PushTask(new BottomUpRewrite(group_id_, context_, rule_set_name_, true)); + this->PushTask(new BottomUpRewrite(group_id_, this->context_, rule_set_name_, true)); for (size_t child_group_idx = 0; child_group_idx < cur_group_expr->GetChildrenGroupsSize(); child_group_idx++) { // Need to rewrite all sub trees first - PushTask( - new BottomUpRewrite(cur_group_expr->GetChildGroupId(child_group_idx), - context_, rule_set_name_, false)); + this->PushTask( + new 
BottomUpRewrite(cur_group_expr->GetChildGroupId(child_group_idx), + this->context_, rule_set_name_, false)); } return; } // Construct valid transformation rules from rule set - ConstructValidRules(cur_group_expr, context_.get(), - GetRuleSet().GetRewriteRulesByName(rule_set_name_), - valid_rules); + this->ConstructValidRules(cur_group_expr, this->context_.get(), + this->GetRuleSet().GetRewriteRulesByName(rule_set_name_), + valid_rules); // Sort so that we apply rewrite rules with higher promise first std::sort(valid_rules.begin(), valid_rules.end(), - std::greater()); + std::greater>()); for (auto &r : valid_rules) { - GroupExprBindingIterator iterator(GetMemo(), cur_group_expr, - r.rule->GetMatchPattern()); + GroupExprBindingIterator iterator(this->GetMemo(), cur_group_expr, + r.rule->GetMatchPattern()); if (iterator.HasNext()) { auto before = iterator.Next(); PELOTON_ASSERT(!iterator.HasNext()); - std::vector> after; - r.rule->Transform(before, after, context_.get()); + + // (TODO): pending terrier issue #332 + // Check whether rule actually can be applied + // as opposed to a structural level test + if (!r.rule->Check(before, this->context_.get())) { + continue; + } + + std::vector> after; + r.rule->Transform(before, after, this->context_.get()); // Rewrite rule should provide at most 1 expression PELOTON_ASSERT(after.size() <= 1); @@ -494,14 +518,24 @@ void BottomUpRewrite::execute() { // saturated, also childs are already been rewritten if (!after.empty()) { auto &new_expr = after[0]; - context_->metadata->ReplaceRewritedExpression(new_expr, group_id_); - PushTask( - new BottomUpRewrite(group_id_, context_, rule_set_name_, false)); + this->context_->metadata->ReplaceRewritedExpression(new_expr, group_id_); + this->PushTask( + new BottomUpRewrite(group_id_, this->context_, rule_set_name_, false)); + return; } } cur_group_expr->SetRuleExplored(r.rule); } } + + +// Explicitly instantiate +template class TopDownRewrite; +template class BottomUpRewrite; + +template 
class TopDownRewrite; +template class BottomUpRewrite; + } // namespace optimizer } // namespace peloton diff --git a/src/optimizer/pattern.cpp b/src/optimizer/pattern.cpp index d7665d678bb..23b976888cf 100644 --- a/src/optimizer/pattern.cpp +++ b/src/optimizer/pattern.cpp @@ -15,17 +15,25 @@ namespace peloton { namespace optimizer { -Pattern::Pattern(OpType op) : _type(op) {} +template +Pattern::Pattern(OperatorType op) : _type(op) {} -void Pattern::AddChild(std::shared_ptr child) { +template +void Pattern::AddChild(std::shared_ptr> child) { children.push_back(child); } -const std::vector> &Pattern::Children() const { +template +const std::vector>> &Pattern::Children() const { return children; } -OpType Pattern::Type() const { return _type; } +template +OperatorType Pattern::Type() const { return _type; } + +// Explicitly instantiate +template class Pattern; +template class Pattern; } // namespace optimizer } // namespace peloton diff --git a/src/optimizer/property_enforcer.cpp b/src/optimizer/property_enforcer.cpp index 834cf9a76d7..98013f214f4 100644 --- a/src/optimizer/property_enforcer.cpp +++ b/src/optimizer/property_enforcer.cpp @@ -19,8 +19,10 @@ namespace peloton { namespace optimizer { -std::shared_ptr PropertyEnforcer::EnforceProperty( - GroupExpression* gexpr, Property* property) { +std::shared_ptr> PropertyEnforcer::EnforceProperty( + GroupExpression* gexpr, + Property* property) { + input_gexpr_ = gexpr; property->Accept(this); return output_gexpr_; @@ -33,13 +35,13 @@ void PropertyEnforcer::Visit(const PropertyColumns *) { void PropertyEnforcer::Visit(const PropertySort *) { std::vector child_groups(1, input_gexpr_->GetGroupID()); output_gexpr_ = - std::make_shared(PhysicalOrderBy::make(), child_groups); + std::make_shared>(PhysicalOrderBy::make(), child_groups); } void PropertyEnforcer::Visit(const PropertyDistinct *) { std::vector child_groups(1, input_gexpr_->GetGroupID()); output_gexpr_ = - std::make_shared(PhysicalDistinct::make(), 
child_groups); + std::make_shared>(PhysicalDistinct::make(), child_groups); } void PropertyEnforcer::Visit(const PropertyLimit *) {} diff --git a/src/optimizer/rewriter.cpp b/src/optimizer/rewriter.cpp new file mode 100644 index 00000000000..d23d998e51d --- /dev/null +++ b/src/optimizer/rewriter.cpp @@ -0,0 +1,150 @@ +//===----------------------------------------------------------------------===// +// +// Peloton +// +// rewriter.cpp +// +// Identification: src/optimizer/rewriter.cpp +// +// Copyright (c) 2015-2018, Carnegie Mellon University Database Group +// +//===----------------------------------------------------------------------===// + +#include + +#include "optimizer/optimizer.h" +#include "optimizer/rewriter.h" +#include "common/exception.h" + +#include "optimizer/cost_model/trivial_cost_model.h" +#include "optimizer/operator_visitor.h" +#include "optimizer/optimize_context.h" +#include "optimizer/optimizer_task_pool.h" +#include "optimizer/rule.h" +#include "optimizer/rule_impls.h" +#include "optimizer/optimizer_metadata.h" +#include "optimizer/absexpr_expression.h" +#include "expression/abstract_expression.h" +#include "expression/constant_value_expression.h" + +using std::vector; +using std::unordered_map; +using std::shared_ptr; +using std::unique_ptr; +using std::move; +using std::pair; +using std::make_shared; + +namespace peloton { +namespace optimizer { + +using OptimizerMetadataTemplate = OptimizerMetadata; + +using OptimizeContextTemplate = OptimizeContext; + +using OptimizerTaskStackTemplate = OptimizerTaskStack; + +using TopDownRewriteTemplate = TopDownRewrite; + +using BottomUpRewriteTemplate = BottomUpRewrite; + +using GroupExpressionTemplate = GroupExpression; + +using GroupTemplate = Group; + +Rewriter::Rewriter() : metadata_(nullptr) { + metadata_ = OptimizerMetadataTemplate(nullptr); +} + +void Rewriter::RewriteLoop(int root_group_id) { + std::shared_ptr root_context = + std::make_shared(&metadata_, nullptr); + auto task_stack = + 
std::unique_ptr(new OptimizerTaskStackTemplate()); + metadata_.SetTaskPool(task_stack.get()); + + // Perform rewrite first + task_stack->Push(new BottomUpRewriteTemplate(root_group_id, root_context, RewriteRuleSetName::COMPARATOR_ELIMINATION, false)); + + ExecuteTaskStack(*task_stack); +} + +expression::AbstractExpression* Rewriter::RebuildExpression(int root) { + auto cur_group = metadata_.memo.GetGroupByID(root); + auto exprs = cur_group->GetLogicalExpressions(); + + // (TODO): what should we do if exprs.size() > 1? + PELOTON_ASSERT(exprs.size() > 0); + auto expr = exprs[0]; + + std::vector child_groups = expr->GetChildGroupIDs(); + std::vector child_exprs; + for (auto group : child_groups) { + // Build children first + expression::AbstractExpression *child = RebuildExpression(group); + PELOTON_ASSERT(child != nullptr); + + child_exprs.push_back(child); + } + + AbsExpr_Container c = expr->Op(); + return c.Rebuild(child_exprs); +} + +expression::AbstractExpression* Rewriter::RewriteExpression(const expression::AbstractExpression *expr) { + // (TODO): do we need to actually convert to a wrapper? + // This is needed in order to provide template classes the correct interface. + // This should probably be better abstracted away. 
+ std::shared_ptr gexpr = ConvertTree(expr); + LOG_DEBUG("Converted tree to internal data structures"); + + GroupID root_id = gexpr->GetGroupID(); + RewriteLoop(root_id); + LOG_DEBUG("Performed rewrite loop pass"); + + expression::AbstractExpression *expr_tree = RebuildExpression(root_id); + LOG_DEBUG("Rebuilt expression tree from memo table"); + + Reset(); + LOG_DEBUG("Reset the rewriter"); + return expr_tree; +} + +void Rewriter::Reset() { + metadata_ = OptimizerMetadataTemplate(nullptr); +} + +std::shared_ptr Rewriter::ConvertToAbsExpr(const expression::AbstractExpression* expr) { + + // (TODO): fix memory management once we get to terrier + // for now, this just directly wraps each AbstractExpression in a AbsExpr_Container + // which is then wrapped in an AbsExpr_Expression to provide the same Operator/OperatorExpression + // interface that is relied upon by the rest of the code base. + + auto container = AbsExpr_Container(expr); + auto exp = std::make_shared(container); + for (size_t i = 0; i < expr->GetChildrenSize(); i++) { + exp->PushChild(ConvertToAbsExpr(expr->GetChild(i))); + } + return exp; +} + +std::shared_ptr Rewriter::ConvertTree( + const expression::AbstractExpression *expr) { + + std::shared_ptr exp = ConvertToAbsExpr(expr); + std::shared_ptr gexpr; + metadata_.RecordTransformedExpression(exp, gexpr); + return gexpr; +} + +void Rewriter::ExecuteTaskStack(OptimizerTaskStackTemplate &task_stack) { + // Iterate through the task stack + while (!task_stack.Empty()) { + auto task = task_stack.Pop(); + task->execute(); + } +} + +} // namespace optimizer +} // namespace peloton diff --git a/src/optimizer/rule.cpp b/src/optimizer/rule.cpp index 8c72ed17fa8..0d14104060d 100644 --- a/src/optimizer/rule.cpp +++ b/src/optimizer/rule.cpp @@ -12,11 +12,31 @@ #include "optimizer/rule_impls.h" #include "optimizer/group_expression.h" +#include "optimizer/absexpr_expression.h" +#include "optimizer/rule_rewrite.h" namespace peloton { namespace optimizer { -int 
Rule::Promise(GroupExpression *group_expr, OptimizeContext *context) const { +template +int Rule::Promise( + GroupExpression *group_expr, + OptimizeContext *context) const { + + (void)group_expr; + (void)context; + + LOG_ERROR("Rule::Promise for rewrite engine not implemented!"); + PELOTON_ASSERT(0); + return 0; +} + +// Specialization due to OpType +template <> +int Rule::Promise( + GroupExpression *group_expr, + OptimizeContext *context) const { + (void)context; auto root_type = match_pattern->Type(); // This rule is not applicable @@ -27,7 +47,20 @@ int Rule::Promise(GroupExpression *group_expr, OptimizeContext *context) const { return LOG_PROMISE; } -RuleSet::RuleSet() { +template +RuleSet::RuleSet() { + LOG_ERROR("Must invoke specialization of RuleSet constructor"); + PELOTON_ASSERT(0); +} + +template <> +RuleSet::RuleSet() { + AddRewriteRule(RewriteRuleSetName::COMPARATOR_ELIMINATION, + new ComparatorElimination()); +} + +template <> +RuleSet::RuleSet() { AddTransformationRule(new InnerJoinCommutativity()); AddTransformationRule(new InnerJoinAssociativity()); AddImplementationRule(new LogicalDeleteToPhysical()); @@ -64,5 +97,9 @@ RuleSet::RuleSet() { new PullFilterThroughAggregation()); } +// Explicitly instantiate +template class Rule; +template class Rule; + } // namespace optimizer } // namespace peloton diff --git a/src/optimizer/rule_impls.cpp b/src/optimizer/rule_impls.cpp index 8574e00f337..ed24b5680ec 100644 --- a/src/optimizer/rule_impls.cpp +++ b/src/optimizer/rule_impls.cpp @@ -34,15 +34,15 @@ namespace optimizer { InnerJoinCommutativity::InnerJoinCommutativity() { type_ = RuleType::INNER_JOIN_COMMUTE; - std::shared_ptr left_child(std::make_shared(OpType::Leaf)); - std::shared_ptr right_child(std::make_shared(OpType::Leaf)); - match_pattern = std::make_shared(OpType::InnerJoin); + std::shared_ptr> left_child(std::make_shared>(OpType::Leaf)); + std::shared_ptr> right_child(std::make_shared>(OpType::Leaf)); + match_pattern = 
std::make_shared>(OpType::InnerJoin); match_pattern->AddChild(left_child); match_pattern->AddChild(right_child); } bool InnerJoinCommutativity::Check(std::shared_ptr expr, - OptimizeContext *context) const { + OptimizeContext *context) const { (void)context; (void)expr; return true; @@ -51,7 +51,7 @@ bool InnerJoinCommutativity::Check(std::shared_ptr expr, void InnerJoinCommutativity::Transform( std::shared_ptr input, std::vector> &transformed, - UNUSED_ATTRIBUTE OptimizeContext *context) const { + UNUSED_ATTRIBUTE OptimizeContext *context) const { auto join_op = input->Op().As(); auto join_predicates = std::vector(join_op->join_predicates); @@ -74,20 +74,20 @@ InnerJoinAssociativity::InnerJoinAssociativity() { type_ = RuleType::INNER_JOIN_ASSOCIATE; // Create left nested join - auto left_child = std::make_shared(OpType::InnerJoin); - left_child->AddChild(std::make_shared(OpType::Leaf)); - left_child->AddChild(std::make_shared(OpType::Leaf)); + auto left_child = std::make_shared>(OpType::InnerJoin); + left_child->AddChild(std::make_shared>(OpType::Leaf)); + left_child->AddChild(std::make_shared>(OpType::Leaf)); - std::shared_ptr right_child(std::make_shared(OpType::Leaf)); + std::shared_ptr> right_child(std::make_shared>(OpType::Leaf)); - match_pattern = std::make_shared(OpType::InnerJoin); + match_pattern = std::make_shared>(OpType::InnerJoin); match_pattern->AddChild(left_child); match_pattern->AddChild(right_child); } // TODO: As far as I know, theres nothing else that needs to be checked bool InnerJoinAssociativity::Check(std::shared_ptr expr, - OptimizeContext *context) const { + OptimizeContext *context) const { (void)context; (void)expr; return true; @@ -96,7 +96,7 @@ bool InnerJoinAssociativity::Check(std::shared_ptr expr, void InnerJoinAssociativity::Transform( std::shared_ptr input, std::vector> &transformed, - OptimizeContext *context) const { + OptimizeContext *context) const { // NOTE: Transforms (left JOIN middle) JOIN right -> left JOIN (middle JOIN 
// right) Variables are named accordingly to above transformation auto parent_join = input->Op().As(); @@ -179,11 +179,11 @@ void InnerJoinAssociativity::Transform( GetToDummyScan::GetToDummyScan() { type_ = RuleType::GET_TO_DUMMY_SCAN; - match_pattern = std::make_shared(OpType::Get); + match_pattern = std::make_shared>(OpType::Get); } bool GetToDummyScan::Check(std::shared_ptr plan, - OptimizeContext *context) const { + OptimizeContext *context) const { (void)context; const LogicalGet *get = plan->Op().As(); return get->table == nullptr; @@ -192,7 +192,7 @@ bool GetToDummyScan::Check(std::shared_ptr plan, void GetToDummyScan::Transform( UNUSED_ATTRIBUTE std::shared_ptr input, std::vector> &transformed, - UNUSED_ATTRIBUTE OptimizeContext *context) const { + UNUSED_ATTRIBUTE OptimizeContext *context) const { auto result_plan = std::make_shared(DummyScan::make()); transformed.push_back(result_plan); @@ -203,11 +203,11 @@ void GetToDummyScan::Transform( GetToSeqScan::GetToSeqScan() { type_ = RuleType::GET_TO_SEQ_SCAN; - match_pattern = std::make_shared(OpType::Get); + match_pattern = std::make_shared>(OpType::Get); } bool GetToSeqScan::Check(std::shared_ptr plan, - OptimizeContext *context) const { + OptimizeContext *context) const { (void)context; const LogicalGet *get = plan->Op().As(); return get->table != nullptr; @@ -216,7 +216,7 @@ bool GetToSeqScan::Check(std::shared_ptr plan, void GetToSeqScan::Transform( std::shared_ptr input, std::vector> &transformed, - UNUSED_ATTRIBUTE OptimizeContext *context) const { + UNUSED_ATTRIBUTE OptimizeContext *context) const { const LogicalGet *get = input->Op().As(); auto result_plan = std::make_shared( @@ -235,11 +235,11 @@ void GetToSeqScan::Transform( GetToIndexScan::GetToIndexScan() { type_ = RuleType::GET_TO_INDEX_SCAN; - match_pattern = std::make_shared(OpType::Get); + match_pattern = std::make_shared>(OpType::Get); } bool GetToIndexScan::Check(std::shared_ptr plan, - OptimizeContext *context) const { + OptimizeContext 
*context) const { // If there is a index for the table, return true, // else return false (void)context; @@ -255,7 +255,7 @@ bool GetToIndexScan::Check(std::shared_ptr plan, void GetToIndexScan::Transform( std::shared_ptr input, std::vector> &transformed, - UNUSED_ATTRIBUTE OptimizeContext *context) const { + UNUSED_ATTRIBUTE OptimizeContext *context) const { UNUSED_ATTRIBUTE std::vector> children = input->Children(); PELOTON_ASSERT(children.size() == 0); @@ -409,13 +409,13 @@ void GetToIndexScan::Transform( /// LogicalQueryDerivedGetToPhysical LogicalQueryDerivedGetToPhysical::LogicalQueryDerivedGetToPhysical() { type_ = RuleType::QUERY_DERIVED_GET_TO_PHYSICAL; - match_pattern = std::make_shared(OpType::LogicalQueryDerivedGet); - std::shared_ptr child(std::make_shared(OpType::Leaf)); + match_pattern = std::make_shared>(OpType::LogicalQueryDerivedGet); + std::shared_ptr> child(std::make_shared>(OpType::Leaf)); match_pattern->AddChild(child); } bool LogicalQueryDerivedGetToPhysical::Check( - std::shared_ptr expr, OptimizeContext *context) const { + std::shared_ptr expr, OptimizeContext *context) const { (void)context; (void)expr; return true; @@ -424,7 +424,7 @@ bool LogicalQueryDerivedGetToPhysical::Check( void LogicalQueryDerivedGetToPhysical::Transform( std::shared_ptr input, std::vector> &transformed, - UNUSED_ATTRIBUTE OptimizeContext *context) const { + UNUSED_ATTRIBUTE OptimizeContext *context) const { const LogicalQueryDerivedGet *get = input->Op().As(); auto result_plan = @@ -439,19 +439,19 @@ void LogicalQueryDerivedGetToPhysical::Transform( /// LogicalExternalFileGetToPhysical LogicalExternalFileGetToPhysical::LogicalExternalFileGetToPhysical() { type_ = RuleType::EXTERNAL_FILE_GET_TO_PHYSICAL; - match_pattern = std::make_shared(OpType::LogicalExternalFileGet); + match_pattern = std::make_shared>(OpType::LogicalExternalFileGet); } bool LogicalExternalFileGetToPhysical::Check( UNUSED_ATTRIBUTE std::shared_ptr plan, - UNUSED_ATTRIBUTE OptimizeContext 
*context) const { + UNUSED_ATTRIBUTE OptimizeContext *context) const { return true; } void LogicalExternalFileGetToPhysical::Transform( std::shared_ptr input, std::vector> &transformed, - UNUSED_ATTRIBUTE OptimizeContext *context) const { + UNUSED_ATTRIBUTE OptimizeContext *context) const { const auto *get = input->Op().As(); auto result_plan = std::make_shared( @@ -467,13 +467,13 @@ void LogicalExternalFileGetToPhysical::Transform( /// LogicalDeleteToPhysical LogicalDeleteToPhysical::LogicalDeleteToPhysical() { type_ = RuleType::DELETE_TO_PHYSICAL; - match_pattern = std::make_shared(OpType::LogicalDelete); - std::shared_ptr child(std::make_shared(OpType::Leaf)); + match_pattern = std::make_shared>(OpType::LogicalDelete); + std::shared_ptr> child(std::make_shared>(OpType::Leaf)); match_pattern->AddChild(child); } bool LogicalDeleteToPhysical::Check(std::shared_ptr plan, - OptimizeContext *context) const { + OptimizeContext *context) const { (void)plan; (void)context; return true; @@ -482,7 +482,7 @@ bool LogicalDeleteToPhysical::Check(std::shared_ptr plan, void LogicalDeleteToPhysical::Transform( std::shared_ptr input, std::vector> &transformed, - UNUSED_ATTRIBUTE OptimizeContext *context) const { + UNUSED_ATTRIBUTE OptimizeContext *context) const { const LogicalDelete *delete_op = input->Op().As(); auto result = std::make_shared( PhysicalDelete::make(delete_op->target_table)); @@ -495,13 +495,13 @@ void LogicalDeleteToPhysical::Transform( /// LogicalUpdateToPhysical LogicalUpdateToPhysical::LogicalUpdateToPhysical() { type_ = RuleType::UPDATE_TO_PHYSICAL; - match_pattern = std::make_shared(OpType::LogicalUpdate); - std::shared_ptr child(std::make_shared(OpType::Leaf)); + match_pattern = std::make_shared>(OpType::LogicalUpdate); + std::shared_ptr> child(std::make_shared>(OpType::Leaf)); match_pattern->AddChild(child); } bool LogicalUpdateToPhysical::Check(std::shared_ptr plan, - OptimizeContext *context) const { + OptimizeContext *context) const { (void)plan; 
(void)context; return true; @@ -510,7 +510,7 @@ bool LogicalUpdateToPhysical::Check(std::shared_ptr plan, void LogicalUpdateToPhysical::Transform( std::shared_ptr input, std::vector> &transformed, - UNUSED_ATTRIBUTE OptimizeContext *context) const { + UNUSED_ATTRIBUTE OptimizeContext *context) const { const LogicalUpdate *update_op = input->Op().As(); auto result = std::make_shared( PhysicalUpdate::make(update_op->target_table, update_op->updates)); @@ -523,13 +523,13 @@ void LogicalUpdateToPhysical::Transform( /// LogicalInsertToPhysical LogicalInsertToPhysical::LogicalInsertToPhysical() { type_ = RuleType::INSERT_TO_PHYSICAL; - match_pattern = std::make_shared(OpType::LogicalInsert); - // std::shared_ptr child(std::make_shared(OpType::Leaf)); + match_pattern = std::make_shared>(OpType::LogicalInsert); + // std::shared_ptr> child(std::make_shared>(OpType::Leaf)); // match_pattern->AddChild(child); } bool LogicalInsertToPhysical::Check(std::shared_ptr plan, - OptimizeContext *context) const { + OptimizeContext *context) const { (void)plan; (void)context; return true; @@ -538,7 +538,7 @@ bool LogicalInsertToPhysical::Check(std::shared_ptr plan, void LogicalInsertToPhysical::Transform( std::shared_ptr input, std::vector> &transformed, - UNUSED_ATTRIBUTE OptimizeContext *context) const { + UNUSED_ATTRIBUTE OptimizeContext *context) const { const LogicalInsert *insert_op = input->Op().As(); auto result = std::make_shared(PhysicalInsert::make( insert_op->target_table, insert_op->columns, insert_op->values)); @@ -551,13 +551,13 @@ void LogicalInsertToPhysical::Transform( /// LogicalInsertSelectToPhysical LogicalInsertSelectToPhysical::LogicalInsertSelectToPhysical() { type_ = RuleType::INSERT_SELECT_TO_PHYSICAL; - match_pattern = std::make_shared(OpType::LogicalInsertSelect); - std::shared_ptr child(std::make_shared(OpType::Leaf)); + match_pattern = std::make_shared>(OpType::LogicalInsertSelect); + std::shared_ptr> child(std::make_shared>(OpType::Leaf)); 
match_pattern->AddChild(child); } bool LogicalInsertSelectToPhysical::Check( - std::shared_ptr plan, OptimizeContext *context) const { + std::shared_ptr plan, OptimizeContext *context) const { (void)plan; (void)context; return true; @@ -566,7 +566,7 @@ bool LogicalInsertSelectToPhysical::Check( void LogicalInsertSelectToPhysical::Transform( std::shared_ptr input, std::vector> &transformed, - UNUSED_ATTRIBUTE OptimizeContext *context) const { + UNUSED_ATTRIBUTE OptimizeContext *context) const { const LogicalInsertSelect *insert_op = input->Op().As(); auto result = std::make_shared( PhysicalInsertSelect::make(insert_op->target_table)); @@ -579,14 +579,14 @@ void LogicalInsertSelectToPhysical::Transform( /// LogicalAggregateAndGroupByToHashGroupBy LogicalGroupByToHashGroupBy::LogicalGroupByToHashGroupBy() { type_ = RuleType::AGGREGATE_TO_HASH_AGGREGATE; - match_pattern = std::make_shared(OpType::LogicalAggregateAndGroupBy); - std::shared_ptr child(std::make_shared(OpType::Leaf)); + match_pattern = std::make_shared>(OpType::LogicalAggregateAndGroupBy); + std::shared_ptr> child(std::make_shared>(OpType::Leaf)); match_pattern->AddChild(child); } bool LogicalGroupByToHashGroupBy::Check( UNUSED_ATTRIBUTE std::shared_ptr plan, - OptimizeContext *context) const { + OptimizeContext *context) const { (void)context; const LogicalAggregateAndGroupBy *agg_op = plan->Op().As(); @@ -596,7 +596,7 @@ bool LogicalGroupByToHashGroupBy::Check( void LogicalGroupByToHashGroupBy::Transform( std::shared_ptr input, std::vector> &transformed, - UNUSED_ATTRIBUTE OptimizeContext *context) const { + UNUSED_ATTRIBUTE OptimizeContext *context) const { const LogicalAggregateAndGroupBy *agg_op = input->Op().As(); auto result = std::make_shared( @@ -610,14 +610,14 @@ void LogicalGroupByToHashGroupBy::Transform( /// LogicalAggregateToPhysical LogicalAggregateToPhysical::LogicalAggregateToPhysical() { type_ = RuleType::AGGREGATE_TO_PLAIN_AGGREGATE; - match_pattern = 
std::make_shared(OpType::LogicalAggregateAndGroupBy); - std::shared_ptr child(std::make_shared(OpType::Leaf)); + match_pattern = std::make_shared>(OpType::LogicalAggregateAndGroupBy); + std::shared_ptr> child(std::make_shared>(OpType::Leaf)); match_pattern->AddChild(child); } bool LogicalAggregateToPhysical::Check( UNUSED_ATTRIBUTE std::shared_ptr plan, - OptimizeContext *context) const { + OptimizeContext *context) const { (void)context; const LogicalAggregateAndGroupBy *agg_op = plan->Op().As(); @@ -627,7 +627,7 @@ bool LogicalAggregateToPhysical::Check( void LogicalAggregateToPhysical::Transform( std::shared_ptr input, std::vector> &transformed, - UNUSED_ATTRIBUTE OptimizeContext *context) const { + UNUSED_ATTRIBUTE OptimizeContext *context) const { auto result = std::make_shared(PhysicalAggregate::make()); PELOTON_ASSERT(input->Children().size() == 1); result->PushChild(input->Children().at(0)); @@ -640,11 +640,11 @@ InnerJoinToInnerNLJoin::InnerJoinToInnerNLJoin() { type_ = RuleType::INNER_JOIN_TO_NL_JOIN; // TODO NLJoin currently only support left deep tree - std::shared_ptr left_child(std::make_shared(OpType::Leaf)); - std::shared_ptr right_child(std::make_shared(OpType::Leaf)); + std::shared_ptr> left_child(std::make_shared>(OpType::Leaf)); + std::shared_ptr> right_child(std::make_shared>(OpType::Leaf)); // Initialize a pattern for optimizer to match - match_pattern = std::make_shared(OpType::InnerJoin); + match_pattern = std::make_shared>(OpType::InnerJoin); // Add node - we match join relation R and S match_pattern->AddChild(left_child); @@ -654,7 +654,7 @@ InnerJoinToInnerNLJoin::InnerJoinToInnerNLJoin() { } bool InnerJoinToInnerNLJoin::Check(std::shared_ptr plan, - OptimizeContext *context) const { + OptimizeContext *context) const { (void)context; (void)plan; return true; @@ -663,7 +663,7 @@ bool InnerJoinToInnerNLJoin::Check(std::shared_ptr plan, void InnerJoinToInnerNLJoin::Transform( std::shared_ptr input, std::vector> &transformed, - 
UNUSED_ATTRIBUTE OptimizeContext *context) const { + UNUSED_ATTRIBUTE OptimizeContext *context) const { // first build an expression representing hash join const LogicalInnerJoin *inner_join = input->Op().As(); @@ -701,11 +701,11 @@ InnerJoinToInnerHashJoin::InnerJoinToInnerHashJoin() { type_ = RuleType::INNER_JOIN_TO_HASH_JOIN; // Make three node types for pattern matching - std::shared_ptr left_child(std::make_shared(OpType::Leaf)); - std::shared_ptr right_child(std::make_shared(OpType::Leaf)); + std::shared_ptr> left_child(std::make_shared>(OpType::Leaf)); + std::shared_ptr> right_child(std::make_shared>(OpType::Leaf)); // Initialize a pattern for optimizer to match - match_pattern = std::make_shared(OpType::InnerJoin); + match_pattern = std::make_shared>(OpType::InnerJoin); // Add node - we match join relation R and S as well as the predicate exp match_pattern->AddChild(left_child); @@ -715,7 +715,7 @@ InnerJoinToInnerHashJoin::InnerJoinToInnerHashJoin() { } bool InnerJoinToInnerHashJoin::Check(std::shared_ptr plan, - OptimizeContext *context) const { + OptimizeContext *context) const { (void)context; (void)plan; return true; @@ -724,7 +724,7 @@ bool InnerJoinToInnerHashJoin::Check(std::shared_ptr plan, void InnerJoinToInnerHashJoin::Transform( std::shared_ptr input, std::vector> &transformed, - UNUSED_ATTRIBUTE OptimizeContext *context) const { + UNUSED_ATTRIBUTE OptimizeContext *context) const { // first build an expression representing hash join const LogicalInnerJoin *inner_join = input->Op().As(); @@ -761,12 +761,12 @@ void InnerJoinToInnerHashJoin::Transform( ImplementDistinct::ImplementDistinct() { type_ = RuleType::IMPLEMENT_DISTINCT; - match_pattern = std::make_shared(OpType::LogicalDistinct); - match_pattern->AddChild(std::make_shared(OpType::Leaf)); + match_pattern = std::make_shared>(OpType::LogicalDistinct); + match_pattern->AddChild(std::make_shared>(OpType::Leaf)); } bool ImplementDistinct::Check(std::shared_ptr plan, - OptimizeContext *context) 
const { + OptimizeContext *context) const { (void)context; (void)plan; return true; @@ -775,7 +775,7 @@ bool ImplementDistinct::Check(std::shared_ptr plan, void ImplementDistinct::Transform( std::shared_ptr input, std::vector> &transformed, - OptimizeContext *context) const { + OptimizeContext *context) const { (void)context; auto result_plan = std::make_shared(PhysicalDistinct::make()); @@ -792,12 +792,12 @@ void ImplementDistinct::Transform( ImplementLimit::ImplementLimit() { type_ = RuleType::IMPLEMENT_LIMIT; - match_pattern = std::make_shared(OpType::LogicalLimit); - match_pattern->AddChild(std::make_shared(OpType::Leaf)); + match_pattern = std::make_shared>(OpType::LogicalLimit); + match_pattern->AddChild(std::make_shared>(OpType::Leaf)); } bool ImplementLimit::Check(std::shared_ptr plan, - OptimizeContext *context) const { + OptimizeContext *context) const { (void)context; (void)plan; return true; @@ -806,7 +806,7 @@ bool ImplementLimit::Check(std::shared_ptr plan, void ImplementLimit::Transform( std::shared_ptr input, std::vector> &transformed, - OptimizeContext *context) const { + OptimizeContext *context) const { (void)context; const LogicalLimit *limit_op = input->Op().As(); @@ -825,20 +825,20 @@ void ImplementLimit::Transform( /// LogicalExport to Physical Export LogicalExportToPhysicalExport::LogicalExportToPhysicalExport() { type_ = RuleType::EXPORT_EXTERNAL_FILE_TO_PHYSICAL; - match_pattern = std::make_shared(OpType::LogicalExportExternalFile); - match_pattern->AddChild(std::make_shared(OpType::Leaf)); + match_pattern = std::make_shared>(OpType::LogicalExportExternalFile); + match_pattern->AddChild(std::make_shared>(OpType::Leaf)); } bool LogicalExportToPhysicalExport::Check( UNUSED_ATTRIBUTE std::shared_ptr plan, - UNUSED_ATTRIBUTE OptimizeContext *context) const { + UNUSED_ATTRIBUTE OptimizeContext *context) const { return true; } void LogicalExportToPhysicalExport::Transform( std::shared_ptr input, std::vector> &transformed, - UNUSED_ATTRIBUTE 
OptimizeContext *context) const { + UNUSED_ATTRIBUTE OptimizeContext *context) const { const auto *export_op = input->Op().As(); auto result_plan = @@ -863,26 +863,26 @@ PushFilterThroughJoin::PushFilterThroughJoin() { type_ = RuleType::PUSH_FILTER_THROUGH_JOIN; // Make three node types for pattern matching - std::shared_ptr child(std::make_shared(OpType::InnerJoin)); - child->AddChild(std::make_shared(OpType::Leaf)); - child->AddChild(std::make_shared(OpType::Leaf)); + std::shared_ptr> child(std::make_shared>(OpType::InnerJoin)); + child->AddChild(std::make_shared>(OpType::Leaf)); + child->AddChild(std::make_shared>(OpType::Leaf)); // Initialize a pattern for optimizer to match - match_pattern = std::make_shared(OpType::LogicalFilter); + match_pattern = std::make_shared>(OpType::LogicalFilter); // Add node - we match join relation R and S as well as the predicate exp match_pattern->AddChild(child); } bool PushFilterThroughJoin::Check(std::shared_ptr, - OptimizeContext *) const { + OptimizeContext *) const { return true; } void PushFilterThroughJoin::Transform( std::shared_ptr input, std::vector> &transformed, - UNUSED_ATTRIBUTE OptimizeContext *context) const { + UNUSED_ATTRIBUTE OptimizeContext *context) const { LOG_TRACE("PushFilterThroughJoin::Transform"); auto &memo = context->metadata->memo; auto join_op_expr = input->Children().at(0); @@ -955,26 +955,26 @@ void PushFilterThroughJoin::Transform( PushFilterThroughAggregation::PushFilterThroughAggregation() { type_ = RuleType::PUSH_FILTER_THROUGH_JOIN; - std::shared_ptr child( - std::make_shared(OpType::LogicalAggregateAndGroupBy)); - child->AddChild(std::make_shared(OpType::Leaf)); + std::shared_ptr> child( + std::make_shared>(OpType::LogicalAggregateAndGroupBy)); + child->AddChild(std::make_shared>(OpType::Leaf)); // Initialize a pattern for optimizer to match - match_pattern = std::make_shared(OpType::LogicalFilter); + match_pattern = std::make_shared>(OpType::LogicalFilter); // Add node - we match 
(filter)->(aggregation)->(leaf) match_pattern->AddChild(child); } bool PushFilterThroughAggregation::Check(std::shared_ptr, - OptimizeContext *) const { + OptimizeContext *) const { return true; } void PushFilterThroughAggregation::Transform( std::shared_ptr input, std::vector> &transformed, - UNUSED_ATTRIBUTE OptimizeContext *context) const { + UNUSED_ATTRIBUTE OptimizeContext *context) const { LOG_TRACE("PushFilterThroughAggregation::Transform"); auto aggregation_op = input->Children().at(0)->Op().As(); @@ -1022,16 +1022,16 @@ void PushFilterThroughAggregation::Transform( CombineConsecutiveFilter::CombineConsecutiveFilter() { type_ = RuleType::COMBINE_CONSECUTIVE_FILTER; - match_pattern = std::make_shared(OpType::LogicalFilter); - std::shared_ptr child( - std::make_shared(OpType::LogicalFilter)); - child->AddChild(std::make_shared(OpType::Leaf)); + match_pattern = std::make_shared>(OpType::LogicalFilter); + std::shared_ptr> child( + std::make_shared>(OpType::LogicalFilter)); + child->AddChild(std::make_shared>(OpType::Leaf)); match_pattern->AddChild(child); } bool CombineConsecutiveFilter::Check(std::shared_ptr plan, - OptimizeContext *context) const { + OptimizeContext *context) const { (void)context; (void)plan; @@ -1048,7 +1048,7 @@ bool CombineConsecutiveFilter::Check(std::shared_ptr plan, void CombineConsecutiveFilter::Transform( std::shared_ptr input, std::vector> &transformed, - UNUSED_ATTRIBUTE OptimizeContext *context) const { + UNUSED_ATTRIBUTE OptimizeContext *context) const { auto child_filter = input->Children()[0]; auto root_predicates = input->Op().As()->predicates; @@ -1071,14 +1071,14 @@ void CombineConsecutiveFilter::Transform( EmbedFilterIntoGet::EmbedFilterIntoGet() { type_ = RuleType::EMBED_FILTER_INTO_GET; - match_pattern = std::make_shared(OpType::LogicalFilter); - std::shared_ptr child(std::make_shared(OpType::Get)); + match_pattern = std::make_shared>(OpType::LogicalFilter); + std::shared_ptr> child(std::make_shared>(OpType::Get)); 
match_pattern->AddChild(child); } bool EmbedFilterIntoGet::Check(std::shared_ptr plan, - OptimizeContext *context) const { + OptimizeContext *context) const { (void)context; (void)plan; return true; @@ -1087,7 +1087,7 @@ bool EmbedFilterIntoGet::Check(std::shared_ptr plan, void EmbedFilterIntoGet::Transform( std::shared_ptr input, std::vector> &transformed, - UNUSED_ATTRIBUTE OptimizeContext *context) const { + UNUSED_ATTRIBUTE OptimizeContext *context) const { auto get = input->Children()[0]->Op().As(); auto predicates = input->Op().As()->predicates; @@ -1105,13 +1105,13 @@ void EmbedFilterIntoGet::Transform( MarkJoinToInnerJoin::MarkJoinToInnerJoin() { type_ = RuleType::MARK_JOIN_GET_TO_INNER_JOIN; - match_pattern = std::make_shared(OpType::LogicalMarkJoin); - match_pattern->AddChild(std::make_shared(OpType::Leaf)); - match_pattern->AddChild(std::make_shared(OpType::Leaf)); + match_pattern = std::make_shared>(OpType::LogicalMarkJoin); + match_pattern->AddChild(std::make_shared>(OpType::Leaf)); + match_pattern->AddChild(std::make_shared>(OpType::Leaf)); } -int MarkJoinToInnerJoin::Promise(GroupExpression *group_expr, - OptimizeContext *context) const { +int MarkJoinToInnerJoin::Promise(GroupExpression *group_expr, + OptimizeContext *context) const { (void)context; auto root_type = match_pattern->Type(); // This rule is not applicable @@ -1122,7 +1122,7 @@ int MarkJoinToInnerJoin::Promise(GroupExpression *group_expr, } bool MarkJoinToInnerJoin::Check(std::shared_ptr plan, - OptimizeContext *context) const { + OptimizeContext *context) const { (void)context; (void)plan; @@ -1135,7 +1135,7 @@ bool MarkJoinToInnerJoin::Check(std::shared_ptr plan, void MarkJoinToInnerJoin::Transform( std::shared_ptr input, std::vector> &transformed, - UNUSED_ATTRIBUTE OptimizeContext *context) const { + UNUSED_ATTRIBUTE OptimizeContext *context) const { LOG_TRACE("MarkJoinToInnerJoin::Transform"); UNUSED_ATTRIBUTE auto mark_join = input->Op().As(); auto &join_children = 
input->Children(); @@ -1156,13 +1156,13 @@ void MarkJoinToInnerJoin::Transform( SingleJoinToInnerJoin::SingleJoinToInnerJoin() { type_ = RuleType::MARK_JOIN_GET_TO_INNER_JOIN; - match_pattern = std::make_shared(OpType::LogicalSingleJoin); - match_pattern->AddChild(std::make_shared(OpType::Leaf)); - match_pattern->AddChild(std::make_shared(OpType::Leaf)); + match_pattern = std::make_shared>(OpType::LogicalSingleJoin); + match_pattern->AddChild(std::make_shared>(OpType::Leaf)); + match_pattern->AddChild(std::make_shared>(OpType::Leaf)); } -int SingleJoinToInnerJoin::Promise(GroupExpression *group_expr, - OptimizeContext *context) const { +int SingleJoinToInnerJoin::Promise(GroupExpression *group_expr, + OptimizeContext *context) const { (void)context; auto root_type = match_pattern->Type(); // This rule is not applicable @@ -1173,7 +1173,7 @@ int SingleJoinToInnerJoin::Promise(GroupExpression *group_expr, } bool SingleJoinToInnerJoin::Check(std::shared_ptr plan, - OptimizeContext *context) const { + OptimizeContext *context) const { (void)context; (void)plan; @@ -1186,7 +1186,7 @@ bool SingleJoinToInnerJoin::Check(std::shared_ptr plan, void SingleJoinToInnerJoin::Transform( std::shared_ptr input, std::vector> &transformed, - UNUSED_ATTRIBUTE OptimizeContext *context) const { + UNUSED_ATTRIBUTE OptimizeContext *context) const { LOG_TRACE("SingleJoinToInnerJoin::Transform"); UNUSED_ATTRIBUTE auto single_join = input->Op().As(); auto &join_children = input->Children(); @@ -1207,15 +1207,15 @@ void SingleJoinToInnerJoin::Transform( PullFilterThroughMarkJoin::PullFilterThroughMarkJoin() { type_ = RuleType::PULL_FILTER_THROUGH_MARK_JOIN; - match_pattern = std::make_shared(OpType::LogicalMarkJoin); - match_pattern->AddChild(std::make_shared(OpType::Leaf)); - auto filter = std::make_shared(OpType::LogicalFilter); - filter->AddChild(std::make_shared(OpType::Leaf)); + match_pattern = std::make_shared>(OpType::LogicalMarkJoin); + 
match_pattern->AddChild(std::make_shared>(OpType::Leaf)); + auto filter = std::make_shared>(OpType::LogicalFilter); + filter->AddChild(std::make_shared>(OpType::Leaf)); match_pattern->AddChild(filter); } -int PullFilterThroughMarkJoin::Promise(GroupExpression *group_expr, - OptimizeContext *context) const { +int PullFilterThroughMarkJoin::Promise(GroupExpression *group_expr, + OptimizeContext *context) const { (void)context; auto root_type = match_pattern->Type(); // This rule is not applicable @@ -1226,7 +1226,7 @@ int PullFilterThroughMarkJoin::Promise(GroupExpression *group_expr, } bool PullFilterThroughMarkJoin::Check(std::shared_ptr plan, - OptimizeContext *context) const { + OptimizeContext *context) const { (void)context; (void)plan; @@ -1241,7 +1241,7 @@ bool PullFilterThroughMarkJoin::Check(std::shared_ptr plan, void PullFilterThroughMarkJoin::Transform( std::shared_ptr input, std::vector> &transformed, - UNUSED_ATTRIBUTE OptimizeContext *context) const { + UNUSED_ATTRIBUTE OptimizeContext *context) const { LOG_TRACE("PullFilterThroughMarkJoin::Transform"); UNUSED_ATTRIBUTE auto mark_join = input->Op().As(); auto &join_children = input->Children(); @@ -1269,14 +1269,14 @@ void PullFilterThroughMarkJoin::Transform( PullFilterThroughAggregation::PullFilterThroughAggregation() { type_ = RuleType::PULL_FILTER_THROUGH_AGGREGATION; - auto filter = std::make_shared(OpType::LogicalFilter); - filter->AddChild(std::make_shared(OpType::Leaf)); - match_pattern = std::make_shared(OpType::LogicalAggregateAndGroupBy); + auto filter = std::make_shared>(OpType::LogicalFilter); + filter->AddChild(std::make_shared>(OpType::Leaf)); + match_pattern = std::make_shared>(OpType::LogicalAggregateAndGroupBy); match_pattern->AddChild(filter); } -int PullFilterThroughAggregation::Promise(GroupExpression *group_expr, - OptimizeContext *context) const { +int PullFilterThroughAggregation::Promise(GroupExpression *group_expr, + OptimizeContext *context) const { (void)context; auto 
root_type = match_pattern->Type(); // This rule is not applicable @@ -1287,7 +1287,7 @@ int PullFilterThroughAggregation::Promise(GroupExpression *group_expr, } bool PullFilterThroughAggregation::Check( - std::shared_ptr plan, OptimizeContext *context) const { + std::shared_ptr plan, OptimizeContext *context) const { (void)context; (void)plan; @@ -1302,7 +1302,7 @@ bool PullFilterThroughAggregation::Check( void PullFilterThroughAggregation::Transform( std::shared_ptr input, std::vector> &transformed, - UNUSED_ATTRIBUTE OptimizeContext *context) const { + UNUSED_ATTRIBUTE OptimizeContext *context) const { LOG_TRACE("PullFilterThroughAggregation::Transform"); auto &memo = context->metadata->memo; auto &filter_expr = input->Children()[0]; diff --git a/src/optimizer/rule_rewrite.cpp b/src/optimizer/rule_rewrite.cpp new file mode 100644 index 00000000000..88d23092c31 --- /dev/null +++ b/src/optimizer/rule_rewrite.cpp @@ -0,0 +1,91 @@ +#include + +#include "catalog/column_catalog.h" +#include "catalog/index_catalog.h" +#include "catalog/table_catalog.h" +#include "optimizer/operators.h" +#include "optimizer/absexpr_expression.h" +#include "optimizer/optimizer_metadata.h" +#include "optimizer/properties.h" +#include "optimizer/rule_rewrite.h" +#include "optimizer/util.h" +#include "type/value_factory.h" + +namespace peloton { +namespace optimizer { + +ComparatorElimination::ComparatorElimination() { + type_ = RuleType::COMP_EQUALITY_ELIMINATION; + + match_pattern = std::make_shared>(ExpressionType::COMPARE_EQUAL); + auto left = std::make_shared>(ExpressionType::VALUE_CONSTANT); + auto right = std::make_shared>(ExpressionType::VALUE_CONSTANT); + match_pattern->AddChild(left); + match_pattern->AddChild(right); +} + +int ComparatorElimination::Promise(GroupExpression *group_expr, + OptimizeContext *context) const { + (void)group_expr; + (void)context; + return static_cast(RulePriority::HIGH); +} + +bool ComparatorElimination::Check(std::shared_ptr plan, + OptimizeContext 
*context) const { + (void)context; + (void)plan; + + // If any of these assertions fail, something is seriously wrong with GroupExprBinding + // Verify the structure of the tree is correct + PELOTON_ASSERT(plan != nullptr); + PELOTON_ASSERT(plan->Children().size() == 2); + PELOTON_ASSERT(plan->Op().GetType() == ExpressionType::COMPARE_EQUAL); + + auto left = plan->Children()[0]; + auto right = plan->Children()[1]; + PELOTON_ASSERT(left->Children().size() == 0); + PELOTON_ASSERT(left->Op().GetType() == ExpressionType::VALUE_CONSTANT); + PELOTON_ASSERT(right->Children().size() == 0); + PELOTON_ASSERT(right->Op().GetType() == ExpressionType::VALUE_CONSTANT); + + // Technically, if structure matches, rule should always be applied + return true; +} + +void ComparatorElimination::Transform(std::shared_ptr input, + std::vector> &transformed, + UNUSED_ATTRIBUTE OptimizeContext *context) const { + (void)transformed; + (void)context; + + // (TODO): create a wrapper for evaluating ConstantValue relations (pending email reply) + + // Extract the AbstractExpression through indirection layer + auto left = input->Children()[0]->Op().GetExpr(); + auto right = input->Children()[1]->Op().GetExpr(); + auto lv = static_cast(left); + auto rv = static_cast(right); + lv = const_cast(lv); + rv = const_cast(rv); + + // Get the Value from ConstantValueExpression + auto lvalue = lv->GetValue(); + auto rvalue = rv->GetValue(); + + // Need to check type equality to prevent assertion failure + // This is only a Peloton issue (terrier checks type for you) + bool is_equal = (lvalue.GetTypeId() == rvalue.GetTypeId()) && + (lv->ExactlyEquals(*rv)); + + // Create the transformed expression + type::Value val = type::ValueFactory::GetBooleanValue(is_equal); + auto eq = new expression::ConstantValueExpression(val); + auto cnt = AbsExpr_Container(eq); + auto shared = std::make_shared(cnt); + + // (TODO): figure out memory management once go to terrier (which use shared_ptr) + 
transformed.push_back(shared); +} +} // namespace optimizer +} // namespace peloton diff --git a/src/optimizer/stats/child_stats_deriver.cpp b/src/optimizer/stats/child_stats_deriver.cpp index d320547915c..0fbf2720d99 100644 --- a/src/optimizer/stats/child_stats_deriver.cpp +++ b/src/optimizer/stats/child_stats_deriver.cpp @@ -20,9 +20,9 @@ namespace peloton { namespace optimizer { using std::vector; -vector ChildStatsDeriver::DeriveInputStats(GroupExpression *gexpr, +vector ChildStatsDeriver::DeriveInputStats(GroupExpression *gexpr, ExprSet required_cols, - Memo *memo) { + Memo *memo) { required_cols_ = required_cols; gexpr_ = gexpr; memo_ = memo; diff --git a/src/optimizer/stats/stats_calculator.cpp b/src/optimizer/stats/stats_calculator.cpp index d086938a817..815e309290b 100644 --- a/src/optimizer/stats/stats_calculator.cpp +++ b/src/optimizer/stats/stats_calculator.cpp @@ -26,8 +26,8 @@ namespace peloton { namespace optimizer { -void StatsCalculator::CalculateStats(GroupExpression *gexpr, - ExprSet required_cols, Memo *memo, +void StatsCalculator::CalculateStats(GroupExpression *gexpr, + ExprSet required_cols, Memo *memo, concurrency::TransactionContext *txn) { gexpr_ = gexpr; memo_ = memo; diff --git a/test/include/optimizer/mock_task.h b/test/include/optimizer/mock_task.h index 32e5e1b8da4..7e18f458445 100644 --- a/test/include/optimizer/mock_task.h +++ b/test/include/optimizer/mock_task.h @@ -20,10 +20,10 @@ namespace peloton { namespace optimizer { namespace test { -class MockTask : public optimizer::OptimizerTask { +class MockTask : public optimizer::OptimizerTask { public: MockTask() - : optimizer::OptimizerTask(nullptr, OptimizerTaskType::OPTIMIZE_GROUP) {} + : optimizer::OptimizerTask(nullptr, OptimizerTaskType::OPTIMIZE_GROUP) {} MOCK_METHOD0(execute, void()); }; diff --git a/test/optimizer/optimizer_rule_test.cpp b/test/optimizer/optimizer_rule_test.cpp index 23f520596dc..9868cfa924e 100644 --- a/test/optimizer/optimizer_rule_test.cpp +++ 
b/test/optimizer/optimizer_rule_test.cpp @@ -132,8 +132,8 @@ TEST_F(OptimizerRuleTests, SimpleAssociativeRuleTest) { optimizer.GetMetadata().memo.InsertExpression( optimizer.GetMetadata().MakeGroupExpression(parent_join), true); - OptimizeContext *root_context = - new OptimizeContext(&(optimizer.GetMetadata()), nullptr); + OptimizeContext *root_context = + new OptimizeContext(&(optimizer.GetMetadata()), nullptr); EXPECT_EQ(left_leaf, parent_join->Children()[0]->Children()[0]); EXPECT_EQ(middle_leaf, parent_join->Children()[0]->Children()[1]); @@ -227,8 +227,8 @@ TEST_F(OptimizerRuleTests, SimpleAssociativeRuleTest2) { optimizer.GetMetadata().memo.InsertExpression( optimizer.GetMetadata().MakeGroupExpression(parent_join), true); - OptimizeContext *root_context = - new OptimizeContext(&(optimizer.GetMetadata()), nullptr); + OptimizeContext *root_context = + new OptimizeContext(&(optimizer.GetMetadata()), nullptr); EXPECT_EQ(left_leaf, parent_join->Children()[0]->Children()[0]); EXPECT_EQ(middle_leaf, parent_join->Children()[0]->Children()[1]); diff --git a/test/optimizer/optimizer_test.cpp b/test/optimizer/optimizer_test.cpp index f1ffd6add66..f9fd843b3b3 100644 --- a/test/optimizer/optimizer_test.cpp +++ b/test/optimizer/optimizer_test.cpp @@ -49,7 +49,9 @@ using namespace optimizer; class OptimizerTests : public PelotonTest { protected: - GroupExpression *GetSingleGroupExpression(Memo &memo, GroupExpression *expr, + GroupExpression *GetSingleGroupExpression( + Memo &memo, + GroupExpression *expr, int child_group_idx) { auto group = memo.GetGroupByID(expr->GetChildGroupId(child_group_idx)); EXPECT_EQ(1, group->GetLogicalExpressions().size()); @@ -343,19 +345,19 @@ TEST_F(OptimizerTests, PushFilterThroughJoinTest) { auto bind_node_visitor = binder::BindNodeVisitor(txn, DEFAULT_DB_NAME); bind_node_visitor.BindNameToNode(parse_tree); - std::shared_ptr gexpr = + std::shared_ptr> gexpr = optimizer.TestInsertQueryTree(parse_tree, txn); std::vector child_groups = 
{gexpr->GetGroupID()}; - std::shared_ptr head_gexpr = - std::make_shared(Operator(), child_groups); + std::shared_ptr> head_gexpr = + std::make_shared>(Operator(), child_groups); - std::shared_ptr root_context = - std::make_shared(&(optimizer.GetMetadata()), nullptr); + std::shared_ptr> root_context = + std::make_shared>(&(optimizer.GetMetadata()), nullptr); auto task_stack = - std::unique_ptr(new OptimizerTaskStack()); + std::unique_ptr>(new OptimizerTaskStack()); optimizer.GetMetadata().SetTaskPool(task_stack.get()); - task_stack->Push(new TopDownRewrite(gexpr->GetGroupID(), root_context, + task_stack->Push(new TopDownRewrite(gexpr->GetGroupID(), root_context, RewriteRuleSetName::PREDICATE_PUSH_DOWN)); while (!task_stack->Empty()) { @@ -430,19 +432,19 @@ TEST_F(OptimizerTests, PredicatePushDownRewriteTest) { auto bind_node_visitor = binder::BindNodeVisitor(txn, DEFAULT_DB_NAME); bind_node_visitor.BindNameToNode(parse_tree); - std::shared_ptr gexpr = + std::shared_ptr> gexpr = optimizer.TestInsertQueryTree(parse_tree, txn); std::vector child_groups = {gexpr->GetGroupID()}; - std::shared_ptr head_gexpr = - std::make_shared(Operator(), child_groups); + std::shared_ptr> head_gexpr = + std::make_shared>(Operator(), child_groups); - std::shared_ptr root_context = - std::make_shared(&(optimizer.GetMetadata()), nullptr); + std::shared_ptr> root_context = + std::make_shared>(&(optimizer.GetMetadata()), nullptr); auto task_stack = - std::unique_ptr(new OptimizerTaskStack()); + std::unique_ptr>(new OptimizerTaskStack()); optimizer.GetMetadata().SetTaskPool(task_stack.get()); - task_stack->Push(new TopDownRewrite(gexpr->GetGroupID(), root_context, + task_stack->Push(new TopDownRewrite(gexpr->GetGroupID(), root_context, RewriteRuleSetName::PREDICATE_PUSH_DOWN)); while (!task_stack->Empty()) { @@ -486,14 +488,14 @@ TEST_F(OptimizerTests, ExecuteTaskStackTest) { optimizer::Optimizer optimizer; const int root_group_id = 0; const auto root_group = - new Group(root_group_id, 
std::unordered_set()); + new Group(root_group_id, std::unordered_set()); optimizer.GetMetadata().memo.Groups().emplace_back(root_group); auto required_prop = std::make_shared(PropertySet()); - auto root_context = std::make_shared( + auto root_context = std::make_shared>( &(optimizer.GetMetadata()), required_prop); auto task_stack = - std::unique_ptr(new OptimizerTaskStack()); + std::unique_ptr>(new OptimizerTaskStack()); auto &timer = optimizer.GetMetadata().timer; // Insert tasks into task stack diff --git a/test/optimizer/rewriter_test.cpp b/test/optimizer/rewriter_test.cpp new file mode 100644 index 00000000000..48c4b0420b9 --- /dev/null +++ b/test/optimizer/rewriter_test.cpp @@ -0,0 +1,206 @@ +//===----------------------------------------------------------------------===// +// +// Peloton +// +// operator_test.cpp +// +// Identification: test/optimizer/operator_test.cpp +// +// Copyright (c) 2015-2018, Carnegie Mellon University Database Group +// +//===----------------------------------------------------------------------===// + +#include +#include "common/harness.h" + +#include "optimizer/operators.h" +#include "optimizer/rewriter.h" +#include "expression/constant_value_expression.h" +#include "expression/comparison_expression.h" +#include "expression/tuple_value_expression.h" +#include "type/value_factory.h" +#include "type/value_peeker.h" +#include "optimizer/rule_rewrite.h" + +namespace peloton { + +namespace test { + +using namespace optimizer; + +class RewriterTests : public PelotonTest {}; + +TEST_F(RewriterTests, ConvertAbsExpr) { + type::Value leftValue = type::ValueFactory::GetIntegerValue(1); + type::Value rightValue = type::ValueFactory::GetIntegerValue(2); + auto left = new expression::ConstantValueExpression(leftValue); + auto right = new expression::ConstantValueExpression(rightValue); + auto common = new expression::ComparisonExpression(ExpressionType::COMPARE_EQUAL, left, right); + + Rewriter *rewriter = new Rewriter(); + + auto absexpr = 
rewriter->ConvertToAbsExpr(common); + EXPECT_TRUE(absexpr != nullptr); + EXPECT_TRUE(absexpr->Op().GetType() == ExpressionType::COMPARE_EQUAL); + EXPECT_TRUE(absexpr->Children().size() == 2); + + auto lefta = absexpr->Children()[0]; + auto righta = absexpr->Children()[1]; + EXPECT_TRUE(lefta != nullptr && righta != nullptr); + EXPECT_TRUE(lefta->Op().GetType() == righta->Op().GetType()); + EXPECT_TRUE(lefta->Op().GetType() == ExpressionType::VALUE_CONSTANT); + + auto left_cve = static_cast(lefta->Op().GetExpr()); + auto right_cve = static_cast(righta->Op().GetExpr()); + EXPECT_TRUE(left_cve == left); + EXPECT_TRUE(right_cve == right); + + // Try applying the rule + ComparatorElimination rule; + EXPECT_TRUE(rule.Check(absexpr, nullptr) == true); + + std::vector> transform; + rule.Transform(absexpr, transform, nullptr); + EXPECT_TRUE(transform.size() == 1); + + delete rewriter; + delete common; + + auto tr_expr = transform[0]; + EXPECT_TRUE(tr_expr != nullptr); + EXPECT_TRUE(tr_expr->Op().GetType() == ExpressionType::VALUE_CONSTANT); + EXPECT_TRUE(tr_expr->Children().size() == 0); + + auto tr_cve = static_cast(tr_expr->Op().GetExpr()); + EXPECT_TRUE(type::ValuePeeker::PeekBoolean(tr_cve->GetValue()) == false); + + // (TODO): hack to fix the memory leak bubbled from Transform() + delete tr_cve; +} + +TEST_F(RewriterTests, SingleCompareEqualRewritePassFalse) { + type::Value leftValue = type::ValueFactory::GetIntegerValue(3); + type::Value rightValue = type::ValueFactory::GetIntegerValue(2); + auto left = new expression::ConstantValueExpression(leftValue); + auto right = new expression::ConstantValueExpression(rightValue); + auto common = new expression::ComparisonExpression(ExpressionType::COMPARE_EQUAL, left, right); + + Rewriter *rewriter = new Rewriter(); + auto rewrote = rewriter->RewriteExpression(common); + + delete rewriter; + delete common; + + auto casted = dynamic_cast(rewrote); + EXPECT_TRUE(casted != nullptr); + 
EXPECT_TRUE(type::ValuePeeker::PeekBoolean(casted->GetValue()) == false); + + delete rewrote; +} + +TEST_F(RewriterTests, SingleCompareEqualRewritePassTrue) { + type::Value leftValue = type::ValueFactory::GetIntegerValue(4); + type::Value rightValue = type::ValueFactory::GetIntegerValue(4); + auto left = new expression::ConstantValueExpression(leftValue); + auto right = new expression::ConstantValueExpression(rightValue); + auto common = new expression::ComparisonExpression(ExpressionType::COMPARE_EQUAL, left, right); + + Rewriter *rewriter = new Rewriter(); + auto rewrote = rewriter->RewriteExpression(common); + + delete rewriter; + delete common; + + auto casted = dynamic_cast(rewrote); + EXPECT_TRUE(casted != nullptr); + EXPECT_TRUE(type::ValuePeeker::PeekBoolean(casted->GetValue()) == true); + + delete rewrote; +} + +TEST_F(RewriterTests, SimpleEqualityTree) { + // [=] + // [=] [=] + // [4] [5] [3] [3] + type::Value val4 = type::ValueFactory::GetIntegerValue(4); + type::Value val5 = type::ValueFactory::GetIntegerValue(5); + type::Value val3 = type::ValueFactory::GetIntegerValue(3); + + auto lb_left_child = new expression::ConstantValueExpression(val4); + auto lb_right_child = new expression::ConstantValueExpression(val5); + auto rb_left_child = new expression::ConstantValueExpression(val3); + auto rb_right_child = new expression::ConstantValueExpression(val3); + + auto lb = new expression::ComparisonExpression(ExpressionType::COMPARE_EQUAL, + lb_left_child, lb_right_child); + auto rb = new expression::ComparisonExpression(ExpressionType::COMPARE_EQUAL, + rb_left_child, rb_right_child); + auto top = new expression::ComparisonExpression(ExpressionType::COMPARE_EQUAL, lb, rb); + + Rewriter *rewriter = new Rewriter(); + auto rewrote = rewriter->RewriteExpression(top); + + delete rewriter; + delete top; + + EXPECT_TRUE(rewrote != nullptr); + EXPECT_TRUE(rewrote->GetChildrenSize() == 0); + EXPECT_TRUE(rewrote->GetExpressionType() == ExpressionType::VALUE_CONSTANT); + 
+ auto casted = dynamic_cast(rewrote); + EXPECT_TRUE(casted->GetValueType() == type::TypeId::BOOLEAN); + EXPECT_TRUE(type::ValuePeeker::PeekBoolean(casted->GetValue()) == false); + + delete rewrote; +} + +// (TODO): delete this test once more rewriting rules implemented +TEST_F(RewriterTests, SimpleJunctionPreserve) { + // [AND] + // [=] [=] + // [4] [5] [3] [3] + type::Value val4 = type::ValueFactory::GetIntegerValue(4); + type::Value val5 = type::ValueFactory::GetIntegerValue(5); + type::Value val3 = type::ValueFactory::GetIntegerValue(3); + + auto lb_left_child = new expression::ConstantValueExpression(val4); + auto lb_right_child = new expression::ConstantValueExpression(val5); + auto rb_left_child = new expression::ConstantValueExpression(val3); + auto rb_right_child = new expression::ConstantValueExpression(val3); + + auto lb = new expression::ComparisonExpression(ExpressionType::COMPARE_EQUAL, + lb_left_child, lb_right_child); + auto rb = new expression::ComparisonExpression(ExpressionType::COMPARE_EQUAL, + rb_left_child, rb_right_child); + auto top = new expression::ConjunctionExpression(ExpressionType::CONJUNCTION_AND, lb, rb); + + Rewriter *rewriter = new Rewriter(); + auto rewrote = rewriter->RewriteExpression(top); + + delete rewriter; + delete top; + + EXPECT_TRUE(rewrote != nullptr); + EXPECT_TRUE(rewrote->GetChildrenSize() == 2); + EXPECT_TRUE(rewrote->GetExpressionType() == ExpressionType::CONJUNCTION_AND); + + auto left = rewrote->GetChild(0); + auto right = rewrote->GetChild(1); + + EXPECT_TRUE(left != nullptr && right != nullptr); + EXPECT_TRUE(left->GetExpressionType() == ExpressionType::VALUE_CONSTANT); + EXPECT_TRUE(right->GetExpressionType() == ExpressionType::VALUE_CONSTANT); + + auto left_cast = dynamic_cast(left); + auto right_cast = dynamic_cast(right); + EXPECT_TRUE(left_cast->GetValueType() == type::TypeId::BOOLEAN); + EXPECT_TRUE(right_cast->GetValueType() == type::TypeId::BOOLEAN); + + 
EXPECT_TRUE(type::ValuePeeker::PeekBoolean(left_cast->GetValue()) == false); + EXPECT_TRUE(type::ValuePeeker::PeekBoolean(right_cast->GetValue()) == true); + + delete rewrote; +} + +} // namespace test +} // namespace peloton