Skip to content

Commit 5c6a027

Browse files
authored
minor: Move shuffle logic from CometExecRule to CometShuffleExchangeExec serde implementation (#2853)
1 parent fd0ab64 commit 5c6a027

File tree

3 files changed

+328
-327
lines changed

3 files changed

+328
-327
lines changed

spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala

Lines changed: 0 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -157,22 +157,6 @@ object CometSparkSessionExtensions extends Logging {
157157
COMET_EXEC_ENABLED.get(conf)
158158
}
159159

160-
private[comet] def isCometNativeShuffleMode(conf: SQLConf): Boolean = {
161-
COMET_SHUFFLE_MODE.get(conf) match {
162-
case "native" => true
163-
case "auto" => true
164-
case _ => false
165-
}
166-
}
167-
168-
private[comet] def isCometJVMShuffleMode(conf: SQLConf): Boolean = {
169-
COMET_SHUFFLE_MODE.get(conf) match {
170-
case "jvm" => true
171-
case "auto" => true
172-
case _ => false
173-
}
174-
}
175-
176160
def isCometScan(op: SparkPlan): Boolean = {
177161
op.isInstanceOf[CometBatchScanExec] || op.isInstanceOf[CometScanExec]
178162
}

spark/src/main/scala/org/apache/comet/rules/CometExecRule.scala

Lines changed: 5 additions & 308 deletions
Original file line numberDiff line numberDiff line change
@@ -22,11 +22,10 @@ package org.apache.comet.rules
2222
import org.apache.spark.sql.SparkSession
2323
import org.apache.spark.sql.catalyst.expressions.{Divide, DoubleLiteral, EqualNullSafe, EqualTo, Expression, FloatLiteral, GreaterThan, GreaterThanOrEqual, KnownFloatingPointNormalized, LessThan, LessThanOrEqual, NamedExpression, Remainder}
2424
import org.apache.spark.sql.catalyst.optimizer.NormalizeNaNAndZero
25-
import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, RangePartitioning, RoundRobinPartitioning, SinglePartition}
2625
import org.apache.spark.sql.catalyst.rules.Rule
2726
import org.apache.spark.sql.catalyst.util.sideBySide
2827
import org.apache.spark.sql.comet._
29-
import org.apache.spark.sql.comet.execution.shuffle.{CometColumnarShuffle, CometNativeShuffle, CometShuffleExchangeExec, CometShuffleManager}
28+
import org.apache.spark.sql.comet.execution.shuffle.{CometColumnarShuffle, CometNativeShuffle, CometShuffleExchangeExec}
3029
import org.apache.spark.sql.execution._
3130
import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanExec, AQEShuffleReadExec, BroadcastQueryStageExec, ShuffleQueryStageExec}
3231
import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, ObjectHashAggregateExec}
@@ -35,14 +34,12 @@ import org.apache.spark.sql.execution.datasources.v2.V2CommandExec
3534
import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ReusedExchangeExec, ShuffleExchangeExec}
3635
import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, ShuffledHashJoinExec, SortMergeJoinExec}
3736
import org.apache.spark.sql.execution.window.WindowExec
38-
import org.apache.spark.sql.internal.SQLConf
3937
import org.apache.spark.sql.types._
4038

