4747import java .util .Map ;
4848import java .util .Map .Entry ;
4949import java .util .Set ;
50+ import java .util .stream .Collectors ;
5051
5152/**
5253 * Push down agg through join with foreign key:
@@ -131,13 +132,27 @@ private LogicalAggregate<?> eliminatePrimaryOutput(LogicalAggregate<?> agg, Plan
131132 if (primary .getOutputSet ().stream ().noneMatch (aggInputs ::contains )) {
132133 return agg ;
133134 }
134- Set <Slot > primaryOutputSet = primary .getOutputSet ();
135- Set <Slot > primarySlots = Sets .intersection (aggInputs , primaryOutputSet );
135+ // Firstly, using fd to eliminate group by key.
136+ // group by primary_table_pk, primary_table_other
137+ // -> group by primary_table_pk
138+ Set <Set <Slot >> groupBySlots = new HashSet <>();
139+ for (Expression expression : agg .getGroupByExpressions ()) {
140+ groupBySlots .add (expression .getInputSlots ());
141+ }
136142 DataTrait dataTrait = child .getLogicalProperties ().getTrait ();
137143 FuncDeps funcDeps = dataTrait .getAllValidFuncDeps (Sets .union (foreign .getOutputSet (), primary .getOutputSet ()));
144+ Set <Slot > foreignOutput = Sets .intersection (agg .getOutputSet (), foreign .getOutputSet ());
145+ Set <Set <Slot >> minGroupBySlots = funcDeps .eliminateDeps (groupBySlots , foreignOutput );
146+ List <Slot > minGroupBySlotList = minGroupBySlots .stream ().flatMap (Set ::stream ).collect (Collectors .toList ());
147+
148+ // Secondly, put bijective slot into map: {primary_table_pk : foreign_table_fk}
149+ // Bijective slots are mutually interchangeable within GROUP BY keys.
150+ // group by primary_table_pk equals group by foreign_table_fk
151+ Set <Slot > primaryOutputSet = primary .getOutputSet ();
152+ Set <Slot > primarySlots = Sets .intersection (aggInputs , primaryOutputSet );
138153 HashMap <Slot , Slot > primaryToForeignDeps = new HashMap <>();
139154 for (Slot slot : primarySlots ) {
140- Set <Set <Slot >> replacedSlotSets = funcDeps .findDeterminats (ImmutableSet .of (slot ));
155+ Set <Set <Slot >> replacedSlotSets = funcDeps .findBijectionSlots (ImmutableSet .of (slot ));
141156 for (Set <Slot > replacedSlots : replacedSlotSets ) {
142157 if (primaryOutputSet .stream ().noneMatch (replacedSlots ::contains )
143158 && replacedSlots .size () == 1 ) {
@@ -147,7 +162,9 @@ private LogicalAggregate<?> eliminatePrimaryOutput(LogicalAggregate<?> agg, Plan
147162 }
148163 }
149164
150- Set <Expression > newGroupBySlots = constructNewGroupBy (agg , primaryOutputSet , primaryToForeignDeps );
165+ // Thirdly, construct new Agg below join.
166+ Set <Expression > newGroupBySlots = constructNewGroupBy (minGroupBySlotList , primaryOutputSet ,
167+ primaryToForeignDeps );
151168 List <NamedExpression > newOutput = constructNewOutput (
152169 agg , primaryOutputSet , primaryToForeignDeps , funcDeps , primary );
153170 if (newGroupBySlots == null || newOutput == null ) {
@@ -156,10 +173,10 @@ private LogicalAggregate<?> eliminatePrimaryOutput(LogicalAggregate<?> agg, Plan
156173 return agg .withGroupByAndOutput (ImmutableList .copyOf (newGroupBySlots ), ImmutableList .copyOf (newOutput ));
157174 }
158175
159- private @ Nullable Set <Expression > constructNewGroupBy (LogicalAggregate <?> agg , Set < Slot > primaryOutputs ,
160- Map <Slot , Slot > primaryToForeignBiDeps ) {
176+ private @ Nullable Set <Expression > constructNewGroupBy (List <? extends Expression > gbyExpression ,
177+ Set < Slot > primaryOutputs , Map <Slot , Slot > primaryToForeignBiDeps ) {
161178 Set <Expression > newGroupBySlots = new HashSet <>();
162- for (Expression expression : agg . getGroupByExpressions () ) {
179+ for (Expression expression : gbyExpression ) {
163180 if (!(expression instanceof Slot )) {
164181 return null ;
165182 }
0 commit comments