Skip to content

Commit 0d63bc1

Browse files
authored
minor: Small refactor for consistent serde for hash aggregate (apache#2764)
1 parent fb37d9a commit 0d63bc1

File tree

3 files changed

+81
-58
lines changed

3 files changed

+81
-58
lines changed

spark/src/main/scala/org/apache/comet/rules/CometExecRule.scala

Lines changed: 32 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,6 @@ import scala.collection.mutable.ListBuffer
2323

2424
import org.apache.spark.sql.SparkSession
2525
import org.apache.spark.sql.catalyst.expressions.{Divide, DoubleLiteral, EqualNullSafe, EqualTo, Expression, FloatLiteral, GreaterThan, GreaterThanOrEqual, KnownFloatingPointNormalized, LessThan, LessThanOrEqual, NamedExpression, Remainder}
26-
import org.apache.spark.sql.catalyst.expressions.aggregate.{Final, Partial}
2726
import org.apache.spark.sql.catalyst.optimizer.NormalizeNaNAndZero
2827
import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, RangePartitioning, RoundRobinPartitioning, SinglePartition}
2928
import org.apache.spark.sql.catalyst.rules.Rule
@@ -32,7 +31,7 @@ import org.apache.spark.sql.comet._
3231
import org.apache.spark.sql.comet.execution.shuffle.{CometColumnarShuffle, CometNativeShuffle, CometShuffleExchangeExec, CometShuffleManager}
3332
import org.apache.spark.sql.execution._
3433
import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanExec, AQEShuffleReadExec, BroadcastQueryStageExec, ShuffleQueryStageExec}
35-
import org.apache.spark.sql.execution.aggregate.{BaseAggregateExec, HashAggregateExec, ObjectHashAggregateExec}
34+
import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, ObjectHashAggregateExec}
3635
import org.apache.spark.sql.execution.command.ExecutedCommandExec
3736
import org.apache.spark.sql.execution.datasources.v2.V2CommandExec
3837
import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ReusedExchangeExec, ShuffleExchangeExec}
@@ -232,44 +231,37 @@ case class CometExecRule(session: SparkSession) extends Rule[SparkPlan] {
232231
op,
233232
CometExpandExec(_, op, op.output, op.projections, op.child, SerializedPlan(None)))
234233

235-
// When Comet shuffle is disabled, we don't want to transform the HashAggregate
236-
// to CometHashAggregate. Otherwise, we probably get partial Comet aggregation
237-
// and final Spark aggregation.
238-
case op: BaseAggregateExec
239-
if op.isInstanceOf[HashAggregateExec] ||
240-
op.isInstanceOf[ObjectHashAggregateExec] &&
241-
isCometShuffleEnabled(conf) =>
242-
val modes = op.aggregateExpressions.map(_.mode).distinct
243-
// In distinct aggregates there can be a combination of modes
244-
val multiMode = modes.size > 1
245-
// For a final mode HashAggregate, we only need to transform the HashAggregate
246-
// if there is Comet partial aggregation.
247-
val sparkFinalMode = modes.contains(Final) && findCometPartialAgg(op.child).isEmpty
248-
249-
if (multiMode || sparkFinalMode) {
250-
op
251-
} else {
252-
newPlanWithProto(
253-
op,
254-
nativeOp => {
255-
// The aggExprs could be empty. For example, if the aggregate functions only have
256-
// distinct aggregate functions or only have group by, the aggExprs is empty and
257-
// modes is empty too. If aggExprs is not empty, we need to verify all the
258-
// aggregates have the same mode.
259-
assert(modes.length == 1 || modes.isEmpty)
260-
CometHashAggregateExec(
261-
nativeOp,
262-
op,
263-
op.output,
264-
op.groupingExpressions,
265-
op.aggregateExpressions,
266-
op.resultExpressions,
267-
op.child.output,
268-
modes.headOption,
269-
op.child,
270-
SerializedPlan(None))
271-
})
272-
}
234+
case op: HashAggregateExec =>
235+
newPlanWithProto(
236+
op,
237+
nativeOp => {
238+
CometHashAggregateExec(
239+
nativeOp,
240+
op,
241+
op.output,
242+
op.groupingExpressions,
243+
op.aggregateExpressions,
244+
op.resultExpressions,
245+
op.child.output,
246+
op.child,
247+
SerializedPlan(None))
248+
})
249+
250+
case op: ObjectHashAggregateExec =>
251+
newPlanWithProto(
252+
op,
253+
nativeOp => {
254+
CometHashAggregateExec(
255+
nativeOp,
256+
op,
257+
op.output,
258+
op.groupingExpressions,
259+
op.aggregateExpressions,
260+
op.resultExpressions,
261+
op.child.output,
262+
op.child,
263+
SerializedPlan(None))
264+
})
273265

274266
case op: ShuffledHashJoinExec
275267
if CometConf.COMET_EXEC_HASH_JOIN_ENABLED.get(conf) &&
@@ -738,22 +730,6 @@ case class CometExecRule(session: SparkSession) extends Rule[SparkPlan] {
738730
}
739731
}
740732

741-
/**
742-
* Find the first Comet partial aggregate in the plan. If it reaches a Spark HashAggregate with
743-
* partial mode, it will return None.
744-
*/
745-
private def findCometPartialAgg(plan: SparkPlan): Option[CometHashAggregateExec] = {
746-
plan.collectFirst {
747-
case agg: CometHashAggregateExec if agg.aggregateExpressions.forall(_.mode == Partial) =>
748-
Some(agg)
749-
case agg: HashAggregateExec if agg.aggregateExpressions.forall(_.mode == Partial) => None
750-
case agg: ObjectHashAggregateExec if agg.aggregateExpressions.forall(_.mode == Partial) =>
751-
None
752-
case a: AQEShuffleReadExec => findCometPartialAgg(a.child)
753-
case s: ShuffleQueryStageExec => findCometPartialAgg(s.plan)
754-
}.flatten
755-
}
756-
757733
/**
758734
* Returns true if a given spark plan is Comet shuffle operator.
759735
*/

