Skip to content
This repository was archived by the owner on Sep 27, 2019. It is now read-only.

Commit 17de3b9

Browse files
committed
Using AbstractNode throughout optimizer.
Abstract nodes were implemented in 209c46a. This is essentially just refactoring and plugging in abstract nodes throughout the optimizer. The abstract interface exposes OpType and ExpressionType for now, which ideally will be fixed later. Work remaining for abstracting OperatorExpression.
1 parent 209c46a commit 17de3b9

23 files changed

+193
-164
lines changed

src/include/optimizer/abstract_node.h

Lines changed: 17 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -83,33 +83,44 @@ enum class OpType {
8383
class OperatorVisitor;
8484

8585
struct AbstractNode {
86-
AbstractNode() {}
86+
AbstractNode(AbstractNode *node) : node(node) {}
8787

8888
~AbstractNode() {}
8989

9090
virtual void Accept(OperatorVisitor *v) const = 0;
9191

9292
virtual std::string GetName() const = 0;
9393

94-
// TODO(ncx): problematic dependence on OpType
95-
virtual OpType GetType() const = 0;
94+
// TODO(ncx): dependence on OpType and ExpressionType (ideally abstracted away)
95+
virtual OpType GetOpType() const = 0;
96+
97+
virtual ExpressionType GetExpType() const = 0;
9698

9799
virtual bool IsLogical() const = 0;
98100

99101
virtual bool IsPhysical() const = 0;
100102

101103
virtual hash_t Hash() const {
102-
OpType t = GetType();
104+
// TODO(ncx): hash should work for ExpressionType nodes
105+
OpType t = GetOpType();
103106
return HashUtil::Hash(&t);
104107
}
105108

106109
virtual bool operator==(const AbstractNode &r) {
107-
return GetType() == r.GetType();
110+
return GetOpType() == r.GetOpType() && GetExpType() == r.GetExpType();
108111
}
109112

110113
virtual bool IsDefined() const { return node != nullptr; }
111114

112-
private:
115+
template <typename T>
116+
const T *As() const {
117+
if (node && typeid(*node) == typeid(T)) {
118+
return (const T *)node.get();
119+
}
120+
return nullptr;
121+
}
122+
123+
protected:
113124
std::shared_ptr<AbstractNode> node;
114125
};
115126

src/include/optimizer/cost_model/default_cost_model.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ class DefaultCostModel : public AbstractCostModel {
3434
gexpr_ = gexpr;
3535
memo_ = memo;
3636
txn_ = txn;
37-
gexpr_->Op().Accept(this);
37+
gexpr_->Op()->Accept(this);
3838
return output_cost_;
3939
}
4040

src/include/optimizer/cost_model/postgres_cost_model.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ class PostgresCostModel : public AbstractCostModel {
3939
gexpr_ = gexpr;
4040
memo_ = memo;
4141
txn_ = txn;
42-
gexpr_->Op().Accept(this);
42+
gexpr_->Op()->Accept(this);
4343
return output_cost_;
4444
};
4545

@@ -279,4 +279,4 @@ class PostgresCostModel : public AbstractCostModel {
279279
};
280280

281281
} // namespace optimizer
282-
} // namespace peloton
282+
} // namespace peloton

src/include/optimizer/cost_model/trivial_cost_model.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ class TrivialCostModel : public AbstractCostModel {
4141
gexpr_ = gexpr;
4242
memo_ = memo;
4343
txn_ = txn;
44-
gexpr_->Op().Accept(this);
44+
gexpr_->Op()->Accept(this);
4545
return output_cost_;
4646
};
4747

@@ -116,4 +116,4 @@ class TrivialCostModel : public AbstractCostModel {
116116
};
117117

118118
} // namespace optimizer
119-
} // namespace peloton
119+
} // namespace peloton

src/include/optimizer/group_expression.h

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212

1313
#pragma once
1414

