Skip to content

Commit 41c69d5

Browse files
authored
chore: CometExecRule code cleanup (#2159)
1 parent f50caa1 commit 41c69d5

File tree

2 files changed

+68
-76
lines changed

2 files changed

+68
-76
lines changed

spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala

Lines changed: 1 addition & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,6 @@ import org.apache.spark.sql.{SparkSession, SparkSessionExtensions}
3030
import org.apache.spark.sql.catalyst.rules.Rule
3131
import org.apache.spark.sql.catalyst.trees.TreeNode
3232
import org.apache.spark.sql.comet._
33-
import org.apache.spark.sql.comet.execution.shuffle.CometShuffleManager
3433
import org.apache.spark.sql.comet.util.Utils
3534
import org.apache.spark.sql.execution._
3635
import org.apache.spark.sql.execution.datasources.csv.CSVFileFormat
@@ -145,17 +144,7 @@ object CometSparkSessionExtensions extends Logging {
145144
private[comet] def isCometShuffleEnabled(conf: SQLConf): Boolean =
146145
COMET_EXEC_SHUFFLE_ENABLED.get(conf) && isCometShuffleManagerEnabled(conf)
147146

148-
private[comet] def getCometShuffleNotEnabledReason(conf: SQLConf): Option[String] = {
149-
if (!COMET_EXEC_SHUFFLE_ENABLED.get(conf)) {
150-
Some(s"${COMET_EXEC_SHUFFLE_ENABLED.key} is not enabled")
151-
} else if (!isCometShuffleManagerEnabled(conf)) {
152-
Some(s"spark.shuffle.manager is not set to ${CometShuffleManager.getClass.getName}")
153-
} else {
154-
None
155-
}
156-
}
157-
158-
private def isCometShuffleManagerEnabled(conf: SQLConf) = {
147+
def isCometShuffleManagerEnabled(conf: SQLConf): Boolean = {
159148
conf.contains("spark.shuffle.manager") && conf.getConfString("spark.shuffle.manager") ==
160149
"org.apache.spark.sql.comet.execution.shuffle.CometShuffleManager"
161150
}

spark/src/main/scala/org/apache/comet/rules/CometExecRule.scala

Lines changed: 67 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ import org.apache.spark.sql.catalyst.optimizer.NormalizeNaNAndZero
2828
import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, RangePartitioning, RoundRobinPartitioning, SinglePartition}
2929
import org.apache.spark.sql.catalyst.rules.Rule
3030
import org.apache.spark.sql.comet._
31-
import org.apache.spark.sql.comet.execution.shuffle.{CometColumnarShuffle, CometNativeShuffle, CometShuffleExchangeExec}
31+
import org.apache.spark.sql.comet.execution.shuffle.{CometColumnarShuffle, CometNativeShuffle, CometShuffleExchangeExec, CometShuffleManager}
3232
import org.apache.spark.sql.execution._
3333
import org.apache.spark.sql.execution.adaptive.{AQEShuffleReadExec, BroadcastQueryStageExec, ShuffleQueryStageExec}
3434
import org.apache.spark.sql.execution.aggregate.{BaseAggregateExec, HashAggregateExec, ObjectHashAggregateExec}
@@ -39,11 +39,10 @@ import org.apache.spark.sql.internal.SQLConf
3939
import org.apache.spark.sql.types.{ArrayType, BinaryType, BooleanType, ByteType, DataType, DateType, DecimalType, DoubleType, FloatType, IntegerType, LongType, MapType, ShortType, StringType, StructType, TimestampNTZType, TimestampType}
4040

4141
import org.apache.comet.{CometConf, ExtendedExplainInfo}
42-
import org.apache.comet.CometConf.COMET_ANSI_MODE_ENABLED
42+
import org.apache.comet.CometConf.{COMET_ANSI_MODE_ENABLED, COMET_EXEC_SHUFFLE_ENABLED}
4343
import org.apache.comet.CometSparkSessionExtensions._
4444
import org.apache.comet.serde.OperatorOuterClass.Operator
4545
import org.apache.comet.serde.QueryPlanSerde
46-
import org.apache.comet.serde.QueryPlanSerde.emitWarning
4746

4847
/**
4948
* Spark physical optimizer rule for replacing Spark operators with Comet operators.
@@ -54,23 +53,15 @@ case class CometExecRule(session: SparkSession) extends Rule[SparkPlan] {
5453

5554
private def applyCometShuffle(plan: SparkPlan): SparkPlan = {
5655
plan.transformUp {
57-
case s: ShuffleExchangeExec
58-
if isCometPlan(s.child) && isCometNativeShuffleMode(conf) &&
59-
nativeShuffleSupported(s)._1 =>
60-
logInfo("Comet extension enabled for Native Shuffle")
61-
56+
case s: ShuffleExchangeExec if nativeShuffleSupported(s) =>
6257
// Switch to use Decimal128 regardless of precision, since Arrow native execution
6358
// doesn't support Decimal32 and Decimal64 yet.
6459
conf.setConfString(CometConf.COMET_USE_DECIMAL_128.key, "true")
6560
CometShuffleExchangeExec(s, shuffleType = CometNativeShuffle)
6661

67-
// Columnar shuffle for regular Spark operators (not Comet) and Comet operators
68-
// (if configured)
69-
case s: ShuffleExchangeExec
70-
if (!s.child.supportsColumnar || isCometPlan(s.child)) && isCometJVMShuffleMode(conf) &&
71-
columnarShuffleSupported(s)._1 &&
72-
!isShuffleOperator(s.child) =>
73-
logInfo("Comet extension enabled for JVM Columnar Shuffle")
62+
case s: ShuffleExchangeExec if columnarShuffleSupported(s) =>
63+
// Columnar shuffle for regular Spark operators (not Comet) and Comet operators
64+
// (if configured)
7465
CometShuffleExchangeExec(s, shuffleType = CometColumnarShuffle)
7566
}
7667
}
@@ -489,12 +480,8 @@ case class CometExecRule(session: SparkSession) extends Rule[SparkPlan] {
489480

490481
// Native shuffle for Comet operators
491482
case s: ShuffleExchangeExec =>
492-
val nativePrecondition = isCometShuffleEnabled(conf) &&
493-
isCometNativeShuffleMode(conf) &&
494-
nativeShuffleSupported(s)._1
495-
496483
val nativeShuffle: Option[SparkPlan] =
497-
if (nativePrecondition) {
484+
if (nativeShuffleSupported(s)) {
498485
val newOp = operator2Proto(s)
499486
newOp match {
500487
case Some(nativeOp) =>
@@ -517,10 +504,7 @@ case class CometExecRule(session: SparkSession) extends Rule[SparkPlan] {
517504
// (if configured).
518505
// If the child of ShuffleExchangeExec is also a ShuffleExchangeExec, we should not
519506
// convert it to CometColumnarShuffle,
520-
if (isCometShuffleEnabled(conf) && isCometJVMShuffleMode(conf) &&
521-
columnarShuffleSupported(s)._1 &&
522-
!isShuffleOperator(s.child)) {
523-
507+
if (columnarShuffleSupported(s)) {
524508
val newOp = QueryPlanSerde.operator2Proto(s)
525509
newOp match {
526510
case Some(nativeOp) =>
@@ -543,20 +527,7 @@ case class CometExecRule(session: SparkSession) extends Rule[SparkPlan] {
543527
if (nativeOrColumnarShuffle.isDefined) {
544528
nativeOrColumnarShuffle.get
545529
} else {
546-
val isShuffleEnabled = isCometShuffleEnabled(conf)
547-
val reason = getCometShuffleNotEnabledReason(conf).getOrElse("no reason available")
548-
val msg1 = createMessage(!isShuffleEnabled, s"Comet shuffle is not enabled: $reason")
549-
val columnarShuffleEnabled = isCometJVMShuffleMode(conf)
550-
val msg2 = createMessage(
551-
isShuffleEnabled && !columnarShuffleEnabled && !nativeShuffleSupported(s)._1,
552-
"Native shuffle: " +
553-
s"${nativeShuffleSupported(s)._2}")
554-
val typeInfo = columnarShuffleSupported(s)._2
555-
val msg3 = createMessage(
556-
isShuffleEnabled && columnarShuffleEnabled && !columnarShuffleSupported(s)._1,
557-
"JVM shuffle: " +
558-
s"$typeInfo")
559-
withInfo(s, Seq(msg1, msg2, msg3).flatten.mkString(","))
530+
s
560531
}
561532

562533
case op =>
@@ -774,10 +745,24 @@ case class CometExecRule(session: SparkSession) extends Rule[SparkPlan] {
774745
}
775746
}
776747

748+
def isCometShuffleEnabledWithInfo(op: SparkPlan): Boolean = {
749+
if (!COMET_EXEC_SHUFFLE_ENABLED.get(op.conf)) {
750+
withInfo(
751+
op,
752+
s"Comet shuffle is not enabled: ${COMET_EXEC_SHUFFLE_ENABLED.key} is not enabled")
753+
false
754+
} else if (!isCometShuffleManagerEnabled(op.conf)) {
755+
withInfo(op, s"spark.shuffle.manager is not set to ${CometShuffleManager.getClass.getName}")
756+
false
757+
} else {
758+
true
759+
}
760+
}
761+
777762
/**
778763
* Whether the given Spark partitioning is supported by Comet native shuffle.
779764
*/
780-
private def nativeShuffleSupported(s: ShuffleExchangeExec): (Boolean, String) = {
765+
private def nativeShuffleSupported(s: ShuffleExchangeExec): Boolean = {
781766

782767
/**
783768
* Determine which data types are supported as hash-partition keys in native shuffle.
@@ -794,11 +779,24 @@ case class CometExecRule(session: SparkSession) extends Rule[SparkPlan] {
794779
false
795780
}
796781

782+
if (!isCometShuffleEnabledWithInfo(s)) {
783+
return false
784+
}
785+
786+
if (!isCometNativeShuffleMode(s.conf)) {
787+
withInfo(s, "Comet native shuffle not enabled")
788+
return false
789+
}
790+
791+
if (!isCometPlan(s.child)) {
792+
withInfo(s, "Child {s.child.getClass.getName} is not native")
793+
return false
794+
}
795+
797796
val inputs = s.child.output
798797
val partitioning = s.outputPartitioning
799798
val conf = SQLConf.get
800-
var msg = ""
801-
val supported = partitioning match {
799+
partitioning match {
802800
case HashPartitioning(expressions, _) =>
803801
// native shuffle currently does not support complex types as partition keys
804802
// due to lack of hashing support for those types
@@ -808,7 +806,7 @@ case class CometExecRule(session: SparkSession) extends Rule[SparkPlan] {
808806
inputs.forall(attr => supportedShuffleDataType(attr.dataType)) &&
809807
CometConf.COMET_EXEC_SHUFFLE_WITH_HASH_PARTITIONING_ENABLED.get(conf)
810808
if (!supported) {
811-
msg = s"unsupported Spark partitioning: $expressions"
809+
withInfo(s, s"unsupported Spark partitioning: $expressions")
812810
}
813811
supported
814812
case SinglePartition =>
@@ -818,31 +816,43 @@ case class CometExecRule(session: SparkSession) extends Rule[SparkPlan] {
818816
inputs.forall(attr => supportedShuffleDataType(attr.dataType)) &&
819817
CometConf.COMET_EXEC_SHUFFLE_WITH_RANGE_PARTITIONING_ENABLED.get(conf)
820818
if (!supported) {
821-
msg = s"unsupported Spark partitioning: $ordering"
819+
withInfo(s, s"unsupported Spark partitioning: $ordering")
822820
}
823821
supported
824822
case _ =>
825-
msg = s"unsupported Spark partitioning: ${partitioning.getClass.getName}"
823+
withInfo(s, s"unsupported Spark partitioning: ${partitioning.getClass.getName}")
826824
false
827825
}
828-
829-
if (!supported) {
830-
emitWarning(msg)
831-
(false, msg)
832-
} else {
833-
(true, null)
834-
}
835826
}
836827

837828
/**
838829
* Check if the datatypes of shuffle input are supported. This is used for Columnar shuffle
839830
* which supports struct/array.
840831
*/
841-
private def columnarShuffleSupported(s: ShuffleExchangeExec): (Boolean, String) = {
832+
private def columnarShuffleSupported(s: ShuffleExchangeExec): Boolean = {
833+
834+
if (!isCometShuffleEnabledWithInfo(s)) {
835+
return false
836+
}
837+
838+
if (!isCometJVMShuffleMode(s.conf)) {
839+
withInfo(s, "Comet columnar shuffle not enabled")
840+
return false
841+
}
842+
843+
if (isShuffleOperator(s.child)) {
844+
withInfo(s, "Child {s.child.getClass.getName} is a shuffle operator")
845+
return false
846+
}
847+
848+
if (!(!s.child.supportsColumnar || isCometPlan(s.child))) {
849+
withInfo(s, "Child {s.child.getClass.getName} is a neither row-based or a Comet operator")
850+
return false
851+
}
852+
842853
val inputs = s.child.output
843854
val partitioning = s.outputPartitioning
844-
var msg = ""
845-
val supported = partitioning match {
855+
partitioning match {
846856
case HashPartitioning(expressions, _) =>
847857
// columnar shuffle supports the same data types (including complex types) both for
848858
// partition keys and for other columns
@@ -851,7 +861,7 @@ case class CometExecRule(session: SparkSession) extends Rule[SparkPlan] {
851861
expressions.forall(e => supportedShuffleDataType(e.dataType)) &&
852862
inputs.forall(attr => supportedShuffleDataType(attr.dataType))
853863
if (!supported) {
854-
msg = s"unsupported Spark partitioning expressions: $expressions"
864+
withInfo(s, s"unsupported Spark partitioning expressions: $expressions")
855865
}
856866
supported
857867
case SinglePartition =>
@@ -864,20 +874,13 @@ case class CometExecRule(session: SparkSession) extends Rule[SparkPlan] {
864874
orderings.forall(e => supportedShuffleDataType(e.dataType)) &&
865875
inputs.forall(attr => supportedShuffleDataType(attr.dataType))
866876
if (!supported) {
867-
msg = s"unsupported Spark partitioning expressions: $orderings"
877+
withInfo(s, s"unsupported Spark partitioning expressions: $orderings")
868878
}
869879
supported
870880
case _ =>
871-
msg = s"unsupported Spark partitioning: ${partitioning.getClass.getName}"
881+
withInfo(s, s"unsupported Spark partitioning: ${partitioning.getClass.getName}")
872882
false
873883
}
874-
875-
if (!supported) {
876-
emitWarning(msg)
877-
(false, msg)
878-
} else {
879-
(true, null)
880-
}
881884
}
882885

883886
/**

0 commit comments

Comments
 (0)