@@ -22,7 +22,7 @@ package org.apache.comet.rules
22
22
import scala .collection .mutable .ListBuffer
23
23
24
24
import org .apache .spark .sql .SparkSession
25
- import org .apache .spark .sql .catalyst .expressions .{Divide , DoubleLiteral , EqualNullSafe , EqualTo , Expression , FloatLiteral , GreaterThan , GreaterThanOrEqual , KnownFloatingPointNormalized , LessThan , LessThanOrEqual , NamedExpression , Remainder }
25
+ import org .apache .spark .sql .catalyst .expressions .{Attribute , Divide , DoubleLiteral , EqualNullSafe , EqualTo , Expression , FloatLiteral , GreaterThan , GreaterThanOrEqual , KnownFloatingPointNormalized , LessThan , LessThanOrEqual , NamedExpression , Remainder , SortOrder }
26
26
import org .apache .spark .sql .catalyst .expressions .aggregate .{Final , Partial }
27
27
import org .apache .spark .sql .catalyst .optimizer .NormalizeNaNAndZero
28
28
import org .apache .spark .sql .catalyst .plans .physical .{HashPartitioning , RangePartitioning , RoundRobinPartitioning , SinglePartition }
@@ -793,34 +793,52 @@ case class CometExecRule(session: SparkSession) extends Rule[SparkPlan] {
793
793
return false
794
794
}
795
795
796
+ if (! checkSupportedShuffleDataTypes(s)) {
797
+ return false
798
+ }
799
+
796
800
val inputs = s.child.output
797
801
val partitioning = s.outputPartitioning
798
802
val conf = SQLConf .get
799
803
partitioning match {
800
804
case HashPartitioning (expressions, _) =>
801
- // native shuffle currently does not support complex types as partition keys
802
- // due to lack of hashing support for those types
803
- val supported =
804
- expressions.map(QueryPlanSerde .exprToProto(_, inputs)).forall(_.isDefined) &&
805
- expressions.forall(e => supportedHashPartitionKeyDataType(e.dataType)) &&
806
- inputs.forall(attr => supportedShuffleDataType(attr.dataType)) &&
807
- CometConf .COMET_EXEC_SHUFFLE_WITH_HASH_PARTITIONING_ENABLED .get(conf)
808
- if (! supported) {
809
- withInfo(s, s " unsupported Spark partitioning: $expressions" )
805
+ var supported = true
806
+ if (! CometConf .COMET_EXEC_SHUFFLE_WITH_HASH_PARTITIONING_ENABLED .get(conf)) {
807
+ withInfo(
808
+ s,
809
+ s " ${CometConf .COMET_EXEC_SHUFFLE_WITH_HASH_PARTITIONING_ENABLED .key} is disabled " )
810
+ supported = false
811
+ }
812
+ for (expr <- expressions) {
813
+ if (QueryPlanSerde .exprToProto(expr, inputs).isEmpty) {
814
+ withInfo(s, s " unsupported hash partitioning expression: $expr" )
815
+ supported = false
816
+ }
817
+ }
818
+ for (dt <- expressions.map(_.dataType).distinct) {
819
+ if (! supportedHashPartitionKeyDataType(dt)) {
820
+ // native shuffle currently does not support complex types as partition keys
821
+ // due to lack of hashing support for those types
822
+ withInfo(s, s " unsupported hash partitioning data type for native shuffle: $dt" )
823
+ supported = false
824
+ }
810
825
}
811
826
supported
812
827
case SinglePartition =>
813
- inputs.forall(attr => supportedShuffleDataType(attr.dataType))
814
- case RangePartitioning (ordering, _) =>
815
- val supported = ordering.map(QueryPlanSerde .exprToProto(_, inputs)).forall(_.isDefined) &&
816
- inputs.forall(attr => supportedShuffleDataType(attr.dataType)) &&
817
- CometConf .COMET_EXEC_SHUFFLE_WITH_RANGE_PARTITIONING_ENABLED .get(conf)
818
- if (! supported) {
819
- withInfo(s, s " unsupported Spark partitioning: $ordering" )
828
+ // we already checked that the input types are supported
829
+ true
830
+ case RangePartitioning (orderings, _) =>
831
+ if (! CometConf .COMET_EXEC_SHUFFLE_WITH_RANGE_PARTITIONING_ENABLED .get(conf)) {
832
+ // do not encourage the users to enable the config because we know that
833
+ // the experimental implementation is not correct yet
834
+ withInfo(s, " Range partitioning is not supported by native shuffle" )
835
+ return false
820
836
}
821
- supported
837
+ rangePartitioningSupported(s, inputs, orderings)
822
838
case _ =>
823
- withInfo(s, s " unsupported Spark partitioning: ${partitioning.getClass.getName}" )
839
+ withInfo(
840
+ s,
841
+ s " unsupported Spark partitioning for native shuffle: ${partitioning.getClass.getName}" )
824
842
false
825
843
}
826
844
}
@@ -851,38 +869,69 @@ case class CometExecRule(session: SparkSession) extends Rule[SparkPlan] {
851
869
}
852
870
853
871
val inputs = s.child.output
872
+ if (! checkSupportedShuffleDataTypes(s)) {
873
+ return false
874
+ }
875
+
854
876
val partitioning = s.outputPartitioning
855
877
partitioning match {
856
878
case HashPartitioning (expressions, _) =>
857
- // columnar shuffle supports the same data types (including complex types) both for
858
- // partition keys and for other columns
859
- val supported =
860
- expressions.map(QueryPlanSerde .exprToProto(_, inputs)).forall(_.isDefined) &&
861
- expressions.forall(e => supportedShuffleDataType(e.dataType)) &&
862
- inputs.forall(attr => supportedShuffleDataType(attr.dataType))
863
- if (! supported) {
864
- withInfo(s, s " unsupported Spark partitioning expressions: $expressions" )
879
+ var supported = true
880
+ for (expr <- expressions) {
881
+ if (QueryPlanSerde .exprToProto(expr, inputs).isEmpty) {
882
+ withInfo(s, s " unsupported hash partitioning expression: $expr" )
883
+ supported = false
884
+ }
865
885
}
866
886
supported
867
887
case SinglePartition =>
868
- inputs.forall(attr => supportedShuffleDataType(attr.dataType))
888
+ // we already checked that the input types are supported
889
+ true
869
890
case RoundRobinPartitioning (_) =>
870
- inputs.forall(attr => supportedShuffleDataType(attr.dataType))
891
+ // we already checked that the input types are supported
892
+ true
871
893
case RangePartitioning (orderings, _) =>
872
- val supported =
873
- orderings.map(QueryPlanSerde .exprToProto(_, inputs)).forall(_.isDefined) &&
874
- orderings.forall(e => supportedShuffleDataType(e.dataType)) &&
875
- inputs.forall(attr => supportedShuffleDataType(attr.dataType))
876
- if (! supported) {
877
- withInfo(s, s " unsupported Spark partitioning expressions: $orderings" )
878
- }
879
- supported
894
+ rangePartitioningSupported(s, inputs, orderings)
880
895
case _ =>
881
- withInfo(s, s " unsupported Spark partitioning: ${partitioning.getClass.getName}" )
896
+ withInfo(
897
+ s,
898
+ s " unsupported Spark partitioning for columnar shuffle: ${partitioning.getClass.getName}" )
882
899
false
883
900
}
884
901
}
885
902
903
+ private def rangePartitioningSupported (
904
+ s : ShuffleExchangeExec ,
905
+ inputs : Seq [Attribute ],
906
+ orderings : Seq [SortOrder ]) = {
907
+ var supported = true
908
+ for (o <- orderings) {
909
+ if (QueryPlanSerde .exprToProto(o, inputs).isEmpty) {
910
+ withInfo(s, s " unsupported range partitioning sort order: $o" )
911
+ supported = false
912
+ }
913
+ }
914
+ for (dt <- orderings.map(_.dataType).distinct) {
915
+ if (! supportedShuffleDataType(dt)) {
916
+ withInfo(s, s " unsupported shuffle data type: $dt" )
917
+ supported = false
918
+ }
919
+ }
920
+ supported
921
+ }
922
+
923
+ /** Check that all input types can be written to a shuffle file */
924
+ private def checkSupportedShuffleDataTypes (s : ShuffleExchangeExec ): Boolean = {
925
+ var supported = true
926
+ for (input <- s.child.output) {
927
+ if (! supportedShuffleDataType(input.dataType)) {
928
+ withInfo(s, s " unsupported shuffle data type ${input.dataType} for input $input" )
929
+ supported = false
930
+ }
931
+ }
932
+ supported
933
+ }
934
+
886
935
/**
887
936
* Determine which data types are supported in a shuffle.
888
937
*/
@@ -895,6 +944,10 @@ case class CometExecRule(session: SparkSession) extends Rule[SparkPlan] {
895
944
fields.forall(f => supportedShuffleDataType(f.dataType)) &&
896
945
// Java Arrow stream reader cannot work on duplicate field name
897
946
fields.map(f => f.name).distinct.length == fields.length
947
+
948
+ // TODO add support for nested complex types
949
+ // https://github.com/apache/datafusion-comet/issues/2199
950
+
898
951
case ArrayType (ArrayType (_, _), _) => false // TODO: nested array is not supported
899
952
case ArrayType (MapType (_, _, _), _) => false // TODO: map array element is not supported
900
953
case ArrayType (elementType, _) =>
0 commit comments