diff --git a/Makefile b/Makefile
index a96fdab6ee..c00a3a27f8 100644
--- a/Makefile
+++ b/Makefile
@@ -105,5 +105,35 @@ release-nogit:
 	./mvnw install -Prelease -DskipTests $(PROFILES) -Dmaven.gitcommitid.skip=true
 benchmark-%: release
 	cd spark && COMET_CONF_DIR=$(shell pwd)/conf MAVEN_OPTS='-Xmx20g ${call spark_jvm_17_extra_args}' ../mvnw exec:java -Dexec.mainClass="$*" -Dexec.classpathScope="test" -Dexec.cleanupDaemonThreads="false" -Dexec.args="$(filter-out $@,$(MAKECMDGOALS))" $(PROFILES)
+
+# Discover all benchmark classes dynamically
+BENCHMARK_CLASSES := $(shell find spark/src/test/scala/org/apache/spark/sql/benchmark -name "Comet*Benchmark.scala" -type f | \
+	xargs grep -l "object.*Benchmark.*extends.*CometBenchmarkBase" | \
+	sed 's|spark/src/test/scala/||g' | \
+	sed 's|/|.|g' | \
+	sed 's|.scala||g' | \
+	sort)
+
+# Run all discovered benchmarks
+benchmark-all:
+	@echo "Discovered benchmarks:"
+	@echo "$(BENCHMARK_CLASSES)" | tr ' ' '\n'
+	@echo ""
+	@echo "Running all benchmarks (this will take a long time)..."
+	@for benchmark in $(BENCHMARK_CLASSES); do \
+		echo ""; \
+		echo "======================================"; \
+		echo "Running: $$benchmark"; \
+		echo "======================================"; \
+		SPARK_GENERATE_BENCHMARK_FILES=1 $(MAKE) benchmark-$$benchmark || echo "WARNING: $$benchmark failed"; \
+	done
+	@echo ""
+	@echo "All benchmarks completed!"
+
+# List all available benchmarks
+list-benchmarks:
+	@echo "Available benchmarks:"
+	@echo "$(BENCHMARK_CLASSES)" | tr ' ' '\n'
+
 .DEFAULT:
 	@: # ignore arguments provided to benchmarks e.g. "make benchmark-foo -- --bar", we do not want to treat "--bar" as target
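For orientation, the new targets are meant to be driven from the repository root; a usage sketch (both commands only invoke the recipes added in the hunk above):

	make list-benchmarks    # print the Comet*Benchmark classes discovered via BENCHMARK_CLASSES
	make benchmark-all      # run every discovered benchmark; a failing class is logged as a WARNING and the loop continues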
diff --git a/spark/src/test/scala/org/apache/spark/sql/benchmark/CometAggregateBenchmark.scala b/spark/src/test/scala/org/apache/spark/sql/benchmark/CometAggregateBenchmark.scala
index d63b3e7106..84a36ff6fa 100644
--- a/spark/src/test/scala/org/apache/spark/sql/benchmark/CometAggregateBenchmark.scala
+++ b/spark/src/test/scala/org/apache/spark/sql/benchmark/CometAggregateBenchmark.scala
@@ -19,8 +19,6 @@
 
 package org.apache.spark.sql.benchmark
 
-import scala.util.Try
-
 import org.apache.spark.benchmark.Benchmark
 import org.apache.spark.sql.SparkSession
 import org.apache.spark.sql.internal.SQLConf
@@ -87,7 +85,7 @@ object CometAggregateBenchmark extends CometBenchmarkBase {
           s"SQL Parquet - Spark (${aggregateFunction.toString}) ansi mode enabled : ${isAnsiMode}") {
           _ =>
             withSQLConf(SQLConf.ANSI_ENABLED.key -> isAnsiMode.toString) {
-              Try { spark.sql(query).noop() }
+              spark.sql(query).noop()
             }
         }
 
@@ -98,7 +96,7 @@ object CometAggregateBenchmark extends CometBenchmarkBase {
             CometConf.COMET_ENABLED.key -> "true",
             CometConf.COMET_EXEC_ENABLED.key -> "true",
             SQLConf.ANSI_ENABLED.key -> isAnsiMode.toString) {
-            Try { spark.sql(query).noop() }
+            spark.sql(query).noop()
           }
         }
 
@@ -137,7 +135,7 @@ object CometAggregateBenchmark extends CometBenchmarkBase {
 
       benchmark.addCase(
         s"SQL Parquet - Spark (${aggregateFunction.toString}) ansi mode enabled : ${isAnsiMode}") {
         _ =>
-          Try { spark.sql(query).noop() }
+          spark.sql(query).noop()
       }
 
@@ -148,7 +146,7 @@ object CometAggregateBenchmark extends CometBenchmarkBase {
           CometConf.COMET_ENABLED.key -> "true",
           CometConf.COMET_EXEC_ENABLED.key -> "true",
           SQLConf.ANSI_ENABLED.key -> isAnsiMode.toString) {
-          Try { spark.sql(query).noop() }
+          spark.sql(query).noop()
         }
       }
 
@@ -185,7 +183,7 @@ object CometAggregateBenchmark extends CometBenchmarkBase {
         s"SQL Parquet - Spark (${aggregateFunction.toString}) isANSIMode: ${isAnsiMode.toString}") {
         _ =>
           withSQLConf(SQLConf.ANSI_ENABLED.key -> isAnsiMode.toString) {
-            Try { spark.sql(query).noop() }
+            spark.sql(query).noop()
           }
       }
 
@@ -197,7 +195,7 @@ object CometAggregateBenchmark extends CometBenchmarkBase {
           CometConf.COMET_EXEC_ENABLED.key -> "true",
           CometConf.COMET_ONHEAP_MEMORY_OVERHEAD.key -> "1G",
           SQLConf.ANSI_ENABLED.key -> isAnsiMode.toString) {
-          Try { spark.sql(query).noop() }
+          spark.sql(query).noop()
        }
      }
 
@@ -236,7 +234,7 @@ object CometAggregateBenchmark extends CometBenchmarkBase {
         s"SQL Parquet - Spark (${aggregateFunction.toString}) isANSIMode: ${isAnsiMode.toString}") {
         _ =>
           withSQLConf(SQLConf.ANSI_ENABLED.key -> isAnsiMode.toString) {
-            Try { spark.sql(query).noop() }
+            spark.sql(query).noop()
           }
       }
 
@@ -247,7 +245,7 @@ object CometAggregateBenchmark extends CometBenchmarkBase {
           CometConf.COMET_ENABLED.key -> "true",
           CometConf.COMET_EXEC_ENABLED.key -> "true",
           SQLConf.ANSI_ENABLED.key -> isAnsiMode.toString) {
-          Try { spark.sql(query).noop() }
+          spark.sql(query).noop()
         }
       }
 
@@ -260,36 +258,36 @@ object CometAggregateBenchmark extends CometBenchmarkBase {
     val total = 1024 * 1024 * 10
     val combinations = List(100, 1024, 1024 * 1024) // number of distinct groups
     benchmarkAggFuncs.foreach { aggFunc =>
-      Seq(true, false).foreach(k => {
-        runBenchmarkWithTable(
+      Seq(true, false).foreach(ansiMode => {
+        runBenchmarkWithSafeTable(
           s"Grouped Aggregate (single group key + single aggregate $aggFunc)",
           total) { v =>
           for (card <- combinations) {
-            singleGroupAndAggregate(v, card, aggFunc, k)
+            singleGroupAndAggregate(v, card, aggFunc, ansiMode)
           }
         }
 
-        runBenchmarkWithTable(
+        runBenchmarkWithSafeTable(
           s"Grouped Aggregate (multiple group keys + single aggregate $aggFunc)",
           total) { v =>
           for (card <- combinations) {
-            multiGroupKeys(v, card, aggFunc, k)
+            multiGroupKeys(v, card, aggFunc, ansiMode)
           }
         }
 
-        runBenchmarkWithTable(
+        runBenchmarkWithSafeTable(
           s"Grouped Aggregate (single group key + multiple aggregates $aggFunc)",
           total) { v =>
           for (card <- combinations) {
-            multiAggregates(v, card, aggFunc, k)
+            multiAggregates(v, card, aggFunc, ansiMode)
          }
        }
 
-        runBenchmarkWithTable(
+        runBenchmarkWithSafeTable(
           s"Grouped Aggregate (single group key + single aggregate $aggFunc on decimal)",
           total) { v =>
           for (card <- combinations) {
-            singleGroupAndAggregateDecimal(v, DecimalType(18, 10), card, aggFunc, k)
+            singleGroupAndAggregateDecimal(v, DecimalType(18, 10), card, aggFunc, ansiMode)
           }
         }
       })
diff --git a/spark/src/test/scala/org/apache/spark/sql/benchmark/CometArithmeticBenchmark.scala b/spark/src/test/scala/org/apache/spark/sql/benchmark/CometArithmeticBenchmark.scala
index a513aa1a77..0c1c1e2cca 100644
--- a/spark/src/test/scala/org/apache/spark/sql/benchmark/CometArithmeticBenchmark.scala
+++ b/spark/src/test/scala/org/apache/spark/sql/benchmark/CometArithmeticBenchmark.scala
@@ -44,7 +44,7 @@ object CometArithmeticBenchmark extends CometBenchmarkBase {
         val name = s"Binary op ${dataType.sql}, dictionary = $useDictionary"
         val query = s"SELECT c1 ${op.sig} c2 FROM $table"
 
-        runExpressionBenchmark(name, values, query)
+        runExpressionBenchmark(name, values, query, isAnsiMode = false)
       }
     }
   }
@@ -64,7 +64,7 @@ object CometArithmeticBenchmark extends CometBenchmarkBase {
         val name = s"Binary op ${dataType.sql}, dictionary = $useDictionary"
         val query = s"SELECT c1 ${op.sig} c2 FROM $table"
 
-        runExpressionBenchmark(name, values, query)
+        runExpressionBenchmark(name, values, query, isAnsiMode = false)
       }
     }
   }
@@ -84,7 +84,7 @@ object CometArithmeticBenchmark extends CometBenchmarkBase {
 
     Seq(true, false).foreach { useDictionary =>
       Seq(Minus, Mul).foreach { op =>
-        runBenchmarkWithTable(op.name, TOTAL, useDictionary) { v =>
+        runBenchmarkWithSafeTable(op.name, TOTAL, useDictionary) { v =>
           integerArithmeticBenchmark(v, op, useDictionary)
         }
       }
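A back-of-the-envelope illustration of why the arithmetic benchmarks need the bounded table (a sketch, not part of the patch; it assumes both operand columns are derived from the generated value column, and mirrors the helper's non-dictionary branch defined in CometBenchmarkBase below):

	// With raw spark.range values running into the millions, an INT product such as
	// c1 * c2 can exceed Int.MaxValue (2147483647); under ANSI mode that overflow
	// becomes a runtime error instead of silently wrapping. Values reduced modulo
	// 10000 keep any product below 10000 * 10000 = 100000000, well within INT range.
	spark.range(TOTAL).map(_ % 10000).createOrReplaceTempView(tbl) // the shape of data runBenchmarkWithSafeTable now writes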
diff --git a/spark/src/test/scala/org/apache/spark/sql/benchmark/CometBenchmarkBase.scala b/spark/src/test/scala/org/apache/spark/sql/benchmark/CometBenchmarkBase.scala
index 8d56cefa05..bf9e7faadf 100644
--- a/spark/src/test/scala/org/apache/spark/sql/benchmark/CometBenchmarkBase.scala
+++ b/spark/src/test/scala/org/apache/spark/sql/benchmark/CometBenchmarkBase.scala
@@ -33,7 +33,7 @@
 import org.apache.spark.benchmark.Benchmark
 import org.apache.spark.sql.{DataFrame, DataFrameWriter, Row, SparkSession}
 import org.apache.spark.sql.execution.benchmark.SqlBasedBenchmark
 import org.apache.spark.sql.internal.SQLConf
-import org.apache.spark.sql.types.DecimalType
+import org.apache.spark.sql.types.{DataType, DecimalType}
 
 import org.apache.comet.CometConf
 import org.apache.comet.CometSparkSessionExtensions
@@ -88,26 +88,52 @@ trait CometBenchmarkBase extends SqlBasedBenchmark {
     }
   }
 
-  /** Runs function `f` with Comet on and off. */
-  final def runWithComet(name: String, cardinality: Long)(f: => Unit): Unit = {
-    val benchmark = new Benchmark(name, cardinality, output = output)
-
-    benchmark.addCase(s"$name - Spark ") { _ =>
-      withSQLConf(CometConf.COMET_ENABLED.key -> "false") {
-        f
-      }
+  /**
+   * Creates a table with ANSI-safe values that won't overflow in arithmetic operations. Use this
+   * instead of runBenchmarkWithTable for arithmetic/aggregate benchmarks.
+   */
+  protected def runBenchmarkWithSafeTable(
+      benchmarkName: String,
+      values: Int,
+      useDictionary: Boolean = false)(f: Int => Any): Unit = {
+    withTempTable(tbl) {
+      import spark.implicits._
+      spark
+        .range(values)
+        .map(i => if (useDictionary) i % 5 else i % 10000)
+        .createOrReplaceTempView(tbl)
+      runBenchmark(benchmarkName)(f(values))
     }
+  }
 
-    benchmark.addCase(s"$name - Comet") { _ =>
-      withSQLConf(
-        CometConf.COMET_ENABLED.key -> "true",
-        CometConf.COMET_EXEC_ENABLED.key -> "true",
-        SQLConf.ANSI_ENABLED.key -> "false") {
-        f
-      }
+  /**
+   * Generates ANSI-safe data for casting from Long to the specified target type. Returns a SQL
+   * expression that transforms the base "value" column to be within safe ranges.
+   *
+   * @param targetType
+   *   The target data type for casting
+   * @return
+   *   SQL expression to generate safe data
+   */
+  protected def generateAnsiSafeData(targetType: DataType): String = {
+    import org.apache.spark.sql.types._
+
+    // We generate long inputs initially; this match narrows them to a range that is safe to cast to the target type, so the query does not fail in ANSI mode.
+    targetType match {
+      case ByteType => "CAST((value % 128) AS BIGINT)"
+      case ShortType => "CAST((value % 32768) AS BIGINT)"
+      case IntegerType => "CAST((value % 2147483648) AS BIGINT)"
+      case LongType => "value"
+      case FloatType => "CAST((value % 1000000) AS BIGINT)"
+      case DoubleType => "value"
+      case _: DecimalType => "CAST((value % 100000000) AS BIGINT)"
+      case StringType => "CAST(value AS STRING)"
+      case BooleanType => "CAST((value % 2) AS BIGINT)"
+      case DateType => "CAST((value % 18262) AS BIGINT)"
+      case TimestampType => "value"
+      case BinaryType => "value"
+      case _ => "value"
     }
-
-    benchmark.run()
   }
 
   /**
@@ -127,37 +153,48 @@ trait CometBenchmarkBase extends SqlBasedBenchmark {
       name: String,
       cardinality: Long,
       query: String,
+      isAnsiMode: Boolean,
       extraCometConfigs: Map[String, String] = Map.empty): Unit = {
+
     val benchmark = new Benchmark(name, cardinality, output = output)
 
     benchmark.addCase("Spark") { _ =>
-      withSQLConf(CometConf.COMET_ENABLED.key -> "false") {
-        spark.sql(query).noop()
+      withSQLConf(
+        CometConf.COMET_ENABLED.key -> "false",
+        SQLConf.ANSI_ENABLED.key -> isAnsiMode.toString) {
+        runSparkCommand(spark, query, isAnsiMode)
       }
     }
 
     benchmark.addCase("Comet (Scan)") { _ =>
       withSQLConf(
         CometConf.COMET_ENABLED.key -> "true",
-        CometConf.COMET_EXEC_ENABLED.key -> "false") {
-        spark.sql(query).noop()
+        CometConf.COMET_EXEC_ENABLED.key -> "false",
+        SQLConf.ANSI_ENABLED.key -> isAnsiMode.toString) {
+        runSparkCommand(spark, query, isAnsiMode)
       }
     }
 
     val cometExecConfigs = Map(
       CometConf.COMET_ENABLED.key -> "true",
       CometConf.COMET_EXEC_ENABLED.key -> "true",
-      "spark.sql.optimizer.constantFolding.enabled" -> "false") ++ extraCometConfigs
+      "spark.sql.optimizer.constantFolding.enabled" -> "false",
+      SQLConf.ANSI_ENABLED.key -> isAnsiMode.toString) ++ extraCometConfigs
 
     benchmark.addCase("Comet (Scan + Exec)") { _ =>
       withSQLConf(cometExecConfigs.toSeq: _*) {
-        spark.sql(query).noop()
+        runSparkCommand(spark, query, isAnsiMode)
       }
     }
 
     benchmark.run()
   }
 
+  private def runSparkCommand(spark: SparkSession, query: String, isANSIMode: Boolean): Unit = {
+    // With ANSI-safe data generation, queries should not throw exceptions
+    spark.sql(query).noop()
+  }
+
   protected def prepareTable(dir: File, df: DataFrame, partition: Option[String] = None): Unit = {
     val testDf = if (partition.isDefined) {
       df.write.partitionBy(partition.get)
@@ -250,7 +287,9 @@ trait CometBenchmarkBase extends SqlBasedBenchmark {
       useDictionary: Boolean): DataFrame = {
     import spark.implicits._
 
-    val div = if (useDictionary) 5 else values
+    // Use safe range to avoid overflow in decimal operations
+    val maxValue = 10000
+    val div = if (useDictionary) 5 else maxValue
     spark
       .range(values)
       .map(_ % div)
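To show how the two new helpers compose, here is a minimal sketch (not part of the patch; the TINYINT query is illustrative and assumes values and a prepared parquetV1Table are in scope, as in the benchmarks below):

	// For ByteType the helper returns the SQL fragment "CAST((value % 128) AS BIGINT)",
	// so the cast input is always in [0, 127] (spark.range produces non-negative longs)
	// and CAST(... AS TINYINT) cannot raise an ANSI overflow error.
	val safeExpr = generateAnsiSafeData(ByteType)
	val query = s"SELECT CAST($safeExpr AS TINYINT) FROM parquetV1Table"
	runExpressionBenchmark("Cast to TINYINT", values, query, isAnsiMode = true)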
diff --git a/spark/src/test/scala/org/apache/spark/sql/benchmark/CometCastBenchmark.scala b/spark/src/test/scala/org/apache/spark/sql/benchmark/CometCastBenchmark.scala
index 975abd632f..91ea7e2317 100644
--- a/spark/src/test/scala/org/apache/spark/sql/benchmark/CometCastBenchmark.scala
+++ b/spark/src/test/scala/org/apache/spark/sql/benchmark/CometCastBenchmark.scala
@@ -79,16 +79,22 @@
     withTempPath { dir =>
       withTempTable("parquetV1Table") {
-        prepareTable(dir, spark.sql(s"SELECT value FROM $tbl"))
+        // Generate ANSI-safe data when in ANSI mode to avoid overflow exceptions
+        // In legacy mode, use raw values to test overflow handling
+        val dataExpr = if (isAnsiMode) {
+          generateAnsiSafeData(toDataType)
+        } else {
+          "value"
+        }
+
+        prepareTable(dir, spark.sql(s"SELECT $dataExpr as value FROM $tbl"))
 
         val functionSQL = castExprSQL(toDataType, "value")
         val query = s"SELECT $functionSQL FROM parquetV1Table"
 
         val name =
-          s"Cast function to : ${toDataType} , ansi mode enabled : ${isAnsiMode}"
-
-        val extraConfigs = Map(SQLConf.ANSI_ENABLED.key -> isAnsiMode.toString)
+          s"Cast function from : ${fromDataType} to : ${toDataType} , ansi mode enabled : ${isAnsiMode}"
 
-        runExpressionBenchmark(name, values, query, extraConfigs)
+        runExpressionBenchmark(name, values, query, isAnsiMode)
       }
     }
   }
diff --git a/spark/src/test/scala/org/apache/spark/sql/benchmark/CometConditionalExpressionBenchmark.scala b/spark/src/test/scala/org/apache/spark/sql/benchmark/CometConditionalExpressionBenchmark.scala
index c5eb9ea390..b710552d44 100644
--- a/spark/src/test/scala/org/apache/spark/sql/benchmark/CometConditionalExpressionBenchmark.scala
+++ b/spark/src/test/scala/org/apache/spark/sql/benchmark/CometConditionalExpressionBenchmark.scala
@@ -35,7 +35,7 @@ object CometConditionalExpressionBenchmark extends CometBenchmarkBase {
         val query =
           "select CASE WHEN c1 < 0 THEN '<0' WHEN c1 = 0 THEN '=0' ELSE '>0' END from parquetV1Table"
 
-        runExpressionBenchmark("Case When Expr", values, query)
+        runExpressionBenchmark("Case When Expr", values, query, isAnsiMode = false)
       }
     }
   }
@@ -47,7 +47,7 @@ object CometConditionalExpressionBenchmark extends CometBenchmarkBase {
 
         val query = "select IF (c1 < 0, '<0', '>=0') from parquetV1Table"
 
-        runExpressionBenchmark("If Expr", values, query)
+        runExpressionBenchmark("If Expr", values, query, isAnsiMode = false)
       }
     }
   }
diff --git a/spark/src/test/scala/org/apache/spark/sql/benchmark/CometDatetimeExpressionBenchmark.scala b/spark/src/test/scala/org/apache/spark/sql/benchmark/CometDatetimeExpressionBenchmark.scala
index 47eff41bbd..abcae06e7d 100644
--- a/spark/src/test/scala/org/apache/spark/sql/benchmark/CometDatetimeExpressionBenchmark.scala
+++ b/spark/src/test/scala/org/apache/spark/sql/benchmark/CometDatetimeExpressionBenchmark.scala
@@ -41,7 +41,7 @@ object CometDatetimeExpressionBenchmark extends CometBenchmarkBase {
           val isDictionary = if (useDictionary) "(Dictionary)" else ""
           val name = s"Date Truncate $isDictionary - $level"
           val query = s"select trunc(dt, '$level') from parquetV1Table"
-          runExpressionBenchmark(name, values, query)
+          runExpressionBenchmark(name, values, query, isAnsiMode = false)
         }
       }
     }
@@ -70,7 +70,7 @@ object CometDatetimeExpressionBenchmark extends CometBenchmarkBase {
           val isDictionary = if (useDictionary) "(Dictionary)" else ""
           val name = s"Timestamp Truncate $isDictionary - $level"
           val query = s"select date_trunc('$level', ts) from parquetV1Table"
-          runExpressionBenchmark(name, values, query)
+          runExpressionBenchmark(name, values, query, isAnsiMode = false)
         }
       }
     }
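Most expression benchmarks in this patch pin isAnsiMode = false; one that wants to cover both modes can loop over the flag instead, as in this sketch (not part of the patch; it assumes values, query, and a prepared parquetV1Table are in scope):

	// runExpressionBenchmark now applies spark.sql.ansi.enabled to the Spark,
	// Comet (Scan), and Comet (Scan + Exec) cases alike, so one call per flag suffices.
	Seq(false, true).foreach { ansiEnabled =>
	  runExpressionBenchmark(s"If Expr, ansi mode enabled : $ansiEnabled", values, query, isAnsiMode = ansiEnabled)
	}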
diff --git a/spark/src/test/scala/org/apache/spark/sql/benchmark/CometJsonExpressionBenchmark.scala b/spark/src/test/scala/org/apache/spark/sql/benchmark/CometJsonExpressionBenchmark.scala
index 5b4741ba68..f52d4900d2 100644
--- a/spark/src/test/scala/org/apache/spark/sql/benchmark/CometJsonExpressionBenchmark.scala
+++ b/spark/src/test/scala/org/apache/spark/sql/benchmark/CometJsonExpressionBenchmark.scala
@@ -120,7 +120,12 @@ object CometJsonExpressionBenchmark extends CometBenchmarkBase {
               CometConf.getExprAllowIncompatConfigKey(
                 classOf[JsonToStructs]) -> "true") ++ config.extraCometConfigs
 
-          runExpressionBenchmark(config.name, values, config.query, extraConfigs)
+          runExpressionBenchmark(
+            config.name,
+            values,
+            config.query,
+            isAnsiMode = false,
+            extraConfigs)
         }
       }
     }
diff --git a/spark/src/test/scala/org/apache/spark/sql/benchmark/CometPredicateExpressionBenchmark.scala b/spark/src/test/scala/org/apache/spark/sql/benchmark/CometPredicateExpressionBenchmark.scala
index 6506c5665d..db68e2db40 100644
--- a/spark/src/test/scala/org/apache/spark/sql/benchmark/CometPredicateExpressionBenchmark.scala
+++ b/spark/src/test/scala/org/apache/spark/sql/benchmark/CometPredicateExpressionBenchmark.scala
@@ -38,7 +38,7 @@ object CometPredicateExpressionBenchmark extends CometBenchmarkBase {
 
         val query = "select * from parquetV1Table where c1 in ('positive', 'zero')"
 
-        runExpressionBenchmark("in Expr", values, query)
+        runExpressionBenchmark("in Expr", values, query, isAnsiMode = false)
       }
     }
   }
diff --git a/spark/src/test/scala/org/apache/spark/sql/benchmark/CometStringExpressionBenchmark.scala b/spark/src/test/scala/org/apache/spark/sql/benchmark/CometStringExpressionBenchmark.scala
index 41eabb8513..7f27cd593b 100644
--- a/spark/src/test/scala/org/apache/spark/sql/benchmark/CometStringExpressionBenchmark.scala
+++ b/spark/src/test/scala/org/apache/spark/sql/benchmark/CometStringExpressionBenchmark.scala
@@ -55,7 +55,12 @@ object CometStringExpressionBenchmark extends CometBenchmarkBase {
           val extraConfigs =
             Map(CometConf.COMET_CASE_CONVERSION_ENABLED.key -> "true") ++ config.extraCometConfigs
 
-          runExpressionBenchmark(config.name, values, config.query, extraConfigs)
+          runExpressionBenchmark(
+            config.name,
+            values,
+            config.query,
+            isAnsiMode = false,
+            extraConfigs)
         }
       }
     }