@@ -31,7 +31,7 @@ import org.apache.spark.broadcast.Broadcast
3131import org .apache .spark .rdd .RDD
3232import org .apache .spark .sql .catalyst .InternalRow
3333import org .apache .spark .sql .catalyst .expressions .{Ascending , Attribute , AttributeSet , Expression , ExpressionSet , Generator , NamedExpression , SortOrder }
34- import org .apache .spark .sql .catalyst .expressions .aggregate .{AggregateExpression , AggregateMode , Final , Partial , PartialMerge }
34+ import org .apache .spark .sql .catalyst .expressions .aggregate .{AggregateExpression , AggregateMode , Complete , Final , Partial , PartialMerge }
3535import org .apache .spark .sql .catalyst .optimizer .{BuildLeft , BuildRight , BuildSide }
3636import org .apache .spark .sql .catalyst .plans ._
3737import org .apache .spark .sql .catalyst .plans .physical ._
@@ -1066,12 +1066,11 @@ trait CometBaseAggregate {
10661066
10671067 val modes = aggregate.aggregateExpressions.map(_.mode).distinct
10681068 // In distinct aggregates there can be a combination of modes
1069- val multiMode = modes.size > 1
10701069 // For a final mode HashAggregate, we only need to transform the HashAggregate
10711070 // if there is Comet partial aggregation.
10721071 val sparkFinalMode = modes.contains(Final ) && findCometPartialAgg(aggregate.child).isEmpty
10731072
1074- if (multiMode || sparkFinalMode) {
1073+ if (sparkFinalMode) {
10751074 return None
10761075 }
10771076
@@ -1144,33 +1143,34 @@ trait CometBaseAggregate {
11441143 hashAggBuilder.addAllResultExprs(resultExprs.map(_.get).asJava)
11451144 Some (builder.setHashAgg(hashAggBuilder).build())
11461145 } else {
1147- val modes = aggregateExpressions.map(_.mode).distinct
1148-
1149- if (modes.size != 1 ) {
1150- // This shouldn't happen as all aggregation expressions should share the same mode.
1151- // Fallback to Spark nevertheless here.
1152- withInfo(aggregate, " All aggregate expressions do not have the same mode" )
1153- return None
1154- }
1155-
1156- val mode = modes.head match {
1157- case Partial => CometAggregateMode .Partial
1158- case Final => CometAggregateMode .Final
1159- case _ =>
1160- withInfo(aggregate, s " Unsupported aggregation mode ${modes.head}" )
1161- return None
1162- }
1146+ // `output` is only used when `binding` is true (i.e., non-Final)
1147+ val output = child.output
11631148
11641149 // In final mode, the aggregate expressions are bound to the output of the
11651150 // child and partial aggregate expressions buffer attributes produced by partial
11661151 // aggregation. This is done in Spark `HashAggregateExec` internally. In Comet,
11671152 // we don't have to do this because we don't use the merging expression.
1168- val binding = mode != CometAggregateMode .Final
1169- // `output` is only used when `binding` is true (i.e., non-Final)
1170- val output = child.output
1171-
1172- val aggExprs =
1173- aggregateExpressions.map(aggExprToProto(_, output, binding, aggregate.conf))
1153+ //
1154+ // It is possible to have multiple modes for queries with DISTINCT agg expression
1155+ // So Spark can Partial and PartialMerge at the same time
1156+ val (aggExprs, aggModes) =
1157+ aggregateExpressions
1158+ .map(a =>
1159+ (
1160+ aggExprToProto(
1161+ a,
1162+ output,
1163+ a.mode != PartialMerge && a.mode != Final ,
1164+ aggregate.conf),
1165+ a.mode match {
1166+ case Partial => CometAggregateMode .Partial
1167+ case PartialMerge => CometAggregateMode .PartialMerge
1168+ case Final => CometAggregateMode .Final
1169+ case mode =>
1170+ withInfo(aggregate, s " Unsupported Aggregation Mode $mode" )
1171+ return None
1172+ }))
1173+ .unzip
11741174
11751175 if (aggExprs.exists(_.isEmpty)) {
11761176 withInfo(
@@ -1185,7 +1185,9 @@ trait CometBaseAggregate {
11851185 val hashAggBuilder = OperatorOuterClass .HashAggregate .newBuilder()
11861186 hashAggBuilder.addAllGroupingExprs(groupingExprs.map(_.get).asJava)
11871187 hashAggBuilder.addAllAggExprs(aggExprs.map(_.get).asJava)
1188- if (mode == CometAggregateMode .Final ) {
1188+ // Spark sending Final separately only,
1189+ // so if any entry is Final means everything else is also Final
1190+ if (modes.contains(Final )) {
11891191 val attributes = groupingExpressions.map(_.toAttribute) ++ aggregateAttributes
11901192 val resultExprs = resultExpressions.map(exprToProto(_, attributes))
11911193 if (resultExprs.exists(_.isEmpty)) {
@@ -1197,7 +1199,7 @@ trait CometBaseAggregate {
11971199 }
11981200 hashAggBuilder.addAllResultExprs(resultExprs.map(_.get).asJava)
11991201 }
1200- hashAggBuilder.setModeValue(mode.getNumber )
1202+ hashAggBuilder.addAllMode(aggModes.asJava )
12011203 Some (builder.setHashAgg(hashAggBuilder).build())
12021204 } else {
12031205 val allChildren : Seq [Expression ] =
@@ -1323,8 +1325,6 @@ case class CometHashAggregateExec(
13231325 // modes is empty too. If aggExprs is not empty, we need to verify all the
13241326 // aggregates have the same mode.
13251327 val modes : Seq [AggregateMode ] = aggregateExpressions.map(_.mode).distinct
1326- assert(modes.length == 1 || modes.isEmpty)
1327- val mode = modes.headOption
13281328
13291329 override def producedAttributes : AttributeSet = outputSet ++ AttributeSet (resultExpressions)
13301330
@@ -1341,7 +1341,7 @@ case class CometHashAggregateExec(
13411341 }
13421342
13431343 override def stringArgs : Iterator [Any ] =
1344- Iterator (input, mode , groupingExpressions, aggregateExpressions, child)
1344+ Iterator (input, modes , groupingExpressions, aggregateExpressions, child)
13451345
13461346 override def equals (obj : Any ): Boolean = {
13471347 obj match {
@@ -1350,7 +1350,7 @@ case class CometHashAggregateExec(
13501350 this .groupingExpressions == other.groupingExpressions &&
13511351 this .aggregateExpressions == other.aggregateExpressions &&
13521352 this .input == other.input &&
1353- this .mode == other.mode &&
1353+ this .modes == other.modes &&
13541354 this .child == other.child &&
13551355 this .serializedPlanOpt == other.serializedPlanOpt
13561356 case _ =>
@@ -1359,7 +1359,7 @@ case class CometHashAggregateExec(
13591359 }
13601360
13611361 override def hashCode (): Int =
1362- Objects .hashCode(output, groupingExpressions, aggregateExpressions, input, mode , child)
1362+ Objects .hashCode(output, groupingExpressions, aggregateExpressions, input, modes , child)
13631363
13641364 override protected def outputExpressions : Seq [NamedExpression ] = resultExpressions
13651365}
0 commit comments