Commit b5e61b1

improve_benchmark

1 parent 593b788 commit b5e61b1

File tree (5 files changed: +113, -32 lines)

- Makefile
- spark/src/test/scala/org/apache/spark/sql/benchmark/CometAggregateBenchmark.scala
- spark/src/test/scala/org/apache/spark/sql/benchmark/CometArithmeticBenchmark.scala
- spark/src/test/scala/org/apache/spark/sql/benchmark/CometBenchmarkBase.scala
- spark/src/test/scala/org/apache/spark/sql/benchmark/CometCastBenchmark.scala

Makefile

Lines changed: 30 additions & 0 deletions

@@ -105,5 +105,35 @@ release-nogit:
 	./mvnw install -Prelease -DskipTests $(PROFILES) -Dmaven.gitcommitid.skip=true
 benchmark-%: release
 	cd spark && COMET_CONF_DIR=$(shell pwd)/conf MAVEN_OPTS='-Xmx20g ${call spark_jvm_17_extra_args}' ../mvnw exec:java -Dexec.mainClass="$*" -Dexec.classpathScope="test" -Dexec.cleanupDaemonThreads="false" -Dexec.args="$(filter-out $@,$(MAKECMDGOALS))" $(PROFILES)
+
+# Discover all benchmark classes dynamically
+BENCHMARK_CLASSES := $(shell find spark/src/test/scala/org/apache/spark/sql/benchmark -name "Comet*Benchmark.scala" -type f | \
+	xargs grep -l "object.*Benchmark.*extends.*CometBenchmarkBase" | \
+	sed 's|spark/src/test/scala/||g' | \
+	sed 's|/|.|g' | \
+	sed 's|.scala||g' | \
+	sort)
+
+# Run all discovered benchmarks
+benchmark-all:
+	@echo "Discovered benchmarks:"
+	@echo "$(BENCHMARK_CLASSES)" | tr ' ' '\n'
+	@echo ""
+	@echo "Running all benchmarks (this will take a long time)..."
+	@for benchmark in $(BENCHMARK_CLASSES); do \
+		echo ""; \
+		echo "======================================"; \
+		echo "Running: $$benchmark"; \
+		echo "======================================"; \
+		SPARK_GENERATE_BENCHMARK_FILES=1 $(MAKE) benchmark-$$benchmark || echo "WARNING: $$benchmark failed"; \
+	done
+	@echo ""
+	@echo "All benchmarks completed!"
+
+# List all available benchmarks
+list-benchmarks:
+	@echo "Available benchmarks:"
+	@echo "$(BENCHMARK_CLASSES)" | tr ' ' '\n'
+
 .DEFAULT:
 	@: # ignore arguments provided to benchmarks e.g. "make benchmark-foo -- --bar", we do not want to treat "--bar" as target
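
Note: the discovery pipeline turns each matching source file into the fully qualified class name expected by the existing benchmark-% target. A minimal Scala sketch of the same rewrite (the path literal is only an example):

// Illustrative sketch of the path-to-class-name mapping done by the sed pipeline above.
object BenchmarkClassNameSketch extends App {
  val path = "spark/src/test/scala/org/apache/spark/sql/benchmark/CometAggregateBenchmark.scala"
  val fqcn = path
    .stripPrefix("spark/src/test/scala/") // sed 's|spark/src/test/scala/||g'
    .stripSuffix(".scala")                // sed 's|.scala||g'
    .replace('/', '.')                    // sed 's|/|.|g'
  println(fqcn) // org.apache.spark.sql.benchmark.CometAggregateBenchmark
}

With these targets, "make list-benchmarks" prints the discovered classes, and "make benchmark-all" invokes "make benchmark-<class>" for each of them, continuing past individual failures and exporting SPARK_GENERATE_BENCHMARK_FILES=1 so the Spark benchmark framework writes result files.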

spark/src/test/scala/org/apache/spark/sql/benchmark/CometAggregateBenchmark.scala

Lines changed: 17 additions & 19 deletions

@@ -19,8 +19,6 @@
 
 package org.apache.spark.sql.benchmark
 
-import scala.util.Try
-
 import org.apache.spark.benchmark.Benchmark
 import org.apache.spark.sql.SparkSession
 import org.apache.spark.sql.internal.SQLConf
@@ -87,7 +85,7 @@ object CometAggregateBenchmark extends CometBenchmarkBase {
         s"SQL Parquet - Spark (${aggregateFunction.toString}) ansi mode enabled : ${isAnsiMode}") {
         _ =>
           withSQLConf(SQLConf.ANSI_ENABLED.key -> isAnsiMode.toString) {
-            Try { spark.sql(query).noop() }
+            spark.sql(query).noop()
           }
       }
 
@@ -98,7 +96,7 @@ object CometAggregateBenchmark extends CometBenchmarkBase {
           CometConf.COMET_ENABLED.key -> "true",
          CometConf.COMET_EXEC_ENABLED.key -> "true",
          SQLConf.ANSI_ENABLED.key -> isAnsiMode.toString) {
-          Try { spark.sql(query).noop() }
+          spark.sql(query).noop()
        }
      }
 
@@ -137,7 +135,7 @@ object CometAggregateBenchmark extends CometBenchmarkBase {
     benchmark.addCase(
       s"SQL Parquet - Spark (${aggregateFunction.toString}) ansi mode enabled : ${isAnsiMode}") {
       _ =>
-        Try { spark.sql(query).noop() }
+        spark.sql(query).noop()
     }
   }
 
@@ -148,7 +146,7 @@ object CometAggregateBenchmark extends CometBenchmarkBase {
          CometConf.COMET_ENABLED.key -> "true",
          CometConf.COMET_EXEC_ENABLED.key -> "true",
          SQLConf.ANSI_ENABLED.key -> isAnsiMode.toString) {
-          Try { spark.sql(query).noop() }
+          spark.sql(query).noop()
        }
      }
 
@@ -185,7 +183,7 @@ object CometAggregateBenchmark extends CometBenchmarkBase {
         s"SQL Parquet - Spark (${aggregateFunction.toString}) isANSIMode: ${isAnsiMode.toString}") {
         _ =>
           withSQLConf(SQLConf.ANSI_ENABLED.key -> isAnsiMode.toString) {
-            Try { spark.sql(query).noop() }
+            spark.sql(query).noop()
           }
       }
 
@@ -197,7 +195,7 @@ object CometAggregateBenchmark extends CometBenchmarkBase {
          CometConf.COMET_EXEC_ENABLED.key -> "true",
          CometConf.COMET_ONHEAP_MEMORY_OVERHEAD.key -> "1G",
          SQLConf.ANSI_ENABLED.key -> isAnsiMode.toString) {
-          Try { spark.sql(query).noop() }
+          spark.sql(query).noop()
        }
      }
 
@@ -236,7 +234,7 @@ object CometAggregateBenchmark extends CometBenchmarkBase {
         s"SQL Parquet - Spark (${aggregateFunction.toString}) isANSIMode: ${isAnsiMode.toString}") {
         _ =>
           withSQLConf(SQLConf.ANSI_ENABLED.key -> isAnsiMode.toString) {
-            Try { spark.sql(query).noop() }
+            spark.sql(query).noop()
           }
       }
 
@@ -247,7 +245,7 @@ object CometAggregateBenchmark extends CometBenchmarkBase {
          CometConf.COMET_ENABLED.key -> "true",
          CometConf.COMET_EXEC_ENABLED.key -> "true",
          SQLConf.ANSI_ENABLED.key -> isAnsiMode.toString) {
-          Try { spark.sql(query).noop() }
+          spark.sql(query).noop()
        }
      }
 
@@ -260,36 +258,36 @@ object CometAggregateBenchmark extends CometBenchmarkBase {
     val total = 1024 * 1024 * 10
     val combinations = List(100, 1024, 1024 * 1024) // number of distinct groups
     benchmarkAggFuncs.foreach { aggFunc =>
-      Seq(true, false).foreach(k => {
-        runBenchmarkWithTable(
+      Seq(true, false).foreach(ansiMode => {
+        runBenchmarkWithSafeTable(
          s"Grouped Aggregate (single group key + single aggregate $aggFunc)",
          total) { v =>
          for (card <- combinations) {
-            singleGroupAndAggregate(v, card, aggFunc, k)
+            singleGroupAndAggregate(v, card, aggFunc, ansiMode)
          }
        }
 
-        runBenchmarkWithTable(
+        runBenchmarkWithSafeTable(
          s"Grouped Aggregate (multiple group keys + single aggregate $aggFunc)",
          total) { v =>
          for (card <- combinations) {
-            multiGroupKeys(v, card, aggFunc, k)
+            multiGroupKeys(v, card, aggFunc, ansiMode)
          }
        }
 
-        runBenchmarkWithTable(
+        runBenchmarkWithSafeTable(
          s"Grouped Aggregate (single group key + multiple aggregates $aggFunc)",
          total) { v =>
          for (card <- combinations) {
-            multiAggregates(v, card, aggFunc, k)
+            multiAggregates(v, card, aggFunc, ansiMode)
          }
        }
 
-        runBenchmarkWithTable(
+        runBenchmarkWithSafeTable(
          s"Grouped Aggregate (single group key + single aggregate $aggFunc on decimal)",
          total) { v =>
          for (card <- combinations) {
-            singleGroupAndAggregateDecimal(v, DecimalType(18, 10), card, aggFunc, k)
+            singleGroupAndAggregateDecimal(v, DecimalType(18, 10), card, aggFunc, ansiMode)
          }
        }
      })
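
Note: removing the Try wrappers changes what the benchmark cases measure. Previously, a query that threw under ANSI mode was silently swallowed, so the case timed the cost of raising an exception rather than of running the aggregate. A self-contained sketch of that distortion (the timing helper and the failure are invented for illustration):

import scala.util.Try

object TrySkewSketch extends App {
  def timeNanos(body: => Unit): Long = {
    val start = System.nanoTime()
    body
    System.nanoTime() - start
  }

  // Stand-in for an ANSI overflow thrown inside spark.sql(query).noop()
  def failingQuery(): Unit = throw new ArithmeticException("overflow")

  // The swallowed failure "finishes" almost instantly, so the recorded time is meaningless.
  println(timeNanos(Try(failingQuery())))
  // With the wrapper removed, as in this commit, the exception propagates and a broken
  // case fails loudly instead of reporting a bogus measurement.
}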

spark/src/test/scala/org/apache/spark/sql/benchmark/CometArithmeticBenchmark.scala

Lines changed: 1 addition & 1 deletion

@@ -84,7 +84,7 @@ object CometArithmeticBenchmark extends CometBenchmarkBase {
 
     Seq(true, false).foreach { useDictionary =>
       Seq(Minus, Mul).foreach { op =>
-        runBenchmarkWithTable(op.name, TOTAL, useDictionary) { v =>
+        runBenchmarkWithSafeTable(op.name, TOTAL, useDictionary) { v =>
          integerArithmeticBenchmark(v, op, useDictionary)
        }
      }
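
Note: switching to runBenchmarkWithSafeTable bounds the generated values below 10000 (or to 0-4 when dictionary encoding is exercised), so the Minus and Mul cases should no longer overflow a 32-bit integer under ANSI mode. A quick check of the worst case, assuming both operands come from the same bounded column:

object ArithmeticBoundsSketch extends App {
  val maxInput = 9999L // runBenchmarkWithSafeTable generates i % 10000
  assert(maxInput * maxInput < Int.MaxValue) // 99980001 < 2147483647
}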

spark/src/test/scala/org/apache/spark/sql/benchmark/CometBenchmarkBase.scala

Lines changed: 56 additions & 11 deletions

@@ -23,7 +23,7 @@ import java.io.File
 import java.nio.charset.StandardCharsets
 import java.util.Base64
 
-import scala.util.{Random, Try}
+import scala.util.Random
 
 import org.apache.parquet.crypto.DecryptionPropertiesFactory
 import org.apache.parquet.crypto.keytools.{KeyToolkit, PropertiesDrivenCryptoFactory}
@@ -33,7 +33,7 @@ import org.apache.spark.benchmark.Benchmark
 import org.apache.spark.sql.{DataFrame, DataFrameWriter, Row, SparkSession}
 import org.apache.spark.sql.execution.benchmark.SqlBasedBenchmark
 import org.apache.spark.sql.internal.SQLConf
-import org.apache.spark.sql.types.DecimalType
+import org.apache.spark.sql.types.{DataType, DecimalType}
 
 import org.apache.comet.CometConf
 import org.apache.comet.CometSparkSessionExtensions
@@ -88,6 +88,54 @@ trait CometBenchmarkBase extends SqlBasedBenchmark {
     }
   }
 
+  /**
+   * Creates a table with ANSI-safe values that won't overflow in arithmetic operations. Use this
+   * instead of runBenchmarkWithTable for arithmetic/aggregate benchmarks.
+   */
+  protected def runBenchmarkWithSafeTable(
+      benchmarkName: String,
+      values: Int,
+      useDictionary: Boolean = false)(f: Int => Any): Unit = {
+    withTempTable(tbl) {
+      import spark.implicits._
+      spark
+        .range(values)
+        .map(i => if (useDictionary) i % 5 else i % 10000)
+        .createOrReplaceTempView(tbl)
+      runBenchmark(benchmarkName)(f(values))
+    }
+  }
+
+  /**
+   * Generates ANSI-safe data for casting from Long to the specified target type. Returns a SQL
+   * expression that transforms the base "value" column to be within safe ranges.
+   *
+   * @param targetType
+   *   The target data type for casting
+   * @return
+   *   SQL expression to generate safe data
+   */
+  protected def generateAnsiSafeData(targetType: DataType): String = {
+    import org.apache.spark.sql.types._
+
+    // Long inputs are generated up front; this match rewrites them into the target type's safe range so the cast cannot fail in ANSI mode.
+    targetType match {
+      case ByteType => "CAST((value % 128) AS BIGINT)"
+      case ShortType => "CAST((value % 32768) AS BIGINT)"
+      case IntegerType => "CAST((value % 2147483648) AS BIGINT)"
+      case LongType => "value"
+      case FloatType => "CAST((value % 1000000) AS BIGINT)"
+      case DoubleType => "value"
+      case _: DecimalType => "CAST((value % 100000000) AS BIGINT)"
+      case StringType => "CAST(value AS STRING)"
+      case BooleanType => "CAST((value % 2) AS BIGINT)"
+      case DateType => "CAST((value % 18262) AS BIGINT)"
+      case TimestampType => "value"
+      case BinaryType => "value"
+      case _ => "value"
+    }
+  }
+
   /**
    * Runs an expression benchmark with standard cases: Spark, Comet (Scan), Comet (Scan + Exec).
    * This provides a consistent benchmark structure for expression evaluation.
@@ -143,13 +191,8 @@ trait CometBenchmarkBase extends SqlBasedBenchmark {
   }
 
   private def runSparkCommand(spark: SparkSession, query: String, isANSIMode: Boolean): Unit = {
-    if (isANSIMode) {
-      Try {
-        spark.sql(query).noop()
-      }
-    } else {
-      spark.sql(query).noop()
-    }
+    // With ANSI-safe data generation, queries should not throw exceptions
+    spark.sql(query).noop()
   }
 
   protected def prepareTable(dir: File, df: DataFrame, partition: Option[String] = None): Unit = {
@@ -244,10 +287,12 @@ trait CometBenchmarkBase extends SqlBasedBenchmark {
       useDictionary: Boolean): DataFrame = {
     import spark.implicits._
 
-    val div = if (useDictionary) 5 else values
+    // Use safe range to avoid overflow in decimal operations
+    val maxValue = 10000
+    val div = if (useDictionary) 5 else maxValue
     spark
       .range(values)
-      .map(_ % div)
+      .map(i => i % div)
       .select((($"value" - 500) / 100.0) cast decimal as Symbol("dec"))
   }
 }
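
Note: the moduli in generateAnsiSafeData keep the non-negative base column inside each target type's domain, since value % m lands in [0, m - 1] for non-negative values. A self-contained check of the integral bounds used above:

object SafeRangeSketch extends App {
  assert(128L - 1 == Byte.MaxValue)       // ByteType:    value % 128   -> [0, 127]
  assert(32768L - 1 == Short.MaxValue)    // ShortType:   value % 32768 -> [0, 32767]
  assert(2147483648L - 1 == Int.MaxValue) // IntegerType: value % 2^31  -> [0, Int.MaxValue]
  // DateType: 18262 days is roughly 50 years past the epoch, well inside the valid range.
  // DecimalType: value % 100000000 leaves at most 8 integer digits, which fits the
  // DecimalType(18, 10) used by the aggregate benchmark.
}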

spark/src/test/scala/org/apache/spark/sql/benchmark/CometCastBenchmark.scala

Lines changed: 9 additions & 1 deletion

@@ -79,7 +79,15 @@ object CometCastBenchmark extends CometBenchmarkBase {
 
     withTempPath { dir =>
       withTempTable("parquetV1Table") {
-        prepareTable(dir, spark.sql(s"SELECT value FROM $tbl"))
+        // Generate ANSI-safe data when in ANSI mode to avoid overflow exceptions.
+        // In legacy mode, use raw values to test overflow handling.
+        val dataExpr = if (isAnsiMode) {
+          generateAnsiSafeData(toDataType)
+        } else {
+          "value"
+        }
+
+        prepareTable(dir, spark.sql(s"SELECT $dataExpr as value FROM $tbl"))
 
         val functionSQL = castExprSQL(toDataType, "value")
        val query = s"SELECT $functionSQL FROM parquetV1Table"
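
Note: in ANSI mode an out-of-range cast throws instead of wrapping, which is why the input column is only rewritten when isAnsiMode is set; legacy mode keeps the raw values to exercise overflow handling. A sketch of the two query shapes this branch produces for a ByteType target (the tbl view name comes from CometBenchmarkBase, and the safe expression assumes the generateAnsiSafeData mapping above):

object CastQueryShapesSketch extends App {
  val tbl = "tbl"
  val ansiQuery = s"SELECT CAST((value % 128) AS BIGINT) as value FROM $tbl" // safe domain
  val legacyQuery = s"SELECT value as value FROM $tbl" // raw values, cast may overflow
  println(ansiQuery)
  println(legacyQuery)
}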
