Skip to content
This repository was archived by the owner on Sep 27, 2019. It is now read-only.

Commit 000b9ff

Browse files
committed
Add a rule to pull correlated predicates up through aggregation
1 parent f335842 commit 000b9ff

File tree

5 files changed

+196
-17
lines changed

5 files changed

+196
-17
lines changed

src/include/common/internal_types.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1335,6 +1335,7 @@ enum class RuleType : uint32_t {
13351335
MARK_JOIN_INNER_JOIN_TO_INNER_JOIN,
13361336
MARK_JOIN_FILTER_TO_INNER_JOIN,
13371337
PULL_FILTER_THROUGH_MARK_JOIN,
1338+
PULL_FILTER_THROUGH_AGGREGATION,
13381339

13391340
// Place holder to generate number of rules compile time
13401341
NUM_RULES

src/include/optimizer/rule_impls.h

Lines changed: 39 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -307,21 +307,37 @@ class EmbedFilterIntoGet : public Rule {
307307
OptimizeContext *context) const override;
308308
};
309309

310-
311310
///////////////////////////////////////////////////////////////////////////////
312311
/// Unnesting rules
313-
enum class UnnestPromise {
314-
Low = 1,
315-
High
316-
};
312+
enum class UnnestPromise { Low = 1, High };
313+
// TODO(boweic): MarkJoin and SingleJoin should not be transformed into inner
314+
// join. Sometimes MarkJoin could be transformed into semi-join, but for now we
315+
// do not have these operators in the llvm cogen engine. Once we have those, we
316+
// should not use the following rules in the rewrite phase
317317
///////////////////////////////////////////////////////////////////////////////
318318
/// MarkJoinGetToInnerJoin
319319
class MarkJoinToInnerJoin : public Rule {
320320
public:
321321
MarkJoinToInnerJoin();
322322

323323
int Promise(GroupExpression *group_expr,
324-
OptimizeContext *context) const override;
324+
OptimizeContext *context) const override;
325+
326+
bool Check(std::shared_ptr<OperatorExpression> plan,
327+
OptimizeContext *context) const override;
328+
329+
void Transform(std::shared_ptr<OperatorExpression> input,
330+
std::vector<std::shared_ptr<OperatorExpression>> &transformed,
331+
OptimizeContext *context) const override;
332+
};
333+
///////////////////////////////////////////////////////////////////////////////
334+
/// SingleJoinToInnerJoin
335+
class SingleJoinToInnerJoin : public Rule {
336+
public:
337+
SingleJoinToInnerJoin();
338+
339+
int Promise(GroupExpression *group_expr,
340+
OptimizeContext *context) const override;
325341

326342
bool Check(std::shared_ptr<OperatorExpression> plan,
327343
OptimizeContext *context) const override;
@@ -338,7 +354,7 @@ class PullFilterThroughMarkJoin : public Rule {
338354
PullFilterThroughMarkJoin();
339355

340356
int Promise(GroupExpression *group_expr,
341-
OptimizeContext *context) const override;
357+
OptimizeContext *context) const override;
342358

343359
bool Check(std::shared_ptr<OperatorExpression> plan,
344360
OptimizeContext *context) const override;
@@ -348,5 +364,21 @@ class PullFilterThroughMarkJoin : public Rule {
348364
OptimizeContext *context) const override;
349365
};
350366

367+
///////////////////////////////////////////////////////////////////////////////
368+
/// PullFilterThroughAggregation
369+
class PullFilterThroughAggregation : public Rule {
370+
public:
371+
PullFilterThroughAggregation();
372+
373+
int Promise(GroupExpression *group_expr,
374+
OptimizeContext *context) const override;
375+
376+
bool Check(std::shared_ptr<OperatorExpression> plan,
377+
OptimizeContext *context) const override;
378+
379+
void Transform(std::shared_ptr<OperatorExpression> input,
380+
std::vector<std::shared_ptr<OperatorExpression>> &transformed,
381+
OptimizeContext *context) const override;
382+
};
351383
} // namespace optimizer
352384
} // namespace peloton

