
Commit 9bb890e

feat: do not fallback to Spark for distincts
1 parent: 0c20769

File tree

1 file changed: +44 −21 lines changed

spark/src/test/scala/org/apache/spark/sql/benchmark/CometAggregateBenchmark.scala

Lines changed: 44 additions & 21 deletions
@@ -44,11 +44,11 @@ object CometAggregateBenchmark extends CometBenchmarkBase {
   def singleGroupAndAggregate(
       values: Int,
       groupingKeyCardinality: Int,
-      aggregateFunction: String): Unit = {
+      aggregateFunction: BenchAggregateFunction): Unit = {
     val benchmark =
       new Benchmark(
         s"Grouped HashAgg Exec: single group key (cardinality $groupingKeyCardinality), " +
-          s"single aggregate $aggregateFunction",
+          s"single aggregate ${aggregateFunction.name}",
         values,
         output = output)

@@ -58,13 +58,14 @@ object CometAggregateBenchmark extends CometBenchmarkBase {
         dir,
         spark.sql(s"SELECT value, floor(rand() * $groupingKeyCardinality) as key FROM $tbl"))

-      val query = s"SELECT key, $aggregateFunction(value) FROM parquetV1Table GROUP BY key"
+      val functionSQL = aggFunctionSQL(aggregateFunction, "value")
+      val query = s"SELECT key, $functionSQL FROM parquetV1Table GROUP BY key"

-      benchmark.addCase(s"SQL Parquet - Spark ($aggregateFunction)") { _ =>
+      benchmark.addCase(s"SQL Parquet - Spark (${aggregateFunction.name})") { _ =>
         spark.sql(query).noop()
       }

-      benchmark.addCase(s"SQL Parquet - Comet ($aggregateFunction)") { _ =>
+      benchmark.addCase(s"SQL Parquet - Comet (${aggregateFunction.name})") { _ =>
         withSQLConf(
           CometConf.COMET_ENABLED.key -> "true",
           CometConf.COMET_EXEC_ENABLED.key -> "true") {
@@ -77,11 +78,15 @@ object CometAggregateBenchmark extends CometBenchmarkBase {
       }
     }
   }

+  def aggFunctionSQL(aggregateFunction: BenchAggregateFunction, input: String): String = {
+    s"${aggregateFunction.name}(${if (aggregateFunction.distinct) s"DISTINCT $input" else input})"
+  }
+
   def singleGroupAndAggregateDecimal(
       values: Int,
       dataType: DecimalType,
       groupingKeyCardinality: Int,
-      aggregateFunction: String): Unit = {
+      aggregateFunction: BenchAggregateFunction): Unit = {
     val benchmark =
       new Benchmark(
         s"Grouped HashAgg Exec: single group key (cardinality $groupingKeyCardinality), " +
@@ -99,13 +104,14 @@ object CometAggregateBenchmark extends CometBenchmarkBase {
         spark.sql(
           s"SELECT dec as value, floor(rand() * $groupingKeyCardinality) as key FROM $tbl"))

-      val query = s"SELECT key, $aggregateFunction(value) FROM parquetV1Table GROUP BY key"
+      val functionSQL = aggFunctionSQL(aggregateFunction, "value")
+      val query = s"SELECT key, $functionSQL FROM parquetV1Table GROUP BY key"

-      benchmark.addCase(s"SQL Parquet - Spark ($aggregateFunction)") { _ =>
+      benchmark.addCase(s"SQL Parquet - Spark (${aggregateFunction.name})") { _ =>
         spark.sql(query).noop()
       }

-      benchmark.addCase(s"SQL Parquet - Comet ($aggregateFunction)") { _ =>
+      benchmark.addCase(s"SQL Parquet - Comet (${aggregateFunction.name})") { _ =>
         withSQLConf(
           CometConf.COMET_ENABLED.key -> "true",
           CometConf.COMET_EXEC_ENABLED.key -> "true") {
@@ -118,11 +124,14 @@ object CometAggregateBenchmark extends CometBenchmarkBase {
       }
     }
   }

-  def multiGroupKeys(values: Int, groupingKeyCard: Int, aggregateFunction: String): Unit = {
+  def multiGroupKeys(
+      values: Int,
+      groupingKeyCard: Int,
+      aggregateFunction: BenchAggregateFunction): Unit = {
     val benchmark =
       new Benchmark(
         s"Grouped HashAgg Exec: multiple group keys (cardinality $groupingKeyCard), " +
-          s"single aggregate $aggregateFunction",
+          s"single aggregate ${aggregateFunction.name}",
         values,
         output = output)

@@ -134,14 +143,15 @@ object CometAggregateBenchmark extends CometBenchmarkBase {
           s"SELECT value, floor(rand() * $groupingKeyCard) as key1, " +
             s"floor(rand() * $groupingKeyCard) as key2 FROM $tbl"))

+      val functionSQL = aggFunctionSQL(aggregateFunction, "value")
       val query =
-        s"SELECT key1, key2, $aggregateFunction(value) FROM parquetV1Table GROUP BY key1, key2"
+        s"SELECT key1, key2, $functionSQL FROM parquetV1Table GROUP BY key1, key2"

-      benchmark.addCase(s"SQL Parquet - Spark ($aggregateFunction)") { _ =>
+      benchmark.addCase(s"SQL Parquet - Spark (${aggregateFunction.name})") { _ =>
         spark.sql(query).noop()
       }

-      benchmark.addCase(s"SQL Parquet - Comet ($aggregateFunction)") { _ =>
+      benchmark.addCase(s"SQL Parquet - Comet (${aggregateFunction.name})") { _ =>
         withSQLConf(
           CometConf.COMET_ENABLED.key -> "true",
           CometConf.COMET_EXEC_ENABLED.key -> "true",
@@ -155,11 +165,14 @@ object CometAggregateBenchmark extends CometBenchmarkBase {
       }
     }
   }

-  def multiAggregates(values: Int, groupingKeyCard: Int, aggregateFunction: String): Unit = {
+  def multiAggregates(
+      values: Int,
+      groupingKeyCard: Int,
+      aggregateFunction: BenchAggregateFunction): Unit = {
     val benchmark =
       new Benchmark(
         s"Grouped HashAgg Exec: single group key (cardinality $groupingKeyCard), " +
-          s"multiple aggregates $aggregateFunction",
+          s"multiple aggregates ${aggregateFunction.name}",
         values,
         output = output)

@@ -171,14 +184,17 @@ object CometAggregateBenchmark extends CometBenchmarkBase {
           s"SELECT value as value1, value as value2, floor(rand() * $groupingKeyCard) as key " +
             s"FROM $tbl"))

-      val query = s"SELECT key, $aggregateFunction(value1), $aggregateFunction(value2) " +
+      val functionSQL1 = aggFunctionSQL(aggregateFunction, "value1")
+      val functionSQL2 = aggFunctionSQL(aggregateFunction, "value2")
+
+      val query = s"SELECT key, $functionSQL1, $functionSQL2 " +
         "FROM parquetV1Table GROUP BY key"

-      benchmark.addCase(s"SQL Parquet - Spark ($aggregateFunction)") { _ =>
+      benchmark.addCase(s"SQL Parquet - Spark (${aggregateFunction.name})") { _ =>
         spark.sql(query).noop()
       }

-      benchmark.addCase(s"SQL Parquet - Comet ($aggregateFunction)") { _ =>
+      benchmark.addCase(s"SQL Parquet - Comet (${aggregateFunction.name})") { _ =>
         withSQLConf(
           CometConf.COMET_ENABLED.key -> "true",
           CometConf.COMET_EXEC_ENABLED.key -> "true") {
@@ -191,12 +207,19 @@ object CometAggregateBenchmark extends CometBenchmarkBase {
       }
     }
   }

+  case class BenchAggregateFunction(name: String, distinct: Boolean = false)
+  private val benchmarkAggFuncs = Seq(
+    BenchAggregateFunction("SUM"),
+    BenchAggregateFunction("MIN"),
+    BenchAggregateFunction("MAX"),
+    BenchAggregateFunction("COUNT"),
+    BenchAggregateFunction("COUNT", distinct = true))
+
   override def runCometBenchmark(mainArgs: Array[String]): Unit = {
     val total = 1024 * 1024 * 10
     val combinations = List(100, 1024, 1024 * 1024) // number of distinct groups
-    val aggregateFunctions = List("SUM", "MIN", "MAX", "COUNT")

-    aggregateFunctions.foreach { aggFunc =>
+    benchmarkAggFuncs.foreach { aggFunc =>
       runBenchmarkWithTable(
         s"Grouped Aggregate (single group key + single aggregate $aggFunc)",
         total) { v =>
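For reference, the SQL-rendering logic this diff introduces can be exercised on its own. The sketch below is a minimal standalone reproduction, not part of the commit: the demo object and println harness are illustrative, while BenchAggregateFunction and aggFunctionSQL mirror the added code.

// Standalone sketch; only BenchAggregateFunction and aggFunctionSQL
// mirror the committed code, the rest is an illustrative harness.
object AggFunctionSQLDemo extends App {
  case class BenchAggregateFunction(name: String, distinct: Boolean = false)

  // Renders the aggregate call, inserting DISTINCT when the flag is set.
  def aggFunctionSQL(f: BenchAggregateFunction, input: String): String =
    s"${f.name}(${if (f.distinct) s"DISTINCT $input" else input})"

  Seq(
    BenchAggregateFunction("SUM"),
    BenchAggregateFunction("COUNT"),
    BenchAggregateFunction("COUNT", distinct = true)).foreach { f =>
    // Prints, in order:
    //   SELECT key, SUM(value) FROM parquetV1Table GROUP BY key
    //   SELECT key, COUNT(value) FROM parquetV1Table GROUP BY key
    //   SELECT key, COUNT(DISTINCT value) FROM parquetV1Table GROUP BY key
    println(s"SELECT key, ${aggFunctionSQL(f, "value")} FROM parquetV1Table GROUP BY key")
  }
}

Because DISTINCT is injected at SQL-generation time, the existing benchmark cases pick up COUNT(DISTINCT value) without any change to their query templates.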
