
Commit 7717a25

chore: More refactoring of type checking logic (#1744)
1 parent: 48af872

13 files changed: +105, -56 lines

spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala

Lines changed: 2 additions & 13 deletions
@@ -47,11 +47,9 @@ import org.apache.comet.rules.{CometExecRule, CometScanRule, EliminateRedundantT
 import org.apache.comet.shims.ShimCometSparkSessionExtensions
 
 /**
- * The entry point of Comet extension to Spark. This class is responsible for injecting Comet
- * rules and extensions into Spark.
+ * CometDriverPlugin will register an instance of this class with Spark.
  *
- * CometScanRule: A rule to transform a Spark scan plan into a Comet scan plan. CometExecRule: A
- * rule to transform a Spark execution plan into a Comet execution plan.
+ * This class is responsible for injecting Comet rules and extensions into Spark.
  */
 class CometSparkSessionExtensions
     extends (SparkSessionExtensions => Unit)

@@ -242,15 +240,6 @@ object CometSparkSessionExtensions extends Logging {
     org.apache.spark.SPARK_VERSION >= "4.0"
   }
 
-  def usingDataSourceExec(conf: SQLConf): Boolean =
-    Seq(CometConf.SCAN_NATIVE_ICEBERG_COMPAT, CometConf.SCAN_NATIVE_DATAFUSION).contains(
-      CometConf.COMET_NATIVE_SCAN_IMPL.get(conf))
-
-  def usingDataSourceExecWithIncompatTypes(conf: SQLConf): Boolean = {
-    usingDataSourceExec(conf) &&
-      !CometConf.COMET_SCAN_ALLOW_INCOMPATIBLE.get(conf)
-  }
-
   /**
    * Whether we should override Spark memory configuration for Comet. This only returns true when
    * Comet native execution is enabled and/or Comet shuffle is enabled and Comet doesn't use
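
Note: the two helpers deleted here are called without a qualifier in the test suites later in this commit, which suggests they now live on a shared test base class; their new location is not shown in this diff. A minimal sketch of the equivalent logic, using only the CometConf keys visible above:

    // Sketch of the removed helpers, kept for reference; the post-refactor
    // home of these methods is not visible in this commit.
    import org.apache.spark.sql.internal.SQLConf
    import org.apache.comet.CometConf

    def usingDataSourceExec(conf: SQLConf): Boolean =
      Seq(CometConf.SCAN_NATIVE_ICEBERG_COMPAT, CometConf.SCAN_NATIVE_DATAFUSION)
        .contains(CometConf.COMET_NATIVE_SCAN_IMPL.get(conf))

    def usingDataSourceExecWithIncompatTypes(conf: SQLConf): Boolean =
      usingDataSourceExec(conf) && !CometConf.COMET_SCAN_ALLOW_INCOMPATIBLE.get(conf)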

spark/src/main/scala/org/apache/comet/rules/CometExecRule.scala

Lines changed: 3 additions & 0 deletions
@@ -40,6 +40,9 @@ import org.apache.comet.CometSparkSessionExtensions.{createMessage, getCometBroa
 import org.apache.comet.serde.OperatorOuterClass.Operator
 import org.apache.comet.serde.QueryPlanSerde
 
+/**
+ * Spark physical optimizer rule for replacing Spark operators with Comet operators.
+ */
 case class CometExecRule(session: SparkSession) extends Rule[SparkPlan] {
   private def applyCometShuffle(plan: SparkPlan): SparkPlan = {
     plan.transformUp {

spark/src/main/scala/org/apache/comet/rules/CometScanRule.scala

Lines changed: 3 additions & 0 deletions
@@ -37,6 +37,9 @@ import org.apache.comet.CometConf._
 import org.apache.comet.CometSparkSessionExtensions.{isCometLoaded, isCometScanEnabled, withInfo, withInfos}
 import org.apache.comet.parquet.{CometParquetScan, SupportsComet}
 
+/**
+ * Spark physical optimizer rule for replacing Spark scans with Comet scans.
+ */
 case class CometScanRule(session: SparkSession) extends Rule[SparkPlan] {
   override def apply(plan: SparkPlan): SparkPlan = {
     if (!isCometLoaded(conf) || !isCometScanEnabled(conf)) {

spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala

Lines changed: 18 additions & 11 deletions
@@ -48,7 +48,7 @@ import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.types.UTF8String
 
 import org.apache.comet.CometConf
-import org.apache.comet.CometSparkSessionExtensions.{isCometScan, usingDataSourceExec, withInfo}
+import org.apache.comet.CometSparkSessionExtensions.{isCometScan, withInfo}
 import org.apache.comet.expressions._
 import org.apache.comet.serde.ExprOuterClass.{AggExpr, DataType => ProtoDataType, Expr, ScalarFunc}
 import org.apache.comet.serde.ExprOuterClass.DataType._

@@ -2518,6 +2518,15 @@ object QueryPlanSerde extends Logging with CometExprShim {
           return None
         }
 
+        if (groupingExpressions.exists(expr =>
+            expr.dataType match {
+              case _: MapType => true
+              case _ => false
+            })) {
+          withInfo(op, "Grouping on map types is not supported")
+          return None
+        }
+
         val groupingExprs = groupingExpressions.map(exprToProto(_, child.output))
         if (groupingExprs.exists(_.isEmpty)) {
           withInfo(op, "Not all grouping expressions are supported")

@@ -2758,16 +2767,14 @@ object QueryPlanSerde extends Logging with CometExprShim {
         withInfo(join, "SortMergeJoin is not enabled")
         None
 
-      case op
-          if isCometSink(op) && op.output.forall(a =>
-            supportedDataType(
-              a.dataType,
-              // Complex type supported if
-              // - Native datafusion reader enabled (experimental) OR
-              // - conversion from Parquet/JSON enabled
-              allowComplex =
-                usingDataSourceExec(conf) || CometConf.COMET_CONVERT_FROM_PARQUET_ENABLED
-                  .get(conf) || CometConf.COMET_CONVERT_FROM_JSON_ENABLED.get(conf))) =>
+      case op if isCometSink(op) =>
+        val supportedTypes =
+          op.output.forall(a => supportedDataType(a.dataType, allowComplex = true))
+
+        if (!supportedTypes) {
+          return None
+        }
+
         // These operators are source of Comet native execution chain
         val scanBuilder = OperatorOuterClass.Scan.newBuilder()
         val source = op.simpleStringWithNodeId()
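
Note: the map-type guard added in the aggregate hunk above spells the match out case by case; an equivalent, slightly more compact form (a sketch, not the committed code) would be:

    // Equivalent sketch of the new guard: fall back to Spark when any
    // grouping key is a Spark MapType.
    if (groupingExpressions.exists(_.dataType.isInstanceOf[MapType])) {
      withInfo(op, "Grouping on map types is not supported")
      return None
    }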

spark/src/main/scala/org/apache/spark/sql/comet/CometNativeScanExec.scala

Lines changed: 3 additions & 4 deletions
@@ -31,14 +31,12 @@ import org.apache.spark.sql.catalyst.plans.physical.{Partitioning, UnknownPartit
 import org.apache.spark.sql.execution._
 import org.apache.spark.sql.execution.datasources._
 import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics}
-import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types._
 import org.apache.spark.util.collection._
 
 import com.google.common.base.Objects
 
 import org.apache.comet.{CometConf, DataTypeSupport}
-import org.apache.comet.CometSparkSessionExtensions.usingDataSourceExecWithIncompatTypes
 import org.apache.comet.parquet.CometParquetFileFormat
 import org.apache.comet.serde.OperatorOuterClass.Operator

@@ -237,8 +235,9 @@ object CometNativeScanExec extends DataTypeSupport {
       name: String,
       fallbackReasons: ListBuffer[String]): Boolean = {
     dt match {
-      case ByteType | ShortType if usingDataSourceExecWithIncompatTypes(SQLConf.get) =>
-        fallbackReasons += s"${CometConf.COMET_SCAN_ALLOW_INCOMPATIBLE.key} is false"
+      case ByteType | ShortType if !CometConf.COMET_SCAN_ALLOW_INCOMPATIBLE.get() =>
+        fallbackReasons += s"${CometConf.SCAN_NATIVE_DATAFUSION} scan cannot read $dt when " +
+          s"${CometConf.COMET_SCAN_ALLOW_INCOMPATIBLE.key} is false. ${CometConf.COMPAT_GUIDE}."
         false
       case _ =>
         super.isTypeSupported(dt, name, fallbackReasons)

spark/src/main/scala/org/apache/spark/sql/comet/CometScanExec.scala

Lines changed: 5 additions & 4 deletions
@@ -39,14 +39,12 @@ import org.apache.spark.sql.execution.datasources._
 import org.apache.spark.sql.execution.datasources.parquet.{ParquetFileFormat, ParquetOptions}
 import org.apache.spark.sql.execution.datasources.v2.DataSourceRDD
 import org.apache.spark.sql.execution.metric._
-import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types._
 import org.apache.spark.sql.vectorized.ColumnarBatch
 import org.apache.spark.util.SerializableConfiguration
 import org.apache.spark.util.collection._
 
 import org.apache.comet.{CometConf, DataTypeSupport, MetricsSupport}
-import org.apache.comet.CometSparkSessionExtensions.usingDataSourceExecWithIncompatTypes
 import org.apache.comet.parquet.{CometParquetFileFormat, CometParquetPartitionReaderFactory}
 
 /**

@@ -530,8 +528,11 @@ object CometScanExec extends DataTypeSupport {
       name: String,
       fallbackReasons: ListBuffer[String]): Boolean = {
     dt match {
-      case ByteType | ShortType if usingDataSourceExecWithIncompatTypes(SQLConf.get) =>
-        fallbackReasons += s"${CometConf.COMET_SCAN_ALLOW_INCOMPATIBLE.key} is false"
+      case ByteType | ShortType
+          if CometConf.COMET_NATIVE_SCAN_IMPL.get() == CometConf.SCAN_NATIVE_ICEBERG_COMPAT &&
+            !CometConf.COMET_SCAN_ALLOW_INCOMPATIBLE.get() =>
+        fallbackReasons += s"${CometConf.SCAN_NATIVE_ICEBERG_COMPAT} scan cannot read $dt when " +
+          s"${CometConf.COMET_SCAN_ALLOW_INCOMPATIBLE.key} is false. ${CometConf.COMPAT_GUIDE}."
         false
       case _: StructType | _: ArrayType | _: MapType
          if CometConf.COMET_NATIVE_SCAN_IMPL.get() != CometConf.SCAN_NATIVE_ICEBERG_COMPAT =>
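
Note: both isTypeSupported overrides above share a shape: gate ByteType/ShortType on configuration, append a human-readable reason, and return false so the caller can report why Comet fell back to Spark. A self-contained sketch of that pattern (names here are illustrative, not Comet's API):

    import scala.collection.mutable.ListBuffer

    // Illustrative sketch: record why a type is rejected instead of throwing,
    // so the caller can surface every fallback reason at once.
    def isTypeSupported(
        typeName: String,
        allowIncompatible: Boolean,
        fallbackReasons: ListBuffer[String]): Boolean =
      typeName match {
        case "ByteType" | "ShortType" if !allowIncompatible =>
          fallbackReasons += s"scan cannot read $typeName when incompatible types are not allowed"
          false
        case _ =>
          true
      }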

spark/src/test/scala/org/apache/comet/CometCastSuite.scala

Lines changed: 1 addition & 1 deletion
@@ -60,7 +60,7 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
   private val timestampPattern = "0123456789/:T" + whitespaceChars
 
   lazy val usingParquetExecWithIncompatTypes: Boolean =
-    CometSparkSessionExtensions.usingDataSourceExecWithIncompatTypes(conf)
+    usingDataSourceExecWithIncompatTypes(conf)
 
   test("all valid cast combinations covered") {
     val names = testNames

spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala

Lines changed: 1 addition & 1 deletion
@@ -140,7 +140,7 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper {
       Byte.MaxValue)
     withParquetTable(path.toString, "tbl") {
       val qry = "select _9 from tbl order by _11"
-      if (CometSparkSessionExtensions.usingDataSourceExec(conf)) {
+      if (usingDataSourceExec(conf)) {
         if (!allowIncompatible) {
           checkSparkAnswer(qry)
         } else {

spark/src/test/scala/org/apache/comet/CometFuzzTestSuite.scala

Lines changed: 38 additions & 5 deletions
@@ -32,7 +32,7 @@ import org.apache.spark.sql.CometTestBase
 import org.apache.spark.sql.comet.{CometNativeScanExec, CometScanExec}
 import org.apache.spark.sql.comet.execution.shuffle.CometShuffleExchangeExec
 import org.apache.spark.sql.execution.SparkPlan
-import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
+import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanExec, AdaptiveSparkPlanHelper}
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.internal.SQLConf.ParquetOutputTimestampType
 import org.apache.spark.sql.types._

@@ -162,14 +162,41 @@ class CometFuzzTestSuite extends CometTestBase with AdaptiveSparkPlanHelper {
     }
   }
 
-  test("shuffle") {
+  test("distribute by single column (complex types)") {
+    val df = spark.read.parquet(filename)
+    df.createOrReplaceTempView("t1")
+    val columns = df.schema.fields.filter(f => isComplexType(f.dataType)).map(_.name)
+    for (col <- columns) {
+      // DISTRIBUTE BY is equivalent to df.repartition($col) and uses
+      val sql = s"SELECT $col FROM t1 DISTRIBUTE BY $col"
+      val df = spark.sql(sql)
+      df.collect()
+      // check for Comet shuffle
+      val plan = df.queryExecution.executedPlan.asInstanceOf[AdaptiveSparkPlanExec].executedPlan
+      val cometShuffleExchanges = collectCometShuffleExchanges(plan)
+      val expectedNumCometShuffles = CometConf.COMET_NATIVE_SCAN_IMPL.get() match {
+        case CometConf.SCAN_NATIVE_COMET =>
+          // native_comet does not support reading complex types
+          0
+        case CometConf.SCAN_NATIVE_ICEBERG_COMPAT | CometConf.SCAN_NATIVE_DATAFUSION =>
+          CometConf.COMET_SHUFFLE_MODE.get() match {
+            case "jvm" =>
+              1
+            case "native" =>
+              // native shuffle does not support complex types as partitioning keys
+              0
+          }
+      }
+      assert(cometShuffleExchanges.length == expectedNumCometShuffles)
+    }
+  }
+
+  test("shuffle supports all types") {
     val df = spark.read.parquet(filename)
     val df2 = df.repartition(8, df.col("c0")).sort("c1")
     df2.collect()
     if (CometConf.isExperimentalNativeScan) {
-      val cometShuffles = collect(df2.queryExecution.executedPlan) {
-        case exec: CometShuffleExchangeExec => exec
-      }
+      val cometShuffles = collectCometShuffleExchanges(df2.queryExecution.executedPlan)
       assert(1 == cometShuffles.length)
     }
   }

@@ -316,4 +343,10 @@ class CometFuzzTestSuite extends CometTestBase with AdaptiveSparkPlanHelper {
     }
   }
 
+  private def collectCometShuffleExchanges(plan: SparkPlan): Seq[SparkPlan] = {
+    collect(plan) { case exchange: CometShuffleExchangeExec =>
+      exchange
+    }
+  }
+
 }
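
Note: the new "distribute by" test leans on the equivalence its comment mentions: DISTRIBUTE BY produces the same shuffle as DataFrame.repartition on that column. A sketch (the column name c0 is hypothetical):

    import org.apache.spark.sql.functions.col

    // These two produce the same shuffle exchange in the physical plan:
    val viaSql = spark.sql("SELECT c0 FROM t1 DISTRIBUTE BY c0")
    val viaApi = spark.table("t1").select("c0").repartition(col("c0"))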

spark/src/test/scala/org/apache/comet/exec/CometColumnarShuffleSuite.scala

Lines changed: 2 additions & 2 deletions
@@ -35,7 +35,7 @@ import org.apache.spark.sql.functions.col
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types._
 
-import org.apache.comet.{CometConf, CometSparkSessionExtensions}
+import org.apache.comet.CometConf
 
 abstract class CometColumnarShuffleSuite extends CometTestBase with AdaptiveSparkPlanHelper {
   protected val adaptiveExecutionEnabled: Boolean

@@ -758,7 +758,7 @@ abstract class CometColumnarShuffleSuite extends CometTestBase with AdaptiveSpar
     // TODO: revisit this when we have resolution of https://github.com/apache/arrow-rs/issues/7040
     // and https://github.com/apache/arrow-rs/issues/7097
     val fieldsToTest =
-      if (CometSparkSessionExtensions.usingDataSourceExec(conf)) {
+      if (usingDataSourceExec(conf)) {
        Seq(
          $"_1",
          $"_4",
