10
10
//
11
11
// ===----------------------------------------------------------------------===//
12
12
13
+ #include < include/concurrency/transaction_manager_factory.h>
14
+ #include < include/parser/postgresparser.h>
13
15
#include " common/harness.h"
14
16
15
17
#define private public
19
21
#include " optimizer/optimizer.h"
20
22
#include " optimizer/rule.h"
21
23
#include " optimizer/rule_impls.h"
22
-
24
+ # include " sql/testing_sql_util.h "
23
25
#include " catalog/catalog.h"
24
26
#include " common/logger.h"
25
27
#include " common/statement.h"
33
35
#include " planner/insert_plan.h"
34
36
#include " planner/update_plan.h"
35
37
#include " type/value_factory.h"
38
+ #include " binder/bind_node_visitor.h"
39
+
36
40
37
41
namespace peloton {
38
42
namespace test {
@@ -43,7 +47,16 @@ namespace test {
43
47
44
48
using namespace optimizer ;
45
49
46
- class OptimizerRuleTests : public PelotonTest {};
50
+ class OptimizerRuleTests : public PelotonTest {
51
+ protected:
52
+ GroupExpression *GetSingleGroupExpression (Memo &memo, GroupExpression *expr,
53
+ int child_group_idx) {
54
+ auto group = memo.GetGroupByID (expr->GetChildGroupId (child_group_idx));
55
+ EXPECT_EQ (1 , group->logical_expressions_ .size ());
56
+ return group->logical_expressions_ [0 ].get ();
57
+ }
58
+ };
59
+
47
60
48
61
TEST_F (OptimizerRuleTests, SimpleCommutativeRuleTest) {
49
62
// Build op plan node to match rule
@@ -69,37 +82,102 @@ TEST_F(OptimizerRuleTests, SimpleCommutativeRuleTest) {
69
82
70
83
}
71
84
72
- TEST_F (OptimizerRuleTests, SimpleAssociativeRuleTest) {
73
- // Build op plan node to match rule
74
- // (left JOIN middle) JOIN right
75
- auto left_get = std::make_shared<OperatorExpression>(LogicalGet::make ());
76
- auto middle_get = std::make_shared<OperatorExpression>(LogicalGet::make ());
77
- auto right_get = std::make_shared<OperatorExpression>(LogicalGet::make ());
78
- auto child_join = std::make_shared<OperatorExpression>(LogicalInnerJoin::make ());
79
- child_join->PushChild (left_get);
80
- child_join->PushChild (middle_get);
81
-
82
- auto parent_join = std::make_shared<OperatorExpression>(LogicalInnerJoin::make ());
83
- parent_join->PushChild (child_join);
84
- parent_join->PushChild (right_get);
85
-
86
-
87
- // Setup rule
88
- InnerJoinAssociativity rule;
89
-
90
- EXPECT_TRUE (rule.Check (parent_join, nullptr ));
91
-
92
- std::vector<std::shared_ptr<OperatorExpression>> outputs;
93
- rule.Transform (parent_join, outputs, nullptr );
94
- EXPECT_EQ (outputs.size (), 1 );
95
-
96
- auto output_join = outputs[0 ];
97
-
98
- EXPECT_EQ (output_join->Children ()[0 ], left_get);
99
- EXPECT_EQ (output_join->Children ()[1 ]->Children ()[0 ], middle_get);
100
- EXPECT_EQ (output_join->Children ()[1 ]->Children ()[1 ], right_get);
101
-
102
-
85
+ TEST_F (OptimizerRuleTests, AssociativeRuleTest) {
86
+ auto &txn_manager = concurrency::TransactionManagerFactory::GetInstance ();
87
+ auto txn = txn_manager.BeginTransaction ();
88
+ catalog::Catalog::GetInstance ()->CreateDatabase (DEFAULT_DB_NAME, txn);
89
+ txn_manager.CommitTransaction (txn);
90
+
91
+ TestingSQLUtil::ExecuteSQLQuery (
92
+ " CREATE TABLE test1(a INT PRIMARY KEY, b INT, c INT);" );
93
+ TestingSQLUtil::ExecuteSQLQuery (
94
+ " CREATE TABLE test2(a INT PRIMARY KEY, b INT, c INT);" );
95
+ TestingSQLUtil::ExecuteSQLQuery (
96
+ " CREATE TABLE test3(a INT PRIMARY KEY, b INT, c INT);" );
97
+
98
+ auto &peloton_parser = parser::PostgresParser::GetInstance ();
99
+ auto stmt = peloton_parser.BuildParseTree (
100
+ " SELECT * FROM test1, test2, test3" );
101
+ auto parse_tree = stmt->GetStatements ().at (0 ).get ();
102
+ auto predicates = std::vector<expression::AbstractExpression *>();
103
+
104
+ optimizer::Optimizer optimizer;
105
+
106
+ // Push Associativity rule and execute tasks
107
+ optimizer.metadata_ .rule_set .transformation_rules_ .clear ();
108
+ optimizer.metadata_ .rule_set .transformation_rules_ .emplace_back (
109
+ new InnerJoinAssociativity ());
110
+
111
+ txn = txn_manager.BeginTransaction ();
112
+
113
+ auto bind_node_visitor =
114
+ std::make_shared<binder::BindNodeVisitor>(txn, DEFAULT_DB_NAME);
115
+ bind_node_visitor->BindNameToNode (parse_tree);
116
+
117
+ std::shared_ptr<GroupExpression> gexpr =
118
+ optimizer.InsertQueryTree (parse_tree, txn);
119
+ std::vector<GroupID> child_groups = {gexpr->GetGroupID ()};
120
+
121
+ auto &memo = optimizer.metadata_ .memo ;
122
+ std::shared_ptr<GroupExpression> head_gexpr = std::make_shared<GroupExpression>(Operator (), child_groups);
123
+
124
+ // Check plan is of structure (left JOIN middle) JOIN right
125
+ // Check Parent join
126
+ auto group_expr = GetSingleGroupExpression (memo, head_gexpr.get (), 0 );
127
+ EXPECT_EQ (OpType::InnerJoin, group_expr->Op ().GetType ());
128
+ auto join_op = group_expr->Op ().As <LogicalInnerJoin>();
129
+ EXPECT_EQ (0 , join_op->join_predicates .size ());
130
+
131
+ // Check left join
132
+ auto l_group_expr = GetSingleGroupExpression (memo, group_expr, 0 );
133
+ EXPECT_EQ (OpType::InnerJoin, l_group_expr->Op ().GetType ());
134
+ auto left = GetSingleGroupExpression (memo, l_group_expr, 0 );
135
+ auto middle = GetSingleGroupExpression (memo, l_group_expr, 1 );
136
+ EXPECT_EQ (OpType::Get, left->Op ().GetType ());
137
+ EXPECT_EQ (OpType::Get, middle->Op ().GetType ());
138
+
139
+ // Check right Get
140
+ auto right = GetSingleGroupExpression (memo, group_expr, 1 );
141
+ EXPECT_EQ (OpType::Get, right->Op ().GetType ());
142
+
143
+ std::shared_ptr<OptimizeContext> root_context =
144
+ std::make_shared<OptimizeContext>(&(optimizer.metadata_ ), nullptr );
145
+
146
+ auto task_stack = std::unique_ptr<OptimizerTaskStack>(new OptimizerTaskStack ());
147
+ optimizer.metadata_ .SetTaskPool (task_stack.get ());
148
+ task_stack->Push (new ApplyRule (group_expr, new InnerJoinAssociativity, root_context));
149
+
150
+ while (!task_stack->Empty ()) {
151
+ auto task = task_stack->Pop ();
152
+ task->execute ();
153
+ }
154
+
155
+ LOG_DEBUG (" Executed all tasks" );
156
+
157
+ // Check plan is now: left JOIN (middle JOIN right)
158
+ // Check Parent join
159
+ EXPECT_EQ (OpType::InnerJoin, group_expr->Op ().GetType ());
160
+ join_op = group_expr->Op ().As <LogicalInnerJoin>();
161
+ EXPECT_EQ (0 , join_op->join_predicates .size ());
162
+ EXPECT_EQ (2 , group_expr->GetChildrenGroupsSize ());
163
+ LOG_DEBUG (" Parent join: OK" );
164
+
165
+ // Check left Get
166
+ // TODO: Not sure why left is at index 1, but the (middle JOIN right) is at index 0
167
+ left = GetSingleGroupExpression (memo, group_expr, 1 );
168
+ EXPECT_EQ (OpType::Get, left->Op ().GetType ());
169
+ LOG_DEBUG (" Left Leaf: OK" );
170
+
171
+ // Check (right JOIN right)
172
+ auto r_group_expr = GetSingleGroupExpression (memo, group_expr, 0 );
173
+ EXPECT_EQ (OpType::InnerJoin, r_group_expr->Op ().GetType ());
174
+ middle = GetSingleGroupExpression (memo, r_group_expr, 0 );
175
+ right = GetSingleGroupExpression (memo, r_group_expr, 1 );
176
+ EXPECT_EQ (OpType::Get, middle->Op ().GetType ());
177
+ EXPECT_EQ (OpType::Get, right->Op ().GetType ());
178
+ LOG_DEBUG (" Right join: OK" );
179
+
180
+ txn_manager.CommitTransaction (txn);
103
181
}
104
182
105
183
} // namespace test
0 commit comments