
Commit 20be79d

feat: CometExecRule refactor: Unify CometNativeExec creation with Serde in CometOperatorSerde trait (#2768)
1 parent 35a99e0 commit 20be79d

16 files changed: 159 additions & 196 deletions
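
In short, this commit collapses CometExecRule's long per-operator pattern match into a single table-driven dispatch: each handler in the cometNativeExecHandlers map (renamed from opSerdeMap) now implements a new createExec method on the CometOperatorSerde trait, pairing protobuf serialization (convert) with construction of the corresponding CometNativeExec. Five representative files from the commit are shown below.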

spark/src/main/scala/org/apache/comet/rules/CometExecRule.scala

Lines changed: 23 additions & 191 deletions
@@ -44,7 +44,7 @@ import org.apache.spark.sql.types._
 import org.apache.comet.{CometConf, ExtendedExplainInfo}
 import org.apache.comet.CometConf.COMET_EXEC_SHUFFLE_ENABLED
 import org.apache.comet.CometSparkSessionExtensions._
-import org.apache.comet.rules.CometExecRule.opSerdeMap
+import org.apache.comet.rules.CometExecRule.cometNativeExecHandlers
 import org.apache.comet.serde.{CometOperatorSerde, Compatible, Incompatible, OperatorOuterClass, QueryPlanSerde, Unsupported}
 import org.apache.comet.serde.OperatorOuterClass.Operator
 import org.apache.comet.serde.QueryPlanSerde.{serializeDataType, supportedDataType}
@@ -55,7 +55,7 @@ object CometExecRule {
   /**
    * Mapping of Spark operator class to Comet operator handler.
    */
-  val opSerdeMap: Map[Class[_ <: SparkPlan], CometOperatorSerde[_]] =
+  val cometNativeExecHandlers: Map[Class[_ <: SparkPlan], CometOperatorSerde[_]] =
     Map(
       classOf[ProjectExec] -> CometProject,
       classOf[FilterExec] -> CometFilter,
@@ -183,7 +183,7 @@ case class CometExecRule(session: SparkSession) extends Rule[SparkPlan] {
       // Fully native scan for V1
       case scan: CometScanExec if scan.scanImpl == CometConf.SCAN_NATIVE_DATAFUSION =>
         val nativeOp = operator2Proto(scan).get
-        CometNativeScanExec(nativeOp, scan.wrapped, scan.session)
+        CometNativeScan.createExec(nativeOp, scan)

       // Comet JVM + native scan for V1 and V2
       case op if isCometScan(op) =>
@@ -195,36 +195,6 @@ case class CometExecRule(session: SparkSession) extends Rule[SparkPlan] {
         val nativeOp = operator2Proto(cometOp)
         CometScanWrapper(nativeOp.get, cometOp)

-      case op: ProjectExec =>
-        newPlanWithProto(
-          op,
-          CometProjectExec(_, op, op.output, op.projectList, op.child, SerializedPlan(None)))
-
-      case op: FilterExec =>
-        newPlanWithProto(
-          op,
-          CometFilterExec(_, op, op.output, op.condition, op.child, SerializedPlan(None)))
-
-      case op: SortExec =>
-        newPlanWithProto(
-          op,
-          CometSortExec(
-            _,
-            op,
-            op.output,
-            op.outputOrdering,
-            op.sortOrder,
-            op.child,
-            SerializedPlan(None)))
-
-      case op: LocalLimitExec =>
-        newPlanWithProto(op, CometLocalLimitExec(_, op, op.limit, op.child, SerializedPlan(None)))
-
-      case op: GlobalLimitExec =>
-        newPlanWithProto(
-          op,
-          CometGlobalLimitExec(_, op, op.limit, op.offset, op.child, SerializedPlan(None)))
-
       case op: CollectLimitExec =>
         val fallbackReasons = new ListBuffer[String]()
         if (!CometConf.COMET_EXEC_COLLECT_LIMIT_ENABLED.get(conf)) {
@@ -250,116 +220,6 @@ case class CometExecRule(session: SparkSession) extends Rule[SparkPlan] {
           }
         }

-      case op: ExpandExec =>
-        newPlanWithProto(
-          op,
-          CometExpandExec(_, op, op.output, op.projections, op.child, SerializedPlan(None)))
-
-      case op: HashAggregateExec =>
-        newPlanWithProto(
-          op,
-          nativeOp => {
-            CometHashAggregateExec(
-              nativeOp,
-              op,
-              op.output,
-              op.groupingExpressions,
-              op.aggregateExpressions,
-              op.resultExpressions,
-              op.child.output,
-              op.child,
-              SerializedPlan(None))
-          })
-
-      case op: ObjectHashAggregateExec =>
-        newPlanWithProto(
-          op,
-          nativeOp => {
-            CometHashAggregateExec(
-              nativeOp,
-              op,
-              op.output,
-              op.groupingExpressions,
-              op.aggregateExpressions,
-              op.resultExpressions,
-              op.child.output,
-              op.child,
-              SerializedPlan(None))
-          })
-
-      case op: ShuffledHashJoinExec
-          if CometConf.COMET_EXEC_HASH_JOIN_ENABLED.get(conf) &&
-            op.children.forall(isCometNative) =>
-        newPlanWithProto(
-          op,
-          CometHashJoinExec(
-            _,
-            op,
-            op.output,
-            op.outputOrdering,
-            op.leftKeys,
-            op.rightKeys,
-            op.joinType,
-            op.condition,
-            op.buildSide,
-            op.left,
-            op.right,
-            SerializedPlan(None)))
-
-      case op: ShuffledHashJoinExec if !CometConf.COMET_EXEC_HASH_JOIN_ENABLED.get(conf) =>
-        withInfo(op, "ShuffleHashJoin is not enabled")
-
-      case op: ShuffledHashJoinExec if !op.children.forall(isCometNative) =>
-        op
-
-      case op: BroadcastHashJoinExec
-          if CometConf.COMET_EXEC_BROADCAST_HASH_JOIN_ENABLED.get(conf) &&
-            op.children.forall(isCometNative) =>
-        newPlanWithProto(
-          op,
-          CometBroadcastHashJoinExec(
-            _,
-            op,
-            op.output,
-            op.outputOrdering,
-            op.leftKeys,
-            op.rightKeys,
-            op.joinType,
-            op.condition,
-            op.buildSide,
-            op.left,
-            op.right,
-            SerializedPlan(None)))
-
-      case op: SortMergeJoinExec
-          if CometConf.COMET_EXEC_SORT_MERGE_JOIN_ENABLED.get(conf) &&
-            op.children.forall(isCometNative) =>
-        newPlanWithProto(
-          op,
-          CometSortMergeJoinExec(
-            _,
-            op,
-            op.output,
-            op.outputOrdering,
-            op.leftKeys,
-            op.rightKeys,
-            op.joinType,
-            op.condition,
-            op.left,
-            op.right,
-            SerializedPlan(None)))
-
-      case op: SortMergeJoinExec
-          if CometConf.COMET_EXEC_SORT_MERGE_JOIN_ENABLED.get(conf) &&
-            !op.children.forall(isCometNative) =>
-        op
-
-      case op: SortMergeJoinExec if !CometConf.COMET_EXEC_SORT_MERGE_JOIN_ENABLED.get(conf) =>
-        withInfo(op, "SortMergeJoin is not enabled")
-
-      case op: SortMergeJoinExec if !op.children.forall(isCometNative) =>
-        op
-
       case c @ CoalesceExec(numPartitions, child)
           if CometConf.COMET_EXEC_COALESCE_ENABLED.get(conf)
             && isCometNative(child) =>
@@ -405,19 +265,6 @@ case class CometExecRule(session: SparkSession) extends Rule[SparkPlan] {
             "TakeOrderedAndProject requires shuffle to be enabled")
         withInfo(s, Seq(info1, info2).flatten.mkString(","))

-      case w: WindowExec =>
-        newPlanWithProto(
-          w,
-          CometWindowExec(
-            _,
-            w,
-            w.output,
-            w.windowExpression,
-            w.partitionSpec,
-            w.orderSpec,
-            w.child,
-            SerializedPlan(None)))
-
       case u: UnionExec
           if CometConf.COMET_EXEC_UNION_ENABLED.get(conf) &&
             u.children.forall(isCometNative) =>
@@ -476,16 +323,6 @@ case class CometExecRule(session: SparkSession) extends Rule[SparkPlan] {
           plan
         }

-      // this case should be checked only after the previous case checking for a
-      // child BroadcastExchange has been applied, otherwise that transform
-      // never gets applied
-      case op: BroadcastHashJoinExec if !op.children.forall(isCometNative) =>
-        op
-
-      case op: BroadcastHashJoinExec
-          if !CometConf.COMET_EXEC_BROADCAST_HASH_JOIN_ENABLED.get(conf) =>
-        withInfo(op, "BroadcastHashJoin is not enabled")
-
       // For AQE shuffle stage on a Comet shuffle exchange
       case s @ ShuffleQueryStageExec(_, _: CometShuffleExchangeExec, _) =>
         newPlanWithProto(s, CometSinkPlaceHolder(_, s, s))
@@ -548,19 +385,28 @@ case class CometExecRule(session: SparkSession) extends Rule[SparkPlan] {
          s
        }

-      case op: LocalTableScanExec =>
-        if (CometConf.COMET_EXEC_LOCAL_TABLE_SCAN_ENABLED.get(conf)) {
-          operator2Proto(op)
-            .map { nativeOp =>
-              val cometOp = CometLocalTableScanExec(op, op.rows, op.output)
-              CometScanWrapper(nativeOp, cometOp)
+      case op =>
+        // check if this is a fully native operator
+        cometNativeExecHandlers
+          .get(op.getClass)
+          .map(_.asInstanceOf[CometOperatorSerde[SparkPlan]]) match {
+          case Some(handler) =>
+            if (op.children.forall(isCometNative)) {
+              if (isOperatorEnabled(handler, op)) {
+                val builder = OperatorOuterClass.Operator.newBuilder().setPlanId(op.id)
+                val childOp = op.children.map(_.asInstanceOf[CometNativeExec].nativeOp)
+                childOp.foreach(builder.addChildren)
+                return handler
+                  .convert(op, builder, childOp: _*)
+                  .map(handler.createExec(_, op))
+                  .getOrElse(op)
+              }
+            } else {
+              return op
             }
-            .getOrElse(op)
-        } else {
-          withInfo(op, "LocalTableScan is not enabled")
+          case _ =>
         }

-      case op =>
         op match {
           case _: CometPlan | _: AQEShuffleReadExec | _: BroadcastExchangeExec |
               _: BroadcastQueryStageExec | _: AdaptiveSparkPlanExec =>
@@ -1030,20 +876,6 @@ case class CometExecRule(session: SparkSession) extends Rule[SparkPlan] {
     val builder = OperatorOuterClass.Operator.newBuilder().setPlanId(op.id)
     childOp.foreach(builder.addChildren)

-    // look for registered handler first
-    val serde = opSerdeMap.get(op.getClass)
-    serde match {
-      case Some(handler) if isOperatorEnabled(handler, op) =>
-        val opSerde = handler.asInstanceOf[CometOperatorSerde[SparkPlan]]
-        val maybeConverted = opSerde.convert(op, builder, childOp: _*)
-        if (maybeConverted.isDefined) {
-          return maybeConverted
-        }
-      case _ =>
-    }
-
-    // now handle special cases that cannot be handled as a simple mapping from class name
-    // and see if operator can be used as a sink
     op match {

       // Fully native scan for V1
@@ -1108,7 +940,7 @@ case class CometExecRule(session: SparkSession) extends Rule[SparkPlan] {
       // Emit warning if:
       // 1. it is not Spark shuffle operator, which is handled separately
       // 2. it is not a Comet operator
-      if (serde.isEmpty && !op.nodeName.contains("Comet") &&
+      if (!op.nodeName.contains("Comet") &&
        !op.isInstanceOf[ShuffleExchangeExec]) {
        withInfo(op, s"unsupported Spark operator: ${op.nodeName}")
      }
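
Note on the new unified dispatch (the final hunks above): a registered handler fires only when every child of the operator is already native and the handler is enabled; if convert returns None, .getOrElse(op) keeps the original Spark operator in the plan. Operators whose children are not all native are returned unchanged, and operators with no entry in cometNativeExecHandlers fall through (case _ =>) to the residual match, which still covers Comet plans, AQE nodes, and exchanges. Because the handler lookup now happens in one place, operator2Proto drops its duplicate lookup, which is why the trailing "unsupported Spark operator" warning no longer needs the serde.isEmpty guard.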

spark/src/main/scala/org/apache/comet/serde/CometOperatorSerde.scala

Lines changed: 2 additions & 0 deletions
@@ -19,6 +19,7 @@

 package org.apache.comet.serde

+import org.apache.spark.sql.comet.CometNativeExec
 import org.apache.spark.sql.execution.SparkPlan

 import org.apache.comet.ConfigEntry
@@ -65,4 +66,5 @@ trait CometOperatorSerde[T <: SparkPlan] {
       builder: Operator.Builder,
       childOp: Operator*): Option[OperatorOuterClass.Operator]

+  def createExec(nativeOp: Operator, op: T): CometNativeExec
 }
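
With this change a handler owns both halves of an operator's lifecycle: convert serializes the Spark operator into the protobuf plan, and createExec wraps the serialized native operator in a Comet physical node. Below is a minimal hypothetical sketch of a handler under the new contract. The convert/createExec signatures and the CometLocalLimitExec constructor are taken from this commit; the object name, the enabledConfig member, the COMET_EXEC_LOCAL_LIMIT_ENABLED flag, and the Limit proto-builder calls are assumptions made for illustration.

    import org.apache.spark.sql.comet.{CometLocalLimitExec, CometNativeExec, SerializedPlan}
    import org.apache.spark.sql.execution.LocalLimitExec

    import org.apache.comet.{CometConf, ConfigEntry}
    import org.apache.comet.serde.{CometOperatorSerde, OperatorOuterClass}
    import org.apache.comet.serde.OperatorOuterClass.Operator

    // Hypothetical sketch only: `MyCometLocalLimit` is an illustration, not
    // part of this commit (Comet ships its own local-limit handler).
    object MyCometLocalLimit extends CometOperatorSerde[LocalLimitExec] {

      // Gate the handler behind a config flag; the member and flag name are
      // assumed here, suggested by the ConfigEntry import in the trait file.
      override def enabledConfig: Option[ConfigEntry[Boolean]] =
        Some(CometConf.COMET_EXEC_LOCAL_LIMIT_ENABLED)

      override def convert(
          op: LocalLimitExec,
          builder: Operator.Builder,
          childOp: Operator*): Option[OperatorOuterClass.Operator] = {
        // Serialize the operator-specific fields. `Limit`/`setLimit` follow
        // the shape of Comet's operator proto but are assumed; returning None
        // (after tagging via withInfo) would fall back to the Spark operator.
        val limitBuilder = OperatorOuterClass.Limit.newBuilder().setLimit(op.limit)
        Some(builder.setLimit(limitBuilder).build())
      }

      // Wrap the serialized native plan in the matching Comet physical node.
      override def createExec(nativeOp: Operator, op: LocalLimitExec): CometNativeExec =
        CometLocalLimitExec(nativeOp, op, op.limit, op.child, SerializedPlan(None))
    }

Registering such a handler would then be a one-line entry in cometNativeExecHandlers, e.g. classOf[LocalLimitExec] -> MyCometLocalLimit.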

spark/src/main/scala/org/apache/comet/serde/operator/CometAggregate.scala

Lines changed: 27 additions & 1 deletion
@@ -23,7 +23,7 @@ import scala.jdk.CollectionConverters._

 import org.apache.spark.sql.catalyst.expressions.Expression
 import org.apache.spark.sql.catalyst.expressions.aggregate.{Final, Partial}
-import org.apache.spark.sql.comet.CometHashAggregateExec
+import org.apache.spark.sql.comet.{CometHashAggregateExec, CometNativeExec, SerializedPlan}
 import org.apache.spark.sql.execution.SparkPlan
 import org.apache.spark.sql.execution.adaptive.{AQEShuffleReadExec, ShuffleQueryStageExec}
 import org.apache.spark.sql.execution.aggregate.{BaseAggregateExec, HashAggregateExec, ObjectHashAggregateExec}
@@ -207,6 +207,19 @@ object CometHashAggregate extends CometOperatorSerde[HashAggregateExec] with Com
       childOp: OperatorOuterClass.Operator*): Option[OperatorOuterClass.Operator] = {
     doConvert(aggregate, builder, childOp: _*)
   }
+
+  override def createExec(nativeOp: Operator, op: HashAggregateExec): CometNativeExec = {
+    CometHashAggregateExec(
+      nativeOp,
+      op,
+      op.output,
+      op.groupingExpressions,
+      op.aggregateExpressions,
+      op.resultExpressions,
+      op.child.output,
+      op.child,
+      SerializedPlan(None))
+  }
 }

 object CometObjectHashAggregate
@@ -230,4 +243,17 @@

     doConvert(aggregate, builder, childOp: _*)
   }
+
+  override def createExec(nativeOp: Operator, op: ObjectHashAggregateExec): CometNativeExec = {
+    CometHashAggregateExec(
+      nativeOp,
+      op,
+      op.output,
+      op.groupingExpressions,
+      op.aggregateExpressions,
+      op.resultExpressions,
+      op.child.output,
+      op.child,
+      SerializedPlan(None))
+  }
 }
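
Both objects intentionally build the same physical node, CometHashAggregateExec: the two createExec bodies are identical except for the static type of op (HashAggregateExec vs. ObjectHashAggregateExec), mirroring the two cases removed from CometExecRule.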

spark/src/main/scala/org/apache/comet/serde/operator/CometExpand.scala

Lines changed: 4 additions & 1 deletion
@@ -22,6 +22,7 @@ package org.apache.comet.serde.operator
 import scala.jdk.CollectionConverters._

 import org.apache.spark.sql.catalyst.expressions.Expression
+import org.apache.spark.sql.comet.{CometExpandExec, CometNativeExec, SerializedPlan}
 import org.apache.spark.sql.execution.ExpandExec

 import org.apache.comet.{CometConf, ConfigEntry}
@@ -55,7 +56,9 @@ object CometExpand extends CometOperatorSerde[ExpandExec] {
       withInfo(op, allProjExprs: _*)
       None
     }
-
   }

+  override def createExec(nativeOp: Operator, op: ExpandExec): CometNativeExec = {
+    CometExpandExec(nativeOp, op, op.output, op.projections, op.child, SerializedPlan(None))
+  }
 }
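
This file also shows the fallback convention the new dispatch relies on: when serialization is not possible, convert tags the operator via withInfo and returns None, which the .getOrElse(op) in CometExecRule turns into leaving the original Spark operator in the plan.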

spark/src/main/scala/org/apache/comet/serde/operator/CometFilter.scala

Lines changed: 4 additions & 0 deletions
@@ -19,6 +19,7 @@

 package org.apache.comet.serde.operator

+import org.apache.spark.sql.comet.{CometFilterExec, CometNativeExec, SerializedPlan}
 import org.apache.spark.sql.execution.FilterExec

 import org.apache.comet.{CometConf, ConfigEntry}
@@ -49,4 +50,7 @@ object CometFilter extends CometOperatorSerde[FilterExec] {
     }
   }

+  override def createExec(nativeOp: Operator, op: FilterExec): CometNativeExec = {
+    CometFilterExec(nativeOp, op, op.output, op.condition, op.child, SerializedPlan(None))
+  }
 }
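
Since createExec is abstract on CometOperatorSerde, every handler registered in cometNativeExecHandlers must now implement it; the files shown above are representative of the remaining handler files in this 16-file commit, which add analogous createExec implementations.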
