@@ -876,7 +876,7 @@ void EmbedFilterIntoGet::Transform(
876
876
}
877
877
878
878
// /////////////////////////////////////////////////////////////////////////////
879
- // / MarkJoinGetToInnerJoin
879
+ // / MarkJoinToInnerJoin
880
880
MarkJoinToInnerJoin::MarkJoinToInnerJoin () {
881
881
type_ = RuleType::MARK_JOIN_GET_TO_INNER_JOIN;
882
882
@@ -886,7 +886,7 @@ MarkJoinToInnerJoin::MarkJoinToInnerJoin() {
886
886
}
887
887
888
888
int MarkJoinToInnerJoin::Promise (GroupExpression *group_expr,
889
- OptimizeContext *context) const {
889
+ OptimizeContext *context) const {
890
890
(void )context;
891
891
auto root_type = match_pattern->Type ();
892
892
// This rule is not applicable
@@ -897,7 +897,7 @@ int MarkJoinToInnerJoin::Promise(GroupExpression *group_expr,
897
897
}
898
898
899
899
bool MarkJoinToInnerJoin::Check (std::shared_ptr<OperatorExpression> plan,
900
- OptimizeContext *context) const {
900
+ OptimizeContext *context) const {
901
901
(void )context;
902
902
(void )plan;
903
903
@@ -925,6 +925,56 @@ void MarkJoinToInnerJoin::Transform(
925
925
transformed.push_back (output);
926
926
}
927
927
928
+ // /////////////////////////////////////////////////////////////////////////////
929
+ // / SingleJoinGetToInnerJoin
930
+ SingleJoinToInnerJoin::SingleJoinToInnerJoin () {
931
+ type_ = RuleType::MARK_JOIN_GET_TO_INNER_JOIN;
932
+
933
+ match_pattern = std::make_shared<Pattern>(OpType::LogicalSingleJoin);
934
+ match_pattern->AddChild (std::make_shared<Pattern>(OpType::Leaf));
935
+ match_pattern->AddChild (std::make_shared<Pattern>(OpType::Leaf));
936
+ }
937
+
938
+ int SingleJoinToInnerJoin::Promise (GroupExpression *group_expr,
939
+ OptimizeContext *context) const {
940
+ (void )context;
941
+ auto root_type = match_pattern->Type ();
942
+ // This rule is not applicable
943
+ if (root_type != OpType::Leaf && root_type != group_expr->Op ().type ()) {
944
+ return 0 ;
945
+ }
946
+ return static_cast <int >(UnnestPromise::Low);
947
+ }
948
+
949
+ bool SingleJoinToInnerJoin::Check (std::shared_ptr<OperatorExpression> plan,
950
+ OptimizeContext *context) const {
951
+ (void )context;
952
+ (void )plan;
953
+
954
+ UNUSED_ATTRIBUTE auto &children = plan->Children ();
955
+ PL_ASSERT (children.size () == 2 );
956
+
957
+ return true ;
958
+ }
959
+
960
+ void SingleJoinToInnerJoin::Transform (
961
+ std::shared_ptr<OperatorExpression> input,
962
+ std::vector<std::shared_ptr<OperatorExpression>> &transformed,
963
+ UNUSED_ATTRIBUTE OptimizeContext *context) const {
964
+ UNUSED_ATTRIBUTE auto single_join = input->Op ().As <LogicalSingleJoin>();
965
+ auto &join_children = input->Children ();
966
+
967
+ PL_ASSERT (single_join->join_predicates .empty ());
968
+
969
+ std::shared_ptr<OperatorExpression> output =
970
+ std::make_shared<OperatorExpression>(LogicalInnerJoin::make ());
971
+
972
+ output->PushChild (join_children[0 ]);
973
+ output->PushChild (join_children[1 ]);
974
+
975
+ transformed.push_back (output);
976
+ }
977
+
928
978
// /////////////////////////////////////////////////////////////////////////////
929
979
// / PullFilterThroughMarkJoin
930
980
PullFilterThroughMarkJoin::PullFilterThroughMarkJoin () {
@@ -986,5 +1036,102 @@ void PullFilterThroughMarkJoin::Transform(
986
1036
transformed.push_back (output);
987
1037
}
988
1038
1039
+ // /////////////////////////////////////////////////////////////////////////////
1040
+ // / PullFilterThroughAggregation
1041
+ PullFilterThroughAggregation::PullFilterThroughAggregation () {
1042
+ type_ = RuleType::PULL_FILTER_THROUGH_AGGREGATION;
1043
+
1044
+ auto filter = std::make_shared<Pattern>(OpType::LogicalFilter);
1045
+ filter->AddChild (std::make_shared<Pattern>(OpType::Leaf));
1046
+ match_pattern = std::make_shared<Pattern>(OpType::LogicalAggregateAndGroupBy);
1047
+ match_pattern->AddChild (filter);
1048
+ }
1049
+
1050
+ int PullFilterThroughAggregation::Promise (GroupExpression *group_expr,
1051
+ OptimizeContext *context) const {
1052
+ (void )context;
1053
+ auto root_type = match_pattern->Type ();
1054
+ // This rule is not applicable
1055
+ if (root_type != OpType::Leaf && root_type != group_expr->Op ().type ()) {
1056
+ return 0 ;
1057
+ }
1058
+ return static_cast <int >(UnnestPromise::High);
1059
+ }
1060
+
1061
+ bool PullFilterThroughAggregation::Check (
1062
+ std::shared_ptr<OperatorExpression> plan, OptimizeContext *context) const {
1063
+ (void )context;
1064
+ (void )plan;
1065
+
1066
+ auto &children = plan->Children ();
1067
+ PL_ASSERT (children.size () == 1 );
1068
+ UNUSED_ATTRIBUTE auto &r_grandchildren = children[1 ]->Children ();
1069
+ PL_ASSERT (r_grandchildren.size () == 1 );
1070
+
1071
+ return true ;
1072
+ }
1073
+
1074
+ void PullFilterThroughAggregation::Transform (
1075
+ std::shared_ptr<OperatorExpression> input,
1076
+ std::vector<std::shared_ptr<OperatorExpression>> &transformed,
1077
+ UNUSED_ATTRIBUTE OptimizeContext *context) const {
1078
+ auto &memo = context->metadata ->memo ;
1079
+ auto &filter_expr = input->Children ()[0 ];
1080
+ auto child_group_id =
1081
+ filter_expr->Children ()[0 ]->Op ().As <LeafOperator>()->origin_group ;
1082
+ const auto &child_group_aliases_set =
1083
+ memo.GetGroupByID (child_group_id)->GetTableAliases ();
1084
+
1085
+ auto &predicates = filter_expr->Op ().As <LogicalFilter>()->predicates ;
1086
+
1087
+ std::vector<AnnotatedExpression> correlated_predicates;
1088
+ std::vector<AnnotatedExpression> normal_predicates;
1089
+ std::vector<std::shared_ptr<expression::AbstractExpression>> new_groupby_cols;
1090
+ for (auto &predicate : predicates) {
1091
+ if (util::IsSubset (child_group_aliases_set, predicate.table_alias_set )) {
1092
+ normal_predicates.emplace_back (predicate);
1093
+ } else {
1094
+ // Correlated predicate, already in the form of
1095
+ // (outer_relation.a = (expr))
1096
+ correlated_predicates.emplace_back (predicate);
1097
+ auto &root_expr = predicate.expr ;
1098
+ if (root_expr->GetChild (0 )->GetDepth () < root_expr->GetDepth ()) {
1099
+ new_groupby_cols.emplace_back (root_expr->GetChild (1 )->Copy ());
1100
+ } else {
1101
+ new_groupby_cols.emplace_back (root_expr->GetChild (0 )->Copy ());
1102
+ }
1103
+ }
1104
+ }
1105
+
1106
+ if (correlated_predicates.empty ()) {
1107
+ // No need to pull
1108
+ return ;
1109
+ }
1110
+ auto aggregation = input->Op ().As <LogicalAggregateAndGroupBy>();
1111
+ for (auto &col : aggregation->columns ) {
1112
+ new_groupby_cols.emplace_back (col->Copy ());
1113
+ }
1114
+ std::vector<AnnotatedExpression> new_having (aggregation->having );
1115
+ std::shared_ptr<OperatorExpression> new_aggregation =
1116
+ std::make_shared<OperatorExpression>(LogicalAggregateAndGroupBy::make (
1117
+ new_groupby_cols, new_having));
1118
+ std::shared_ptr<OperatorExpression> output =
1119
+ std::make_shared<OperatorExpression>(
1120
+ LogicalFilter::make (correlated_predicates));
1121
+ output->PushChild (new_aggregation);
1122
+ auto bottom_operator = new_aggregation;
1123
+
1124
+ // Construct child filter if any
1125
+ if (!normal_predicates.empty ()) {
1126
+ std::shared_ptr<OperatorExpression> new_filter =
1127
+ std::make_shared<OperatorExpression>(
1128
+ LogicalFilter::make (normal_predicates));
1129
+ new_aggregation->PushChild (new_filter);
1130
+ bottom_operator = new_filter;
1131
+ }
1132
+ bottom_operator->PushChild (filter_expr->Children ()[0 ]);
1133
+
1134
+ transformed.push_back (output);
1135
+ }
989
1136
} // namespace optimizer
990
1137
} // namespace peloton
0 commit comments