spark/src/main/scala/org/apache/comet/serde/operator/CometAggregate.scala

Lines changed: 40 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,11 +23,14 @@ import scala.jdk.CollectionConverters._
2323

2424
import org.apache.spark.sql.catalyst.expressions.Expression
2525
import org.apache.spark.sql.catalyst.expressions.aggregate.{Final, Partial}
26+
import org.apache.spark.sql.comet.CometHashAggregateExec
27+
import org.apache.spark.sql.execution.SparkPlan
28+
import org.apache.spark.sql.execution.adaptive.{AQEShuffleReadExec, ShuffleQueryStageExec}
2629
import org.apache.spark.sql.execution.aggregate.{BaseAggregateExec, HashAggregateExec, ObjectHashAggregateExec}
2730
import org.apache.spark.sql.types.MapType
2831

2932
import org.apache.comet.{CometConf, ConfigEntry}
30-
import org.apache.comet.CometSparkSessionExtensions.withInfo
33+
import org.apache.comet.CometSparkSessionExtensions.{isCometShuffleEnabled, withInfo}
3134
import org.apache.comet.serde.{CometOperatorSerde, OperatorOuterClass}
3235
import org.apache.comet.serde.OperatorOuterClass.{AggregateMode => CometAggregateMode, Operator}
3336
import org.apache.comet.serde.QueryPlanSerde.{aggExprToProto, exprToProto}
@@ -38,6 +41,18 @@ trait CometBaseAggregate {
3841
aggregate: BaseAggregateExec,
3942
builder: Operator.Builder,
4043
childOp: OperatorOuterClass.Operator*): Option[OperatorOuterClass.Operator] = {
44+
45+
val modes = aggregate.aggregateExpressions.map(_.mode).distinct
46+
// In distinct aggregates there can be a combination of modes
47+
val multiMode = modes.size > 1
48+
// For a final mode HashAggregate, we only need to transform the HashAggregate
49+
// if there is Comet partial aggregation.
50+
val sparkFinalMode = modes.contains(Final) && findCometPartialAgg(aggregate.child).isEmpty
51+
52+
if (multiMode || sparkFinalMode) {
53+
return None
54+
}
55+
4156
val groupingExpressions = aggregate.groupingExpressions
4257
val aggregateExpressions = aggregate.aggregateExpressions
4358
val aggregateAttributes = aggregate.aggregateAttributes
@@ -163,6 +178,22 @@ trait CometBaseAggregate {
163178

164179
}
165180

181+
/**
182+
* Find the first Comet partial aggregate in the plan. If it reaches a Spark HashAggregate with
183+
* partial mode, it will return None.
184+
*/
185+
private def findCometPartialAgg(plan: SparkPlan): Option[CometHashAggregateExec] = {
186+
plan.collectFirst {
187+
case agg: CometHashAggregateExec if agg.aggregateExpressions.forall(_.mode == Partial) =>
188+
Some(agg)
189+
case agg: HashAggregateExec if agg.aggregateExpressions.forall(_.mode == Partial) => None
190+
case agg: ObjectHashAggregateExec if agg.aggregateExpressions.forall(_.mode == Partial) =>
191+
None
192+
case a: AQEShuffleReadExec => findCometPartialAgg(a.child)
193+
case s: ShuffleQueryStageExec => findCometPartialAgg(s.plan)
194+
}.flatten
195+
}
196+
166197
}
167198

168199
object CometHashAggregate extends CometOperatorSerde[HashAggregateExec] with CometBaseAggregate {
@@ -189,6 +220,14 @@ object CometObjectHashAggregate
189220
aggregate: ObjectHashAggregateExec,
190221
builder: Operator.Builder,
191222
childOp: OperatorOuterClass.Operator*): Option[OperatorOuterClass.Operator] = {
223+
224+
if (!isCometShuffleEnabled(aggregate.conf)) {
225+
// When Comet shuffle is disabled, we don't want to transform the HashAggregate
226+
// to CometHashAggregate. Otherwise, we probably get partial Comet aggregation
227+
// and final Spark aggregation.
228+
return None
229+
}
230+
192231
doConvert(aggregate, builder, childOp: _*)
193232
}
194233
}

spark/src/main/scala/org/apache/spark/sql/comet/operators.scala

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -739,11 +739,19 @@ case class CometHashAggregateExec(
739739
aggregateExpressions: Seq[AggregateExpression],
740740
resultExpressions: Seq[NamedExpression],
741741
input: Seq[Attribute],
742-
mode: Option[AggregateMode],
743742
child: SparkPlan,
744743
override val serializedPlanOpt: SerializedPlan)
745744
extends CometUnaryExec
746745
with PartitioningPreservingUnaryExecNode {
746+
747+
// The aggExprs could be empty. For example, if the aggregate functions only have
748+
// distinct aggregate functions or only have group by, the aggExprs is empty and
749+
// modes is empty too. If aggExprs is not empty, we need to verify all the
750+
// aggregates have the same mode.
751+
val modes: Seq[AggregateMode] = aggregateExpressions.map(_.mode).distinct
752+
assert(modes.length == 1 || modes.isEmpty)
753+
val mode = modes.headOption
754+
747755
override def producedAttributes: AttributeSet = outputSet ++ AttributeSet(resultExpressions)
748756

749757
override protected def withNewChildInternal(newChild: SparkPlan): SparkPlan =

0 commit comments

Comments
 (0)