@@ -22,7 +22,7 @@ package org.apache.comet.rules
2222import scala .collection .mutable .ListBuffer
2323
2424import org .apache .spark .sql .SparkSession
25- import org .apache .spark .sql .catalyst .expressions .{Divide , DoubleLiteral , EqualNullSafe , EqualTo , Expression , FloatLiteral , GreaterThan , GreaterThanOrEqual , KnownFloatingPointNormalized , LessThan , LessThanOrEqual , NamedExpression , Remainder }
25+ import org .apache .spark .sql .catalyst .expressions .{Attribute , Divide , DoubleLiteral , EqualNullSafe , EqualTo , Expression , FloatLiteral , GreaterThan , GreaterThanOrEqual , KnownFloatingPointNormalized , LessThan , LessThanOrEqual , NamedExpression , Remainder , SortOrder }
2626import org .apache .spark .sql .catalyst .expressions .aggregate .{Final , Partial }
2727import org .apache .spark .sql .catalyst .optimizer .NormalizeNaNAndZero
2828import org .apache .spark .sql .catalyst .plans .physical .{HashPartitioning , RangePartitioning , RoundRobinPartitioning , SinglePartition }
@@ -793,34 +793,52 @@ case class CometExecRule(session: SparkSession) extends Rule[SparkPlan] {
793793 return false
794794 }
795795
796+ if (! checkSupportedShuffleDataTypes(s)) {
797+ return false
798+ }
799+
796800 val inputs = s.child.output
797801 val partitioning = s.outputPartitioning
798802 val conf = SQLConf .get
799803 partitioning match {
800804 case HashPartitioning (expressions, _) =>
801- // native shuffle currently does not support complex types as partition keys
802- // due to lack of hashing support for those types
803- val supported =
804- expressions.map(QueryPlanSerde .exprToProto(_, inputs)).forall(_.isDefined) &&
805- expressions.forall(e => supportedHashPartitionKeyDataType(e.dataType)) &&
806- inputs.forall(attr => supportedShuffleDataType(attr.dataType)) &&
807- CometConf .COMET_EXEC_SHUFFLE_WITH_HASH_PARTITIONING_ENABLED .get(conf)
808- if (! supported) {
809- withInfo(s, s " unsupported Spark partitioning: $expressions" )
805+ var supported = true
806+ if (! CometConf .COMET_EXEC_SHUFFLE_WITH_HASH_PARTITIONING_ENABLED .get(conf)) {
807+ withInfo(
808+ s,
809+ s " ${CometConf .COMET_EXEC_SHUFFLE_WITH_HASH_PARTITIONING_ENABLED .key} is disabled " )
810+ supported = false
811+ }
812+ for (expr <- expressions) {
813+ if (QueryPlanSerde .exprToProto(expr, inputs).isEmpty) {
814+ withInfo(s, s " unsupported hash partitioning expression: $expr" )
815+ supported = false
816+ }
817+ }
818+ for (dt <- expressions.map(_.dataType).distinct) {
819+ if (! supportedHashPartitionKeyDataType(dt)) {
820+ // native shuffle currently does not support complex types as partition keys
821+ // due to lack of hashing support for those types
822+ withInfo(s, s " unsupported hash partitioning data type for native shuffle: $dt" )
823+ supported = false
824+ }
810825 }
811826 supported
812827 case SinglePartition =>
813- inputs.forall(attr => supportedShuffleDataType(attr.dataType))
814- case RangePartitioning (ordering, _) =>
815- val supported = ordering.map(QueryPlanSerde .exprToProto(_, inputs)).forall(_.isDefined) &&
816- inputs.forall(attr => supportedShuffleDataType(attr.dataType)) &&
817- CometConf .COMET_EXEC_SHUFFLE_WITH_RANGE_PARTITIONING_ENABLED .get(conf)
818- if (! supported) {
819- withInfo(s, s " unsupported Spark partitioning: $ordering" )
828+ // we already checked that the input types are supported
829+ true
830+ case RangePartitioning (orderings, _) =>
831+ if (! CometConf .COMET_EXEC_SHUFFLE_WITH_RANGE_PARTITIONING_ENABLED .get(conf)) {
832+ // do not encourage the users to enable the config because we know that
833+ // the experimental implementation is not correct yet
834+ withInfo(s, " Range partitioning is not supported by native shuffle" )
835+ return false
820836 }
821- supported
837+ rangePartitioningSupported(s, inputs, orderings)
822838 case _ =>
823- withInfo(s, s " unsupported Spark partitioning: ${partitioning.getClass.getName}" )
839+ withInfo(
840+ s,
841+ s " unsupported Spark partitioning for native shuffle: ${partitioning.getClass.getName}" )
824842 false
825843 }
826844 }
@@ -851,38 +869,69 @@ case class CometExecRule(session: SparkSession) extends Rule[SparkPlan] {
851869 }
852870
853871 val inputs = s.child.output
872+ if (! checkSupportedShuffleDataTypes(s)) {
873+ return false
874+ }
875+
854876 val partitioning = s.outputPartitioning
855877 partitioning match {
856878 case HashPartitioning (expressions, _) =>
857- // columnar shuffle supports the same data types (including complex types) both for
858- // partition keys and for other columns
859- val supported =
860- expressions.map(QueryPlanSerde .exprToProto(_, inputs)).forall(_.isDefined) &&
861- expressions.forall(e => supportedShuffleDataType(e.dataType)) &&
862- inputs.forall(attr => supportedShuffleDataType(attr.dataType))
863- if (! supported) {
864- withInfo(s, s " unsupported Spark partitioning expressions: $expressions" )
879+ var supported = true
880+ for (expr <- expressions) {
881+ if (QueryPlanSerde .exprToProto(expr, inputs).isEmpty) {
882+ withInfo(s, s " unsupported hash partitioning expression: $expr" )
883+ supported = false
884+ }
865885 }
866886 supported
867887 case SinglePartition =>
868- inputs.forall(attr => supportedShuffleDataType(attr.dataType))
888+ // we already checked that the input types are supported
889+ true
869890 case RoundRobinPartitioning (_) =>
870- inputs.forall(attr => supportedShuffleDataType(attr.dataType))
891+ // we already checked that the input types are supported
892+ true
871893 case RangePartitioning (orderings, _) =>
872- val supported =
873- orderings.map(QueryPlanSerde .exprToProto(_, inputs)).forall(_.isDefined) &&
874- orderings.forall(e => supportedShuffleDataType(e.dataType)) &&
875- inputs.forall(attr => supportedShuffleDataType(attr.dataType))
876- if (! supported) {
877- withInfo(s, s " unsupported Spark partitioning expressions: $orderings" )
878- }
879- supported
894+ rangePartitioningSupported(s, inputs, orderings)
880895 case _ =>
881- withInfo(s, s " unsupported Spark partitioning: ${partitioning.getClass.getName}" )
896+ withInfo(
897+ s,
898+ s " unsupported Spark partitioning for columnar shuffle: ${partitioning.getClass.getName}" )
882899 false
883900 }
884901 }
885902
903+ private def rangePartitioningSupported (
904+ s : ShuffleExchangeExec ,
905+ inputs : Seq [Attribute ],
906+ orderings : Seq [SortOrder ]) = {
907+ var supported = true
908+ for (o <- orderings) {
909+ if (QueryPlanSerde .exprToProto(o, inputs).isEmpty) {
910+ withInfo(s, s " unsupported range partitioning sort order: $o" )
911+ supported = false
912+ }
913+ }
914+ for (dt <- orderings.map(_.dataType).distinct) {
915+ if (! supportedShuffleDataType(dt)) {
916+ withInfo(s, s " unsupported shuffle data type: $dt" )
917+ supported = false
918+ }
919+ }
920+ supported
921+ }
922+
923+ /** Check that all input types can be written to a shuffle file */
924+ private def checkSupportedShuffleDataTypes (s : ShuffleExchangeExec ): Boolean = {
925+ var supported = true
926+ for (input <- s.child.output) {
927+ if (! supportedShuffleDataType(input.dataType)) {
928+ withInfo(s, s " unsupported shuffle data type ${input.dataType} for input $input" )
929+ supported = false
930+ }
931+ }
932+ supported
933+ }
934+
886935 /**
887936 * Determine which data types are supported in a shuffle.
888937 */
@@ -895,6 +944,10 @@ case class CometExecRule(session: SparkSession) extends Rule[SparkPlan] {
895944 fields.forall(f => supportedShuffleDataType(f.dataType)) &&
896945 // Java Arrow stream reader cannot work on duplicate field name
897946 fields.map(f => f.name).distinct.length == fields.length
947+
948+ // TODO add support for nested complex types
949+ // https://github.com/apache/datafusion-comet/issues/2199
950+
898951 case ArrayType (ArrayType (_, _), _) => false // TODO: nested array is not supported
899952 case ArrayType (MapType (_, _, _), _) => false // TODO: map array element is not supported
900953 case ArrayType (elementType, _) =>
0 commit comments