4139
import org.apache.comet.{CometConf, ExtendedExplainInfo}
42-
import org.apache.comet.CometConf.COMET_EXEC_SHUFFLE_ENABLED
4340
import org.apache.comet.CometSparkSessionExtensions._
4441
import org.apache.comet.rules.CometExecRule.allExecs
45-
import org.apache.comet.serde.{CometOperatorSerde, Compatible, Incompatible, OperatorOuterClass, QueryPlanSerde, Unsupported}
42+
import org.apache.comet.serde.{CometOperatorSerde, Compatible, Incompatible, OperatorOuterClass, Unsupported}
4643
import org.apache.comet.serde.OperatorOuterClass.Operator
4744
import org.apache.comet.serde.operator._
4845
import org.apache.comet.serde.operator.CometDataWritingCommand
@@ -92,21 +89,19 @@ case class CometExecRule(session: SparkSession) extends Rule[SparkPlan] {
9289

9390
private def applyCometShuffle(plan: SparkPlan): SparkPlan = {
9491
plan.transformUp {
95-
case s: ShuffleExchangeExec if nativeShuffleSupported(s) =>
92+
case s: ShuffleExchangeExec if CometShuffleExchangeExec.nativeShuffleSupported(s) =>
9693
// Switch to use Decimal128 regardless of precision, since Arrow native execution
9794
// doesn't support Decimal32 and Decimal64 yet.
9895
conf.setConfString(CometConf.COMET_USE_DECIMAL_128.key, "true")
9996
CometShuffleExchangeExec(s, shuffleType = CometNativeShuffle)
10097

101-
case s: ShuffleExchangeExec if columnarShuffleSupported(s) =>
98+
case s: ShuffleExchangeExec if CometShuffleExchangeExec.columnarShuffleSupported(s) =>
10299
// Columnar shuffle for regular Spark operators (not Comet) and Comet operators
103100
// (if configured)
104101
CometShuffleExchangeExec(s, shuffleType = CometColumnarShuffle)
105102
}
106103
}
107104

108-
private def isCometPlan(op: SparkPlan): Boolean = op.isInstanceOf[CometPlan]
109-
110105
private def isCometNative(op: SparkPlan): Boolean = op.isInstanceOf[CometNativeExec]
111106

112107
// spotless:off
@@ -249,9 +244,7 @@ case class CometExecRule(session: SparkSession) extends Rule[SparkPlan] {
249244
convertToCometIfAllChildrenAreNative(s, CometExchangeSink).getOrElse(s)
250245

251246
case s: ShuffleExchangeExec =>
252-
// try native shuffle first, then columnar shuffle, then fall back to Spark
253-
// if neither are supported
254-
tryNativeShuffle(s).orElse(tryColumnarShuffle(s)).getOrElse(s)
247+
convertToComet(s, CometShuffleExchangeExec).getOrElse(s)
255248

256249
case op =>
257250
val handler = allExecs
@@ -288,39 +281,6 @@ case class CometExecRule(session: SparkSession) extends Rule[SparkPlan] {
288281
}
289282
}
290283

291-
private def tryNativeShuffle(s: ShuffleExchangeExec): Option[SparkPlan] = {
292-
Some(s)
293-
.filter(nativeShuffleSupported)
294-
.filter(_.children.forall(_.isInstanceOf[CometNativeExec]))
295-
.flatMap(_ => operator2Proto(s))
296-
.map { nativeOp =>
297-
// Switch to use Decimal128 regardless of precision, since Arrow native execution
298-
// doesn't support Decimal32 and Decimal64 yet.
299-
conf.setConfString(CometConf.COMET_USE_DECIMAL_128.key, "true")
300-
val cometOp = CometShuffleExchangeExec(s, shuffleType = CometNativeShuffle)
301-
CometSinkPlaceHolder(nativeOp, s, cometOp)
302-
}
303-
}
304-
305-
private def tryColumnarShuffle(s: ShuffleExchangeExec): Option[SparkPlan] = {
306-
// Columnar shuffle for regular Spark operators (not Comet) and Comet operators
307-
// (if configured).
308-
// If the child of ShuffleExchangeExec is also a ShuffleExchangeExec, we should not
309-
// convert it to CometColumnarShuffle,
310-
Some(s)
311-
.filter(columnarShuffleSupported)
312-
.flatMap(_ => operator2Proto(s))
313-
.flatMap { nativeOp =>
314-
s.child match {
315-
case n if n.isInstanceOf[CometNativeExec] || !n.supportsColumnar =>
316-
val cometOp = CometShuffleExchangeExec(s, shuffleType = CometColumnarShuffle)
317-
Some(CometSinkPlaceHolder(nativeOp, s, cometOp))
318-
case _ =>
319-
None
320-
}
321-
}
322-
}
323-
324284
private def normalizePlan(plan: SparkPlan): SparkPlan = {
325285
plan.transformUp {
326286
case p: ProjectExec =>
@@ -497,269 +457,6 @@ case class CometExecRule(session: SparkSession) extends Rule[SparkPlan] {
497457
}
498458
}
499459

500-
/**
501-
* Returns true if a given spark plan is Comet shuffle operator.
502-
*/
503-
private def isShuffleOperator(op: SparkPlan): Boolean = {
504-
op match {
505-
case op: ShuffleQueryStageExec if op.plan.isInstanceOf[CometShuffleExchangeExec] => true
506-
case _: CometShuffleExchangeExec => true
507-
case op: CometSinkPlaceHolder => isShuffleOperator(op.child)
508-
case _ => false
509-
}
510-
}
511-
512-
def isCometShuffleEnabledWithInfo(op: SparkPlan): Boolean = {
513-
if (!COMET_EXEC_SHUFFLE_ENABLED.get(op.conf)) {
514-
withInfo(
515-
op,
516-
s"Comet shuffle is not enabled: ${COMET_EXEC_SHUFFLE_ENABLED.key} is not enabled")
517-
false
518-
} else if (!isCometShuffleManagerEnabled(op.conf)) {
519-
withInfo(op, s"spark.shuffle.manager is not set to ${classOf[CometShuffleManager].getName}")
520-
false
521-
} else {
522-
true
523-
}
524-
}
525-
526-
/**
527-
* Whether the given Spark partitioning is supported by Comet native shuffle.
528-
*/
529-
private def nativeShuffleSupported(s: ShuffleExchangeExec): Boolean = {
530-
531-
/**
532-
* Determine which data types are supported as partition columns in native shuffle.
533-
*
534-
* For HashPartitioning this defines the key that determines how data should be collocated for
535-
* operations like `groupByKey`, `reduceByKey`, or `join`. Native code does not support
536-
* hashing complex types, see hash_funcs/utils.rs
537-
*/
538-
def supportedHashPartitioningDataType(dt: DataType): Boolean = dt match {
539-
case _: BooleanType | _: ByteType | _: ShortType | _: IntegerType | _: LongType |
540-
_: FloatType | _: DoubleType | _: StringType | _: BinaryType | _: TimestampType |
541-
_: TimestampNTZType | _: DecimalType | _: DateType =>
542-
true
543-
case _ =>
544-
false
545-
}
546-
547-
/**
548-
* Determine which data types are supported as partition columns in native shuffle.
549-
*
550-
* For RangePartitioning this defines the key that determines how data should be collocated
551-
* for operations like `orderBy`, `repartitionByRange`. Native code does not support sorting
552-
* complex types.
553-
*/
554-
def supportedRangePartitioningDataType(dt: DataType): Boolean = dt match {
555-
case _: BooleanType | _: ByteType | _: ShortType | _: IntegerType | _: LongType |
556-
_: FloatType | _: DoubleType | _: StringType | _: BinaryType | _: TimestampType |
557-
_: TimestampNTZType | _: DecimalType | _: DateType =>
558-
true
559-
case _ =>
560-
false
561-
}
562-
563-
/**
564-
* Determine which data types are supported as data columns in native shuffle.
565-
*
566-
* Native shuffle relies on the Arrow IPC writer to serialize batches to disk, so it should
567-
* support all types that Comet supports.
568-
*/
569-
def supportedSerializableDataType(dt: DataType): Boolean = dt match {
570-
case _: BooleanType | _: ByteType | _: ShortType | _: IntegerType | _: LongType |
571-
_: FloatType | _: DoubleType | _: StringType | _: BinaryType | _: TimestampType |
572-
_: TimestampNTZType | _: DecimalType | _: DateType =>
573-
true
574-
case StructType(fields) =>
575-
fields.nonEmpty && fields.forall(f => supportedSerializableDataType(f.dataType))
576-
case ArrayType(elementType, _) =>
577-
supportedSerializableDataType(elementType)
578-
case MapType(keyType, valueType, _) =>
579-
supportedSerializableDataType(keyType) && supportedSerializableDataType(valueType)
580-
case _ =>
581-
false
582-
}
583-
584-
if (!isCometShuffleEnabledWithInfo(s)) {
585-
return false
586-
}
587-
588-
if (!isCometNativeShuffleMode(s.conf)) {
589-
withInfo(s, "Comet native shuffle not enabled")
590-
return false
591-
}
592-
593-
if (!isCometPlan(s.child)) {
594-
// we do not need to report a fallback reason if the child plan is not a Comet plan
595-
return false
596-
}
597-
598-
val inputs = s.child.output
599-
600-
for (input <- inputs) {
601-
if (!supportedSerializableDataType(input.dataType)) {
602-
withInfo(s, s"unsupported shuffle data type ${input.dataType} for input $input")
603-
return false
604-
}
605-
}
606-
607-
val partitioning = s.outputPartitioning
608-
val conf = SQLConf.get
609-
partitioning match {
610-
case HashPartitioning(expressions, _) =>
611-
var supported = true
612-
if (!CometConf.COMET_EXEC_SHUFFLE_WITH_HASH_PARTITIONING_ENABLED.get(conf)) {
613-
withInfo(
614-
s,
615-
s"${CometConf.COMET_EXEC_SHUFFLE_WITH_HASH_PARTITIONING_ENABLED.key} is disabled")
616-
supported = false
617-
}
618-
for (expr <- expressions) {
619-
if (QueryPlanSerde.exprToProto(expr, inputs).isEmpty) {
620-
withInfo(s, s"unsupported hash partitioning expression: $expr")
621-
supported = false
622-
// We don't short-circuit in case there is more than one unsupported expression
623-
// to provide info for.
624-
}
625-
}
626-
for (dt <- expressions.map(_.dataType).distinct) {
627-
if (!supportedHashPartitioningDataType(dt)) {
628-
withInfo(s, s"unsupported hash partitioning data type for native shuffle: $dt")
629-
supported = false
630-
}
631-
}
632-
supported
633-
case SinglePartition =>
634-
// we already checked that the input types are supported
635-
true
636-
case RangePartitioning(orderings, _) =>
637-
if (!CometConf.COMET_EXEC_SHUFFLE_WITH_RANGE_PARTITIONING_ENABLED.get(conf)) {
638-
withInfo(
639-
s,
640-
s"${CometConf.COMET_EXEC_SHUFFLE_WITH_RANGE_PARTITIONING_ENABLED.key} is disabled")
641-
return false
642-
}
643-
var supported = true
644-
for (o <- orderings) {
645-
if (QueryPlanSerde.exprToProto(o, inputs).isEmpty) {
646-
withInfo(s, s"unsupported range partitioning sort order: $o", o)
647-
supported = false
648-
// We don't short-circuit in case there is more than one unsupported expression
649-
// to provide info for.
650-
}
651-
}
652-
for (dt <- orderings.map(_.dataType).distinct) {
653-
if (!supportedRangePartitioningDataType(dt)) {
654-
withInfo(s, s"unsupported range partitioning data type for native shuffle: $dt")
655-
supported = false
656-
}
657-
}
658-
supported
659-
case _ =>
660-
withInfo(
661-
s,
662-
s"unsupported Spark partitioning for native shuffle: ${partitioning.getClass.getName}")
663-
false
664-
}
665-
}
666-
667-
/**
668-
* Check if the datatypes of shuffle input are supported. This is used for Columnar shuffle
669-
* which supports struct/array.
670-
*/
671-
private def columnarShuffleSupported(s: ShuffleExchangeExec): Boolean = {
672-
673-
/**
674-
* Determine which data types are supported as data columns in columnar shuffle.
675-
*
676-
* Comet columnar shuffle used native code to convert Spark unsafe rows to Arrow batches, see
677-
* shuffle/row.rs
678-
*/
679-
def supportedSerializableDataType(dt: DataType): Boolean = dt match {
680-
case _: BooleanType | _: ByteType | _: ShortType | _: IntegerType | _: LongType |
681-
_: FloatType | _: DoubleType | _: StringType | _: BinaryType | _: TimestampType |
682-
_: TimestampNTZType | _: DecimalType | _: DateType =>
683-
true
684-
case StructType(fields) =>
685-
fields.nonEmpty && fields.forall(f => supportedSerializableDataType(f.dataType)) &&
686-
// Java Arrow stream reader cannot work on duplicate field name
687-
fields.map(f => f.name).distinct.length == fields.length &&
688-
fields.nonEmpty
689-
case ArrayType(elementType, _) =>
690-
supportedSerializableDataType(elementType)
691-
case MapType(keyType, valueType, _) =>
692-
supportedSerializableDataType(keyType) && supportedSerializableDataType(valueType)
693-
case _ =>
694-
false
695-
}
696-
697-
if (!isCometShuffleEnabledWithInfo(s)) {
698-
return false
699-
}
700-
701-
if (!isCometJVMShuffleMode(s.conf)) {
702-
withInfo(s, "Comet columnar shuffle not enabled")
703-
return false
704-
}
705-
706-
if (isShuffleOperator(s.child)) {
707-
withInfo(s, s"Child ${s.child.getClass.getName} is a shuffle operator")
708-
return false
709-
}
710-
711-
if (!(!s.child.supportsColumnar || isCometPlan(s.child))) {
712-
withInfo(s, s"Child ${s.child.getClass.getName} is a neither row-based or a Comet operator")
713-
return false
714-
}
715-
716-
val inputs = s.child.output
717-
718-
for (input <- inputs) {
719-
if (!supportedSerializableDataType(input.dataType)) {
720-
withInfo(s, s"unsupported shuffle data type ${input.dataType} for input $input")
721-
return false
722-
}
723-
}
724-
725-
val partitioning = s.outputPartitioning
726-
partitioning match {
727-
case HashPartitioning(expressions, _) =>
728-
var supported = true
729-
for (expr <- expressions) {
730-
if (QueryPlanSerde.exprToProto(expr, inputs).isEmpty) {
731-
withInfo(s, s"unsupported hash partitioning expression: $expr")
732-
supported = false
733-
// We don't short-circuit in case there is more than one unsupported expression
734-
// to provide info for.
735-
}
736-
}
737-
supported
738-
case SinglePartition =>
739-
// we already checked that the input types are supported
740-
true
741-
case RoundRobinPartitioning(_) =>
742-
// we already checked that the input types are supported
743-
true
744-
case RangePartitioning(orderings, _) =>
745-
var supported = true
746-
for (o <- orderings) {
747-
if (QueryPlanSerde.exprToProto(o, inputs).isEmpty) {
748-
withInfo(s, s"unsupported range partitioning sort order: $o")
749-
supported = false
750-
// We don't short-circuit in case there is more than one unsupported expression
751-
// to provide info for.
752-
}
753-
}
754-
supported
755-
case _ =>
756-
withInfo(
757-
s,
758-
s"unsupported Spark partitioning for columnar shuffle: ${partitioning.getClass.getName}")
759-
false
760-
}
761-
}
762-
763460
/**
764461
* Fallback for handling sinks that have not been handled explicitly. This method should
765462
* eventually be removed once CometExecRule fully uses the operator serde framework.

0 commit comments

Comments
 (0)