@@ -103,14 +103,12 @@ void InnerJoinAssociativity::Transform(
103
103
// NOTE: Transforms (left JOIN middle) JOIN right -> left JOIN (middle JOIN right)
104
104
// Variables are named accordingly to above transformation
105
105
106
-
107
-
108
- // auto result_plan = std::make_shared<OperatorExpression>(
109
- // LogicalInnerJoin::make(join_predicates));
110
-
111
106
auto parent_join = input->Op ().As <LogicalInnerJoin>();
112
107
std::vector<std::shared_ptr<OperatorExpression>> children = input->Children ();
113
108
auto child_join = children[0 ]->Op ().As <LogicalInnerJoin>();
109
+ auto left = children[0 ]->Children ()[0 ];
110
+ auto middle = children[0 ]->Children ()[1 ];
111
+ auto right = children[1 ];
114
112
PL_ASSERT (children.size () == 2 );
115
113
PL_ASSERT (children[0 ]->Op ().GetType () == OpType::InnerJoin);
116
114
PL_ASSERT (children[0 ]->Children ().size () == 2 );
@@ -128,8 +126,6 @@ void InnerJoinAssociativity::Transform(
128
126
const auto &right_group_aliases_set =
129
127
memo.GetGroupByID (right_group_id)->GetTableAliases ();
130
128
131
-
132
-
133
129
// Redistribute predicates
134
130
auto parent_join_predicates = std::vector<AnnotatedExpression>(parent_join->join_predicates );
135
131
auto child_join_predicates = std::vector<AnnotatedExpression>(child_join->join_predicates );
@@ -138,19 +134,41 @@ void InnerJoinAssociativity::Transform(
138
134
predicates.insert (predicates.end (), parent_join_predicates.begin (), parent_join_predicates.end ());
139
135
predicates.insert (predicates.end (), child_join_predicates.begin (), child_join_predicates.end ());
140
136
141
- // for (auto predicate : predicates) {
142
- //
143
- // }
144
- //
145
- //
146
- //
147
- // LOG_TRACE(
148
- // "Reorder left child with op %s and right child with op %s for inner join",
149
- // children[0]->Op().GetName().c_str(), children[1]->Op().GetName().c_str());
150
- // result_plan->PushChild(children[1]);
151
- // result_plan->PushChild(children[0]);
152
- //
153
- // transformed.push_back(result_plan);
137
+ std::vector<AnnotatedExpression> new_child_join_predicates;
138
+ std::vector<AnnotatedExpression> new_parent_join_predicates;
139
+
140
+ // TODO: This assumes that predicate pushdown has not occured yet, as it will put all non-join predicates into parent join
141
+ for (auto predicate : predicates) {
142
+
143
+ // New child join predicate must contain middle and right group
144
+ if (util::IsSubset (middle_group_aliases_set, predicate.table_alias_set ) &&
145
+ util::IsSubset (right_group_aliases_set, predicate.table_alias_set ))
146
+ new_child_join_predicates.emplace_back (predicate);
147
+ else
148
+ new_parent_join_predicates.emplace_back (predicate);
149
+ }
150
+
151
+ // Construct new child join operator
152
+ std::shared_ptr<OperatorExpression> new_child_join =
153
+ std::make_shared<OperatorExpression>(
154
+ LogicalInnerJoin::make (new_child_join_predicates));
155
+ new_child_join->PushChild (middle);
156
+ new_child_join->PushChild (right);
157
+
158
+ // Construct new parent join operator
159
+ std::shared_ptr<OperatorExpression> new_parent_join =
160
+ std::make_shared<OperatorExpression>(
161
+ LogicalInnerJoin::make (new_parent_join_predicates));
162
+ new_parent_join->PushChild (left);
163
+ new_parent_join->PushChild (new_child_join);
164
+
165
+
166
+ LOG_TRACE (
167
+ " Reordered join structured: (%s JOIN %s) JOIN %s" ,
168
+ left->Op ().GetName ().c_str (), middle->Op ().GetName ().c_str (), right->Op ().GetName ().c_str ());
169
+
170
+ transformed.push_back (new_parent_join);
171
+
154
172
}
155
173
156
174
// ===--------------------------------------------------------------------===//
0 commit comments