@@ -25,6 +25,7 @@ import org.apache.spark.sql.SparkSession
 import org.apache.spark.sql.catalyst.expressions.{Divide, DoubleLiteral, EqualNullSafe, EqualTo, Expression, FloatLiteral, GreaterThan, GreaterThanOrEqual, KnownFloatingPointNormalized, LessThan, LessThanOrEqual, NamedExpression, Remainder}
 import org.apache.spark.sql.catalyst.expressions.aggregate.{Final, Partial}
 import org.apache.spark.sql.catalyst.optimizer.NormalizeNaNAndZero
+import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, RangePartitioning, RoundRobinPartitioning, SinglePartition}
 import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.comet._
 import org.apache.spark.sql.comet.execution.shuffle.{CometColumnarShuffle, CometNativeShuffle, CometShuffleExchangeExec}
@@ -34,13 +35,15 @@ import org.apache.spark.sql.execution.aggregate.{BaseAggregateExec, HashAggregat
 import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ReusedExchangeExec, ShuffleExchangeExec}
 import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, ShuffledHashJoinExec, SortMergeJoinExec}
 import org.apache.spark.sql.execution.window.WindowExec
-import org.apache.spark.sql.types.{DoubleType, FloatType}
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.types.{ArrayType, BinaryType, BooleanType, ByteType, DataType, DateType, DecimalType, DoubleType, FloatType, IntegerType, LongType, MapType, ShortType, StringType, StructType, TimestampNTZType, TimestampType}
 
 import org.apache.comet.{CometConf, ExtendedExplainInfo}
 import org.apache.comet.CometConf.COMET_ANSI_MODE_ENABLED
 import org.apache.comet.CometSparkSessionExtensions._
 import org.apache.comet.serde.OperatorOuterClass.Operator
 import org.apache.comet.serde.QueryPlanSerde
+import org.apache.comet.serde.QueryPlanSerde.emitWarning
 
 /**
  * Spark physical optimizer rule for replacing Spark operators with Comet operators.
@@ -53,7 +56,7 @@ case class CometExecRule(session: SparkSession) extends Rule[SparkPlan] {
     plan.transformUp {
       case s: ShuffleExchangeExec
           if isCometPlan(s.child) && isCometNativeShuffleMode(conf) &&
-            QueryPlanSerde.nativeShuffleSupported(s)._1 =>
+            nativeShuffleSupported(s)._1 =>
         logInfo("Comet extension enabled for Native Shuffle")
 
         // Switch to use Decimal128 regardless of precision, since Arrow native execution
@@ -65,7 +68,7 @@ case class CometExecRule(session: SparkSession) extends Rule[SparkPlan] {
       // (if configured)
       case s: ShuffleExchangeExec
           if (!s.child.supportsColumnar || isCometPlan(s.child)) && isCometJVMShuffleMode(conf) &&
-            QueryPlanSerde.columnarShuffleSupported(s)._1 &&
+            columnarShuffleSupported(s)._1 &&
             !isShuffleOperator(s.child) =>
         logInfo("Comet extension enabled for JVM Columnar Shuffle")
         CometShuffleExchangeExec(s, shuffleType = CometColumnarShuffle)
@@ -490,7 +493,7 @@ case class CometExecRule(session: SparkSession) extends Rule[SparkPlan] {
       case s: ShuffleExchangeExec =>
         val nativePrecondition = isCometShuffleEnabled(conf) &&
           isCometNativeShuffleMode(conf) &&
-          QueryPlanSerde.nativeShuffleSupported(s)._1
+          nativeShuffleSupported(s)._1
 
         val nativeShuffle: Option[SparkPlan] =
           if (nativePrecondition) {
@@ -517,7 +520,7 @@ case class CometExecRule(session: SparkSession) extends Rule[SparkPlan] {
         // If the child of ShuffleExchangeExec is also a ShuffleExchangeExec, we should not
         // convert it to CometColumnarShuffle,
         if (isCometShuffleEnabled(conf) && isCometJVMShuffleMode(conf) &&
-          QueryPlanSerde.columnarShuffleSupported(s)._1 &&
+          columnarShuffleSupported(s)._1 &&
           !isShuffleOperator(s.child)) {
 
           val newOp = QueryPlanSerde.operator2Proto(s)
@@ -547,18 +550,12 @@ case class CometExecRule(session: SparkSession) extends Rule[SparkPlan] {
         val msg1 = createMessage(!isShuffleEnabled, s"Comet shuffle is not enabled: $reason")
         val columnarShuffleEnabled = isCometJVMShuffleMode(conf)
         val msg2 = createMessage(
-          isShuffleEnabled && !columnarShuffleEnabled && !QueryPlanSerde
-            .nativeShuffleSupported(s)
-            ._1,
+          isShuffleEnabled && !columnarShuffleEnabled && !nativeShuffleSupported(s)._1,
           "Native shuffle: " +
-            s"${QueryPlanSerde.nativeShuffleSupported(s)._2}")
-        val typeInfo = QueryPlanSerde
-          .columnarShuffleSupported(s)
-          ._2
+            s"${nativeShuffleSupported(s)._2}")
+        val typeInfo = columnarShuffleSupported(s)._2
         val msg3 = createMessage(
-          isShuffleEnabled && columnarShuffleEnabled && !QueryPlanSerde
-            .columnarShuffleSupported(s)
-            ._1,
+          isShuffleEnabled && columnarShuffleEnabled && !columnarShuffleSupported(s)._1,
           "JVM shuffle: " +
             s"$typeInfo")
         withInfo(s, Seq(msg1, msg2, msg3).flatten.mkString(","))
@@ -578,7 +575,7 @@ case class CometExecRule(session: SparkSession) extends Rule[SparkPlan] {
     }
   }
 
-  def normalizePlan(plan: SparkPlan): SparkPlan = {
+  private def normalizePlan(plan: SparkPlan): SparkPlan = {
     plan.transformUp {
       case p: ProjectExec =>
         val newProjectList = p.projectList.map(normalize(_).asInstanceOf[NamedExpression])
@@ -595,7 +592,7 @@ case class CometExecRule(session: SparkSession) extends Rule[SparkPlan] {
   // because they are handled well in Spark (e.g., `SQLOrderingUtil.compareFloats`). But the
   // comparison functions in arrow-rs do not normalize NaN and zero. So we need to normalize NaN
   // and zero for comparison operators in Comet.
-  def normalize(expr: Expression): Expression = {
+  private def normalize(expr: Expression): Expression = {
     expr.transformUp {
       case EqualTo(left, right) =>
         EqualTo(normalizeNaNAndZero(left), normalizeNaNAndZero(right))
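
The normalization matters because Spark's own float comparisons treat NaN and signed zeros consistently, while bit-level comparisons such as those in arrow-rs do not. A minimal plain-Scala illustration of the underlying IEEE-754 behavior (standalone sketch, not part of this change):

    // 0.0f and -0.0f compare equal numerically but have different bit
    // patterns, so a bitwise comparator treats them as distinct keys.
    java.lang.Float.floatToRawIntBits(0.0f)  // 0
    java.lang.Float.floatToRawIntBits(-0.0f) // -2147483648 (0x80000000)
    0.0f == -0.0f                            // true (IEEE-754 equality)

    // NaN never equals itself under IEEE-754, whereas Spark SQL treats
    // NaN as equal to NaN; hence the normalization before comparison.
    Float.NaN == Float.NaN                   // false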
@@ -616,7 +613,7 @@ case class CometExecRule(session: SparkSession) extends Rule[SparkPlan] {
     }
   }
 
-  def normalizeNaNAndZero(expr: Expression): Expression = {
+  private def normalizeNaNAndZero(expr: Expression): Expression = {
     expr match {
       case _: KnownFloatingPointNormalized => expr
       case FloatLiteral(f) if !f.equals(-0.0f) => expr
@@ -755,7 +752,7 @@ case class CometExecRule(session: SparkSession) extends Rule[SparkPlan] {
   * Find the first Comet partial aggregate in the plan. If it reaches a Spark HashAggregate with
   * partial mode, it will return None.
   */
-  def findCometPartialAgg(plan: SparkPlan): Option[CometHashAggregateExec] = {
+  private def findCometPartialAgg(plan: SparkPlan): Option[CometHashAggregateExec] = {
     plan.collectFirst {
       case agg: CometHashAggregateExec if agg.aggregateExpressions.forall(_.mode == Partial) =>
         Some(agg)
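
Note the `Some(agg)` in the match arm: wrapping results in `Option` inside `collectFirst` lets a "blocking" case map to `None`, so that flattening the result turns it into an overall miss (the tail of this method is outside the hunk shown here). A standalone sketch of the idiom, with hypothetical data:

    // collectFirst yields the result of the first matching case in
    // traversal order, so mapping a blocker to None and flattening
    // makes the search fail fast once a blocker comes first.
    val nodes = Seq("spark-partial-agg", "comet-partial-agg")
    val found: Option[String] = nodes.collectFirst {
      case n if n.startsWith("spark") => None    // blocker: give up
      case n if n.startsWith("comet") => Some(n) // usable partial aggregate
    }.flatten
    // found == None: the Spark partial aggregate was encountered first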
@@ -770,12 +767,147 @@ case class CometExecRule(session: SparkSession) extends Rule[SparkPlan] {
   /**
    * Returns true if a given Spark plan is a Comet shuffle operator.
    */
-  def isShuffleOperator(op: SparkPlan): Boolean = {
+  private def isShuffleOperator(op: SparkPlan): Boolean = {
     op match {
       case op: ShuffleQueryStageExec if op.plan.isInstanceOf[CometShuffleExchangeExec] => true
       case _: CometShuffleExchangeExec => true
       case op: CometSinkPlaceHolder => isShuffleOperator(op.child)
       case _ => false
     }
   }
+
+  /**
+   * Whether the given Spark partitioning is supported by Comet native shuffle.
+   */
+  private def nativeShuffleSupported(s: ShuffleExchangeExec): (Boolean, String) = {
+
+    /**
+     * Determine which data types are supported as hash-partition keys in native shuffle.
+     *
+     * The hash partition key determines how data is collocated for operations like
+     * `groupByKey`, `reduceByKey`, or `join`.
+     */
+    def supportedHashPartitionKeyDataType(dt: DataType): Boolean = dt match {
+      case _: BooleanType | _: ByteType | _: ShortType | _: IntegerType | _: LongType |
+          _: FloatType | _: DoubleType | _: StringType | _: BinaryType | _: TimestampType |
+          _: TimestampNTZType | _: DecimalType | _: DateType =>
+        true
+      case _ =>
+        false
+    }
+
+    val inputs = s.child.output
+    val partitioning = s.outputPartitioning
+    val conf = SQLConf.get
+    var msg = ""
+    val supported = partitioning match {
+      case HashPartitioning(expressions, _) =>
+        // native shuffle currently does not support complex types as partition keys
+        // due to lack of hashing support for those types
+        val supported =
+          expressions.map(QueryPlanSerde.exprToProto(_, inputs)).forall(_.isDefined) &&
+            expressions.forall(e => supportedHashPartitionKeyDataType(e.dataType)) &&
+            inputs.forall(attr => supportedShuffleDataType(attr.dataType)) &&
+            CometConf.COMET_EXEC_SHUFFLE_WITH_HASH_PARTITIONING_ENABLED.get(conf)
+        if (!supported) {
+          msg = s"unsupported Spark partitioning: $expressions"
+        }
+        supported
+      case SinglePartition =>
+        inputs.forall(attr => supportedShuffleDataType(attr.dataType))
+      case RangePartitioning(ordering, _) =>
+        val supported = ordering.map(QueryPlanSerde.exprToProto(_, inputs)).forall(_.isDefined) &&
+          inputs.forall(attr => supportedShuffleDataType(attr.dataType)) &&
+          CometConf.COMET_EXEC_SHUFFLE_WITH_RANGE_PARTITIONING_ENABLED.get(conf)
+        if (!supported) {
+          msg = s"unsupported Spark partitioning: $ordering"
+        }
+        supported
+      case _ =>
+        msg = s"unsupported Spark partitioning: ${partitioning.getClass.getName}"
+        false
+    }
+
+    if (!supported) {
+      emitWarning(msg)
+      (false, msg)
+    } else {
+      (true, null)
+    }
+  }
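
For intuition, the check above admits hash partitioning only over primitive key types, and only when every input column passes supportedShuffleDataType. A rough spark-shell sketch (hypothetical DataFrame; the outcome also depends on the Comet flags referenced above):

    import org.apache.spark.sql.functions.col

    val df = Seq((1L, ("a", "b"))).toDF("id", "pair")

    // LongType key: eligible for native shuffle, assuming
    // COMET_EXEC_SHUFFLE_WITH_HASH_PARTITIONING_ENABLED is set.
    df.repartition(8, col("id"))

    // StructType key: nativeShuffleSupported returns (false, msg) and the
    // exchange falls back (JVM columnar shuffle may still apply).
    df.repartition(8, col("pair"))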
+
+  /**
+   * Check whether the data types of the shuffle input are supported. This is used for columnar
+   * shuffle, which also supports struct and array types.
+   */
+  private def columnarShuffleSupported(s: ShuffleExchangeExec): (Boolean, String) = {
+    val inputs = s.child.output
+    val partitioning = s.outputPartitioning
+    var msg = ""
+    val supported = partitioning match {
+      case HashPartitioning(expressions, _) =>
+        // columnar shuffle supports the same data types (including complex types) both for
+        // partition keys and for other columns
+        val supported =
+          expressions.map(QueryPlanSerde.exprToProto(_, inputs)).forall(_.isDefined) &&
+            expressions.forall(e => supportedShuffleDataType(e.dataType)) &&
+            inputs.forall(attr => supportedShuffleDataType(attr.dataType))
+        if (!supported) {
+          msg = s"unsupported Spark partitioning expressions: $expressions"
+        }
+        supported
+      case SinglePartition =>
+        inputs.forall(attr => supportedShuffleDataType(attr.dataType))
+      case RoundRobinPartitioning(_) =>
+        inputs.forall(attr => supportedShuffleDataType(attr.dataType))
+      case RangePartitioning(orderings, _) =>
+        val supported =
+          orderings.map(QueryPlanSerde.exprToProto(_, inputs)).forall(_.isDefined) &&
+            orderings.forall(e => supportedShuffleDataType(e.dataType)) &&
+            inputs.forall(attr => supportedShuffleDataType(attr.dataType))
+        if (!supported) {
+          msg = s"unsupported Spark partitioning expressions: $orderings"
+        }
+        supported
+      case _ =>
+        msg = s"unsupported Spark partitioning: ${partitioning.getClass.getName}"
+        false
+    }
+
+    if (!supported) {
+      emitWarning(msg)
+      (false, msg)
+    } else {
+      (true, null)
+    }
+  }
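
Unlike the native check, this one allows complex types for both partition keys and payload columns, and it additionally accepts RoundRobinPartitioning. Both helpers return a (supported, reason) pair instead of throwing, which is what lets the msg2/msg3 diagnostics earlier in this file surface the reason in extended explain output. A toy consumer of that shape (hypothetical helper, not in this patch):

    // Convert when supported; otherwise keep the Spark operator and
    // retain the reason for the explain output.
    def tryConvert[T](plan: T, check: T => (Boolean, String)): Either[String, T] =
      check(plan) match {
        case (true, _)       => Right(plan)  // safe to replace with a Comet op
        case (false, reason) => Left(reason) // fall back, keep the reason
      }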
+
+  /**
+   * Determine which data types are supported in a shuffle.
+   */
+  private def supportedShuffleDataType(dt: DataType): Boolean = dt match {
+    case _: BooleanType | _: ByteType | _: ShortType | _: IntegerType | _: LongType |
+        _: FloatType | _: DoubleType | _: StringType | _: BinaryType | _: TimestampType |
+        _: TimestampNTZType | _: DecimalType | _: DateType =>
+      true
+    case StructType(fields) =>
+      fields.forall(f => supportedShuffleDataType(f.dataType)) &&
+        // the Java Arrow stream reader cannot handle duplicate field names
+        fields.map(f => f.name).distinct.length == fields.length
+    case ArrayType(ArrayType(_, _), _) => false // TODO: nested arrays are not supported
+    case ArrayType(MapType(_, _, _), _) => false // TODO: map array elements are not supported
+    case ArrayType(elementType, _) =>
+      supportedShuffleDataType(elementType)
+    case MapType(MapType(_, _, _), _, _) => false // TODO: nested maps are not supported
+    case MapType(_, MapType(_, _, _), _) => false
+    case MapType(StructType(_), _, _) => false // TODO: struct map keys/values are not supported
+    case MapType(_, StructType(_), _) => false
+    case MapType(ArrayType(_, _), _, _) => false // TODO: array map keys/values are not supported
+    case MapType(_, ArrayType(_, _), _) => false
+    case MapType(keyType, valueType, _) =>
+      supportedShuffleDataType(keyType) && supportedShuffleDataType(valueType)
+    case _ =>
+      false
+  }
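
A few concrete evaluations of this predicate, following directly from the cases above (REPL-style sketch, written as if the private method were accessible):

    import org.apache.spark.sql.types._

    supportedShuffleDataType(LongType)                          // true
    supportedShuffleDataType(ArrayType(IntegerType))            // true
    supportedShuffleDataType(ArrayType(ArrayType(IntegerType))) // false: nested array
    supportedShuffleDataType(MapType(StringType, LongType))     // true
    supportedShuffleDataType(
      MapType(StringType, MapType(StringType, LongType)))       // false: nested map value
    supportedShuffleDataType(StructType(Seq(
      StructField("a", IntegerType),
      StructField("a", LongType))))                             // false: duplicate field names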
+
 }