@@ -28,7 +28,7 @@ import org.apache.spark.sql.catalyst.optimizer.NormalizeNaNAndZero
2828import org .apache .spark .sql .catalyst .plans .physical .{HashPartitioning , RangePartitioning , RoundRobinPartitioning , SinglePartition }
2929import org .apache .spark .sql .catalyst .rules .Rule
3030import org .apache .spark .sql .comet ._
31- import org .apache .spark .sql .comet .execution .shuffle .{CometColumnarShuffle , CometNativeShuffle , CometShuffleExchangeExec }
31+ import org .apache .spark .sql .comet .execution .shuffle .{CometColumnarShuffle , CometNativeShuffle , CometShuffleExchangeExec , CometShuffleManager }
3232import org .apache .spark .sql .execution ._
3333import org .apache .spark .sql .execution .adaptive .{AQEShuffleReadExec , BroadcastQueryStageExec , ShuffleQueryStageExec }
3434import org .apache .spark .sql .execution .aggregate .{BaseAggregateExec , HashAggregateExec , ObjectHashAggregateExec }
@@ -39,11 +39,10 @@ import org.apache.spark.sql.internal.SQLConf
3939import org .apache .spark .sql .types .{ArrayType , BinaryType , BooleanType , ByteType , DataType , DateType , DecimalType , DoubleType , FloatType , IntegerType , LongType , MapType , ShortType , StringType , StructType , TimestampNTZType , TimestampType }
4040
4141import org .apache .comet .{CometConf , ExtendedExplainInfo }
42- import org .apache .comet .CometConf .COMET_ANSI_MODE_ENABLED
42+ import org .apache .comet .CometConf .{ COMET_ANSI_MODE_ENABLED , COMET_EXEC_SHUFFLE_ENABLED }
4343import org .apache .comet .CometSparkSessionExtensions ._
4444import org .apache .comet .serde .OperatorOuterClass .Operator
4545import org .apache .comet .serde .QueryPlanSerde
46- import org .apache .comet .serde .QueryPlanSerde .emitWarning
4746
4847/**
4948 * Spark physical optimizer rule for replacing Spark operators with Comet operators.
@@ -54,23 +53,15 @@ case class CometExecRule(session: SparkSession) extends Rule[SparkPlan] {
5453
5554 private def applyCometShuffle (plan : SparkPlan ): SparkPlan = {
5655 plan.transformUp {
57- case s : ShuffleExchangeExec
58- if isCometPlan(s.child) && isCometNativeShuffleMode(conf) &&
59- nativeShuffleSupported(s)._1 =>
60- logInfo(" Comet extension enabled for Native Shuffle" )
61-
56+ case s : ShuffleExchangeExec if nativeShuffleSupported(s) =>
6257 // Switch to use Decimal128 regardless of precision, since Arrow native execution
6358 // doesn't support Decimal32 and Decimal64 yet.
6459 conf.setConfString(CometConf .COMET_USE_DECIMAL_128 .key, " true" )
6560 CometShuffleExchangeExec (s, shuffleType = CometNativeShuffle )
6661
67- // Columnar shuffle for regular Spark operators (not Comet) and Comet operators
68- // (if configured)
69- case s : ShuffleExchangeExec
70- if (! s.child.supportsColumnar || isCometPlan(s.child)) && isCometJVMShuffleMode(conf) &&
71- columnarShuffleSupported(s)._1 &&
72- ! isShuffleOperator(s.child) =>
73- logInfo(" Comet extension enabled for JVM Columnar Shuffle" )
62+ case s : ShuffleExchangeExec if columnarShuffleSupported(s) =>
63+ // Columnar shuffle for regular Spark operators (not Comet) and Comet operators
64+ // (if configured)
7465 CometShuffleExchangeExec (s, shuffleType = CometColumnarShuffle )
7566 }
7667 }
@@ -489,12 +480,8 @@ case class CometExecRule(session: SparkSession) extends Rule[SparkPlan] {
489480
490481 // Native shuffle for Comet operators
491482 case s : ShuffleExchangeExec =>
492- val nativePrecondition = isCometShuffleEnabled(conf) &&
493- isCometNativeShuffleMode(conf) &&
494- nativeShuffleSupported(s)._1
495-
496483 val nativeShuffle : Option [SparkPlan ] =
497- if (nativePrecondition ) {
484+ if (nativeShuffleSupported(s) ) {
498485 val newOp = operator2Proto(s)
499486 newOp match {
500487 case Some (nativeOp) =>
@@ -517,10 +504,7 @@ case class CometExecRule(session: SparkSession) extends Rule[SparkPlan] {
517504 // (if configured).
518505 // If the child of ShuffleExchangeExec is also a ShuffleExchangeExec, we should not
519506 // convert it to CometColumnarShuffle,
520- if (isCometShuffleEnabled(conf) && isCometJVMShuffleMode(conf) &&
521- columnarShuffleSupported(s)._1 &&
522- ! isShuffleOperator(s.child)) {
523-
507+ if (columnarShuffleSupported(s)) {
524508 val newOp = QueryPlanSerde .operator2Proto(s)
525509 newOp match {
526510 case Some (nativeOp) =>
@@ -543,20 +527,7 @@ case class CometExecRule(session: SparkSession) extends Rule[SparkPlan] {
543527 if (nativeOrColumnarShuffle.isDefined) {
544528 nativeOrColumnarShuffle.get
545529 } else {
546- val isShuffleEnabled = isCometShuffleEnabled(conf)
547- val reason = getCometShuffleNotEnabledReason(conf).getOrElse(" no reason available" )
548- val msg1 = createMessage(! isShuffleEnabled, s " Comet shuffle is not enabled: $reason" )
549- val columnarShuffleEnabled = isCometJVMShuffleMode(conf)
550- val msg2 = createMessage(
551- isShuffleEnabled && ! columnarShuffleEnabled && ! nativeShuffleSupported(s)._1,
552- " Native shuffle: " +
553- s " ${nativeShuffleSupported(s)._2}" )
554- val typeInfo = columnarShuffleSupported(s)._2
555- val msg3 = createMessage(
556- isShuffleEnabled && columnarShuffleEnabled && ! columnarShuffleSupported(s)._1,
557- " JVM shuffle: " +
558- s " $typeInfo" )
559- withInfo(s, Seq (msg1, msg2, msg3).flatten.mkString(" ," ))
530+ s
560531 }
561532
562533 case op =>
@@ -774,10 +745,24 @@ case class CometExecRule(session: SparkSession) extends Rule[SparkPlan] {
774745 }
775746 }
776747
748+ def isCometShuffleEnabledWithInfo (op : SparkPlan ): Boolean = {
749+ if (! COMET_EXEC_SHUFFLE_ENABLED .get(op.conf)) {
750+ withInfo(
751+ op,
752+ s " Comet shuffle is not enabled: ${COMET_EXEC_SHUFFLE_ENABLED .key} is not enabled " )
753+ false
754+ } else if (! isCometShuffleManagerEnabled(op.conf)) {
755+ withInfo(op, s " spark.shuffle.manager is not set to ${classOf[CometShuffleManager ].getName}" )
756+ false
757+ } else {
758+ true
759+ }
760+ }
761+
777762 /**
778763 * Whether the given Spark partitioning is supported by Comet native shuffle.
779764 */
780- private def nativeShuffleSupported (s : ShuffleExchangeExec ): ( Boolean , String ) = {
765+ private def nativeShuffleSupported (s : ShuffleExchangeExec ): Boolean = {
781766
782767 /**
783768 * Determine which data types are supported as hash-partition keys in native shuffle.
@@ -794,11 +779,24 @@ case class CometExecRule(session: SparkSession) extends Rule[SparkPlan] {
794779 false
795780 }
796781
782+ if (! isCometShuffleEnabledWithInfo(s)) {
783+ return false
784+ }
785+
786+ if (! isCometNativeShuffleMode(s.conf)) {
787+ withInfo(s, " Comet native shuffle not enabled" )
788+ return false
789+ }
790+
791+ if (! isCometPlan(s.child)) {
792+ withInfo(s, s " Child ${s.child.getClass.getName} is not native" )
793+ return false
794+ }
795+
797796 val inputs = s.child.output
798797 val partitioning = s.outputPartitioning
799798 val conf = SQLConf .get
800- var msg = " "
801- val supported = partitioning match {
799+ partitioning match {
802800 case HashPartitioning (expressions, _) =>
803801 // native shuffle currently does not support complex types as partition keys
804802 // due to lack of hashing support for those types
@@ -808,7 +806,7 @@ case class CometExecRule(session: SparkSession) extends Rule[SparkPlan] {
808806 inputs.forall(attr => supportedShuffleDataType(attr.dataType)) &&
809807 CometConf .COMET_EXEC_SHUFFLE_WITH_HASH_PARTITIONING_ENABLED .get(conf)
810808 if (! supported) {
811- msg = s " unsupported Spark partitioning: $expressions"
809+ withInfo(s, s " unsupported Spark partitioning: $expressions" )
812810 }
813811 supported
814812 case SinglePartition =>
@@ -818,31 +816,43 @@ case class CometExecRule(session: SparkSession) extends Rule[SparkPlan] {
818816 inputs.forall(attr => supportedShuffleDataType(attr.dataType)) &&
819817 CometConf .COMET_EXEC_SHUFFLE_WITH_RANGE_PARTITIONING_ENABLED .get(conf)
820818 if (! supported) {
821- msg = s " unsupported Spark partitioning: $ordering"
819+ withInfo(s, s " unsupported Spark partitioning: $ordering" )
822820 }
823821 supported
824822 case _ =>
825- msg = s " unsupported Spark partitioning: ${partitioning.getClass.getName}"
823+ withInfo(s, s " unsupported Spark partitioning: ${partitioning.getClass.getName}" )
826824 false
827825 }
828-
829- if (! supported) {
830- emitWarning(msg)
831- (false , msg)
832- } else {
833- (true , null )
834- }
835826 }
836827
837828 /**
838829 * Check if the datatypes of shuffle input are supported. This is used for Columnar shuffle
839830 * which supports struct/array.
840831 */
841- private def columnarShuffleSupported (s : ShuffleExchangeExec ): (Boolean , String ) = {
832+ private def columnarShuffleSupported (s : ShuffleExchangeExec ): Boolean = {
833+
834+ if (! isCometShuffleEnabledWithInfo(s)) {
835+ return false
836+ }
837+
838+ if (! isCometJVMShuffleMode(s.conf)) {
839+ withInfo(s, " Comet columnar shuffle not enabled" )
840+ return false
841+ }
842+
843+ if (isShuffleOperator(s.child)) {
844+ withInfo(s, s " Child ${s.child.getClass.getName} is a shuffle operator" )
845+ return false
846+ }
847+
848+ if (! (! s.child.supportsColumnar || isCometPlan(s.child))) {
849+ withInfo(s, s " Child ${s.child.getClass.getName} is neither row-based nor a Comet operator" )
850+ return false
851+ }
852+
842853 val inputs = s.child.output
843854 val partitioning = s.outputPartitioning
844- var msg = " "
845- val supported = partitioning match {
855+ partitioning match {
846856 case HashPartitioning (expressions, _) =>
847857 // columnar shuffle supports the same data types (including complex types) both for
848858 // partition keys and for other columns
@@ -851,7 +861,7 @@ case class CometExecRule(session: SparkSession) extends Rule[SparkPlan] {
851861 expressions.forall(e => supportedShuffleDataType(e.dataType)) &&
852862 inputs.forall(attr => supportedShuffleDataType(attr.dataType))
853863 if (! supported) {
854- msg = s " unsupported Spark partitioning expressions: $expressions"
864+ withInfo(s, s " unsupported Spark partitioning expressions: $expressions" )
855865 }
856866 supported
857867 case SinglePartition =>
@@ -864,20 +874,13 @@ case class CometExecRule(session: SparkSession) extends Rule[SparkPlan] {
864874 orderings.forall(e => supportedShuffleDataType(e.dataType)) &&
865875 inputs.forall(attr => supportedShuffleDataType(attr.dataType))
866876 if (! supported) {
867- msg = s " unsupported Spark partitioning expressions: $orderings"
877+ withInfo(s, s " unsupported Spark partitioning expressions: $orderings" )
868878 }
869879 supported
870880 case _ =>
871- msg = s " unsupported Spark partitioning: ${partitioning.getClass.getName}"
881+ withInfo(s, s " unsupported Spark partitioning: ${partitioning.getClass.getName}" )
872882 false
873883 }
874-
875- if (! supported) {
876- emitWarning(msg)
877- (false , msg)
878- } else {
879- (true , null )
880- }
881884 }
882885
883886 /**
0 commit comments