Skip to content
This repository was archived by the owner on Sep 27, 2019. It is now read-only.

Commit 82ccbd9

Browse files
author
GustavoAngulo
committed
Simple reordering test case
1 parent b01590e commit 82ccbd9

File tree

2 files changed

+112
-38
lines changed

2 files changed

+112
-38
lines changed

src/optimizer/rule_impls.cpp

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,6 @@ void InnerJoinAssociativity::Transform(
102102

103103
// NOTE: Transforms (left JOIN middle) JOIN right -> left JOIN (middle JOIN right)
104104
// Variables are named accordingly to above transformation
105-
106105
auto parent_join = input->Op().As<LogicalInnerJoin>();
107106
std::vector<std::shared_ptr<OperatorExpression>> children = input->Children();
108107
auto child_join = children[0]->Op().As<LogicalInnerJoin>();
@@ -115,12 +114,9 @@ void InnerJoinAssociativity::Transform(
115114

116115
// Get Alias sets
117116
auto &memo = context->metadata->memo;
118-
// auto left_group_id = children[0]->Children()[0]->Op().As<LeafOperator>()->origin_group;
119117
auto middle_group_id = children[0]->Children()[1]->Op().As<LeafOperator>()->origin_group;
120118
auto right_group_id = children[1]->Op().As<LeafOperator>()->origin_group;
121119

122-
// const auto &left_group_aliases_set =
123-
// memo.GetGroupByID(left_group_id)->GetTableAliases();
124120
const auto &middle_group_aliases_set =
125121
memo.GetGroupByID(middle_group_id)->GetTableAliases();
126122
const auto &right_group_aliases_set =
@@ -163,7 +159,7 @@ void InnerJoinAssociativity::Transform(
163159
new_parent_join->PushChild(new_child_join);
164160

165161

166-
LOG_TRACE(
162+
LOG_DEBUG(
167163
"Reordered join structured: (%s JOIN %s) JOIN %s",
168164
left->Op().GetName().c_str(), middle->Op().GetName().c_str(), right->Op().GetName().c_str());
169165

test/optimizer/optimizer_rule_test.cpp

Lines changed: 111 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@
1010
//
1111
//===----------------------------------------------------------------------===//
1212

13+
#include <include/concurrency/transaction_manager_factory.h>
14+
#include <include/parser/postgresparser.h>
1315
#include "common/harness.h"
1416

1517
#define private public
@@ -19,7 +21,7 @@
1921
#include "optimizer/optimizer.h"
2022
#include "optimizer/rule.h"
2123
#include "optimizer/rule_impls.h"
22-
24+
#include "sql/testing_sql_util.h"
2325
#include "catalog/catalog.h"
2426
#include "common/logger.h"
2527
#include "common/statement.h"
@@ -33,6 +35,8 @@
3335
#include "planner/insert_plan.h"
3436
#include "planner/update_plan.h"
3537
#include "type/value_factory.h"
38+
#include "binder/bind_node_visitor.h"
39+
3640

3741
namespace peloton {
3842
namespace test {
@@ -43,7 +47,16 @@ namespace test {
4347

4448
using namespace optimizer;
4549

46-
class OptimizerRuleTests : public PelotonTest {};
50+
class OptimizerRuleTests : public PelotonTest {
51+
protected:
52+
GroupExpression *GetSingleGroupExpression(Memo &memo, GroupExpression *expr,
53+
int child_group_idx) {
54+
auto group = memo.GetGroupByID(expr->GetChildGroupId(child_group_idx));
55+
EXPECT_EQ(1, group->logical_expressions_.size());
56+
return group->logical_expressions_[0].get();
57+
}
58+
};
59+
4760

4861
TEST_F(OptimizerRuleTests, SimpleCommutativeRuleTest) {
4962
// Build op plan node to match rule
@@ -69,37 +82,102 @@ TEST_F(OptimizerRuleTests, SimpleCommutativeRuleTest) {
6982

7083
}
7184

72-
TEST_F(OptimizerRuleTests, SimpleAssociativeRuleTest) {
73-
// Build op plan node to match rule
74-
// (left JOIN middle) JOIN right
75-
auto left_get = std::make_shared<OperatorExpression>(LogicalGet::make());
76-
auto middle_get = std::make_shared<OperatorExpression>(LogicalGet::make());
77-
auto right_get = std::make_shared<OperatorExpression>(LogicalGet::make());
78-
auto child_join = std::make_shared<OperatorExpression>(LogicalInnerJoin::make());
79-
child_join->PushChild(left_get);
80-
child_join->PushChild(middle_get);
81-
82-
auto parent_join = std::make_shared<OperatorExpression>(LogicalInnerJoin::make());
83-
parent_join->PushChild(child_join);
84-
parent_join->PushChild(right_get);
85-
86-
87-
// Setup rule
88-
InnerJoinAssociativity rule;
89-
90-
EXPECT_TRUE(rule.Check(parent_join, nullptr));
91-
92-
std::vector<std::shared_ptr<OperatorExpression>> outputs;
93-
rule.Transform(parent_join, outputs, nullptr);
94-
EXPECT_EQ(outputs.size(), 1);
95-
96-
auto output_join = outputs[0];
97-
98-
EXPECT_EQ(output_join->Children()[0], left_get);
99-
EXPECT_EQ(output_join->Children()[1]->Children()[0], middle_get);
100-
EXPECT_EQ(output_join->Children()[1]->Children()[1], right_get);
101-
102-
85+
TEST_F(OptimizerRuleTests, AssociativeRuleTest) {
86+
auto &txn_manager = concurrency::TransactionManagerFactory::GetInstance();
87+
auto txn = txn_manager.BeginTransaction();
88+
catalog::Catalog::GetInstance()->CreateDatabase(DEFAULT_DB_NAME, txn);
89+
txn_manager.CommitTransaction(txn);
90+
91+
TestingSQLUtil::ExecuteSQLQuery(
92+
"CREATE TABLE test1(a INT PRIMARY KEY, b INT, c INT);");
93+
TestingSQLUtil::ExecuteSQLQuery(
94+
"CREATE TABLE test2(a INT PRIMARY KEY, b INT, c INT);");
95+
TestingSQLUtil::ExecuteSQLQuery(
96+
"CREATE TABLE test3(a INT PRIMARY KEY, b INT, c INT);");
97+
98+
auto &peloton_parser = parser::PostgresParser::GetInstance();
99+
auto stmt = peloton_parser.BuildParseTree(
100+
"SELECT * FROM test1, test2, test3");
101+
auto parse_tree = stmt->GetStatements().at(0).get();
102+
auto predicates = std::vector<expression::AbstractExpression *>();
103+
104+
optimizer::Optimizer optimizer;
105+
106+
// Push Associativity rule and execute tasks
107+
optimizer.metadata_.rule_set.transformation_rules_.clear();
108+
optimizer.metadata_.rule_set.transformation_rules_.emplace_back(
109+
new InnerJoinAssociativity());
110+
111+
txn = txn_manager.BeginTransaction();
112+
113+
auto bind_node_visitor =
114+
std::make_shared<binder::BindNodeVisitor>(txn, DEFAULT_DB_NAME);
115+
bind_node_visitor->BindNameToNode(parse_tree);
116+
117+
std::shared_ptr<GroupExpression> gexpr =
118+
optimizer.InsertQueryTree(parse_tree, txn);
119+
std::vector<GroupID> child_groups = {gexpr->GetGroupID()};
120+
121+
auto &memo = optimizer.metadata_.memo;
122+
std::shared_ptr<GroupExpression> head_gexpr = std::make_shared<GroupExpression>(Operator(), child_groups);
123+
124+
// Check plan is of structure (left JOIN middle) JOIN right
125+
// Check Parent join
126+
auto group_expr = GetSingleGroupExpression(memo, head_gexpr.get(), 0);
127+
EXPECT_EQ(OpType::InnerJoin, group_expr->Op().GetType());
128+
auto join_op = group_expr->Op().As<LogicalInnerJoin>();
129+
EXPECT_EQ(0, join_op->join_predicates.size());
130+
131+
// Check left join
132+
auto l_group_expr = GetSingleGroupExpression(memo, group_expr, 0);
133+
EXPECT_EQ(OpType::InnerJoin, l_group_expr->Op().GetType());
134+
auto left = GetSingleGroupExpression(memo, l_group_expr, 0);
135+
auto middle = GetSingleGroupExpression(memo, l_group_expr, 1);
136+
EXPECT_EQ(OpType::Get, left->Op().GetType());
137+
EXPECT_EQ(OpType::Get, middle->Op().GetType());
138+
139+
// Check right Get
140+
auto right = GetSingleGroupExpression(memo, group_expr, 1);
141+
EXPECT_EQ(OpType::Get, right->Op().GetType());
142+
143+
std::shared_ptr<OptimizeContext> root_context =
144+
std::make_shared<OptimizeContext>(&(optimizer.metadata_), nullptr);
145+
146+
auto task_stack = std::unique_ptr<OptimizerTaskStack>(new OptimizerTaskStack());
147+
optimizer.metadata_.SetTaskPool(task_stack.get());
148+
task_stack->Push(new ApplyRule(group_expr, new InnerJoinAssociativity, root_context));
149+
150+
while (!task_stack->Empty()) {
151+
auto task = task_stack->Pop();
152+
task->execute();
153+
}
154+
155+
LOG_DEBUG("Executed all tasks");
156+
157+
// Check plan is now: left JOIN (middle JOIN right)
158+
// Check Parent join
159+
EXPECT_EQ(OpType::InnerJoin, group_expr->Op().GetType());
160+
join_op = group_expr->Op().As<LogicalInnerJoin>();
161+
EXPECT_EQ(0, join_op->join_predicates.size());
162+
EXPECT_EQ(2, group_expr->GetChildrenGroupsSize());
163+
LOG_DEBUG("Parent join: OK");
164+
165+
// Check left Get
166+
//TODO: Not sure why left is at index 1, but the (middle JOIN right) is at index 0
167+
left = GetSingleGroupExpression(memo, group_expr, 1);
168+
EXPECT_EQ(OpType::Get, left->Op().GetType());
169+
LOG_DEBUG("Left Leaf: OK");
170+
171+
// Check (right JOIN right)
172+
auto r_group_expr = GetSingleGroupExpression(memo, group_expr, 0);
173+
EXPECT_EQ(OpType::InnerJoin, r_group_expr->Op().GetType());
174+
middle = GetSingleGroupExpression(memo, r_group_expr, 0);
175+
right = GetSingleGroupExpression(memo, r_group_expr, 1);
176+
EXPECT_EQ(OpType::Get, middle->Op().GetType());
177+
EXPECT_EQ(OpType::Get, right->Op().GetType());
178+
LOG_DEBUG("Right join: OK");
179+
180+
txn_manager.CommitTransaction(txn);
103181
}
104182

105183
} // namespace test

0 commit comments

Comments
 (0)