30 changes: 30 additions & 0 deletions Makefile
@@ -105,5 +105,35 @@ release-nogit:
./mvnw install -Prelease -DskipTests $(PROFILES) -Dmaven.gitcommitid.skip=true
benchmark-%: release
cd spark && COMET_CONF_DIR=$(shell pwd)/conf MAVEN_OPTS='-Xmx20g ${call spark_jvm_17_extra_args}' ../mvnw exec:java -Dexec.mainClass="$*" -Dexec.classpathScope="test" -Dexec.cleanupDaemonThreads="false" -Dexec.args="$(filter-out $@,$(MAKECMDGOALS))" $(PROFILES)

# Discover all benchmark classes dynamically
BENCHMARK_CLASSES := $(shell find spark/src/test/scala/org/apache/spark/sql/benchmark -name "Comet*Benchmark.scala" -type f | \
xargs grep -l "object.*Benchmark.*extends.*CometBenchmarkBase" | \
sed 's|spark/src/test/scala/||g' | \
sed 's|/|.|g' | \
sed 's|.scala||g' | \
sort)

# Run all discovered benchmarks
benchmark-all:
Contributor Author: Added a convenient way to run all benchmarks through the Makefile.
@echo "Discovered benchmarks:"
@echo "$(BENCHMARK_CLASSES)" | tr ' ' '\n'
@echo ""
@echo "Running all benchmarks (this will take a long time)..."
@for benchmark in $(BENCHMARK_CLASSES); do \
echo ""; \
echo "======================================"; \
echo "Running: $$benchmark"; \
echo "======================================"; \
SPARK_GENERATE_BENCHMARK_FILES=1 $(MAKE) benchmark-$$benchmark || echo "WARNING: $$benchmark failed"; \
done
@echo ""
@echo "All benchmarks completed!"

# List all available benchmarks
list-benchmarks:
@echo "Available benchmarks:"
@echo "$(BENCHMARK_CLASSES)" | tr ' ' '\n'

.DEFAULT:
@: # ignore arguments provided to benchmarks e.g. "make benchmark-foo -- --bar", we do not want to treat "--bar" as target
spark/src/test/scala/org/apache/spark/sql/benchmark/CometAggregateBenchmark.scala
@@ -19,8 +19,6 @@

package org.apache.spark.sql.benchmark

import scala.util.Try

import org.apache.spark.benchmark.Benchmark
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.internal.SQLConf
@@ -87,7 +85,7 @@ object CometAggregateBenchmark extends CometBenchmarkBase {
s"SQL Parquet - Spark (${aggregateFunction.toString}) ansi mode enabled : ${isAnsiMode}") {
_ =>
withSQLConf(SQLConf.ANSI_ENABLED.key -> isAnsiMode.toString) {
Try { spark.sql(query).noop() }
spark.sql(query).noop()
}
}

@@ -98,7 +96,7 @@ object CometAggregateBenchmark extends CometBenchmarkBase {
CometConf.COMET_ENABLED.key -> "true",
CometConf.COMET_EXEC_ENABLED.key -> "true",
SQLConf.ANSI_ENABLED.key -> isAnsiMode.toString) {
Try { spark.sql(query).noop() }
spark.sql(query).noop()
}
}

@@ -137,7 +135,7 @@ object CometAggregateBenchmark extends CometBenchmarkBase {
benchmark.addCase(
s"SQL Parquet - Spark (${aggregateFunction.toString}) ansi mode enabled : ${isAnsiMode}") {
_ =>
Try { spark.sql(query).noop() }
spark.sql(query).noop()
}
}

@@ -148,7 +146,7 @@ object CometAggregateBenchmark extends CometBenchmarkBase {
CometConf.COMET_ENABLED.key -> "true",
CometConf.COMET_EXEC_ENABLED.key -> "true",
SQLConf.ANSI_ENABLED.key -> isAnsiMode.toString) {
Try { spark.sql(query).noop() }
spark.sql(query).noop()
}
}

@@ -185,7 +183,7 @@ object CometAggregateBenchmark extends CometBenchmarkBase {
s"SQL Parquet - Spark (${aggregateFunction.toString}) isANSIMode: ${isAnsiMode.toString}") {
_ =>
withSQLConf(SQLConf.ANSI_ENABLED.key -> isAnsiMode.toString) {
Try { spark.sql(query).noop() }
spark.sql(query).noop()
}
}

@@ -197,7 +195,7 @@ object CometAggregateBenchmark extends CometBenchmarkBase {
CometConf.COMET_EXEC_ENABLED.key -> "true",
CometConf.COMET_ONHEAP_MEMORY_OVERHEAD.key -> "1G",
SQLConf.ANSI_ENABLED.key -> isAnsiMode.toString) {
Try { spark.sql(query).noop() }
spark.sql(query).noop()
}
}

@@ -236,7 +234,7 @@ object CometAggregateBenchmark extends CometBenchmarkBase {
s"SQL Parquet - Spark (${aggregateFunction.toString}) isANSIMode: ${isAnsiMode.toString}") {
_ =>
withSQLConf(SQLConf.ANSI_ENABLED.key -> isAnsiMode.toString) {
Try { spark.sql(query).noop() }
spark.sql(query).noop()
}
}

@@ -247,7 +245,7 @@ object CometAggregateBenchmark extends CometBenchmarkBase {
CometConf.COMET_ENABLED.key -> "true",
CometConf.COMET_EXEC_ENABLED.key -> "true",
SQLConf.ANSI_ENABLED.key -> isAnsiMode.toString) {
Try { spark.sql(query).noop() }
spark.sql(query).noop()
}
}

@@ -260,36 +258,36 @@ object CometAggregateBenchmark extends CometBenchmarkBase {
val total = 1024 * 1024 * 10
val combinations = List(100, 1024, 1024 * 1024) // number of distinct groups
benchmarkAggFuncs.foreach { aggFunc =>
Seq(true, false).foreach(k => {
runBenchmarkWithTable(
Seq(true, false).foreach(ansiMode => {
runBenchmarkWithSafeTable(
s"Grouped Aggregate (single group key + single aggregate $aggFunc)",
total) { v =>
for (card <- combinations) {
singleGroupAndAggregate(v, card, aggFunc, k)
singleGroupAndAggregate(v, card, aggFunc, ansiMode)
}
}

runBenchmarkWithTable(
runBenchmarkWithSafeTable(
s"Grouped Aggregate (multiple group keys + single aggregate $aggFunc)",
total) { v =>
for (card <- combinations) {
multiGroupKeys(v, card, aggFunc, k)
multiGroupKeys(v, card, aggFunc, ansiMode)
}
}

runBenchmarkWithTable(
runBenchmarkWithSafeTable(
s"Grouped Aggregate (single group key + multiple aggregates $aggFunc)",
total) { v =>
for (card <- combinations) {
multiAggregates(v, card, aggFunc, k)
multiAggregates(v, card, aggFunc, ansiMode)
}
}

runBenchmarkWithTable(
runBenchmarkWithSafeTable(
s"Grouped Aggregate (single group key + single aggregate $aggFunc on decimal)",
total) { v =>
for (card <- combinations) {
singleGroupAndAggregateDecimal(v, DecimalType(18, 10), card, aggFunc, k)
singleGroupAndAggregateDecimal(v, DecimalType(18, 10), card, aggFunc, ansiMode)
}
}
})
spark/src/test/scala/org/apache/spark/sql/benchmark/CometArithmeticBenchmark.scala
@@ -44,7 +44,7 @@ object CometArithmeticBenchmark extends CometBenchmarkBase {
val name = s"Binary op ${dataType.sql}, dictionary = $useDictionary"
val query = s"SELECT c1 ${op.sig} c2 FROM $table"

runExpressionBenchmark(name, values, query)
runExpressionBenchmark(name, values, query, isAnsiMode = false)
}
}
}
@@ -64,7 +64,7 @@ object CometArithmeticBenchmark extends CometBenchmarkBase {
val name = s"Binary op ${dataType.sql}, dictionary = $useDictionary"
val query = s"SELECT c1 ${op.sig} c2 FROM $table"

runExpressionBenchmark(name, values, query)
runExpressionBenchmark(name, values, query, isAnsiMode = false)
}
}
}
@@ -84,7 +84,7 @@ object CometArithmeticBenchmark extends CometBenchmarkBase {

Seq(true, false).foreach { useDictionary =>
Seq(Minus, Mul).foreach { op =>
runBenchmarkWithTable(op.name, TOTAL, useDictionary) { v =>
runBenchmarkWithSafeTable(op.name, TOTAL, useDictionary) { v =>
integerArithmeticBenchmark(v, op, useDictionary)
}
}
spark/src/test/scala/org/apache/spark/sql/benchmark/CometBenchmarkBase.scala
@@ -33,7 +33,7 @@ import org.apache.spark.benchmark.Benchmark
import org.apache.spark.sql.{DataFrame, DataFrameWriter, Row, SparkSession}
import org.apache.spark.sql.execution.benchmark.SqlBasedBenchmark
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.DecimalType
import org.apache.spark.sql.types.{DataType, DecimalType}

import org.apache.comet.CometConf
import org.apache.comet.CometSparkSessionExtensions
@@ -88,26 +88,52 @@ trait CometBenchmarkBase extends SqlBasedBenchmark {
}
}

/** Runs function `f` with Comet on and off. */
final def runWithComet(name: String, cardinality: Long)(f: => Unit): Unit = {
val benchmark = new Benchmark(name, cardinality, output = output)

@coderfender (Contributor Author), Dec 29, 2025: This is no longer being used, and no longer needed, IMO.
benchmark.addCase(s"$name - Spark ") { _ =>
withSQLConf(CometConf.COMET_ENABLED.key -> "false") {
f
}
/**
* Creates a table with ANSI-safe values that won't overflow in arithmetic operations. Use this
* instead of runBenchmarkWithTable for arithmetic/aggregate benchmarks.
*/
protected def runBenchmarkWithSafeTable(
benchmarkName: String,
values: Int,
useDictionary: Boolean = false)(f: Int => Any): Unit = {
withTempTable(tbl) {
import spark.implicits._
spark
.range(values)
.map(i => if (useDictionary) i % 5 else i % 10000)
.createOrReplaceTempView(tbl)
runBenchmark(benchmarkName)(f(values))
}
}

benchmark.addCase(s"$name - Comet") { _ =>
withSQLConf(
CometConf.COMET_ENABLED.key -> "true",
CometConf.COMET_EXEC_ENABLED.key -> "true",
SQLConf.ANSI_ENABLED.key -> "false") {
f
}
/**
* Generates ANSI-safe data for casting from Long to the specified target type. Returns a SQL
* expression that transforms the base "value" column to be within safe ranges.
*
* @param targetType
* The target data type for casting
* @return
* SQL expression to generate safe data
*/
protected def generateAnsiSafeData(targetType: DataType): String = {
import org.apache.spark.sql.types._

// We generate long inputs initially; this match translates them into a range that fits the target data type so the benchmark doesn't fail in ANSI mode.
targetType match {
case ByteType => "CAST((value % 128) AS BIGINT)"
case ShortType => "CAST((value % 32768) AS BIGINT)"
case IntegerType => "CAST((value % 2147483648) AS BIGINT)"
case LongType => "value"
case FloatType => "CAST((value % 1000000) AS BIGINT)"
case DoubleType => "value"
case _: DecimalType => "CAST((value % 100000000) AS BIGINT)"
case StringType => "CAST(value AS STRING)"
case BooleanType => "CAST((value % 2) AS BIGINT)"
case DateType => "CAST((value % 18262) AS BIGINT)"
case TimestampType => "value"
case BinaryType => "value"
case _ => "value"
}

benchmark.run()
}

/**
@@ -127,37 +153,48 @@ trait CometBenchmarkBase extends SqlBasedBenchmark {
name: String,
cardinality: Long,
query: String,
isAnsiMode: Boolean,
extraCometConfigs: Map[String, String] = Map.empty): Unit = {

val benchmark = new Benchmark(name, cardinality, output = output)

benchmark.addCase("Spark") { _ =>
withSQLConf(CometConf.COMET_ENABLED.key -> "false") {
spark.sql(query).noop()
withSQLConf(
CometConf.COMET_ENABLED.key -> "false",
SQLConf.ANSI_ENABLED.key -> isAnsiMode.toString) {
runSparkCommand(spark, query, isAnsiMode)
}
}

benchmark.addCase("Comet (Scan)") { _ =>
withSQLConf(
CometConf.COMET_ENABLED.key -> "true",
CometConf.COMET_EXEC_ENABLED.key -> "false") {
spark.sql(query).noop()
CometConf.COMET_EXEC_ENABLED.key -> "false",
SQLConf.ANSI_ENABLED.key -> isAnsiMode.toString) {
runSparkCommand(spark, query, isAnsiMode)
}
}

val cometExecConfigs = Map(
CometConf.COMET_ENABLED.key -> "true",
CometConf.COMET_EXEC_ENABLED.key -> "true",
"spark.sql.optimizer.constantFolding.enabled" -> "false") ++ extraCometConfigs
"spark.sql.optimizer.constantFolding.enabled" -> "false",
SQLConf.ANSI_ENABLED.key -> isAnsiMode.toString) ++ extraCometConfigs

benchmark.addCase("Comet (Scan + Exec)") { _ =>
withSQLConf(cometExecConfigs.toSeq: _*) {
spark.sql(query).noop()
runSparkCommand(spark, query, isAnsiMode)
}
}

benchmark.run()
}

private def runSparkCommand(spark: SparkSession, query: String, isANSIMode: Boolean): Unit = {
// With ANSI-safe data generation, queries should not throw exceptions
spark.sql(query).noop()
}

protected def prepareTable(dir: File, df: DataFrame, partition: Option[String] = None): Unit = {
val testDf = if (partition.isDefined) {
df.write.partitionBy(partition.get)
@@ -250,7 +287,9 @@ trait CometBenchmarkBase extends SqlBasedBenchmark {
useDictionary: Boolean): DataFrame = {
import spark.implicits._

val div = if (useDictionary) 5 else values
// Use safe range to avoid overflow in decimal operations
val maxValue = 10000
val div = if (useDictionary) 5 else maxValue
spark
.range(values)
.map(_ % div)
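To show how the two new helpers are meant to be combined, here is a minimal, hypothetical sketch of a benchmark built on CometBenchmarkBase. It is not part of this change: the object name and query are made up, and it assumes the runCometBenchmark entry point and the tbl temp-view name that CometBenchmarkBase already provides.

package org.apache.spark.sql.benchmark

// Hypothetical example, not part of this PR: exercises runBenchmarkWithSafeTable
// together with the updated runExpressionBenchmark signature.
object CometExampleBenchmark extends CometBenchmarkBase {
  override def runCometBenchmark(mainArgs: Array[String]): Unit = {
    val total = 1024 * 1024
    Seq(true, false).foreach { ansiMode =>
      // runBenchmarkWithSafeTable registers the temp view `tbl` with values bounded
      // by `% 10000`, so the aggregate below cannot overflow even under ANSI mode.
      runBenchmarkWithSafeTable(s"Example aggregate (ansi = $ansiMode)", total) { v =>
        runExpressionBenchmark(
          s"SUM over ANSI-safe data, ansi mode enabled : $ansiMode",
          v,
          s"SELECT SUM(value) FROM $tbl",
          isAnsiMode = ansiMode)
      }
    }
  }
}

The sketch runs each query three ways (Spark, Comet scan only, Comet scan + exec) via runExpressionBenchmark, with ANSI mode propagated to every case through the new isAnsiMode parameter.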
spark/src/test/scala/org/apache/spark/sql/benchmark/CometCastBenchmark.scala
@@ -79,16 +79,22 @@ object CometCastBenchmark extends CometBenchmarkBase {

withTempPath { dir =>
withTempTable("parquetV1Table") {
prepareTable(dir, spark.sql(s"SELECT value FROM $tbl"))
// Generate ANSI-safe data when in ANSI mode to avoid overflow exceptions
// In legacy mode, use raw values to test overflow handling
val dataExpr = if (isAnsiMode) {
generateAnsiSafeData(toDataType)
} else {
"value"
}

prepareTable(dir, spark.sql(s"SELECT $dataExpr as value FROM $tbl"))

val functionSQL = castExprSQL(toDataType, "value")
val query = s"SELECT $functionSQL FROM parquetV1Table"
val name =
s"Cast function to : ${toDataType} , ansi mode enabled : ${isAnsiMode}"

val extraConfigs = Map(SQLConf.ANSI_ENABLED.key -> isAnsiMode.toString)
s"Cast function from : ${fromDataType} to : ${toDataType} , ansi mode enabled : ${isAnsiMode}"

runExpressionBenchmark(name, values, query, extraConfigs)
runExpressionBenchmark(name, values, query, isAnsiMode)
}
}
}
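As a concrete example of the branch above: with isAnsiMode = true and toDataType = ByteType, generateAnsiSafeData returns "CAST((value % 128) AS BIGINT)", so the Parquet table written by prepareTable only holds values that fit in a byte and the benchmarked cast cannot raise an ANSI overflow error; with isAnsiMode = false the raw value column is written unchanged, preserving the legacy-mode overflow handling the benchmark also exercises.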
spark/src/test/scala/org/apache/spark/sql/benchmark/CometConditionalExpressionBenchmark.scala
@@ -35,7 +35,7 @@ object CometConditionalExpressionBenchmark extends CometBenchmarkBase {
val query =
"select CASE WHEN c1 < 0 THEN '<0' WHEN c1 = 0 THEN '=0' ELSE '>0' END from parquetV1Table"

runExpressionBenchmark("Case When Expr", values, query)
runExpressionBenchmark("Case When Expr", values, query, isAnsiMode = false)
}
}
}
@@ -47,7 +47,7 @@ object CometConditionalExpressionBenchmark extends CometBenchmarkBase {

val query = "select IF (c1 < 0, '<0', '>=0') from parquetV1Table"

runExpressionBenchmark("If Expr", values, query)
runExpressionBenchmark("If Expr", values, query, isAnsiMode = false)
}
}
}
spark/src/test/scala/org/apache/spark/sql/benchmark/CometDatetimeExpressionBenchmark.scala
@@ -41,7 +41,7 @@ object CometDatetimeExpressionBenchmark extends CometBenchmarkBase {
val isDictionary = if (useDictionary) "(Dictionary)" else ""
val name = s"Date Truncate $isDictionary - $level"
val query = s"select trunc(dt, '$level') from parquetV1Table"
runExpressionBenchmark(name, values, query)
runExpressionBenchmark(name, values, query, isAnsiMode = false)
}
}
}
@@ -70,7 +70,7 @@ object CometDatetimeExpressionBenchmark extends CometBenchmarkBase {
val isDictionary = if (useDictionary) "(Dictionary)" else ""
val name = s"Timestamp Truncate $isDictionary - $level"
val query = s"select date_trunc('$level', ts) from parquetV1Table"
runExpressionBenchmark(name, values, query)
runExpressionBenchmark(name, values, query, isAnsiMode = false)
}
}
}