@@ -22,11 +22,10 @@ package org.apache.comet.rules
 import org.apache.spark.sql.SparkSession
 import org.apache.spark.sql.catalyst.expressions.{Divide, DoubleLiteral, EqualNullSafe, EqualTo, Expression, FloatLiteral, GreaterThan, GreaterThanOrEqual, KnownFloatingPointNormalized, LessThan, LessThanOrEqual, NamedExpression, Remainder}
 import org.apache.spark.sql.catalyst.optimizer.NormalizeNaNAndZero
-import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, RangePartitioning, RoundRobinPartitioning, SinglePartition}
 import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.catalyst.util.sideBySide
 import org.apache.spark.sql.comet._
-import org.apache.spark.sql.comet.execution.shuffle.{CometColumnarShuffle, CometNativeShuffle, CometShuffleExchangeExec, CometShuffleManager}
+import org.apache.spark.sql.comet.execution.shuffle.{CometColumnarShuffle, CometNativeShuffle, CometShuffleExchangeExec}
 import org.apache.spark.sql.execution._
 import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanExec, AQEShuffleReadExec, BroadcastQueryStageExec, ShuffleQueryStageExec}
 import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, ObjectHashAggregateExec}
@@ -35,14 +34,12 @@ import org.apache.spark.sql.execution.datasources.v2.V2CommandExec
 import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ReusedExchangeExec, ShuffleExchangeExec}
 import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, ShuffledHashJoinExec, SortMergeJoinExec}
 import org.apache.spark.sql.execution.window.WindowExec
-import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types._
 
 import org.apache.comet.{CometConf, ExtendedExplainInfo}
-import org.apache.comet.CometConf.COMET_EXEC_SHUFFLE_ENABLED
 import org.apache.comet.CometSparkSessionExtensions._
 import org.apache.comet.rules.CometExecRule.allExecs
-import org.apache.comet.serde.{CometOperatorSerde, Compatible, Incompatible, OperatorOuterClass, QueryPlanSerde, Unsupported}
+import org.apache.comet.serde.{CometOperatorSerde, Compatible, Incompatible, OperatorOuterClass, Unsupported}
 import org.apache.comet.serde.OperatorOuterClass.Operator
 import org.apache.comet.serde.operator._
 import org.apache.comet.serde.operator.CometDataWritingCommand
@@ -92,21 +89,19 @@ case class CometExecRule(session: SparkSession) extends Rule[SparkPlan] {
 
   private def applyCometShuffle(plan: SparkPlan): SparkPlan = {
     plan.transformUp {
-      case s: ShuffleExchangeExec if nativeShuffleSupported(s) =>
+      case s: ShuffleExchangeExec if CometShuffleExchangeExec.nativeShuffleSupported(s) =>
         // Switch to use Decimal128 regardless of precision, since Arrow native execution
         // doesn't support Decimal32 and Decimal64 yet.
         conf.setConfString(CometConf.COMET_USE_DECIMAL_128.key, "true")
         CometShuffleExchangeExec(s, shuffleType = CometNativeShuffle)
 
-      case s: ShuffleExchangeExec if columnarShuffleSupported(s) =>
+      case s: ShuffleExchangeExec if CometShuffleExchangeExec.columnarShuffleSupported(s) =>
         // Columnar shuffle for regular Spark operators (not Comet) and Comet operators
         // (if configured)
         CometShuffleExchangeExec(s, shuffleType = CometColumnarShuffle)
     }
   }
 
-  private def isCometPlan(op: SparkPlan): Boolean = op.isInstanceOf[CometPlan]
-
   private def isCometNative(op: SparkPlan): Boolean = op.isInstanceOf[CometNativeExec]
 
111106
   // spotless:off
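
A note on the hunk above: the two guards now delegate to predicates on the
CometShuffleExchangeExec companion object instead of private methods in this rule.
A minimal sketch of the assumed shape of the moved methods (signatures inferred
from the call sites; bodies stubbed, not the actual implementation):

    import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec

    object CometShuffleExchangeExec {
      // Assumed home of the former nativeShuffleSupported check: true when the
      // exchange can run as an Arrow-based native shuffle.
      def nativeShuffleSupported(s: ShuffleExchangeExec): Boolean = ???

      // Assumed home of the former columnarShuffleSupported check: true when
      // JVM columnar shuffle can handle the exchange instead.
      def columnarShuffleSupported(s: ShuffleExchangeExec): Boolean = ???
    }
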
@@ -249,9 +244,7 @@ case class CometExecRule(session: SparkSession) extends Rule[SparkPlan] {
         convertToCometIfAllChildrenAreNative(s, CometExchangeSink).getOrElse(s)
 
       case s: ShuffleExchangeExec =>
-        // try native shuffle first, then columnar shuffle, then fall back to Spark
-        // if neither are supported
-        tryNativeShuffle(s).orElse(tryColumnarShuffle(s)).getOrElse(s)
+        convertToComet(s, CometShuffleExchangeExec).getOrElse(s)
 
       case op =>
         val handler = allExecs
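
For reference, the branch replaced here previously composed the two deleted
helpers (shown further below) as native-first, columnar-second, with a Spark
fallback; the single convertToComet call is presumably expected to preserve that
ordering. The old control flow, reconstructed from the deleted lines of this diff:

    // Reconstructed from the deleted lines below: try native shuffle first,
    // then columnar shuffle, then keep the original Spark plan.
    def convertShuffle(s: ShuffleExchangeExec): SparkPlan =
      tryNativeShuffle(s).orElse(tryColumnarShuffle(s)).getOrElse(s)
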
@@ -288,39 +281,6 @@ case class CometExecRule(session: SparkSession) extends Rule[SparkPlan] {
     }
   }
 
-  private def tryNativeShuffle(s: ShuffleExchangeExec): Option[SparkPlan] = {
-    Some(s)
-      .filter(nativeShuffleSupported)
-      .filter(_.children.forall(_.isInstanceOf[CometNativeExec]))
-      .flatMap(_ => operator2Proto(s))
-      .map { nativeOp =>
-        // Switch to use Decimal128 regardless of precision, since Arrow native execution
-        // doesn't support Decimal32 and Decimal64 yet.
-        conf.setConfString(CometConf.COMET_USE_DECIMAL_128.key, "true")
-        val cometOp = CometShuffleExchangeExec(s, shuffleType = CometNativeShuffle)
-        CometSinkPlaceHolder(nativeOp, s, cometOp)
-      }
-  }
-
-  private def tryColumnarShuffle(s: ShuffleExchangeExec): Option[SparkPlan] = {
-    // Columnar shuffle for regular Spark operators (not Comet) and Comet operators
-    // (if configured).
-    // If the child of ShuffleExchangeExec is also a ShuffleExchangeExec, we should not
-    // convert it to CometColumnarShuffle,
-    Some(s)
-      .filter(columnarShuffleSupported)
-      .flatMap(_ => operator2Proto(s))
-      .flatMap { nativeOp =>
-        s.child match {
-          case n if n.isInstanceOf[CometNativeExec] || !n.supportsColumnar =>
-            val cometOp = CometShuffleExchangeExec(s, shuffleType = CometColumnarShuffle)
-            Some(CometSinkPlaceHolder(nativeOp, s, cometOp))
-          case _ =>
-            None
-        }
-      }
-  }
-
   private def normalizePlan(plan: SparkPlan): SparkPlan = {
     plan.transformUp {
       case p: ProjectExec =>
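
The child guard inside the deleted tryColumnarShuffle is the piece that kept this
rule from converting a shuffle whose child is an unsupported columnar operator
(shuffle-on-shuffle is rejected separately by the isShuffleOperator check deleted
further below). Condensed from the deleted match:

    // Condensed from the deleted match: columnar shuffle converts only when the
    // child is a Comet native operator or a row-based (non-columnar) operator.
    def childEligibleForColumnarShuffle(child: SparkPlan): Boolean =
      child.isInstanceOf[CometNativeExec] || !child.supportsColumnar
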
@@ -497,269 +457,6 @@ case class CometExecRule(session: SparkSession) extends Rule[SparkPlan] {
     }
   }
 
-  /**
-   * Returns true if a given spark plan is Comet shuffle operator.
-   */
-  private def isShuffleOperator(op: SparkPlan): Boolean = {
-    op match {
-      case op: ShuffleQueryStageExec if op.plan.isInstanceOf[CometShuffleExchangeExec] => true
-      case _: CometShuffleExchangeExec => true
-      case op: CometSinkPlaceHolder => isShuffleOperator(op.child)
-      case _ => false
-    }
-  }
-
-  def isCometShuffleEnabledWithInfo(op: SparkPlan): Boolean = {
-    if (!COMET_EXEC_SHUFFLE_ENABLED.get(op.conf)) {
-      withInfo(
-        op,
-        s"Comet shuffle is not enabled: ${COMET_EXEC_SHUFFLE_ENABLED.key} is not enabled")
-      false
-    } else if (!isCometShuffleManagerEnabled(op.conf)) {
-      withInfo(op, s"spark.shuffle.manager is not set to ${classOf[CometShuffleManager].getName}")
-      false
-    } else {
-      true
-    }
-  }
-
-  /**
-   * Whether the given Spark partitioning is supported by Comet native shuffle.
-   */
-  private def nativeShuffleSupported(s: ShuffleExchangeExec): Boolean = {
-
-    /**
-     * Determine which data types are supported as partition columns in native shuffle.
-     *
-     * For HashPartitioning this defines the key that determines how data should be collocated for
-     * operations like `groupByKey`, `reduceByKey`, or `join`. Native code does not support
-     * hashing complex types, see hash_funcs/utils.rs
-     */
-    def supportedHashPartitioningDataType(dt: DataType): Boolean = dt match {
-      case _: BooleanType | _: ByteType | _: ShortType | _: IntegerType | _: LongType |
-          _: FloatType | _: DoubleType | _: StringType | _: BinaryType | _: TimestampType |
-          _: TimestampNTZType | _: DecimalType | _: DateType =>
-        true
-      case _ =>
-        false
-    }
-
-    /**
-     * Determine which data types are supported as partition columns in native shuffle.
-     *
-     * For RangePartitioning this defines the key that determines how data should be collocated
-     * for operations like `orderBy`, `repartitionByRange`. Native code does not support sorting
-     * complex types.
-     */
-    def supportedRangePartitioningDataType(dt: DataType): Boolean = dt match {
-      case _: BooleanType | _: ByteType | _: ShortType | _: IntegerType | _: LongType |
-          _: FloatType | _: DoubleType | _: StringType | _: BinaryType | _: TimestampType |
-          _: TimestampNTZType | _: DecimalType | _: DateType =>
-        true
-      case _ =>
-        false
-    }
-
-    /**
-     * Determine which data types are supported as data columns in native shuffle.
-     *
-     * Native shuffle relies on the Arrow IPC writer to serialize batches to disk, so it should
-     * support all types that Comet supports.
-     */
-    def supportedSerializableDataType(dt: DataType): Boolean = dt match {
-      case _: BooleanType | _: ByteType | _: ShortType | _: IntegerType | _: LongType |
-          _: FloatType | _: DoubleType | _: StringType | _: BinaryType | _: TimestampType |
-          _: TimestampNTZType | _: DecimalType | _: DateType =>
-        true
-      case StructType(fields) =>
-        fields.nonEmpty && fields.forall(f => supportedSerializableDataType(f.dataType))
-      case ArrayType(elementType, _) =>
-        supportedSerializableDataType(elementType)
-      case MapType(keyType, valueType, _) =>
-        supportedSerializableDataType(keyType) && supportedSerializableDataType(valueType)
-      case _ =>
-        false
-    }
-
-    if (!isCometShuffleEnabledWithInfo(s)) {
-      return false
-    }
-
-    if (!isCometNativeShuffleMode(s.conf)) {
-      withInfo(s, "Comet native shuffle not enabled")
-      return false
-    }
-
-    if (!isCometPlan(s.child)) {
-      // we do not need to report a fallback reason if the child plan is not a Comet plan
-      return false
-    }
-
-    val inputs = s.child.output
-
-    for (input <- inputs) {
-      if (!supportedSerializableDataType(input.dataType)) {
-        withInfo(s, s"unsupported shuffle data type ${input.dataType} for input $input")
-        return false
-      }
-    }
-
-    val partitioning = s.outputPartitioning
-    val conf = SQLConf.get
-    partitioning match {
-      case HashPartitioning(expressions, _) =>
-        var supported = true
-        if (!CometConf.COMET_EXEC_SHUFFLE_WITH_HASH_PARTITIONING_ENABLED.get(conf)) {
-          withInfo(
-            s,
-            s"${CometConf.COMET_EXEC_SHUFFLE_WITH_HASH_PARTITIONING_ENABLED.key} is disabled")
-          supported = false
-        }
-        for (expr <- expressions) {
-          if (QueryPlanSerde.exprToProto(expr, inputs).isEmpty) {
-            withInfo(s, s"unsupported hash partitioning expression: $expr")
-            supported = false
-            // We don't short-circuit in case there is more than one unsupported expression
-            // to provide info for.
-          }
-        }
-        for (dt <- expressions.map(_.dataType).distinct) {
-          if (!supportedHashPartitioningDataType(dt)) {
-            withInfo(s, s"unsupported hash partitioning data type for native shuffle: $dt")
-            supported = false
-          }
-        }
-        supported
-      case SinglePartition =>
-        // we already checked that the input types are supported
-        true
-      case RangePartitioning(orderings, _) =>
-        if (!CometConf.COMET_EXEC_SHUFFLE_WITH_RANGE_PARTITIONING_ENABLED.get(conf)) {
-          withInfo(
-            s,
-            s"${CometConf.COMET_EXEC_SHUFFLE_WITH_RANGE_PARTITIONING_ENABLED.key} is disabled")
-          return false
-        }
-        var supported = true
-        for (o <- orderings) {
-          if (QueryPlanSerde.exprToProto(o, inputs).isEmpty) {
-            withInfo(s, s"unsupported range partitioning sort order: $o", o)
-            supported = false
-            // We don't short-circuit in case there is more than one unsupported expression
-            // to provide info for.
-          }
-        }
-        for (dt <- orderings.map(_.dataType).distinct) {
-          if (!supportedRangePartitioningDataType(dt)) {
-            withInfo(s, s"unsupported range partitioning data type for native shuffle: $dt")
-            supported = false
-          }
-        }
-        supported
-      case _ =>
-        withInfo(
-          s,
-          s"unsupported Spark partitioning for native shuffle: ${partitioning.getClass.getName}")
-        false
-    }
-  }
-
-  /**
-   * Check if the datatypes of shuffle input are supported. This is used for Columnar shuffle
-   * which supports struct/array.
-   */
-  private def columnarShuffleSupported(s: ShuffleExchangeExec): Boolean = {
-
-    /**
-     * Determine which data types are supported as data columns in columnar shuffle.
-     *
-     * Comet columnar shuffle used native code to convert Spark unsafe rows to Arrow batches, see
-     * shuffle/row.rs
-     */
-    def supportedSerializableDataType(dt: DataType): Boolean = dt match {
-      case _: BooleanType | _: ByteType | _: ShortType | _: IntegerType | _: LongType |
-          _: FloatType | _: DoubleType | _: StringType | _: BinaryType | _: TimestampType |
-          _: TimestampNTZType | _: DecimalType | _: DateType =>
-        true
-      case StructType(fields) =>
-        fields.nonEmpty && fields.forall(f => supportedSerializableDataType(f.dataType)) &&
-        // Java Arrow stream reader cannot work on duplicate field name
-        fields.map(f => f.name).distinct.length == fields.length &&
-        fields.nonEmpty
-      case ArrayType(elementType, _) =>
-        supportedSerializableDataType(elementType)
-      case MapType(keyType, valueType, _) =>
-        supportedSerializableDataType(keyType) && supportedSerializableDataType(valueType)
-      case _ =>
-        false
-    }
-
-    if (!isCometShuffleEnabledWithInfo(s)) {
-      return false
-    }
-
-    if (!isCometJVMShuffleMode(s.conf)) {
-      withInfo(s, "Comet columnar shuffle not enabled")
-      return false
-    }
-
-    if (isShuffleOperator(s.child)) {
-      withInfo(s, s"Child ${s.child.getClass.getName} is a shuffle operator")
-      return false
-    }
-
-    if (!(!s.child.supportsColumnar || isCometPlan(s.child))) {
-      withInfo(s, s"Child ${s.child.getClass.getName} is a neither row-based or a Comet operator")
-      return false
-    }
-
-    val inputs = s.child.output
-
-    for (input <- inputs) {
-      if (!supportedSerializableDataType(input.dataType)) {
-        withInfo(s, s"unsupported shuffle data type ${input.dataType} for input $input")
-        return false
-      }
-    }
-
-    val partitioning = s.outputPartitioning
-    partitioning match {
-      case HashPartitioning(expressions, _) =>
-        var supported = true
-        for (expr <- expressions) {
-          if (QueryPlanSerde.exprToProto(expr, inputs).isEmpty) {
-            withInfo(s, s"unsupported hash partitioning expression: $expr")
-            supported = false
-            // We don't short-circuit in case there is more than one unsupported expression
-            // to provide info for.
-          }
-        }
-        supported
-      case SinglePartition =>
-        // we already checked that the input types are supported
-        true
-      case RoundRobinPartitioning(_) =>
-        // we already checked that the input types are supported
-        true
-      case RangePartitioning(orderings, _) =>
-        var supported = true
-        for (o <- orderings) {
-          if (QueryPlanSerde.exprToProto(o, inputs).isEmpty) {
-            withInfo(s, s"unsupported range partitioning sort order: $o")
-            supported = false
-            // We don't short-circuit in case there is more than one unsupported expression
-            // to provide info for.
-          }
-        }
-        supported
-      case _ =>
-        withInfo(
-          s,
-          s"unsupported Spark partitioning for columnar shuffle: ${partitioning.getClass.getName}")
-        false
-    }
-  }
-
   /**
    * Fallback for handling sinks that have not been handled explicitly. This method should
    * eventually be removed once CometExecRule fully uses the operator serde framework.
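
The deleted isCometShuffleEnabledWithInfo spelled out the two settings every Comet
shuffle path requires: COMET_EXEC_SHUFFLE_ENABLED and spark.shuffle.manager set to
CometShuffleManager. A hedged setup sketch (the literal key string behind
COMET_EXEC_SHUFFLE_ENABLED is an assumption; the manager class name comes from the
import removed at the top of this diff):

    import org.apache.spark.sql.SparkSession

    // Session config the deleted checks required; the shuffle-enabled key string
    // is assumed, not confirmed by this diff.
    val spark = SparkSession
      .builder()
      .config("spark.comet.exec.shuffle.enabled", "true") // assumed key
      .config(
        "spark.shuffle.manager",
        "org.apache.spark.sql.comet.execution.shuffle.CometShuffleManager")
      .getOrCreate()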