Skip to content

Commit 0050ed8

Browse files
authored
feat: Improve shuffle fallback reporting (#2194)
1 parent 0a0b65b commit 0050ed8

File tree

1 file changed

+91
-38
lines changed

1 file changed

+91
-38
lines changed

spark/src/main/scala/org/apache/comet/rules/CometExecRule.scala

Lines changed: 91 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ package org.apache.comet.rules
2222
import scala.collection.mutable.ListBuffer
2323

2424
import org.apache.spark.sql.SparkSession
25-
import org.apache.spark.sql.catalyst.expressions.{Divide, DoubleLiteral, EqualNullSafe, EqualTo, Expression, FloatLiteral, GreaterThan, GreaterThanOrEqual, KnownFloatingPointNormalized, LessThan, LessThanOrEqual, NamedExpression, Remainder}
25+
import org.apache.spark.sql.catalyst.expressions.{Attribute, Divide, DoubleLiteral, EqualNullSafe, EqualTo, Expression, FloatLiteral, GreaterThan, GreaterThanOrEqual, KnownFloatingPointNormalized, LessThan, LessThanOrEqual, NamedExpression, Remainder, SortOrder}
2626
import org.apache.spark.sql.catalyst.expressions.aggregate.{Final, Partial}
2727
import org.apache.spark.sql.catalyst.optimizer.NormalizeNaNAndZero
2828
import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, RangePartitioning, RoundRobinPartitioning, SinglePartition}
@@ -793,34 +793,52 @@ case class CometExecRule(session: SparkSession) extends Rule[SparkPlan] {
793793
return false
794794
}
795795

