Skip to content
This repository was archived by the owner on Sep 27, 2019. It is now read-only.

Commit b7a035b

Browse files
authored
Merge pull request #1150 from GustavoAngulo/join-reordering
Support for Join reordering
2 parents 8dde4d6 + b8d8cf8 commit b7a035b

15 files changed

+395
-73
lines changed

src/include/common/internal_types.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1304,6 +1304,7 @@ std::ostream &operator<<(std::ostream &os, const PropertyType &type);
13041304
enum class RuleType : uint32_t {
13051305
// Transformation rules (logical -> logical)
13061306
INNER_JOIN_COMMUTE = 0,
1307+
INNER_JOIN_ASSOCIATE,
13071308

13081309
// Don't move this one
13091310
LogicalPhysicalDelimiter,

src/include/optimizer/operator_node.h

Lines changed: 15 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -84,9 +84,9 @@ struct BaseOperatorNode {
8484

8585
virtual void Accept(OperatorVisitor *v) const = 0;
8686

87-
virtual std::string name() const = 0;
87+
virtual std::string GetName() const = 0;
8888

89-
virtual OpType type() const = 0;
89+
virtual OpType GetType() const = 0;
9090

9191
virtual bool IsLogical() const = 0;
9292

@@ -97,12 +97,12 @@ struct BaseOperatorNode {
9797
}
9898

9999
virtual hash_t Hash() const {
100-
OpType t = type();
100+
OpType t = GetType();
101101
return HashUtil::Hash(&t);
102102
}
103103

104104
virtual bool operator==(const BaseOperatorNode &r) {
105-
return type() == r.type();
105+
return GetType() == r.GetType();
106106
}
107107
};
108108

@@ -111,9 +111,9 @@ template <typename T>
111111
struct OperatorNode : public BaseOperatorNode {
112112
void Accept(OperatorVisitor *v) const;
113113

114-
std::string name() const { return name_; }
114+
std::string GetName() const { return name_; }
115115

116-
OpType type() const { return type_; }
116+
OpType GetType() const { return type_; }
117117

118118
bool IsLogical() const;
119119

@@ -130,21 +130,27 @@ class Operator {
130130

131131
Operator(BaseOperatorNode *node);
132132

133+
// Calls corresponding visitor to node
133134
void Accept(OperatorVisitor *v) const;
134135

135-
std::string name() const;
136+
// Return name of operator
137+
std::string GetName() const;
136138

137-
OpType type() const;
139+
// Return operator type
140+
OpType GetType() const;
138141

142+
// Operator contains Logical node
139143
bool IsLogical() const;
140144

145+
// Operator contains Physical node
141146
bool IsPhysical() const;
142147

143148
hash_t Hash() const;
144149

145150
bool operator==(const Operator &r);
146151

147-
bool defined() const;
152+
// Operator contains physical or logical operator node
153+
bool IsDefined() const;
148154

149155
template <typename T>
150156
const T *As() const {

src/include/optimizer/rule_impls.h

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,22 @@ class InnerJoinCommutativity : public Rule {
3838
OptimizeContext *context) const override;
3939
};
4040

41+
/**
42+
* @brief (A join B) join C -> A join (B join C)
43+
*/
44+
45+
class InnerJoinAssociativity : public Rule {
46+
public:
47+
InnerJoinAssociativity();
48+
49+
bool Check(std::shared_ptr<OperatorExpression> plan,
50+
OptimizeContext *context) const override;
51+
52+
void Transform(std::shared_ptr<OperatorExpression> input,
53+
std::vector<std::shared_ptr<OperatorExpression>> &transformed,
54+
OptimizeContext *context) const override;
55+
};
56+
4157
//===--------------------------------------------------------------------===//
4258
// Implementation rules
4359
//===--------------------------------------------------------------------===//

src/optimizer/binding.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -88,8 +88,8 @@ GroupExprBindingIterator::GroupExprBindingIterator(Memo& memo,
8888
has_next_(false),
8989
current_binding_(std::make_shared<OperatorExpression>(gexpr->Op())) {
9090
LOG_TRACE("Attempting to bind on group %d with expression of type %s",
91-
gexpr->GetGroupID(), gexpr->Op().name().c_str());
92-
if (gexpr->Op().type() != pattern->Type()) return;
91+
gexpr->GetGroupID(), gexpr->Op().GetName().c_str());
92+
if (gexpr->Op().GetType() != pattern->Type()) return;
9393

9494
const std::vector<GroupID> &child_groups = gexpr->GetChildGroupIDs();
9595
const std::vector<std::shared_ptr<Pattern>> &child_patterns =

src/optimizer/group.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ void Group::AddExpression(std::shared_ptr<GroupExpression> expr,
4040
bool Group::SetExpressionCost(GroupExpression *expr, double cost,
4141
std::shared_ptr<PropertySet> &properties) {
4242
LOG_TRACE("Adding expression cost on group %d with op %s, req %s",
43-
expr->GetGroupID(), expr->Op().name().c_str(),
43+
expr->GetGroupID(), expr->Op().GetName().c_str(),
4444
properties->ToString().c_str());
4545
auto it = lowest_cost_expressions_.find(properties);
4646
if (it == lowest_cost_expressions_.end() || std::get<0>(it->second) > cost) {

src/optimizer/input_column_deriver.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -207,9 +207,9 @@ void InputColumnDeriver::AggregateHelper(const BaseOperatorNode *op) {
207207

208208
// TODO(boweic): do not use shared_ptr
209209
vector<shared_ptr<AbstractExpression>> groupby_cols;
210-
if (op->type() == OpType::HashGroupBy) {
210+
if (op->GetType() == OpType::HashGroupBy) {
211211
groupby_cols = reinterpret_cast<const PhysicalHashGroupBy *>(op)->columns;
212-
} else if (op->type() == OpType::SortGroupBy) {
212+
} else if (op->GetType() == OpType::SortGroupBy) {
213213
groupby_cols = reinterpret_cast<const PhysicalSortGroupBy *>(op)->columns;
214214
}
215215
for (auto &groupby_col : groupby_cols) {
@@ -230,12 +230,12 @@ void InputColumnDeriver::JoinHelper(const BaseOperatorNode *op) {
230230
const vector<unique_ptr<expression::AbstractExpression>> *left_keys = nullptr;
231231
const vector<unique_ptr<expression::AbstractExpression>> *right_keys =
232232
nullptr;
233-
if (op->type() == OpType::InnerHashJoin) {
233+
if (op->GetType() == OpType::InnerHashJoin) {
234234
auto join_op = reinterpret_cast<const PhysicalInnerHashJoin *>(op);
235235
join_conds = &(join_op->join_predicates);
236236
left_keys = &(join_op->left_keys);
237237
right_keys = &(join_op->right_keys);
238-
} else if (op->type() == OpType::InnerNLJoin) {
238+
} else if (op->GetType() == OpType::InnerNLJoin) {
239239
auto join_op = reinterpret_cast<const PhysicalInnerNLJoin *>(op);
240240
join_conds = &(join_op->join_predicates);
241241
left_keys = &(join_op->left_keys);

src/optimizer/memo.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ GroupExpression *Memo::InsertExpression(std::shared_ptr<GroupExpression> gexpr,
3131
GroupExpression *Memo::InsertExpression(std::shared_ptr<GroupExpression> gexpr,
3232
GroupID target_group, bool enforced) {
3333
// If leaf, then just return
34-
if (gexpr->Op().type() == OpType::Leaf) {
34+
if (gexpr->Op().GetType() == OpType::Leaf) {
3535
const LeafOperator *leaf = gexpr->Op().As<LeafOperator>();
3636
assert(target_group == UNDEFINED_GROUP ||
3737
target_group == leaf->origin_group);
@@ -73,7 +73,7 @@ GroupID Memo::AddNewGroup(std::shared_ptr<GroupExpression> gexpr) {
7373
GroupID new_group_id = groups_.size();
7474
// Find out the table alias that this group represents
7575
std::unordered_set<std::string> table_aliases;
76-
auto op_type = gexpr->Op().type();
76+
auto op_type = gexpr->Op().GetType();
7777
if (op_type == OpType::Get) {
7878
// For base group, the table alias can get directly from logical get
7979
const LogicalGet *logical_get = gexpr->Op().As<LogicalGet>();

src/optimizer/operator_node.cpp

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -24,51 +24,51 @@ Operator::Operator(BaseOperatorNode *node) : node(node) {}
2424

2525
void Operator::Accept(OperatorVisitor *v) const { node->Accept(v); }
2626

27-
std::string Operator::name() const {
28-
if (defined()) {
29-
return node->name();
27+
std::string Operator::GetName() const {
28+
if (IsDefined()) {
29+
return node->GetName();
3030
}
3131
return "Undefined";
3232
}
3333

34-
OpType Operator::type() const {
35-
if (defined()) {
36-
return node->type();
34+
OpType Operator::GetType() const {
35+
if (IsDefined()) {
36+
return node->GetType();
3737
}
3838
return OpType::Undefined;
3939
}
4040

4141
bool Operator::IsLogical() const {
42-
if (defined()) {
42+
if (IsDefined()) {
4343
return node->IsLogical();
4444
}
4545
return false;
4646
}
4747

4848
bool Operator::IsPhysical() const {
49-
if (defined()) {
49+
if (IsDefined()) {
5050
return node->IsPhysical();
5151
}
5252
return false;
5353
}
5454

5555
hash_t Operator::Hash() const {
56-
if (defined()) {
56+
if (IsDefined()) {
5757
return node->Hash();
5858
}
5959
return 0;
6060
}
6161

6262
bool Operator::operator==(const Operator &r) {
63-
if (defined() && r.defined()) {
63+
if (IsDefined() && r.IsDefined()) {
6464
return *node == *r.node;
65-
} else if (!defined() && !r.defined()) {
65+
} else if (!IsDefined() && !r.IsDefined()) {
6666
return true;
6767
}
6868
return false;
6969
}
7070

71-
bool Operator::defined() const { return node != nullptr; }
71+
bool Operator::IsDefined() const { return node != nullptr; }
7272

7373
} // namespace optimizer
7474
} // namespace peloton

src/optimizer/operators.cpp

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ hash_t LogicalGet::Hash() const {
5151
}
5252

5353
bool LogicalGet::operator==(const BaseOperatorNode &r) {
54-
if (r.type() != OpType::Get) return false;
54+
if (r.GetType()!= OpType::Get) return false;
5555
const LogicalGet &node = *static_cast<const LogicalGet *>(&r);
5656
if (predicates.size() != node.predicates.size()) return false;
5757
for (size_t i = 0; i < predicates.size(); i++) {
@@ -78,7 +78,7 @@ Operator LogicalQueryDerivedGet::make(
7878
}
7979

8080
bool LogicalQueryDerivedGet::operator==(const BaseOperatorNode &node) {
81-
if (node.type() != OpType::LogicalQueryDerivedGet) return false;
81+
if (node.GetType() != OpType::LogicalQueryDerivedGet) return false;
8282
const LogicalQueryDerivedGet &r =
8383
*static_cast<const LogicalQueryDerivedGet *>(&node);
8484
return get_id == r.get_id;
@@ -107,7 +107,7 @@ hash_t LogicalFilter::Hash() const {
107107
}
108108

109109
bool LogicalFilter::operator==(const BaseOperatorNode &r) {
110-
if (r.type() != OpType::LogicalFilter) return false;
110+
if (r.GetType() != OpType::LogicalFilter) return false;
111111
const LogicalFilter &node = *static_cast<const LogicalFilter *>(&r);
112112
if (predicates.size() != node.predicates.size()) return false;
113113
for (size_t i = 0; i < predicates.size(); i++) {
@@ -150,7 +150,7 @@ hash_t LogicalDependentJoin::Hash() const {
150150
}
151151

152152
bool LogicalDependentJoin::operator==(const BaseOperatorNode &r) {
153-
if (r.type() != OpType::LogicalDependentJoin) return false;
153+
if (r.GetType() != OpType::LogicalDependentJoin) return false;
154154
const LogicalDependentJoin &node =
155155
*static_cast<const LogicalDependentJoin *>(&r);
156156
if (join_predicates.size() != node.join_predicates.size()) return false;
@@ -184,7 +184,7 @@ hash_t LogicalMarkJoin::Hash() const {
184184
}
185185

186186
bool LogicalMarkJoin::operator==(const BaseOperatorNode &r) {
187-
if (r.type() != OpType::LogicalMarkJoin) return false;
187+
if (r.GetType() != OpType::LogicalMarkJoin) return false;
188188
const LogicalMarkJoin &node = *static_cast<const LogicalMarkJoin *>(&r);
189189
if (join_predicates.size() != node.join_predicates.size()) return false;
190190
for (size_t i = 0; i < join_predicates.size(); i++) {
@@ -218,7 +218,7 @@ hash_t LogicalSingleJoin::Hash() const {
218218
}
219219

220220
bool LogicalSingleJoin::operator==(const BaseOperatorNode &r) {
221-
if (r.type() != OpType::LogicalSingleJoin) return false;
221+
if (r.GetType() != OpType::LogicalSingleJoin) return false;
222222
const LogicalSingleJoin &node = *static_cast<const LogicalSingleJoin *>(&r);
223223
if (join_predicates.size() != node.join_predicates.size()) return false;
224224
for (size_t i = 0; i < join_predicates.size(); i++) {
@@ -251,7 +251,7 @@ hash_t LogicalInnerJoin::Hash() const {
251251
}
252252

253253
bool LogicalInnerJoin::operator==(const BaseOperatorNode &r) {
254-
if (r.type() != OpType::InnerJoin) return false;
254+
if (r.GetType() != OpType::InnerJoin) return false;
255255
const LogicalInnerJoin &node = *static_cast<const LogicalInnerJoin *>(&r);
256256
if (join_predicates.size() != node.join_predicates.size()) return false;
257257
for (size_t i = 0; i < join_predicates.size(); i++) {
@@ -319,7 +319,7 @@ Operator LogicalAggregateAndGroupBy::make(
319319
}
320320

321321
bool LogicalAggregateAndGroupBy::operator==(const BaseOperatorNode &node) {
322-
if (node.type() != OpType::LogicalAggregateAndGroupBy) return false;
322+
if (node.GetType() != OpType::LogicalAggregateAndGroupBy) return false;
323323
const LogicalAggregateAndGroupBy &r =
324324
*static_cast<const LogicalAggregateAndGroupBy *>(&node);
325325
if (having.size() != r.having.size() || columns.size() != r.columns.size())
@@ -426,7 +426,7 @@ Operator PhysicalSeqScan::make(oid_t get_id, std::shared_ptr<catalog::TableCatal
426426
}
427427

428428
bool PhysicalSeqScan::operator==(const BaseOperatorNode &r) {
429-
if (r.type() != OpType::SeqScan) return false;
429+
if (r.GetType() != OpType::SeqScan) return false;
430430
const PhysicalSeqScan &node = *static_cast<const PhysicalSeqScan *>(&r);
431431
if (predicates.size() != node.predicates.size()) return false;
432432
for (size_t i = 0; i < predicates.size(); i++) {
@@ -470,7 +470,7 @@ Operator PhysicalIndexScan::make(oid_t get_id, std::shared_ptr<catalog::TableCat
470470
}
471471

472472
bool PhysicalIndexScan::operator==(const BaseOperatorNode &r) {
473-
if (r.type() != OpType::IndexScan) return false;
473+
if (r.GetType() != OpType::IndexScan) return false;
474474
const PhysicalIndexScan &node = *static_cast<const PhysicalIndexScan *>(&r);
475475
// TODO: Should also check value list
476476
if (index_id != node.index_id ||
@@ -512,7 +512,7 @@ Operator QueryDerivedScan::make(
512512
}
513513

514514
bool QueryDerivedScan::operator==(const BaseOperatorNode &node) {
515-
if (node.type() != OpType::QueryDerivedScan) return false;
515+
if (node.GetType() != OpType::QueryDerivedScan) return false;
516516
const QueryDerivedScan &r = *static_cast<const QueryDerivedScan *>(&node);
517517
return get_id == r.get_id;
518518
}
@@ -569,7 +569,7 @@ hash_t PhysicalInnerNLJoin::Hash() const {
569569
}
570570

571571
bool PhysicalInnerNLJoin::operator==(const BaseOperatorNode &r) {
572-
if (r.type() != OpType::InnerNLJoin) return false;
572+
if (r.GetType() != OpType::InnerNLJoin) return false;
573573
const PhysicalInnerNLJoin &node =
574574
*static_cast<const PhysicalInnerNLJoin *>(&r);
575575
if (join_predicates.size() != node.join_predicates.size() ||
@@ -646,7 +646,7 @@ hash_t PhysicalInnerHashJoin::Hash() const {
646646
}
647647

648648
bool PhysicalInnerHashJoin::operator==(const BaseOperatorNode &r) {
649-
if (r.type() != OpType::InnerHashJoin) return false;
649+
if (r.GetType() != OpType::InnerHashJoin) return false;
650650
const PhysicalInnerHashJoin &node =
651651
*static_cast<const PhysicalInnerHashJoin *>(&r);
652652
if (join_predicates.size() != node.join_predicates.size() ||
@@ -755,7 +755,7 @@ Operator PhysicalHashGroupBy::make(
755755
}
756756

757757
bool PhysicalHashGroupBy::operator==(const BaseOperatorNode &node) {
758-
if (node.type() != OpType::HashGroupBy) return false;
758+
if (node.GetType() != OpType::HashGroupBy) return false;
759759
const PhysicalHashGroupBy &r =
760760
*static_cast<const PhysicalHashGroupBy *>(&node);
761761
if (having.size() != r.having.size() || columns.size() != r.columns.size())
@@ -788,7 +788,7 @@ Operator PhysicalSortGroupBy::make(
788788
}
789789

790790
bool PhysicalSortGroupBy::operator==(const BaseOperatorNode &node) {
791-
if (node.type() != OpType::SortGroupBy) return false;
791+
if (node.GetType() != OpType::SortGroupBy) return false;
792792
const PhysicalSortGroupBy &r =
793793
*static_cast<const PhysicalSortGroupBy *>(&node);
794794
if (having.size() != r.having.size() || columns.size() != r.columns.size())

src/optimizer/optimizer.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -294,7 +294,7 @@ unique_ptr<planner::AbstractPlan> Optimizer::ChooseBestPlan(
294294
auto gexpr = group->GetBestExpression(required_props);
295295

296296
LOG_TRACE("Choosing best plan for group %d with op %s", gexpr->GetGroupID(),
297-
gexpr->Op().name().c_str());
297+
gexpr->Op().GetName().c_str());
298298

299299
vector<GroupID> child_groups = gexpr->GetChildGroupIDs();
300300
auto required_input_props = gexpr->GetInputProperties(required_props);

0 commit comments

Comments
 (0)