15-
#include "optimizer/operator_node.h"
15+
#include "optimizer/abstract_node.h"
1616
#include "optimizer/stats/stats.h"
1717
#include "optimizer/util.h"
1818
#include "optimizer/property_set.h"
@@ -34,7 +34,7 @@ using GroupID = int32_t;
3434
//===--------------------------------------------------------------------===//
3535
class GroupExpression {
3636
public:
37-
GroupExpression(Operator op, std::vector<GroupID> child_groups);
37+
GroupExpression(std::shared_ptr<AbstractNode> node, std::vector<GroupID> child_groups);
3838

3939
GroupID GetGroupID() const;
4040

@@ -46,7 +46,7 @@ class GroupExpression {
4646

4747
GroupID GetChildGroupId(int child_idx) const;
4848

49-
Operator Op() const;
49+
std::shared_ptr<AbstractNode> Op() const;
5050

5151
double GetCost(std::shared_ptr<PropertySet>& requirements) const;
5252

@@ -75,7 +75,7 @@ class GroupExpression {
7575

7676
private:
7777
GroupID group_id;
78-
Operator op;
78+
std::shared_ptr<AbstractNode> node;
7979
std::vector<GroupID> child_groups;
8080
std::bitset<static_cast<uint32_t>(RuleType::NUM_RULES)> rule_mask_;
8181
bool stats_derived_;

src/include/optimizer/operator_node.h

Lines changed: 11 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -28,26 +28,29 @@ namespace optimizer {
2828
class OperatorVisitor;
2929

3030
// Curiously recurring template pattern
31-
// TODO(ncx): this templating would be nice to clean up
3231
template <typename T>
3332
struct OperatorNode : public AbstractNode {
34-
OperatorNode() {}
33+
OperatorNode() : AbstractNode(nullptr) {}
3534

3635
virtual ~OperatorNode() {}
3736

3837
void Accept(OperatorVisitor *v) const;
3938

40-
virtual std::string GetName() const { return name_; }
39+
std::string GetName() const { return name_; }
4140

42-
virtual OpType GetType() const { return type_; }
41+
OpType GetOpType() const { return op_type_; }
42+
43+
ExpressionType GetExpType() const { return exp_type_; }
4344

4445
bool IsLogical() const;
4546

4647
bool IsPhysical() const;
4748

4849
static std::string name_;
4950

50-
static OpType type_;
51+
static OpType op_type_;
52+
53+
static ExpressionType exp_type_;
5154
};
5255

5356
class Operator : public AbstractNode {
@@ -60,7 +63,9 @@ class Operator : public AbstractNode {
6063

6164
std::string GetName() const;
6265

63-
OpType GetType() const;
66+
OpType GetOpType() const;
67+
68+
ExpressionType GetExpType() const;
6469

6570
bool IsLogical() const;
6671

@@ -71,21 +76,6 @@ class Operator : public AbstractNode {
7176
bool operator==(const Operator &r);
7277

7378
bool IsDefined() const;
74-
75-
template <typename T>
76-
const T *As() const {
77-
if (node && typeid(*node) == typeid(T)) {
78-
return (const T *)node.get();
79-
}
80-
return nullptr;
81-
}
82-
83-
static std::string name_;
84-
85-
static OpType type_;
86-
87-
private:
88-
std::shared_ptr<AbstractNode> node;
8979
};
9080

9181
} // namespace optimizer

src/include/optimizer/optimizer_metadata.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@
1919
#include "optimizer/rule.h"
2020
#include "settings/settings_manager.h"
2121

22+
#include <memory>
23+
2224
namespace peloton {
2325
namespace catalog {
2426
class Catalog;
@@ -58,7 +60,7 @@ class OptimizerMetadata {
5860
memo.InsertExpression(gexpr, false);
5961
child_groups.push_back(gexpr->GetGroupID());
6062
}
61-
return std::make_shared<GroupExpression>(expr->Op(),
63+
return std::make_shared<GroupExpression>(std::make_shared<Operator>(expr->Op()),
6264
std::move(child_groups));
6365
}
6466

src/optimizer/binding.cpp

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -85,8 +85,9 @@ GroupExprBindingIterator::GroupExprBindingIterator(
8585
pattern_(pattern),
8686
first_(true),
8787
has_next_(false),
88-
current_binding_(std::make_shared<OperatorExpression>(gexpr->Op())) {
89-
if (gexpr->Op().GetType() != pattern->Type()) {
88+
// TODO(ncx): fix once OperatorExpression is abstracted
89+
current_binding_(std::make_shared<OperatorExpression>(*(Operator *)gexpr->Op().get())) {
90+
if (gexpr->Op()->GetOpType() != pattern->Type()) {
9091
return;
9192
}
9293

@@ -100,7 +101,7 @@ GroupExprBindingIterator::GroupExprBindingIterator(
100101
LOG_TRACE(
101102
"Attempting to bind on group %d with expression of type %s, children "
102103
"size %lu",
103-
gexpr->GetGroupID(), gexpr->Op().GetName().c_str(), child_groups.size());
104+
gexpr->GetGroupID(), gexpr->Op()->GetName().c_str(), child_groups.size());
104105

105106
// Find all bindings for children
106107
children_bindings_.resize(child_groups.size(), {});

src/optimizer/child_property_deriver.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ ChildPropertyDeriver::GetProperties(GroupExpression *gexpr,
3838
output_.clear();
3939
memo_ = memo;
4040
gexpr_ = gexpr;
41-
gexpr->Op().Accept(this);
41+
gexpr->Op()->Accept(this);
4242
return move(output_);
4343
}
4444

src/optimizer/group.cpp

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ void Group::AddExpression(std::shared_ptr<GroupExpression> expr,
3131
expr->SetGroupID(id_);
3232
if (enforced)
3333
enforced_exprs_.push_back(expr);
34-
else if (expr->Op().IsPhysical())
34+
else if (expr->Op()->IsPhysical())
3535
physical_expressions_.push_back(expr);
3636
else
3737
logical_expressions_.push_back(expr);
@@ -40,7 +40,7 @@ void Group::AddExpression(std::shared_ptr<GroupExpression> expr,
4040
bool Group::SetExpressionCost(GroupExpression *expr, double cost,
4141
std::shared_ptr<PropertySet> &properties) {
4242
LOG_TRACE("Adding expression cost on group %d with op %s, req %s",
43-
expr->GetGroupID(), expr->Op().GetName().c_str(),
43+
expr->GetGroupID(), expr->Op()->GetName().c_str(),
4444
properties->ToString().c_str());
4545
auto it = lowest_cost_expressions_.find(properties);
4646
if (it == lowest_cost_expressions_.end() || std::get<0>(it->second) > cost) {
@@ -86,7 +86,7 @@ const std::string Group::GetInfo(int num_indent) const {
8686

8787
for (auto expr : logical_expressions_) {
8888
os << StringUtil::Indent(num_indent + 4)
89-
<< expr->Op().GetName() << std::endl;
89+
<< expr->Op()->GetName() << std::endl;
9090
const std::vector<GroupID> ChildGroupIDs = expr->GetChildGroupIDs();
9191
if (ChildGroupIDs.size() > 0) {
9292
os << StringUtil::Indent(num_indent + 6)
@@ -102,7 +102,7 @@ const std::string Group::GetInfo(int num_indent) const {
102102
<< "physical_expressions_: \n";
103103
for (auto expr : physical_expressions_) {
104104
os << StringUtil::Indent(num_indent + 4)
105-
<< expr->Op().GetName() << std::endl;
105+
<< expr->Op()->GetName() << std::endl;
106106
const std::vector<GroupID> ChildGroupIDs = expr->GetChildGroupIDs();
107107
if (ChildGroupIDs.size() > 0) {
108108
os << StringUtil::Indent(num_indent + 6)
@@ -119,7 +119,7 @@ const std::string Group::GetInfo(int num_indent) const {
119119
<< "enforced_exprs_: \n";
120120
for (auto expr : enforced_exprs_) {
121121
os << StringUtil::Indent(num_indent + 4)
122-
<< expr->Op().GetName() << std::endl;
122+
<< expr->Op()->GetName() << std::endl;
123123
const std::vector<GroupID> ChildGroupIDs = expr->GetChildGroupIDs();
124124
if (ChildGroupIDs.size() > 0) {
125125
os << StringUtil::Indent(num_indent + 6)

0 commit comments

Comments
 (0)