Skip to content
This repository was archived by the owner on Sep 27, 2019. It is now read-only.

Commit 4e168fd

Browse files
Attempt at fixing pushdown filter
1 parent aa739fd commit 4e168fd

File tree

6 files changed

+27
-8
lines changed

6 files changed

+27
-8
lines changed

src/include/planner/abstract_join_plan.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,10 @@ class AbstractJoinPlan : public AbstractPlan {
8787
virtual void HandleSubplanBinding(bool from_left,
8888
const BindingContext &input) = 0;
8989

90+
const std::string GetPredicateInfo() const {
91+
return predicate_ != nullptr ? predicate_->GetInfo() : "";
92+
}
93+
9094
private:
9195
/** @brief The type of join that we're going to perform */
9296
JoinType join_type_;

src/include/planner/hash_join_plan.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,9 @@ class HashJoinPlan : public AbstractJoinPlan {
5757

5858
void SetBloomFilterFlag(bool flag) { build_bloomfilter_ = flag; }
5959

60-
const std::string GetInfo() const override { return "HashJoin"; }
60+
const std::string GetInfo() const override {
61+
return "HashJoin(" + GetPredicateInfo() + ")";
62+
}
6163

6264
const std::vector<oid_t> &GetOuterHashIds() const {
6365
return outer_column_ids_;

src/optimizer/optimizer_task.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -423,6 +423,11 @@ void TopDownRewrite::execute() {
423423
r.rule->GetMatchPattern());
424424
if (iterator.HasNext()) {
425425
auto before = iterator.Next();
426+
427+
if (!r.rule->Check(before, context_.get())) {
428+
continue;
429+
}
430+
426431
PELOTON_ASSERT(!iterator.HasNext());
427432
std::vector<std::shared_ptr<OperatorExpression>> after;
428433
r.rule->Transform(before, after, context_.get());

src/optimizer/query_to_operator_transformer.cpp

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,8 @@ void QueryToOperatorTransformer::Visit(parser::SelectStatement *op) {
123123
predicates_ = std::move(pre_predicates);
124124
}
125125
void QueryToOperatorTransformer::Visit(parser::JoinDefinition *node) {
126+
auto pre_predicates = std::move(predicates_);
127+
126128
// Get left operator
127129
node->left->Accept(this);
128130
auto left_expr = output_expr_;
@@ -137,25 +139,25 @@ void QueryToOperatorTransformer::Visit(parser::JoinDefinition *node) {
137139
case JoinType::INNER: {
138140
predicates_ = CollectPredicates(node->condition.get(), predicates_);
139141
join_expr = std::make_shared<OperatorExpression>(
140-
LogicalJoin::make(JoinType::INNER));
142+
LogicalJoin::make(JoinType::INNER, predicates_));
141143
break;
142144
}
143145
case JoinType::OUTER: {
144146
predicates_ = CollectPredicates(node->condition.get(), predicates_);
145147
join_expr = std::make_shared<OperatorExpression>(
146-
LogicalJoin::make(JoinType::OUTER));
148+
LogicalJoin::make(JoinType::OUTER, predicates_));
147149
break;
148150
}
149151
case JoinType::LEFT: {
150152
predicates_ = CollectPredicates(node->condition.get(), predicates_);
151153
join_expr = std::make_shared<OperatorExpression>(
152-
LogicalJoin::make(JoinType::LEFT));
154+
LogicalJoin::make(JoinType::LEFT, predicates_));
153155
break;
154156
}
155157
case JoinType::RIGHT: {
156158
predicates_ = CollectPredicates(node->condition.get(), predicates_);
157159
join_expr = std::make_shared<OperatorExpression>(
158-
LogicalJoin::make(JoinType::RIGHT));
160+
LogicalJoin::make(JoinType::RIGHT, predicates_));
159161
break;
160162
}
161163
case JoinType::SEMI: {
@@ -171,6 +173,7 @@ void QueryToOperatorTransformer::Visit(parser::JoinDefinition *node) {
171173
join_expr->PushChild(right_expr);
172174

173175
output_expr_ = join_expr;
176+
predicates_ = std::move(pre_predicates);
174177
}
175178
void QueryToOperatorTransformer::Visit(parser::TableRef *node) {
176179
if (node->select != nullptr) {

src/optimizer/rule.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ RuleSet::RuleSet() {
4646
AddImplementationRule(new ImplementLimit());
4747

4848
AddRewriteRule(RewriteRuleSetName::PREDICATE_PUSH_DOWN,
49-
new PushFilterThroughJoin());
49+
new PushFilterThroughJoin());
5050
AddRewriteRule(RewriteRuleSetName::PREDICATE_PUSH_DOWN,
5151
new PushFilterThroughAggregation());
5252
AddRewriteRule(RewriteRuleSetName::PREDICATE_PUSH_DOWN,

src/optimizer/rule_impls.cpp

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -853,17 +853,22 @@ void PushFilterThroughJoin::Transform(
853853
std::vector<AnnotatedExpression> right_predicates;
854854
std::vector<AnnotatedExpression> join_predicates;
855855

856+
auto join_type = join_op_expr->Op().As<LogicalJoin>()->type;
857+
bool outer_push = (join_type == JoinType::OUTER ||
858+
join_type == JoinType::LEFT ||
859+
join_type == JoinType::RIGHT);
860+
856861
// Loop over all predicates, check each of them if they can be pushed down to
857862
// either the left child or the right child to be evaluated
858863
// All predicates in this loop follow conjunction relationship because we
859864
// already extract these predicates from the original.
860865
// E.g. An expression (test.a = test1.b and test.a = 5) would become
861866
// {test.a = test1.b, test.a = 5}
862867
for (auto &predicate : predicates) {
863-
if (util::IsSubset(left_group_aliases_set, predicate.table_alias_set)) {
868+
if (util::IsSubset(left_group_aliases_set, predicate.table_alias_set) && !outer_push) {
864869
left_predicates.emplace_back(predicate);
865870
} else if (util::IsSubset(right_group_aliases_set,
866-
predicate.table_alias_set)) {
871+
predicate.table_alias_set) && !outer_push) {
867872
right_predicates.emplace_back(predicate);
868873
} else {
869874
join_predicates.emplace_back(predicate);

0 commit comments

Comments
 (0)