diff --git a/spark/src/test/scala/org/apache/spark/sql/CometTestBase.scala b/spark/src/test/scala/org/apache/spark/sql/CometTestBase.scala
index 7dba24bff7..81ac72247f 100644
--- a/spark/src/test/scala/org/apache/spark/sql/CometTestBase.scala
+++ b/spark/src/test/scala/org/apache/spark/sql/CometTestBase.scala
@@ -37,7 +37,7 @@ import org.apache.parquet.hadoop.example.{ExampleParquetWriter, GroupWriteSuppor
 import org.apache.parquet.schema.{MessageType, MessageTypeParser}
 import org.apache.spark._
 import org.apache.spark.internal.config.{MEMORY_OFFHEAP_ENABLED, MEMORY_OFFHEAP_SIZE, SHUFFLE_MANAGER}
-import org.apache.spark.sql.comet._
+import org.apache.spark.sql.comet.CometPlanChecker
 import org.apache.spark.sql.comet.execution.shuffle.{CometColumnarShuffle, CometNativeShuffle, CometShuffleExchangeExec}
 import org.apache.spark.sql.execution._
 import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
@@ -58,7 +58,8 @@ abstract class CometTestBase
     with BeforeAndAfterEach
     with AdaptiveSparkPlanHelper
     with ShimCometSparkSessionExtensions
-    with ShimCometTestBase {
+    with ShimCometTestBase
+    with CometPlanChecker {
   import testImplicits._
 
   protected val shuffleManager: String =
@@ -396,26 +397,6 @@ abstract class CometTestBase
     checkPlanNotMissingInput(plan)
   }
 
-  protected def findFirstNonCometOperator(
-      plan: SparkPlan,
-      excludedClasses: Class[_]*): Option[SparkPlan] = {
-    val wrapped = wrapCometSparkToColumnar(plan)
-    wrapped.foreach {
-      case _: CometNativeScanExec | _: CometScanExec | _: CometBatchScanExec |
-          _: CometIcebergNativeScanExec =>
-      case _: CometSinkPlaceHolder | _: CometScanWrapper =>
-      case _: CometColumnarToRowExec =>
-      case _: CometSparkToColumnarExec =>
-      case _: CometExec | _: CometShuffleExchangeExec =>
-      case _: CometBroadcastExchangeExec =>
-      case _: WholeStageCodegenExec | _: ColumnarToRowExec | _: InputAdapter =>
-      case op if !excludedClasses.exists(c => c.isAssignableFrom(op.getClass)) =>
-        return Some(op)
-      case _ =>
-    }
-    None
-  }
-
   // checks the plan node has no missing inputs
   // such nodes represented in plan with exclamation mark !
   // example: !CometWindowExec
@@ -449,14 +430,6 @@ abstract class CometTestBase
     }
   }
 
-  /** Wraps the CometRowToColumn as ScanWrapper, so the child operators will not be checked */
-  private def wrapCometSparkToColumnar(plan: SparkPlan): SparkPlan = {
-    plan.transformDown {
-      // don't care the native operators
-      case p: CometSparkToColumnarExec => CometScanWrapper(null, p)
-    }
-  }
-
   private var _spark: SparkSessionType = _
   override protected implicit def spark: SparkSessionType = _spark
   protected implicit def sqlContext: SQLContext = _spark.sqlContext
diff --git a/spark/src/test/scala/org/apache/spark/sql/benchmark/CometBenchmarkBase.scala b/spark/src/test/scala/org/apache/spark/sql/benchmark/CometBenchmarkBase.scala
index 8d56cefa05..5d1d0c5718 100644
--- a/spark/src/test/scala/org/apache/spark/sql/benchmark/CometBenchmarkBase.scala
+++ b/spark/src/test/scala/org/apache/spark/sql/benchmark/CometBenchmarkBase.scala
@@ -31,6 +31,8 @@ import org.apache.parquet.crypto.keytools.mocks.InMemoryKMS
 import org.apache.spark.SparkConf
 import org.apache.spark.benchmark.Benchmark
 import org.apache.spark.sql.{DataFrame, DataFrameWriter, Row, SparkSession}
+import org.apache.spark.sql.comet.CometPlanChecker
+import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
 import org.apache.spark.sql.execution.benchmark.SqlBasedBenchmark
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types.DecimalType
@@ -38,7 +40,10 @@ import org.apache.spark.sql.types.DecimalType
 import org.apache.comet.CometConf
 import org.apache.comet.CometSparkSessionExtensions
 
-trait CometBenchmarkBase extends SqlBasedBenchmark {
+trait CometBenchmarkBase
+    extends SqlBasedBenchmark
+    with AdaptiveSparkPlanHelper
+    with CometPlanChecker {
   override def getSparkSession: SparkSession = {
     val conf = new SparkConf()
       .setAppName("CometReadBenchmark")
@@ -88,28 +93,6 @@ trait CometBenchmarkBase extends SqlBasedBenchmark {
     }
   }
 
-  /** Runs function `f` with Comet on and off. */
-  final def runWithComet(name: String, cardinality: Long)(f: => Unit): Unit = {
-    val benchmark = new Benchmark(name, cardinality, output = output)
-
-    benchmark.addCase(s"$name - Spark ") { _ =>
-      withSQLConf(CometConf.COMET_ENABLED.key -> "false") {
-        f
-      }
-    }
-
-    benchmark.addCase(s"$name - Comet") { _ =>
-      withSQLConf(
-        CometConf.COMET_ENABLED.key -> "true",
-        CometConf.COMET_EXEC_ENABLED.key -> "true",
-        SQLConf.ANSI_ENABLED.key -> "false") {
-        f
-      }
-    }
-
-    benchmark.run()
-  }
-
   /**
    * Runs an expression benchmark with standard cases: Spark, Comet (Scan), Comet (Scan + Exec).
    * This provides a consistent benchmark structure for expression evaluation.
@@ -149,6 +132,29 @@ trait CometBenchmarkBase extends SqlBasedBenchmark {
         CometConf.COMET_EXEC_ENABLED.key -> "true",
         "spark.sql.optimizer.constantFolding.enabled" -> "false") ++ extraCometConfigs
 
+    // Check that the plan is fully Comet native before running the benchmark
+    withSQLConf(cometExecConfigs.toSeq: _*) {
+      val df = spark.sql(query)
+      df.noop()
+      val plan = stripAQEPlan(df.queryExecution.executedPlan)
+      findFirstNonCometOperator(plan) match {
+        case Some(op) =>
+          // scalastyle:off println
+          println()
+          println("=" * 80)
+          println("WARNING: Benchmark plan is NOT fully Comet native!")
+          println(s"First non-Comet operator: ${op.nodeName}")
+          println("=" * 80)
+          println("Query plan:")
+          println(plan.treeString)
+          println("=" * 80)
+          println()
+          // scalastyle:on println
+        case None =>
+        // All operators are Comet native, so no warning is needed
+      }
+    }
+
     benchmark.addCase("Comet (Scan + Exec)") { _ =>
       withSQLConf(cometExecConfigs.toSeq: _*) {
         spark.sql(query).noop()
diff --git a/spark/src/test/scala/org/apache/spark/sql/comet/CometPlanChecker.scala b/spark/src/test/scala/org/apache/spark/sql/comet/CometPlanChecker.scala
new file mode 100644
index 0000000000..7caac71351
--- /dev/null
+++ b/spark/src/test/scala/org/apache/spark/sql/comet/CometPlanChecker.scala
@@ -0,0 +1,68 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.spark.sql.comet
+
+import org.apache.spark.sql.comet.execution.shuffle.CometShuffleExchangeExec
+import org.apache.spark.sql.execution.{ColumnarToRowExec, InputAdapter, SparkPlan, WholeStageCodegenExec}
+
+/**
+ * Trait providing utilities to check whether a Spark plan runs fully on Comet native operators.
+ * Used by both CometTestBase and CometBenchmarkBase.
+ */
+trait CometPlanChecker {
+
+  /**
+   * Finds the first non-Comet operator in the plan, if any.
+   *
+   * @param plan
+   *   The SparkPlan to check
+   * @param excludedClasses
+   *   Classes to exclude from the check (these are allowed to be non-Comet)
+   * @return
+   *   Some(operator) if a non-Comet operator is found, None otherwise
+   */
+  protected def findFirstNonCometOperator(
+      plan: SparkPlan,
+      excludedClasses: Class[_]*): Option[SparkPlan] = {
+    val wrapped = wrapCometSparkToColumnar(plan)
+    wrapped.foreach {
+      case _: CometNativeScanExec | _: CometScanExec | _: CometBatchScanExec |
+          _: CometIcebergNativeScanExec =>
+      case _: CometSinkPlaceHolder | _: CometScanWrapper =>
+      case _: CometColumnarToRowExec =>
+      case _: CometSparkToColumnarExec =>
+      case _: CometExec | _: CometShuffleExchangeExec =>
+      case _: CometBroadcastExchangeExec =>
+      case _: WholeStageCodegenExec | _: ColumnarToRowExec | _: InputAdapter =>
+      case op if !excludedClasses.exists(c => c.isAssignableFrom(op.getClass)) =>
+        return Some(op)
+      case _ =>
+    }
+    None
+  }
+
+  /** Wraps CometSparkToColumnarExec in a CometScanWrapper so its child operators are not checked */
+  private def wrapCometSparkToColumnar(plan: SparkPlan): SparkPlan = {
+    plan.transformDown {
+      // don't check operators below the row-to-columnar conversion
+      case p: CometSparkToColumnarExec => CometScanWrapper(null, p)
+    }
+  }
+}
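
Reviewer note: below is a minimal usage sketch of the extracted trait, not part of the patch. It assumes a ScalaTest suite extending CometTestBase (which now mixes in CometPlanChecker); the suite name and query are illustrative, and stripAQEPlan comes from AdaptiveSparkPlanHelper, which CometTestBase already mixes in. The excludedClasses varargs can whitelist operators that are expected to fall back to Spark.

// Hypothetical example suite; not part of this patch.
package org.apache.spark.sql.comet

import org.apache.spark.sql.CometTestBase

class CometPlanCheckerExampleSuite extends CometTestBase {
  test("aggregate plan is fully Comet native") {
    val df = spark.range(1000).selectExpr("id % 10 AS k").groupBy("k").count()
    df.collect() // run the query so the final AQE plan is materialized
    val plan = stripAQEPlan(df.queryExecution.executedPlan)
    // Pass excludedClasses (e.g. classOf[SomeSparkExec]) to allow known fallbacks.
    findFirstNonCometOperator(plan) match {
      case Some(op) => fail(s"First non-Comet operator: ${op.nodeName}\n${plan.treeString}")
      case None => // every operator is Comet native or allowed Spark glue
    }
  }
}

A test would typically fail hard as above, whereas the benchmark-side check added in CometBenchmarkBase deliberately only prints a warning, so benchmarks still run when part of the plan falls back to Spark.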