@@ -44,7 +44,7 @@ import org.apache.spark.sql.types._
4444import org .apache .comet .{CometConf , ExtendedExplainInfo }
4545import org .apache .comet .CometConf .COMET_EXEC_SHUFFLE_ENABLED
4646import org .apache .comet .CometSparkSessionExtensions ._
47- import org .apache .comet .rules .CometExecRule .opSerdeMap
47+ import org .apache .comet .rules .CometExecRule .cometNativeExecHandlers
4848import org .apache .comet .serde .{CometOperatorSerde , Compatible , Incompatible , OperatorOuterClass , QueryPlanSerde , Unsupported }
4949import org .apache .comet .serde .OperatorOuterClass .Operator
5050import org .apache .comet .serde .QueryPlanSerde .{serializeDataType , supportedDataType }
@@ -55,7 +55,7 @@ object CometExecRule {
5555 /**
5656 * Mapping of Spark operator class to the Comet handler used to convert it to a fully native operator.
5757 */
58- val opSerdeMap : Map [Class [_ <: SparkPlan ], CometOperatorSerde [_]] =
58+ val cometNativeExecHandlers : Map [Class [_ <: SparkPlan ], CometOperatorSerde [_]] =
5959 Map (
6060 classOf [ProjectExec ] -> CometProject ,
6161 classOf [FilterExec ] -> CometFilter ,
@@ -183,7 +183,7 @@ case class CometExecRule(session: SparkSession) extends Rule[SparkPlan] {
183183 // Fully native scan for V1
184184 case scan : CometScanExec if scan.scanImpl == CometConf .SCAN_NATIVE_DATAFUSION =>
185185 val nativeOp = operator2Proto(scan).get
186- CometNativeScanExec (nativeOp, scan.wrapped, scan.session )
186+ CometNativeScan .createExec (nativeOp, scan)
187187
188188 // Comet JVM + native scan for V1 and V2
189189 case op if isCometScan(op) =>
@@ -195,36 +195,6 @@ case class CometExecRule(session: SparkSession) extends Rule[SparkPlan] {
195195 val nativeOp = operator2Proto(cometOp)
196196 CometScanWrapper (nativeOp.get, cometOp)
197197
198- case op : ProjectExec =>
199- newPlanWithProto(
200- op,
201- CometProjectExec (_, op, op.output, op.projectList, op.child, SerializedPlan (None )))
202-
203- case op : FilterExec =>
204- newPlanWithProto(
205- op,
206- CometFilterExec (_, op, op.output, op.condition, op.child, SerializedPlan (None )))
207-
208- case op : SortExec =>
209- newPlanWithProto(
210- op,
211- CometSortExec (
212- _,
213- op,
214- op.output,
215- op.outputOrdering,
216- op.sortOrder,
217- op.child,
218- SerializedPlan (None )))
219-
220- case op : LocalLimitExec =>
221- newPlanWithProto(op, CometLocalLimitExec (_, op, op.limit, op.child, SerializedPlan (None )))
222-
223- case op : GlobalLimitExec =>
224- newPlanWithProto(
225- op,
226- CometGlobalLimitExec (_, op, op.limit, op.offset, op.child, SerializedPlan (None )))
227-
228198 case op : CollectLimitExec =>
229199 val fallbackReasons = new ListBuffer [String ]()
230200 if (! CometConf .COMET_EXEC_COLLECT_LIMIT_ENABLED .get(conf)) {
@@ -250,116 +220,6 @@ case class CometExecRule(session: SparkSession) extends Rule[SparkPlan] {
250220 }
251221 }
252222
253- case op : ExpandExec =>
254- newPlanWithProto(
255- op,
256- CometExpandExec (_, op, op.output, op.projections, op.child, SerializedPlan (None )))
257-
258- case op : HashAggregateExec =>
259- newPlanWithProto(
260- op,
261- nativeOp => {
262- CometHashAggregateExec (
263- nativeOp,
264- op,
265- op.output,
266- op.groupingExpressions,
267- op.aggregateExpressions,
268- op.resultExpressions,
269- op.child.output,
270- op.child,
271- SerializedPlan (None ))
272- })
273-
274- case op : ObjectHashAggregateExec =>
275- newPlanWithProto(
276- op,
277- nativeOp => {
278- CometHashAggregateExec (
279- nativeOp,
280- op,
281- op.output,
282- op.groupingExpressions,
283- op.aggregateExpressions,
284- op.resultExpressions,
285- op.child.output,
286- op.child,
287- SerializedPlan (None ))
288- })
289-
290- case op : ShuffledHashJoinExec
291- if CometConf .COMET_EXEC_HASH_JOIN_ENABLED .get(conf) &&
292- op.children.forall(isCometNative) =>
293- newPlanWithProto(
294- op,
295- CometHashJoinExec (
296- _,
297- op,
298- op.output,
299- op.outputOrdering,
300- op.leftKeys,
301- op.rightKeys,
302- op.joinType,
303- op.condition,
304- op.buildSide,
305- op.left,
306- op.right,
307- SerializedPlan (None )))
308-
309- case op : ShuffledHashJoinExec if ! CometConf .COMET_EXEC_HASH_JOIN_ENABLED .get(conf) =>
310- withInfo(op, " ShuffleHashJoin is not enabled" )
311-
312- case op : ShuffledHashJoinExec if ! op.children.forall(isCometNative) =>
313- op
314-
315- case op : BroadcastHashJoinExec
316- if CometConf .COMET_EXEC_BROADCAST_HASH_JOIN_ENABLED .get(conf) &&
317- op.children.forall(isCometNative) =>
318- newPlanWithProto(
319- op,
320- CometBroadcastHashJoinExec (
321- _,
322- op,
323- op.output,
324- op.outputOrdering,
325- op.leftKeys,
326- op.rightKeys,
327- op.joinType,
328- op.condition,
329- op.buildSide,
330- op.left,
331- op.right,
332- SerializedPlan (None )))
333-
334- case op : SortMergeJoinExec
335- if CometConf .COMET_EXEC_SORT_MERGE_JOIN_ENABLED .get(conf) &&
336- op.children.forall(isCometNative) =>
337- newPlanWithProto(
338- op,
339- CometSortMergeJoinExec (
340- _,
341- op,
342- op.output,
343- op.outputOrdering,
344- op.leftKeys,
345- op.rightKeys,
346- op.joinType,
347- op.condition,
348- op.left,
349- op.right,
350- SerializedPlan (None )))
351-
352- case op : SortMergeJoinExec
353- if CometConf .COMET_EXEC_SORT_MERGE_JOIN_ENABLED .get(conf) &&
354- ! op.children.forall(isCometNative) =>
355- op
356-
357- case op : SortMergeJoinExec if ! CometConf .COMET_EXEC_SORT_MERGE_JOIN_ENABLED .get(conf) =>
358- withInfo(op, " SortMergeJoin is not enabled" )
359-
360- case op : SortMergeJoinExec if ! op.children.forall(isCometNative) =>
361- op
362-
363223 case c @ CoalesceExec (numPartitions, child)
364224 if CometConf .COMET_EXEC_COALESCE_ENABLED .get(conf)
365225 && isCometNative(child) =>
@@ -405,19 +265,6 @@ case class CometExecRule(session: SparkSession) extends Rule[SparkPlan] {
405265 " TakeOrderedAndProject requires shuffle to be enabled" )
406266 withInfo(s, Seq (info1, info2).flatten.mkString(" ," ))
407267
408- case w : WindowExec =>
409- newPlanWithProto(
410- w,
411- CometWindowExec (
412- _,
413- w,
414- w.output,
415- w.windowExpression,
416- w.partitionSpec,
417- w.orderSpec,
418- w.child,
419- SerializedPlan (None )))
420-
421268 case u : UnionExec
422269 if CometConf .COMET_EXEC_UNION_ENABLED .get(conf) &&
423270 u.children.forall(isCometNative) =>
@@ -476,16 +323,6 @@ case class CometExecRule(session: SparkSession) extends Rule[SparkPlan] {
476323 plan
477324 }
478325
479- // this case should be checked only after the previous case checking for a
480- // child BroadcastExchange has been applied, otherwise that transform
481- // never gets applied
482- case op : BroadcastHashJoinExec if ! op.children.forall(isCometNative) =>
483- op
484-
485- case op : BroadcastHashJoinExec
486- if ! CometConf .COMET_EXEC_BROADCAST_HASH_JOIN_ENABLED .get(conf) =>
487- withInfo(op, " BroadcastHashJoin is not enabled" )
488-
489326 // For AQE shuffle stage on a Comet shuffle exchange
490327 case s @ ShuffleQueryStageExec (_, _ : CometShuffleExchangeExec , _) =>
491328 newPlanWithProto(s, CometSinkPlaceHolder (_, s, s))
@@ -548,19 +385,28 @@ case class CometExecRule(session: SparkSession) extends Rule[SparkPlan] {
548385 s
549386 }
550387
551- case op : LocalTableScanExec =>
552- if (CometConf .COMET_EXEC_LOCAL_TABLE_SCAN_ENABLED .get(conf)) {
553- operator2Proto(op)
554- .map { nativeOp =>
555- val cometOp = CometLocalTableScanExec (op, op.rows, op.output)
556- CometScanWrapper (nativeOp, cometOp)
388+ case op =>
389+ // check whether a native handler is registered for this operator's class
390+ cometNativeExecHandlers
391+ .get(op.getClass)
392+ .map(_.asInstanceOf [CometOperatorSerde [SparkPlan ]]) match {
393+ case Some (handler) =>
394+ if (op.children.forall(isCometNative)) {
395+ if (isOperatorEnabled(handler, op)) {
396+ val builder = OperatorOuterClass .Operator .newBuilder().setPlanId(op.id)
397+ val childOp = op.children.map(_.asInstanceOf [CometNativeExec ].nativeOp)
398+ childOp.foreach(builder.addChildren)
399+ return handler
400+ .convert(op, builder, childOp : _* )
401+ .map(handler.createExec(_, op))
402+ .getOrElse(op)
403+ }
404+ } else {
405+ return op
557406 }
558- .getOrElse(op)
559- } else {
560- withInfo(op, " LocalTableScan is not enabled" )
407+ case _ =>
561408 }
562409
563- case op =>
564410 op match {
565411 case _ : CometPlan | _ : AQEShuffleReadExec | _ : BroadcastExchangeExec |
566412 _ : BroadcastQueryStageExec | _ : AdaptiveSparkPlanExec =>
@@ -1030,20 +876,6 @@ case class CometExecRule(session: SparkSession) extends Rule[SparkPlan] {
1030876 val builder = OperatorOuterClass .Operator .newBuilder().setPlanId(op.id)
1031877 childOp.foreach(builder.addChildren)
1032878
1033- // look for registered handler first
1034- val serde = opSerdeMap.get(op.getClass)
1035- serde match {
1036- case Some (handler) if isOperatorEnabled(handler, op) =>
1037- val opSerde = handler.asInstanceOf [CometOperatorSerde [SparkPlan ]]
1038- val maybeConverted = opSerde.convert(op, builder, childOp : _* )
1039- if (maybeConverted.isDefined) {
1040- return maybeConverted
1041- }
1042- case _ =>
1043- }
1044-
1045- // now handle special cases that cannot be handled as a simple mapping from class name
1046- // and see if operator can be used as a sink
1047879 op match {
1048880
1049881 // Fully native scan for V1
@@ -1108,7 +940,7 @@ case class CometExecRule(session: SparkSession) extends Rule[SparkPlan] {
1108940 // Emit warning if:
1109941 // 1. it is not Spark shuffle operator, which is handled separately
1110942 // 2. it is not a Comet operator
1111- if (serde.isEmpty && ! op.nodeName.contains(" Comet" ) &&
943+ if (! op.nodeName.contains(" Comet" ) &&
1112944 ! op.isInstanceOf [ShuffleExchangeExec ]) {
1113945 withInfo(op, s " unsupported Spark operator: ${op.nodeName}" )
1114946 }
0 commit comments