@@ -23,7 +23,6 @@ import scala.collection.mutable.ListBuffer
2323
2424import org .apache .spark .sql .SparkSession
2525import org .apache .spark .sql .catalyst .expressions .{Divide , DoubleLiteral , EqualNullSafe , EqualTo , Expression , FloatLiteral , GreaterThan , GreaterThanOrEqual , KnownFloatingPointNormalized , LessThan , LessThanOrEqual , NamedExpression , Remainder }
26- import org .apache .spark .sql .catalyst .expressions .aggregate .{Final , Partial }
2726import org .apache .spark .sql .catalyst .optimizer .NormalizeNaNAndZero
2827import org .apache .spark .sql .catalyst .plans .physical .{HashPartitioning , RangePartitioning , RoundRobinPartitioning , SinglePartition }
2928import org .apache .spark .sql .catalyst .rules .Rule
@@ -32,7 +31,7 @@ import org.apache.spark.sql.comet._
3231import org .apache .spark .sql .comet .execution .shuffle .{CometColumnarShuffle , CometNativeShuffle , CometShuffleExchangeExec , CometShuffleManager }
3332import org .apache .spark .sql .execution ._
3433import org .apache .spark .sql .execution .adaptive .{AdaptiveSparkPlanExec , AQEShuffleReadExec , BroadcastQueryStageExec , ShuffleQueryStageExec }
35- import org .apache .spark .sql .execution .aggregate .{BaseAggregateExec , HashAggregateExec , ObjectHashAggregateExec }
34+ import org .apache .spark .sql .execution .aggregate .{HashAggregateExec , ObjectHashAggregateExec }
3635import org .apache .spark .sql .execution .command .ExecutedCommandExec
3736import org .apache .spark .sql .execution .datasources .v2 .V2CommandExec
3837import org .apache .spark .sql .execution .exchange .{BroadcastExchangeExec , ReusedExchangeExec , ShuffleExchangeExec }
@@ -232,44 +231,37 @@ case class CometExecRule(session: SparkSession) extends Rule[SparkPlan] {
232231 op,
233232 CometExpandExec (_, op, op.output, op.projections, op.child, SerializedPlan (None )))
234233
235- // When Comet shuffle is disabled, we don't want to transform the HashAggregate
236- // to CometHashAggregate. Otherwise, we probably get partial Comet aggregation
237- // and final Spark aggregation.
238- case op : BaseAggregateExec
239- if op.isInstanceOf [HashAggregateExec ] ||
240- op.isInstanceOf [ObjectHashAggregateExec ] &&
241- isCometShuffleEnabled(conf) =>
242- val modes = op.aggregateExpressions.map(_.mode).distinct
243- // In distinct aggregates there can be a combination of modes
244- val multiMode = modes.size > 1
245- // For a final mode HashAggregate, we only need to transform the HashAggregate
246- // if there is Comet partial aggregation.
247- val sparkFinalMode = modes.contains(Final ) && findCometPartialAgg(op.child).isEmpty
248-
249- if (multiMode || sparkFinalMode) {
250- op
251- } else {
252- newPlanWithProto(
253- op,
254- nativeOp => {
255- // The aggExprs could be empty. For example, if the aggregate functions only have
256- // distinct aggregate functions or only have group by, the aggExprs is empty and
257- // modes is empty too. If aggExprs is not empty, we need to verify all the
258- // aggregates have the same mode.
259- assert(modes.length == 1 || modes.isEmpty)
260- CometHashAggregateExec (
261- nativeOp,
262- op,
263- op.output,
264- op.groupingExpressions,
265- op.aggregateExpressions,
266- op.resultExpressions,
267- op.child.output,
268- modes.headOption,
269- op.child,
270- SerializedPlan (None ))
271- })
272- }
234+ case op : HashAggregateExec =>
235+ newPlanWithProto(
236+ op,
237+ nativeOp => {
238+ CometHashAggregateExec (
239+ nativeOp,
240+ op,
241+ op.output,
242+ op.groupingExpressions,
243+ op.aggregateExpressions,
244+ op.resultExpressions,
245+ op.child.output,
246+ op.child,
247+ SerializedPlan (None ))
248+ })
249+
250+ case op : ObjectHashAggregateExec =>
251+ newPlanWithProto(
252+ op,
253+ nativeOp => {
254+ CometHashAggregateExec (
255+ nativeOp,
256+ op,
257+ op.output,
258+ op.groupingExpressions,
259+ op.aggregateExpressions,
260+ op.resultExpressions,
261+ op.child.output,
262+ op.child,
263+ SerializedPlan (None ))
264+ })
273265
274266 case op : ShuffledHashJoinExec
275267 if CometConf .COMET_EXEC_HASH_JOIN_ENABLED .get(conf) &&
@@ -738,22 +730,6 @@ case class CometExecRule(session: SparkSession) extends Rule[SparkPlan] {
738730 }
739731 }
740732
741- /**
742- * Find the first Comet partial aggregate in the plan. If it reaches a Spark HashAggregate with
743- * partial mode, it will return None.
744- */
745- private def findCometPartialAgg (plan : SparkPlan ): Option [CometHashAggregateExec ] = {
746- plan.collectFirst {
747- case agg : CometHashAggregateExec if agg.aggregateExpressions.forall(_.mode == Partial ) =>
748- Some (agg)
749- case agg : HashAggregateExec if agg.aggregateExpressions.forall(_.mode == Partial ) => None
750- case agg : ObjectHashAggregateExec if agg.aggregateExpressions.forall(_.mode == Partial ) =>
751- None
752- case a : AQEShuffleReadExec => findCometPartialAgg(a.child)
753- case s : ShuffleQueryStageExec => findCometPartialAgg(s.plan)
754- }.flatten
755- }
756-
757733 /**
758734 * Returns true if a given spark plan is Comet shuffle operator.
759735 */
0 commit comments