796+
if (!checkSupportedShuffleDataTypes(s)) {
797+
return false
798+
}
799+
796800
val inputs = s.child.output
797801
val partitioning = s.outputPartitioning
798802
val conf = SQLConf.get
799803
partitioning match {
800804
case HashPartitioning(expressions, _) =>
801-
// native shuffle currently does not support complex types as partition keys
802-
// due to lack of hashing support for those types
803-
val supported =
804-
expressions.map(QueryPlanSerde.exprToProto(_, inputs)).forall(_.isDefined) &&
805-
expressions.forall(e => supportedHashPartitionKeyDataType(e.dataType)) &&
806-
inputs.forall(attr => supportedShuffleDataType(attr.dataType)) &&
807-
CometConf.COMET_EXEC_SHUFFLE_WITH_HASH_PARTITIONING_ENABLED.get(conf)
808-
if (!supported) {
809-
withInfo(s, s"unsupported Spark partitioning: $expressions")
805+
var supported = true
806+
if (!CometConf.COMET_EXEC_SHUFFLE_WITH_HASH_PARTITIONING_ENABLED.get(conf)) {
807+
withInfo(
808+
s,
809+
s"${CometConf.COMET_EXEC_SHUFFLE_WITH_HASH_PARTITIONING_ENABLED.key} is disabled")
810+
supported = false
811+
}
812+
for (expr <- expressions) {
813+
if (QueryPlanSerde.exprToProto(expr, inputs).isEmpty) {
814+
withInfo(s, s"unsupported hash partitioning expression: $expr")
815+
supported = false
816+
}
817+
}
818+
for (dt <- expressions.map(_.dataType).distinct) {
819+
if (!supportedHashPartitionKeyDataType(dt)) {
820+
// native shuffle currently does not support complex types as partition keys
821+
// due to lack of hashing support for those types
822+
withInfo(s, s"unsupported hash partitioning data type for native shuffle: $dt")
823+
supported = false
824+
}
810825
}
811826
supported
812827
case SinglePartition =>
813-
inputs.forall(attr => supportedShuffleDataType(attr.dataType))
814-
case RangePartitioning(ordering, _) =>
815-
val supported = ordering.map(QueryPlanSerde.exprToProto(_, inputs)).forall(_.isDefined) &&
816-
inputs.forall(attr => supportedShuffleDataType(attr.dataType)) &&
817-
CometConf.COMET_EXEC_SHUFFLE_WITH_RANGE_PARTITIONING_ENABLED.get(conf)
818-
if (!supported) {
819-
withInfo(s, s"unsupported Spark partitioning: $ordering")
828+
// we already checked that the input types are supported
829+
true
830+
case RangePartitioning(orderings, _) =>
831+
if (!CometConf.COMET_EXEC_SHUFFLE_WITH_RANGE_PARTITIONING_ENABLED.get(conf)) {
832+
// do not encourage the users to enable the config because we know that
833+
// the experimental implementation is not correct yet
834+
withInfo(s, "Range partitioning is not supported by native shuffle")
835+
return false
820836
}
821-
supported
837+
rangePartitioningSupported(s, inputs, orderings)
822838
case _ =>
823-
withInfo(s, s"unsupported Spark partitioning: ${partitioning.getClass.getName}")
839+
withInfo(
840+
s,
841+
s"unsupported Spark partitioning for native shuffle: ${partitioning.getClass.getName}")
824842
false
825843
}
826844
}
@@ -851,38 +869,69 @@ case class CometExecRule(session: SparkSession) extends Rule[SparkPlan] {
851869
}
852870

853871
val inputs = s.child.output
872+
if (!checkSupportedShuffleDataTypes(s)) {
873+
return false
874+
}
875+
854876
val partitioning = s.outputPartitioning
855877
partitioning match {
856878
case HashPartitioning(expressions, _) =>
857-
// columnar shuffle supports the same data types (including complex types) both for
858-
// partition keys and for other columns
859-
val supported =
860-
expressions.map(QueryPlanSerde.exprToProto(_, inputs)).forall(_.isDefined) &&
861-
expressions.forall(e => supportedShuffleDataType(e.dataType)) &&
862-
inputs.forall(attr => supportedShuffleDataType(attr.dataType))
863-
if (!supported) {
864-
withInfo(s, s"unsupported Spark partitioning expressions: $expressions")
879+
var supported = true
880+
for (expr <- expressions) {
881+
if (QueryPlanSerde.exprToProto(expr, inputs).isEmpty) {
882+
withInfo(s, s"unsupported hash partitioning expression: $expr")
883+
supported = false
884+
}
865885
}
866886
supported
867887
case SinglePartition =>
868-
inputs.forall(attr => supportedShuffleDataType(attr.dataType))
888+
// we already checked that the input types are supported
889+
true
869890
case RoundRobinPartitioning(_) =>
870-
inputs.forall(attr => supportedShuffleDataType(attr.dataType))
891+
// we already checked that the input types are supported
892+
true
871893
case RangePartitioning(orderings, _) =>
872-
val supported =
873-
orderings.map(QueryPlanSerde.exprToProto(_, inputs)).forall(_.isDefined) &&
874-
orderings.forall(e => supportedShuffleDataType(e.dataType)) &&
875-
inputs.forall(attr => supportedShuffleDataType(attr.dataType))
876-
if (!supported) {
877-
withInfo(s, s"unsupported Spark partitioning expressions: $orderings")
878-
}
879-
supported
894+
rangePartitioningSupported(s, inputs, orderings)
880895
case _ =>
881-
withInfo(s, s"unsupported Spark partitioning: ${partitioning.getClass.getName}")
896+
withInfo(
897+
s,
898+
s"unsupported Spark partitioning for columnar shuffle: ${partitioning.getClass.getName}")
882899
false
883900
}
884901
}
885902

903+
private def rangePartitioningSupported(
904+
s: ShuffleExchangeExec,
905+
inputs: Seq[Attribute],
906+
orderings: Seq[SortOrder]) = {
907+
var supported = true
908+
for (o <- orderings) {
909+
if (QueryPlanSerde.exprToProto(o, inputs).isEmpty) {
910+
withInfo(s, s"unsupported range partitioning sort order: $o")
911+
supported = false
912+
}
913+
}
914+
for (dt <- orderings.map(_.dataType).distinct) {
915+
if (!supportedShuffleDataType(dt)) {
916+
withInfo(s, s"unsupported shuffle data type: $dt")
917+
supported = false
918+
}
919+
}
920+
supported
921+
}
922+
923+
/** Check that all input types can be written to a shuffle file */
924+
private def checkSupportedShuffleDataTypes(s: ShuffleExchangeExec): Boolean = {
925+
var supported = true
926+
for (input <- s.child.output) {
927+
if (!supportedShuffleDataType(input.dataType)) {
928+
withInfo(s, s"unsupported shuffle data type ${input.dataType} for input $input")
929+
supported = false
930+
}
931+
}
932+
supported
933+
}
934+
886935
/**
887936
* Determine which data types are supported in a shuffle.
888937
*/
@@ -895,6 +944,10 @@ case class CometExecRule(session: SparkSession) extends Rule[SparkPlan] {
895944
fields.forall(f => supportedShuffleDataType(f.dataType)) &&
896945
// Java Arrow stream reader cannot work on duplicate field name
897946
fields.map(f => f.name).distinct.length == fields.length
947+
948+
// TODO add support for nested complex types
949+
// https://github.com/apache/datafusion-comet/issues/2199
950+
898951
case ArrayType(ArrayType(_, _), _) => false // TODO: nested array is not supported
899952
case ArrayType(MapType(_, _, _), _) => false // TODO: map array element is not supported
900953
case ArrayType(elementType, _) =>

0 commit comments

Comments
 (0)