Skip to content
This repository was archived by the owner on Sep 27, 2019. It is now read-only.

Commit f335842

Browse files
committed
Use promise to order unnesting rules apply order
1 parent 52be828 commit f335842

File tree

5 files changed

+55
-70
lines changed

5 files changed

+55
-70
lines changed

src/include/optimizer/rule.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,7 @@ struct RuleWithPromise {
108108
int promise;
109109

110110
bool operator<(const RuleWithPromise &r) const { return promise < r.promise; }
111+
bool operator>(const RuleWithPromise &r) const { return promise > r.promise; }
111112
};
112113

113114
enum class RewriteRuleSetName : uint32_t {

src/include/optimizer/rule_impls.h

Lines changed: 14 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -307,25 +307,21 @@ class EmbedFilterIntoGet : public Rule {
307307
OptimizeContext *context) const override;
308308
};
309309

310-
///////////////////////////////////////////////////////////////////////////////
311-
/// MarkJoinGetToInnerJoin
312-
class MarkJoinGetToInnerJoin : public Rule {
313-
public:
314-
MarkJoinGetToInnerJoin();
315310

316-
bool Check(std::shared_ptr<OperatorExpression> plan,
317-
OptimizeContext *context) const override;
318-
319-
void Transform(std::shared_ptr<OperatorExpression> input,
320-
std::vector<std::shared_ptr<OperatorExpression>> &transformed,
321-
OptimizeContext *context) const override;
311+
///////////////////////////////////////////////////////////////////////////////
312+
/// Unnesting rules
313+
enum class UnnestPromise {
314+
Low = 1,
315+
High
322316
};
323-
324317
///////////////////////////////////////////////////////////////////////////////
325-
/// MarkJoinInnerJoinToInnerJoin
326-
class MarkJoinInnerJoinToInnerJoin : public Rule {
318+
/// MarkJoinGetToInnerJoin
319+
class MarkJoinToInnerJoin : public Rule {
327320
public:
328-
MarkJoinInnerJoinToInnerJoin();
321+
MarkJoinToInnerJoin();
322+
323+
int Promise(GroupExpression *group_expr,
324+
OptimizeContext *context) const override;
329325

330326
bool Check(std::shared_ptr<OperatorExpression> plan,
331327
OptimizeContext *context) const override;
@@ -341,6 +337,9 @@ class PullFilterThroughMarkJoin : public Rule {
341337
public:
342338
PullFilterThroughMarkJoin();
343339

340+
int Promise(GroupExpression *group_expr,
341+
OptimizeContext *context) const override;
342+
344343
bool Check(std::shared_ptr<OperatorExpression> plan,
345344
OptimizeContext *context) const override;
346345

src/optimizer/optimizer_task.cpp

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -417,7 +417,8 @@ void TopDownRewrite::execute() {
417417
GetRuleSet().GetRewriteRulesByName(rule_set_name_),
418418
valid_rules);
419419

420-
std::sort(valid_rules.begin(), valid_rules.end());
420+
// Sort so that we apply rewrite rules with higher promise first
421+
std::sort(valid_rules.begin(), valid_rules.end(), std::greater<RuleWithPromise>());
421422

422423
for (auto &r : valid_rules) {
423424
GroupExprBindingIterator iterator(GetMemo(), cur_group_expr,
@@ -476,7 +477,9 @@ void BottomUpRewrite::execute() {
476477
GetRuleSet().GetRewriteRulesByName(rule_set_name_),
477478
valid_rules);
478479

479-
std::sort(valid_rules.begin(), valid_rules.end());
480+
// Sort so that we apply rewrite rules with higher promise first
481+
std::sort(valid_rules.begin(), valid_rules.end(), std::greater<RuleWithPromise>());
482+
// std::reverse(valid_rules.begin(), valid_rules.end());
480483

481484
for (auto &r : valid_rules) {
482485
GroupExprBindingIterator iterator(GetMemo(), cur_group_expr,

src/optimizer/rule.cpp

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -44,15 +44,17 @@ RuleSet::RuleSet() {
4444
AddImplementationRule(new ImplementDistinct());
4545
AddImplementationRule(new ImplementLimit());
4646

47-
AddRewriteRule(RewriteRuleSetName::PREDICATE_PUSH_DOWN, new PushFilterThroughJoin());
48-
AddRewriteRule(RewriteRuleSetName::PREDICATE_PUSH_DOWN, new CombineConsecutiveFilter());
49-
AddRewriteRule(RewriteRuleSetName::PREDICATE_PUSH_DOWN, new EmbedFilterIntoGet());
50-
51-
AddRewriteRule(RewriteRuleSetName::UNNEST_SUBQUERY, new PullFilterThroughMarkJoin());
52-
AddRewriteRule(RewriteRuleSetName::UNNEST_SUBQUERY, new MarkJoinInnerJoinToInnerJoin());
53-
AddRewriteRule(RewriteRuleSetName::UNNEST_SUBQUERY, new MarkJoinGetToInnerJoin());
54-
55-
47+
AddRewriteRule(RewriteRuleSetName::PREDICATE_PUSH_DOWN,
48+
new PushFilterThroughJoin());
49+
AddRewriteRule(RewriteRuleSetName::PREDICATE_PUSH_DOWN,
50+
new CombineConsecutiveFilter());
51+
AddRewriteRule(RewriteRuleSetName::PREDICATE_PUSH_DOWN,
52+
new EmbedFilterIntoGet());
53+
54+
AddRewriteRule(RewriteRuleSetName::UNNEST_SUBQUERY,
55+
new PullFilterThroughMarkJoin());
56+
AddRewriteRule(RewriteRuleSetName::UNNEST_SUBQUERY,
57+
new MarkJoinToInnerJoin());
5658
}
5759

5860
} // namespace optimizer

src/optimizer/rule_impls.cpp

Lines changed: 24 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -877,58 +877,27 @@ void EmbedFilterIntoGet::Transform(
877877

878878
///////////////////////////////////////////////////////////////////////////////
879879
/// MarkJoinGetToInnerJoin
880-
MarkJoinGetToInnerJoin::MarkJoinGetToInnerJoin() {
880+
MarkJoinToInnerJoin::MarkJoinToInnerJoin() {
881881
type_ = RuleType::MARK_JOIN_GET_TO_INNER_JOIN;
882882

883883
match_pattern = std::make_shared<Pattern>(OpType::LogicalMarkJoin);
884884
match_pattern->AddChild(std::make_shared<Pattern>(OpType::Leaf));
885-
match_pattern->AddChild(std::make_shared<Pattern>(OpType::Get));
885+
match_pattern->AddChild(std::make_shared<Pattern>(OpType::Leaf));
886886
}
887887

888-
bool MarkJoinGetToInnerJoin::Check(std::shared_ptr<OperatorExpression> plan,
889-
OptimizeContext *context) const {
888+
int MarkJoinToInnerJoin::Promise(GroupExpression *group_expr,
889+
OptimizeContext *context) const {
890890
(void)context;
891-
(void)plan;
892-
893-
UNUSED_ATTRIBUTE auto &children = plan->Children();
894-
PL_ASSERT(children.size() == 2);
895-
896-
return true;
897-
}
898-
899-
void MarkJoinGetToInnerJoin::Transform(
900-
std::shared_ptr<OperatorExpression> input,
901-
std::vector<std::shared_ptr<OperatorExpression>> &transformed,
902-
UNUSED_ATTRIBUTE OptimizeContext *context) const {
903-
UNUSED_ATTRIBUTE auto mark_join = input->Op().As<LogicalMarkJoin>();
904-
auto &join_children = input->Children();
905-
906-
PL_ASSERT(mark_join->join_predicates.empty());
907-
908-
std::shared_ptr<OperatorExpression> output =
909-
std::make_shared<OperatorExpression>(LogicalInnerJoin::make());
910-
911-
output->PushChild(join_children[0]);
912-
output->PushChild(join_children[1]);
913-
914-
transformed.push_back(output);
915-
}
916-
917-
///////////////////////////////////////////////////////////////////////////////
918-
/// MarkJoinInnerJoinToInnerJoin
919-
MarkJoinInnerJoinToInnerJoin::MarkJoinInnerJoinToInnerJoin() {
920-
type_ = RuleType::MARK_JOIN_INNER_JOIN_TO_INNER_JOIN;
921-
922-
match_pattern = std::make_shared<Pattern>(OpType::LogicalMarkJoin);
923-
match_pattern->AddChild(std::make_shared<Pattern>(OpType::Leaf));
924-
auto inner_join = std::make_shared<Pattern>(OpType::InnerJoin);
925-
inner_join->AddChild(std::make_shared<Pattern>(OpType::Leaf));
926-
inner_join->AddChild(std::make_shared<Pattern>(OpType::Leaf));
927-
match_pattern->AddChild(inner_join);
891+
auto root_type = match_pattern->Type();
892+
// This rule is not applicable
893+
if (root_type != OpType::Leaf && root_type != group_expr->Op().type()) {
894+
return 0;
895+
}
896+
return static_cast<int>(UnnestPromise::Low);
928897
}
929898

930-
bool MarkJoinInnerJoinToInnerJoin::Check(
931-
std::shared_ptr<OperatorExpression> plan, OptimizeContext *context) const {
899+
bool MarkJoinToInnerJoin::Check(std::shared_ptr<OperatorExpression> plan,
900+
OptimizeContext *context) const {
932901
(void)context;
933902
(void)plan;
934903

@@ -938,7 +907,7 @@ bool MarkJoinInnerJoinToInnerJoin::Check(
938907
return true;
939908
}
940909

941-
void MarkJoinInnerJoinToInnerJoin::Transform(
910+
void MarkJoinToInnerJoin::Transform(
942911
std::shared_ptr<OperatorExpression> input,
943912
std::vector<std::shared_ptr<OperatorExpression>> &transformed,
944913
UNUSED_ATTRIBUTE OptimizeContext *context) const {
@@ -968,6 +937,17 @@ PullFilterThroughMarkJoin::PullFilterThroughMarkJoin() {
968937
match_pattern->AddChild(filter);
969938
}
970939

940+
int PullFilterThroughMarkJoin::Promise(GroupExpression *group_expr,
941+
OptimizeContext *context) const {
942+
(void)context;
943+
auto root_type = match_pattern->Type();
944+
// This rule is not applicable
945+
if (root_type != OpType::Leaf && root_type != group_expr->Op().type()) {
946+
return 0;
947+
}
948+
return static_cast<int>(UnnestPromise::High);
949+
}
950+
971951
bool PullFilterThroughMarkJoin::Check(std::shared_ptr<OperatorExpression> plan,
972952
OptimizeContext *context) const {
973953
(void)context;

0 commit comments

Comments
 (0)