|
23 | 23 | #include "executor/insert_executor.h"
|
24 | 24 | #include "executor/plan_executor.h"
|
25 | 25 | #include "executor/update_executor.h"
|
| 26 | +#include "expression/abstract_expression.h" |
| 27 | +#include "expression/operator_expression.h" |
26 | 28 | #include "optimizer/operator_expression.h"
|
27 | 29 | #include "optimizer/operators.h"
|
28 | 30 | #include "optimizer/optimizer.h"
|
|
36 | 38 | #include "sql/testing_sql_util.h"
|
37 | 39 | #include "type/value_factory.h"
|
38 | 40 |
|
| 41 | + |
39 | 42 | namespace peloton {
|
40 | 43 | namespace test {
|
41 | 44 |
|
@@ -78,106 +81,73 @@ TEST_F(OptimizerRuleTests, SimpleCommutativeRuleTest) {
|
78 | 81 | EXPECT_EQ(output_join->Children()[1], left_get);
|
79 | 82 | }
|
80 | 83 |
|
81 |
| -TEST_F(OptimizerRuleTests, AssociativeRuleTest) { |
82 |
| - auto &txn_manager = concurrency::TransactionManagerFactory::GetInstance(); |
83 |
| - auto txn = txn_manager.BeginTransaction(); |
84 |
| - catalog::Catalog::GetInstance()->CreateDatabase(DEFAULT_DB_NAME, txn); |
85 |
| - txn_manager.CommitTransaction(txn); |
86 |
| - |
87 |
| - TestingSQLUtil::ExecuteSQLQuery( |
88 |
| - "CREATE TABLE test1(a INT PRIMARY KEY, b INT, c INT);"); |
89 |
| - TestingSQLUtil::ExecuteSQLQuery( |
90 |
| - "CREATE TABLE test2(a INT PRIMARY KEY, b INT, c INT);"); |
91 |
| - TestingSQLUtil::ExecuteSQLQuery( |
92 |
| - "CREATE TABLE test3(a INT PRIMARY KEY, b INT, c INT);"); |
93 |
| - |
94 |
| - auto &peloton_parser = parser::PostgresParser::GetInstance(); |
95 |
| - auto stmt = |
96 |
| - peloton_parser.BuildParseTree("SELECT * FROM test1, test2, test3"); |
97 |
| - auto parse_tree = stmt->GetStatements().at(0).get(); |
98 |
| - auto predicates = std::vector<expression::AbstractExpression *>(); |
99 |
| - |
100 |
| - optimizer::Optimizer optimizer; |
101 |
| - |
102 |
| - // Push Associativity rule and execute tasks |
103 |
| - optimizer.metadata_.rule_set.transformation_rules_.clear(); |
104 |
| - optimizer.metadata_.rule_set.transformation_rules_.emplace_back( |
105 |
| - new InnerJoinAssociativity()); |
106 |
| - |
107 |
| - txn = txn_manager.BeginTransaction(); |
108 |
| - |
109 |
| - auto bind_node_visitor = |
110 |
| - std::make_shared<binder::BindNodeVisitor>(txn, DEFAULT_DB_NAME); |
111 |
| - bind_node_visitor->BindNameToNode(parse_tree); |
112 |
| - |
113 |
| - std::shared_ptr<GroupExpression> gexpr = |
114 |
| - optimizer.InsertQueryTree(parse_tree, txn); |
115 |
| - std::vector<GroupID> child_groups = {gexpr->GetGroupID()}; |
116 |
| - |
117 |
| - auto &memo = optimizer.metadata_.memo; |
118 |
| - std::shared_ptr<GroupExpression> head_gexpr = |
119 |
| - std::make_shared<GroupExpression>(Operator(), child_groups); |
120 |
| - |
121 |
| - // Check plan is of structure (left JOIN middle) JOIN right |
122 |
| - // Check Parent join |
123 |
| - auto group_expr = GetSingleGroupExpression(memo, head_gexpr.get(), 0); |
124 |
| - EXPECT_EQ(OpType::InnerJoin, group_expr->Op().GetType()); |
125 |
| - auto join_op = group_expr->Op().As<LogicalInnerJoin>(); |
126 |
| - EXPECT_EQ(0, join_op->join_predicates.size()); |
127 |
| - |
128 |
| - // Check left join |
129 |
| - auto l_group_expr = GetSingleGroupExpression(memo, group_expr, 0); |
130 |
| - EXPECT_EQ(OpType::InnerJoin, l_group_expr->Op().GetType()); |
131 |
| - auto left = GetSingleGroupExpression(memo, l_group_expr, 0); |
132 |
| - auto middle = GetSingleGroupExpression(memo, l_group_expr, 1); |
133 |
| - EXPECT_EQ(OpType::Get, left->Op().GetType()); |
134 |
| - EXPECT_EQ(OpType::Get, middle->Op().GetType()); |
135 |
| - |
136 |
| - // Check right Get |
137 |
| - auto right = GetSingleGroupExpression(memo, group_expr, 1); |
138 |
| - EXPECT_EQ(OpType::Get, right->Op().GetType()); |
139 |
| - |
140 |
| - std::shared_ptr<OptimizeContext> root_context = |
141 |
| - std::make_shared<OptimizeContext>(&(optimizer.metadata_), nullptr); |
142 |
| - |
143 |
| - auto task_stack = |
144 |
| - std::unique_ptr<OptimizerTaskStack>(new OptimizerTaskStack()); |
145 |
| - optimizer.metadata_.SetTaskPool(task_stack.get()); |
146 |
| - task_stack->Push( |
147 |
| - new ApplyRule(group_expr, new InnerJoinAssociativity, root_context)); |
148 |
| - |
149 |
| - while (!task_stack->Empty()) { |
150 |
| - auto task = task_stack->Pop(); |
151 |
| - task->execute(); |
152 |
| - } |
| 84 | +TEST_F(OptimizerRuleTests, SimpleAssociativeRuleTest) { |
| 85 | + |
| 86 | + // (left JOIN middle) JOIN right |
| 87 | + // Build Operator Expression |
| 88 | + |
| 89 | + // Setup Memo |
| 90 | + Optimizer optimizer; |
| 91 | + |
| 92 | + auto left_get = std::make_shared<OperatorExpression>(LogicalGet::make(0, {}, nullptr, "test1")); |
| 93 | + auto middle_get = std::make_shared<OperatorExpression>(LogicalGet::make(0, {}, nullptr, "test2")); |
| 94 | + auto right_get = std::make_shared<OperatorExpression>(LogicalGet::make(0, {}, nullptr, "test3")); |
| 95 | + |
| 96 | + auto left_get_group = optimizer.metadata_.memo.InsertExpression(optimizer.metadata_.MakeGroupExpression(left_get), true); |
| 97 | + auto middle_get_group = optimizer.metadata_.memo.InsertExpression(optimizer.metadata_.MakeGroupExpression(middle_get),true); |
| 98 | + auto right_get_group = optimizer.metadata_.memo.InsertExpression(optimizer.metadata_.MakeGroupExpression(right_get),true); |
| 99 | + |
| 100 | + auto left_leaf = std::make_shared<OperatorExpression>(LeafOperator::make(left_get_group->GetGroupID())); |
| 101 | + auto middle_leaf = std::make_shared<OperatorExpression>(LeafOperator::make(middle_get_group->GetGroupID())); |
| 102 | + auto right_leaf = std::make_shared<OperatorExpression>(LeafOperator::make(right_get_group->GetGroupID())); |
| 103 | + |
| 104 | + // Make Child Join |
| 105 | + std::vector<AnnotatedExpression> child_join_predicates; |
| 106 | + std::unordered_set<std::string> child_tables({"test1","test2"}); |
| 107 | + auto dummy_expr = std::shared_ptr<expression::AbstractExpression>{ |
| 108 | + new expression::OperatorExpression(ExpressionType::COMPARE_EQUAL, type::TypeId::INTEGER)}; |
| 109 | + |
| 110 | + AnnotatedExpression pred = {dummy_expr, child_tables}; |
| 111 | + child_join_predicates.push_back(pred); |
| 112 | + |
| 113 | + auto child_join = std::make_shared<OperatorExpression>(LogicalInnerJoin::make(child_join_predicates)); |
| 114 | + child_join->PushChild(left_leaf); |
| 115 | + child_join->PushChild(middle_leaf); |
| 116 | + optimizer.metadata_.memo.InsertExpression(optimizer.metadata_.MakeGroupExpression(child_join), true); |
| 117 | + |
| 118 | + // Make Parent join |
| 119 | + std::vector<AnnotatedExpression> parent_join_predicates; |
| 120 | + std::unordered_set<std::string> parent_tables({"test1","test3"}); |
| 121 | + pred = {dummy_expr, parent_tables}; |
| 122 | + parent_join_predicates.push_back(pred); |
| 123 | + |
| 124 | + auto parent_join = std::make_shared<OperatorExpression>(LogicalInnerJoin::make(parent_join_predicates)); |
| 125 | + parent_join->PushChild(child_join); |
| 126 | + parent_join->PushChild(right_leaf); |
| 127 | + |
| 128 | + optimizer.metadata_.memo.InsertExpression(optimizer.metadata_.MakeGroupExpression(parent_join), true); |
| 129 | + OptimizeContext* root_context = new OptimizeContext(&(optimizer.metadata_), nullptr); |
| 130 | + LOG_DEBUG("Set up Memo"); |
| 131 | + |
| 132 | + // Setup rule |
| 133 | + InnerJoinAssociativity rule; |
| 134 | + |
| 135 | + EXPECT_TRUE(rule.Check(parent_join, root_context)); |
| 136 | + std::vector<std::shared_ptr<OperatorExpression>> outputs; |
| 137 | + rule.Transform(parent_join, outputs, root_context); |
| 138 | + EXPECT_EQ(1, outputs.size()); |
| 139 | + |
| 140 | + auto output_join = outputs[0]; |
| 141 | + |
| 142 | + EXPECT_EQ(left_leaf, output_join->Children()[0]); |
| 143 | + EXPECT_EQ(middle_leaf, output_join->Children()[1]->Children()[0]); |
| 144 | + EXPECT_EQ(right_leaf, output_join->Children()[1]->Children()[1]); |
| 145 | + |
| 146 | + auto parent_join_op = output_join->Op().As<LogicalInnerJoin>(); |
| 147 | + auto child_join_op = output_join->Children()[1]->Op().As<LogicalInnerJoin>(); |
| 148 | + EXPECT_EQ(2, parent_join_op->join_predicates.size()); |
| 149 | + EXPECT_EQ(0, child_join_op->join_predicates.size()); |
153 | 150 |
|
154 |
| - LOG_DEBUG("Executed all tasks"); |
155 |
| - |
156 |
| - // Check plan is now: left JOIN (middle JOIN right) |
157 |
| - // Check Parent join |
158 |
| - EXPECT_EQ(OpType::InnerJoin, group_expr->Op().GetType()); |
159 |
| - join_op = group_expr->Op().As<LogicalInnerJoin>(); |
160 |
| - EXPECT_EQ(0, join_op->join_predicates.size()); |
161 |
| - EXPECT_EQ(2, group_expr->GetChildrenGroupsSize()); |
162 |
| - LOG_DEBUG("Parent join: OK"); |
163 |
| - |
164 |
| - // Check left Get |
165 |
| - // TODO: Not sure why left is at index 1, but the (middle JOIN right) is at |
166 |
| - // index 0 |
167 |
| - left = GetSingleGroupExpression(memo, group_expr, 1); |
168 |
| - EXPECT_EQ(OpType::Get, left->Op().GetType()); |
169 |
| - LOG_DEBUG("Left Leaf: OK"); |
170 |
| - |
171 |
| - // Check (right JOIN right) |
172 |
| - auto r_group_expr = GetSingleGroupExpression(memo, group_expr, 0); |
173 |
| - EXPECT_EQ(OpType::InnerJoin, r_group_expr->Op().GetType()); |
174 |
| - middle = GetSingleGroupExpression(memo, r_group_expr, 0); |
175 |
| - right = GetSingleGroupExpression(memo, r_group_expr, 1); |
176 |
| - EXPECT_EQ(OpType::Get, middle->Op().GetType()); |
177 |
| - EXPECT_EQ(OpType::Get, right->Op().GetType()); |
178 |
| - LOG_DEBUG("Right join: OK"); |
179 |
| - |
180 |
| - txn_manager.CommitTransaction(txn); |
181 | 151 | }
|
182 | 152 |
|
183 | 153 | } // namespace test
|
|
0 commit comments