diff --git a/src/optimizer/group_expression.cpp b/src/optimizer/group_expression.cpp index 4d874bd27ef..b1f76b5b3c7 100644 --- a/src/optimizer/group_expression.cpp +++ b/src/optimizer/group_expression.cpp @@ -87,15 +87,21 @@ hash_t GroupExpression::Hash() const { bool GroupExpression::operator==(const GroupExpression &r) { bool eq = (op == r.Op()); - for (size_t i = 0; i < child_groups.size(); ++i) { - eq = eq && (child_groups[i] == r.child_groups[i]); + auto left_groups = child_groups; + auto right_groups = r.child_groups; + + std::sort(left_groups.begin(), left_groups.end()); + + std::sort(right_groups.begin(), right_groups.end()); + for (size_t i = 0; i < left_groups.size(); ++i) { + eq = eq && (left_groups[i] == right_groups[i]); } return eq; } void GroupExpression::SetRuleExplored(Rule *rule) { - rule_mask_.set(rule->GetRuleIdx()) = true; + rule_mask_.set(rule->GetRuleIdx(), true); } bool GroupExpression::HasRuleExplored(Rule *rule) { diff --git a/src/optimizer/memo.cpp b/src/optimizer/memo.cpp index 28eed420726..2a6cc49bf4c 100644 --- a/src/optimizer/memo.cpp +++ b/src/optimizer/memo.cpp @@ -43,8 +43,6 @@ GroupExpression *Memo::InsertExpression(std::shared_ptr gexpr, auto it = group_expressions_.find(gexpr.get()); if (it != group_expressions_.end()) { - PELOTON_ASSERT(target_group == UNDEFINED_GROUP || - target_group == (*it)->GetGroupID()); gexpr->SetGroupID((*it)->GetGroupID()); return *it; } else { diff --git a/test/optimizer/optimizer_rule_test.cpp b/test/optimizer/optimizer_rule_test.cpp index 12d047ad51a..23f520596dc 100644 --- a/test/optimizer/optimizer_rule_test.cpp +++ b/test/optimizer/optimizer_rule_test.cpp @@ -261,5 +261,32 @@ TEST_F(OptimizerRuleTests, SimpleAssociativeRuleTest2) { delete root_context; } +TEST_F(OptimizerRuleTests, RuleBitmapTest) { + Optimizer optimizer; + auto &memo = optimizer.GetMetadata().memo; + + auto dummy_operator = std::make_shared(LogicalGet::make()); + auto dummy_group = memo.InsertExpression(optimizer.GetMetadata().MakeGroupExpression(dummy_operator), false); + + auto rule1 = new InnerJoinCommutativity(); + auto rule2 = new GetToSeqScan(); + + EXPECT_FALSE(dummy_group->HasRuleExplored(rule1)); + EXPECT_FALSE(dummy_group->HasRuleExplored(rule2)); + + dummy_group->SetRuleExplored(rule1); + + EXPECT_TRUE(dummy_group->HasRuleExplored(rule1)); + EXPECT_FALSE(dummy_group->HasRuleExplored(rule2)); + + dummy_group->SetRuleExplored(rule2); + + EXPECT_TRUE(dummy_group->HasRuleExplored(rule1)); + EXPECT_TRUE(dummy_group->HasRuleExplored(rule2)); + + delete rule1; + delete rule2; +} + } // namespace test } // namespace peloton diff --git a/test/sql/testing_sql_util.cpp b/test/sql/testing_sql_util.cpp index 9549c794a91..feb9fad20c4 100644 --- a/test/sql/testing_sql_util.cpp +++ b/test/sql/testing_sql_util.cpp @@ -98,6 +98,21 @@ ResultType TestingSQLUtil::ExecuteSQLQuery( return status; } +void PrintPlan(planner::AbstractPlan *plan, int level = 0) { + auto spacing = std::string(level, '\t'); + if (plan->GetPlanNodeType() == PlanNodeType::SEQSCAN) { + auto scan = dynamic_cast(plan); + LOG_INFO("%s%s(%s)", spacing.c_str(), scan->GetInfo().c_str(), + scan->GetTable()->GetName().c_str()); + } else { + LOG_INFO("%s%s", spacing.c_str(), plan->GetInfo().c_str()); + } + for (size_t i = 0; i < plan->GetChildren().size(); i++) { + PrintPlan(plan->GetChildren()[i].get(), level + 1); + } + return; +} + // Execute a SQL query end-to-end with the specific optimizer ResultType TestingSQLUtil::ExecuteSQLQueryWithOptimizer( std::unique_ptr &optimizer,