3636
3737import java .util .HashMap ;
3838import java .util .HashSet ;
39+ import java .util .LinkedList ;
3940import java .util .List ;
4041import java .util .Map ;
42+ import java .util .Optional ;
4143import java .util .Set ;
44+ import java .util .function .Function ;
4245import java .util .stream .Collectors ;
4346
4447import static com .facebook .presto .SystemSessionProperties .isMergeAggregationsWithAndWithoutFilter ;
48+ import static com .facebook .presto .expressions .LogicalRowExpressions .or ;
4549import static com .facebook .presto .spi .StandardErrorCode .GENERIC_INTERNAL_ERROR ;
4650import static com .facebook .presto .spi .plan .AggregationNode .Step .FINAL ;
4751import static com .facebook .presto .spi .plan .AggregationNode .Step .PARTIAL ;
@@ -123,11 +127,13 @@ private static class Context
123127 {
124128 private final Map <VariableReferenceExpression , VariableReferenceExpression > partialResultToMask ;
125129 private final Map <VariableReferenceExpression , VariableReferenceExpression > partialOutputMapping ;
130+ private final List <VariableReferenceExpression > newAggregationOutput ;
126131
127132 public Context ()
128133 {
129134 partialResultToMask = new HashMap <>();
130135 partialOutputMapping = new HashMap <>();
136+ newAggregationOutput = new LinkedList <>();
131137 }
132138
133139 public boolean isEmpty ()
@@ -139,6 +145,7 @@ public void clear()
139145 {
140146 partialResultToMask .clear ();
141147 partialOutputMapping .clear ();
148+ newAggregationOutput .clear ();
142149 }
143150
144151 public Map <VariableReferenceExpression , VariableReferenceExpression > getPartialOutputMapping ()
@@ -150,6 +157,11 @@ public Map<VariableReferenceExpression, VariableReferenceExpression> getPartialR
150157 {
151158 return partialResultToMask ;
152159 }
160+
161+ public List <VariableReferenceExpression > getNewAggregationOutput ()
162+ {
163+ return newAggregationOutput ;
164+ }
153165 }
154166
155167 private static class Rewriter
@@ -218,17 +230,60 @@ else if (node.getStep().equals(FINAL)) {
218230 private AggregationNode createPartialAggregationNode (AggregationNode node , PlanNode rewrittenSource , RewriteContext <Context > context )
219231 {
220232 checkState (context .get ().isEmpty (), "There should be no partial aggregation left unmerged for a partial aggregation node" );
233+
221234 Map <AggregationNode .Aggregation , VariableReferenceExpression > aggregationsWithoutMaskToOutput = node .getAggregations ().entrySet ().stream ()
222235 .filter (x -> !x .getValue ().getMask ().isPresent ())
223- .collect (toImmutableMap (x -> x . getValue (), x -> x . getKey () , (a , b ) -> a ));
236+ .collect (toImmutableMap (Map . Entry :: getValue , Map . Entry :: getKey , (a , b ) -> a ));
224237 Map <AggregationNode .Aggregation , VariableReferenceExpression > aggregationsToMergeOutput = node .getAggregations ().entrySet ().stream ()
225238 .filter (x -> x .getValue ().getMask ().isPresent () && aggregationsWithoutMaskToOutput .containsKey (removeFilterAndMask (x .getValue ())))
226- .collect (toImmutableMap (x -> x .getValue (), x -> x .getKey ()));
239+ .collect (toImmutableMap (Map .Entry ::getValue , Map .Entry ::getKey ));
240+
241+ ImmutableMap .Builder <AggregationNode .Aggregation , VariableReferenceExpression > partialAggregationToOutputBuilder = ImmutableMap .builder ();
242+ partialAggregationToOutputBuilder .putAll (aggregationsToMergeOutput .keySet ().stream ().collect (toImmutableMap (Function .identity (), x -> aggregationsWithoutMaskToOutput .get (removeFilterAndMask (x )))));
243+
244+ List <List <AggregationNode .Aggregation >> candidateAggregationsWithMaskNotMatched = node .getAggregations ().entrySet ().stream ().map (Map .Entry ::getValue )
245+ .filter (x -> x .getMask ().isPresent () && !aggregationsToMergeOutput .containsKey (x ))
246+ .collect (Collectors .groupingBy (AggregationNodeUtils ::removeFilterAndMask )).values ()
247+ .stream ().filter (x -> x .size () > 1 ).collect (toImmutableList ());
248+
249+ Map <AggregationNode .Aggregation , VariableReferenceExpression > aggregationsWithMaskToMerge = node .getAggregations ().entrySet ().stream ()
250+ .filter (x -> aggregationsToMergeOutput .containsKey (x .getValue ()) || candidateAggregationsWithMaskNotMatched .stream ().anyMatch (aggregations -> aggregations .contains (x .getValue ())))
251+ .collect (toImmutableMap (Map .Entry ::getValue , Map .Entry ::getKey ));
252+ ImmutableMap .Builder <VariableReferenceExpression , RowExpression > newMaskAssignmentsBuilder = ImmutableMap .builder ();
253+ ImmutableMap .Builder <VariableReferenceExpression , AggregationNode .Aggregation > aggregationsAddedBuilder = ImmutableMap .builder ();
254+ List <AggregationNode .Aggregation > newAggregationAdded = candidateAggregationsWithMaskNotMatched .stream ()
255+ .map (aggregations ->
256+ {
257+ List <VariableReferenceExpression > maskVariables = aggregations .stream ().map (x -> x .getMask ().get ()).collect (toImmutableList ());
258+ RowExpression orMaskVariables = or (maskVariables );
259+ VariableReferenceExpression newMaskVariable = variableAllocator .newVariable (orMaskVariables );
260+ newMaskAssignmentsBuilder .put (newMaskVariable , orMaskVariables );
261+ AggregationNode .Aggregation newAggregation = new AggregationNode .Aggregation (
262+ aggregations .get (0 ).getCall (),
263+ Optional .empty (),
264+ aggregations .get (0 ).getOrderBy (),
265+ aggregations .get (0 ).isDistinct (),
266+ Optional .of (newMaskVariable ));
267+ VariableReferenceExpression newAggregationVariable = variableAllocator .newVariable (newAggregation .getCall ());
268+ aggregationsAddedBuilder .put (newAggregationVariable , newAggregation );
269+ aggregations .forEach (x -> partialAggregationToOutputBuilder .put (x , newAggregationVariable ));
270+ return newAggregation ;
271+ })
272+ .collect (toImmutableList ());
273+ Map <VariableReferenceExpression , RowExpression > newMaskAssignments = newMaskAssignmentsBuilder .build ();
274+ Map <VariableReferenceExpression , AggregationNode .Aggregation > aggregationsAdded = aggregationsAddedBuilder .build ();
275+ Map <AggregationNode .Aggregation , VariableReferenceExpression > partialAggregationToOutput = partialAggregationToOutputBuilder .build ();
276+
277+ Map <AggregationNode .Aggregation , VariableReferenceExpression > aggregationsToMergeOutputCombined =
278+ node .getAggregations ().entrySet ().stream ()
279+ .filter (x -> x .getValue ().getMask ().isPresent () && aggregationsToMergeOutput .containsKey (x .getValue ()) || candidateAggregationsWithMaskNotMatched .stream ().anyMatch (aggregations -> aggregations .contains (x .getValue ())))
280+ .collect (toImmutableMap (Map .Entry ::getValue , Map .Entry ::getKey ));
227281
228- context .get ().getPartialResultToMask ().putAll (aggregationsToMergeOutput .entrySet ().stream ()
229- .collect (toImmutableMap (x -> x .getValue (), x -> x .getKey ().getMask ().get ())));
230- context .get ().getPartialOutputMapping ().putAll (aggregationsToMergeOutput .entrySet ().stream ()
231- .collect (toImmutableMap (x -> x .getValue (), x -> aggregationsWithoutMaskToOutput .get (removeFilterAndMask (x .getKey ())))));
282+ context .get ().getNewAggregationOutput ().addAll (aggregationsAdded .keySet ());
283+ context .get ().getPartialResultToMask ().putAll (aggregationsWithMaskToMerge .entrySet ().stream ()
284+ .collect (toImmutableMap (Map .Entry ::getValue , x -> x .getKey ().getMask ().get ())));
285+ context .get ().getPartialOutputMapping ().putAll (aggregationsWithMaskToMerge .entrySet ().stream ()
286+ .collect (toImmutableMap (Map .Entry ::getValue , x -> partialAggregationToOutput .get (x .getKey ()))));
232287
233288 Set <VariableReferenceExpression > maskVariables = new HashSet <>(context .get ().getPartialResultToMask ().values ());
234289 if (maskVariables .isEmpty ()) {
@@ -242,14 +297,21 @@ private AggregationNode createPartialAggregationNode(AggregationNode node, PlanN
242297 AggregationNode .GroupingSetDescriptor partialGroupingSetDescriptor = new AggregationNode .GroupingSetDescriptor (
243298 groupingVariables .build (), groupingSetDescriptor .getGroupingSetCount (), groupingSetDescriptor .getGlobalGroupingSets ());
244299
245- Set <VariableReferenceExpression > partialResultToMerge = new HashSet <>(aggregationsToMergeOutput .values ());
246- Map <VariableReferenceExpression , AggregationNode .Aggregation > newAggregations = node .getAggregations ().entrySet ().stream ()
300+ Set <VariableReferenceExpression > partialResultToMerge = new HashSet <>(aggregationsToMergeOutputCombined .values ());
301+ Map <VariableReferenceExpression , AggregationNode .Aggregation > aggregationsRemained = node .getAggregations ().entrySet ().stream ()
247302 .filter (x -> !partialResultToMerge .contains (x .getKey ())).collect (toImmutableMap (Map .Entry ::getKey , Map .Entry ::getValue ));
303+ Map <VariableReferenceExpression , AggregationNode .Aggregation > newAggregations = ImmutableMap .<VariableReferenceExpression , AggregationNode .Aggregation >builder ()
304+ .putAll (aggregationsRemained ).putAll (aggregationsAdded ).build ();
305+
306+ PlanNode newChild = rewrittenSource ;
307+ if (!newMaskAssignments .isEmpty ()) {
308+ newChild = addProjections (newChild , planNodeIdAllocator , newMaskAssignments );
309+ }
248310
249311 return new AggregationNode (
250312 node .getSourceLocation (),
251313 node .getId (),
252- rewrittenSource ,
314+ newChild ,
253315 newAggregations ,
254316 partialGroupingSetDescriptor ,
255317 node .getPreGroupedVariables (),
@@ -265,7 +327,7 @@ private AggregationNode createFinalAggregationNode(AggregationNode node, PlanNod
265327 return (AggregationNode ) node .replaceChildren (ImmutableList .of (rewrittenSource ));
266328 }
267329 List <VariableReferenceExpression > intermediateVariables = node .getAggregations ().values ().stream ()
268- .map (x -> (VariableReferenceExpression ) x .getArguments ().get (0 )).collect (Collectors . toList ());
330+ .map (x -> (VariableReferenceExpression ) x .getArguments ().get (0 )).collect (toImmutableList ());
269331 checkState (intermediateVariables .containsAll (context .get ().partialResultToMask .keySet ()));
270332
271333 ImmutableList .Builder <RowExpression > projectionsFromPartialAgg = ImmutableList .builder ();
@@ -331,6 +393,7 @@ public PlanNode visitProject(ProjectNode node, RewriteContext<Context> context)
331393 .collect (toImmutableMap (Map .Entry ::getKey , Map .Entry ::getValue ));
332394 assignments .putAll (excludeMergedAssignments );
333395 assignments .putAll (identityAssignments (context .get ().getPartialResultToMask ().values ()));
396+ assignments .putAll (identityAssignments (context .get ().getNewAggregationOutput ()));
334397 return new ProjectNode (
335398 node .getSourceLocation (),
336399 node .getId (),
0 commit comments