Skip to content

Commit 9d3a40f

Browse files
committed
feat: do not fall back to Spark for distincts
1 parent 9bb890e commit 9d3a40f

File tree

1 file changed

+29
-24
lines changed

1 file changed

+29
-24
lines changed

spark/src/test/scala/org/apache/spark/sql/benchmark/CometAggregateBenchmark.scala

Lines changed: 29 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -41,14 +41,31 @@ object CometAggregateBenchmark extends CometBenchmarkBase {
4141
session
4242
}
4343

44+
// Wrapper on SQL aggregation function
45+
case class BenchAggregateFunction(name: String, distinct: Boolean = false) {
46+
override def toString: String = if (distinct) s"$name(DISTINCT)" else name
47+
}
48+
49+
// Aggregation functions to test
50+
private val benchmarkAggFuncs = Seq(
51+
BenchAggregateFunction("SUM"),
52+
BenchAggregateFunction("MIN"),
53+
BenchAggregateFunction("MAX"),
54+
BenchAggregateFunction("COUNT"),
55+
BenchAggregateFunction("COUNT", distinct = true))
56+
57+
def aggFunctionSQL(aggregateFunction: BenchAggregateFunction, input: String): String = {
58+
s"${aggregateFunction.name}(${if (aggregateFunction.distinct) s"DISTINCT $input" else input})"
59+
}
60+
4461
def singleGroupAndAggregate(
4562
values: Int,
4663
groupingKeyCardinality: Int,
4764
aggregateFunction: BenchAggregateFunction): Unit = {
4865
val benchmark =
4966
new Benchmark(
5067
s"Grouped HashAgg Exec: single group key (cardinality $groupingKeyCardinality), " +
51-
s"single aggregate ${aggregateFunction.name}",
68+
s"single aggregate ${aggregateFunction.toString}",
5269
values,
5370
output = output)
5471

@@ -61,11 +78,11 @@ object CometAggregateBenchmark extends CometBenchmarkBase {
6178
val functionSQL = aggFunctionSQL(aggregateFunction, "value")
6279
val query = s"SELECT key, $functionSQL FROM parquetV1Table GROUP BY key"
6380

64-
benchmark.addCase(s"SQL Parquet - Spark (${aggregateFunction.name})") { _ =>
81+
benchmark.addCase(s"SQL Parquet - Spark (${aggregateFunction.toString})") { _ =>
6582
spark.sql(query).noop()
6683
}
6784

68-
benchmark.addCase(s"SQL Parquet - Comet (${aggregateFunction.name})") { _ =>
85+
benchmark.addCase(s"SQL Parquet - Comet (${aggregateFunction.toString})") { _ =>
6986
withSQLConf(
7087
CometConf.COMET_ENABLED.key -> "true",
7188
CometConf.COMET_EXEC_ENABLED.key -> "true") {
@@ -78,10 +95,6 @@ object CometAggregateBenchmark extends CometBenchmarkBase {
7895
}
7996
}
8097

81-
def aggFunctionSQL(aggregateFunction: BenchAggregateFunction, input: String): String = {
82-
s"${aggregateFunction.name}(${if (aggregateFunction.distinct) s"DISTINCT $input" else input})"
83-
}
84-
8598
def singleGroupAndAggregateDecimal(
8699
values: Int,
87100
dataType: DecimalType,
@@ -90,7 +103,7 @@ object CometAggregateBenchmark extends CometBenchmarkBase {
90103
val benchmark =
91104
new Benchmark(
92105
s"Grouped HashAgg Exec: single group key (cardinality $groupingKeyCardinality), " +
93-
s"single aggregate $aggregateFunction on decimal",
106+
s"single aggregate ${aggregateFunction.toString} on decimal",
94107
values,
95108
output = output)
96109

@@ -107,11 +120,11 @@ object CometAggregateBenchmark extends CometBenchmarkBase {
107120
val functionSQL = aggFunctionSQL(aggregateFunction, "value")
108121
val query = s"SELECT key, $functionSQL FROM parquetV1Table GROUP BY key"
109122

110-
benchmark.addCase(s"SQL Parquet - Spark (${aggregateFunction.name})") { _ =>
123+
benchmark.addCase(s"SQL Parquet - Spark (${aggregateFunction.toString})") { _ =>
111124
spark.sql(query).noop()
112125
}
113126

114-
benchmark.addCase(s"SQL Parquet - Comet (${aggregateFunction.name})") { _ =>
127+
benchmark.addCase(s"SQL Parquet - Comet (${aggregateFunction.toString})") { _ =>
115128
withSQLConf(
116129
CometConf.COMET_ENABLED.key -> "true",
117130
CometConf.COMET_EXEC_ENABLED.key -> "true") {
@@ -131,7 +144,7 @@ object CometAggregateBenchmark extends CometBenchmarkBase {
131144
val benchmark =
132145
new Benchmark(
133146
s"Grouped HashAgg Exec: multiple group keys (cardinality $groupingKeyCard), " +
134-
s"single aggregate ${aggregateFunction.name}",
147+
s"single aggregate ${aggregateFunction.toString}",
135148
values,
136149
output = output)
137150

@@ -147,11 +160,11 @@ object CometAggregateBenchmark extends CometBenchmarkBase {
147160
val query =
148161
s"SELECT key1, key2, $functionSQL FROM parquetV1Table GROUP BY key1, key2"
149162

150-
benchmark.addCase(s"SQL Parquet - Spark (${aggregateFunction.name})") { _ =>
163+
benchmark.addCase(s"SQL Parquet - Spark (${aggregateFunction.toString})") { _ =>
151164
spark.sql(query).noop()
152165
}
153166

154-
benchmark.addCase(s"SQL Parquet - Comet (${aggregateFunction.name})") { _ =>
167+
benchmark.addCase(s"SQL Parquet - Comet (${aggregateFunction.toString})") { _ =>
155168
withSQLConf(
156169
CometConf.COMET_ENABLED.key -> "true",
157170
CometConf.COMET_EXEC_ENABLED.key -> "true",
@@ -172,7 +185,7 @@ object CometAggregateBenchmark extends CometBenchmarkBase {
172185
val benchmark =
173186
new Benchmark(
174187
s"Grouped HashAgg Exec: single group key (cardinality $groupingKeyCard), " +
175-
s"multiple aggregates ${aggregateFunction.name}",
188+
s"multiple aggregates ${aggregateFunction.toString}",
176189
values,
177190
output = output)
178191

@@ -190,11 +203,11 @@ object CometAggregateBenchmark extends CometBenchmarkBase {
190203
val query = s"SELECT key, $functionSQL1, $functionSQL2 " +
191204
"FROM parquetV1Table GROUP BY key"
192205

193-
benchmark.addCase(s"SQL Parquet - Spark (${aggregateFunction.name})") { _ =>
206+
benchmark.addCase(s"SQL Parquet - Spark (${aggregateFunction.toString})") { _ =>
194207
spark.sql(query).noop()
195208
}
196209

197-
benchmark.addCase(s"SQL Parquet - Comet (${aggregateFunction.name})") { _ =>
210+
benchmark.addCase(s"SQL Parquet - Comet (${aggregateFunction.toString})") { _ =>
198211
withSQLConf(
199212
CometConf.COMET_ENABLED.key -> "true",
200213
CometConf.COMET_EXEC_ENABLED.key -> "true") {
@@ -207,14 +220,6 @@ object CometAggregateBenchmark extends CometBenchmarkBase {
207220
}
208221
}
209222

210-
case class BenchAggregateFunction(name: String, distinct: Boolean = false)
211-
private val benchmarkAggFuncs = Seq(
212-
BenchAggregateFunction("SUM"),
213-
BenchAggregateFunction("MIN"),
214-
BenchAggregateFunction("MAX"),
215-
BenchAggregateFunction("COUNT"),
216-
BenchAggregateFunction("COUNT", distinct = true))
217-
218223
override def runCometBenchmark(mainArgs: Array[String]): Unit = {
219224
val total = 1024 * 1024 * 10
220225
val combinations = List(100, 1024, 1024 * 1024) // number of distinct groups

0 commit comments

Comments (0)