src/optimizer/optimizer_task.cpp

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -296,8 +296,8 @@ void OptimizeInputs::execute() {
296296
// 1. Collect stats needed and cache them in the group
297297
// 2. Calculate cost based on children's stats
298298
CostCalculator cost_calculator;
299-
cur_total_cost_ += cost_calculator.CalculateCost(
300-
group_expr_, &context_->metadata->memo);
299+
cur_total_cost_ +=
300+
cost_calculator.CalculateCost(group_expr_, &context_->metadata->memo);
301301
}
302302

303303
for (; cur_child_idx_ < (int)group_expr_->GetChildrenGroupsSize();
@@ -418,7 +418,8 @@ void TopDownRewrite::execute() {
418418
valid_rules);
419419

420420
// Sort so that we apply rewrite rules with higher promise first
421-
std::sort(valid_rules.begin(), valid_rules.end(), std::greater<RuleWithPromise>());
421+
std::sort(valid_rules.begin(), valid_rules.end(),
422+
std::greater<RuleWithPromise>());
422423

423424
for (auto &r : valid_rules) {
424425
GroupExprBindingIterator iterator(GetMemo(), cur_group_expr,
@@ -478,8 +479,8 @@ void BottomUpRewrite::execute() {
478479
valid_rules);
479480

480481
// Sort so that we apply rewrite rules with higher promise first
481-
std::sort(valid_rules.begin(), valid_rules.end(), std::greater<RuleWithPromise>());
482-
// std::reverse(valid_rules.begin(), valid_rules.end());
482+
std::sort(valid_rules.begin(), valid_rules.end(),
483+
std::greater<RuleWithPromise>());
483484

484485
for (auto &r : valid_rules) {
485486
GroupExprBindingIterator iterator(GetMemo(), cur_group_expr,

src/optimizer/rule_impls.cpp

Lines changed: 150 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -876,7 +876,7 @@ void EmbedFilterIntoGet::Transform(
876876
}
877877

878878
///////////////////////////////////////////////////////////////////////////////
879-
/// MarkJoinGetToInnerJoin
879+
/// MarkJoinToInnerJoin
880880
MarkJoinToInnerJoin::MarkJoinToInnerJoin() {
881881
type_ = RuleType::MARK_JOIN_GET_TO_INNER_JOIN;
882882

@@ -886,7 +886,7 @@ MarkJoinToInnerJoin::MarkJoinToInnerJoin() {
886886
}
887887

888888
int MarkJoinToInnerJoin::Promise(GroupExpression *group_expr,
889-
OptimizeContext *context) const {
889+
OptimizeContext *context) const {
890890
(void)context;
891891
auto root_type = match_pattern->Type();
892892
// This rule is not applicable
@@ -897,7 +897,7 @@ int MarkJoinToInnerJoin::Promise(GroupExpression *group_expr,
897897
}
898898

899899
bool MarkJoinToInnerJoin::Check(std::shared_ptr<OperatorExpression> plan,
900-
OptimizeContext *context) const {
900+
OptimizeContext *context) const {
901901
(void)context;
902902
(void)plan;
903903

@@ -925,6 +925,56 @@ void MarkJoinToInnerJoin::Transform(
925925
transformed.push_back(output);
926926
}
927927

928+
///////////////////////////////////////////////////////////////////////////////
929+
/// SingleJoinGetToInnerJoin
930+
SingleJoinToInnerJoin::SingleJoinToInnerJoin() {
931+
type_ = RuleType::MARK_JOIN_GET_TO_INNER_JOIN;
932+
933+
match_pattern = std::make_shared<Pattern>(OpType::LogicalSingleJoin);
934+
match_pattern->AddChild(std::make_shared<Pattern>(OpType::Leaf));
935+
match_pattern->AddChild(std::make_shared<Pattern>(OpType::Leaf));
936+
}
937+
938+
int SingleJoinToInnerJoin::Promise(GroupExpression *group_expr,
939+
OptimizeContext *context) const {
940+
(void)context;
941+
auto root_type = match_pattern->Type();
942+
// This rule is not applicable
943+
if (root_type != OpType::Leaf && root_type != group_expr->Op().type()) {
944+
return 0;
945+
}
946+
return static_cast<int>(UnnestPromise::Low);
947+
}
948+
949+
bool SingleJoinToInnerJoin::Check(std::shared_ptr<OperatorExpression> plan,
950+
OptimizeContext *context) const {
951+
(void)context;
952+
(void)plan;
953+
954+
UNUSED_ATTRIBUTE auto &children = plan->Children();
955+
PL_ASSERT(children.size() == 2);
956+
957+
return true;
958+
}
959+
960+
void SingleJoinToInnerJoin::Transform(
961+
std::shared_ptr<OperatorExpression> input,
962+
std::vector<std::shared_ptr<OperatorExpression>> &transformed,
963+
UNUSED_ATTRIBUTE OptimizeContext *context) const {
964+
UNUSED_ATTRIBUTE auto single_join = input->Op().As<LogicalSingleJoin>();
965+
auto &join_children = input->Children();
966+
967+
PL_ASSERT(single_join->join_predicates.empty());
968+
969+
std::shared_ptr<OperatorExpression> output =
970+
std::make_shared<OperatorExpression>(LogicalInnerJoin::make());
971+
972+
output->PushChild(join_children[0]);
973+
output->PushChild(join_children[1]);
974+
975+
transformed.push_back(output);
976+
}
977+
928978
///////////////////////////////////////////////////////////////////////////////
929979
/// PullFilterThroughMarkJoin
930980
PullFilterThroughMarkJoin::PullFilterThroughMarkJoin() {
@@ -986,5 +1036,102 @@ void PullFilterThroughMarkJoin::Transform(
9861036
transformed.push_back(output);
9871037
}
9881038

1039+
///////////////////////////////////////////////////////////////////////////////
1040+
/// PullFilterThroughAggregation
1041+
PullFilterThroughAggregation::PullFilterThroughAggregation() {
1042+
type_ = RuleType::PULL_FILTER_THROUGH_AGGREGATION;
1043+
1044+
auto filter = std::make_shared<Pattern>(OpType::LogicalFilter);
1045+
filter->AddChild(std::make_shared<Pattern>(OpType::Leaf));
1046+
match_pattern = std::make_shared<Pattern>(OpType::LogicalAggregateAndGroupBy);
1047+
match_pattern->AddChild(filter);
1048+
}
1049+
1050+
int PullFilterThroughAggregation::Promise(GroupExpression *group_expr,
1051+
OptimizeContext *context) const {
1052+
(void)context;
1053+
auto root_type = match_pattern->Type();
1054+
// This rule is not applicable
1055+
if (root_type != OpType::Leaf && root_type != group_expr->Op().type()) {
1056+
return 0;
1057+
}
1058+
return static_cast<int>(UnnestPromise::High);
1059+
}
1060+
1061+
bool PullFilterThroughAggregation::Check(
1062+
std::shared_ptr<OperatorExpression> plan, OptimizeContext *context) const {
1063+
(void)context;
1064+
(void)plan;
1065+
1066+
auto &children = plan->Children();
1067+
PL_ASSERT(children.size() == 1);
1068+
UNUSED_ATTRIBUTE auto &r_grandchildren = children[1]->Children();
1069+
PL_ASSERT(r_grandchildren.size() == 1);
1070+
1071+
return true;
1072+
}
1073+
1074+
void PullFilterThroughAggregation::Transform(
1075+
std::shared_ptr<OperatorExpression> input,
1076+
std::vector<std::shared_ptr<OperatorExpression>> &transformed,
1077+
UNUSED_ATTRIBUTE OptimizeContext *context) const {
1078+
auto &memo = context->metadata->memo;
1079+
auto &filter_expr = input->Children()[0];
1080+
auto child_group_id =
1081+
filter_expr->Children()[0]->Op().As<LeafOperator>()->origin_group;
1082+
const auto &child_group_aliases_set =
1083+
memo.GetGroupByID(child_group_id)->GetTableAliases();
1084+
1085+
auto &predicates = filter_expr->Op().As<LogicalFilter>()->predicates;
1086+
1087+
std::vector<AnnotatedExpression> correlated_predicates;
1088+
std::vector<AnnotatedExpression> normal_predicates;
1089+
std::vector<std::shared_ptr<expression::AbstractExpression>> new_groupby_cols;
1090+
for (auto &predicate : predicates) {
1091+
if (util::IsSubset(child_group_aliases_set, predicate.table_alias_set)) {
1092+
normal_predicates.emplace_back(predicate);
1093+
} else {
1094+
// Correlated predicate, already in the form of
1095+
// (outer_relation.a = (expr))
1096+
correlated_predicates.emplace_back(predicate);
1097+
auto &root_expr = predicate.expr;
1098+
if (root_expr->GetChild(0)->GetDepth() < root_expr->GetDepth()) {
1099+
new_groupby_cols.emplace_back(root_expr->GetChild(1)->Copy());
1100+
} else {
1101+
new_groupby_cols.emplace_back(root_expr->GetChild(0)->Copy());
1102+
}
1103+
}
1104+
}
1105+
1106+
if (correlated_predicates.empty()) {
1107+
// No need to pull
1108+
return;
1109+
}
1110+
auto aggregation = input->Op().As<LogicalAggregateAndGroupBy>();
1111+
for (auto &col : aggregation->columns) {
1112+
new_groupby_cols.emplace_back(col->Copy());
1113+
}
1114+
std::vector<AnnotatedExpression> new_having(aggregation->having);
1115+
std::shared_ptr<OperatorExpression> new_aggregation =
1116+
std::make_shared<OperatorExpression>(LogicalAggregateAndGroupBy::make(
1117+
new_groupby_cols, new_having));
1118+
std::shared_ptr<OperatorExpression> output =
1119+
std::make_shared<OperatorExpression>(
1120+
LogicalFilter::make(correlated_predicates));
1121+
output->PushChild(new_aggregation);
1122+
auto bottom_operator = new_aggregation;
1123+
1124+
// Construct child filter if any
1125+
if (!normal_predicates.empty()) {
1126+
std::shared_ptr<OperatorExpression> new_filter =
1127+
std::make_shared<OperatorExpression>(
1128+
LogicalFilter::make(normal_predicates));
1129+
new_aggregation->PushChild(new_filter);
1130+
bottom_operator = new_filter;
1131+
}
1132+
bottom_operator->PushChild(filter_expr->Children()[0]);
1133+
1134+
transformed.push_back(output);
1135+
}
9891136
} // namespace optimizer
9901137
} // namespace peloton

test/sql/optimizer_sql_test.cpp

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -733,7 +733,6 @@ TEST_F(OptimizerSQLTests, NestedQueryTest) {
733733
{"1", "2", "3", "4"}, false);
734734
}
735735

736-
/*
737736
TEST_F(OptimizerSQLTests, NestedQueryWithAggregationTest) {
738737
// Nested with aggregation
739738
TestingSQLUtil::ExecuteSQLQuery("CREATE TABLE agg(a int, b int);");
@@ -790,7 +789,6 @@ TEST_F(OptimizerSQLTests, NestedQueryWithAggregationTest) {
790789
"s.sid < 4;",
791790
{"Patrick", "4", "David", "4", "Alice", "2"}, false);
792791
}
793-
*/
794792

795793
} // namespace test
796794
} // namespace peloton

0 commit comments

Comments
 (0)