@@ -41,14 +41,31 @@ object CometAggregateBenchmark extends CometBenchmarkBase {
     session
   }
 
+  // Wrapper around a SQL aggregate function, optionally applied with DISTINCT
+  case class BenchAggregateFunction(name: String, distinct: Boolean = false) {
+    override def toString: String = if (distinct) s"$name(DISTINCT)" else name
+  }
+
+  // Aggregate functions to benchmark
+  private val benchmarkAggFuncs = Seq(
+    BenchAggregateFunction("SUM"),
+    BenchAggregateFunction("MIN"),
+    BenchAggregateFunction("MAX"),
+    BenchAggregateFunction("COUNT"),
+    BenchAggregateFunction("COUNT", distinct = true))
+
+  def aggFunctionSQL(aggregateFunction: BenchAggregateFunction, input: String): String = {
+    s"${aggregateFunction.name}(${if (aggregateFunction.distinct) s"DISTINCT $input" else input})"
+  }
+
   def singleGroupAndAggregate(
       values: Int,
       groupingKeyCardinality: Int,
-      aggregateFunction: String): Unit = {
+      aggregateFunction: BenchAggregateFunction): Unit = {
     val benchmark =
       new Benchmark(
         s"Grouped HashAgg Exec: single group key (cardinality $groupingKeyCardinality), " +
-          s"single aggregate $aggregateFunction",
+          s"single aggregate ${aggregateFunction.toString}",
         values,
         output = output)
 
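For reference, the new helper renders aggregate calls like this (illustrative only, not part of the patch; the results follow directly from the definitions above):

    aggFunctionSQL(BenchAggregateFunction("SUM"), "value")
    // "SUM(value)"
    aggFunctionSQL(BenchAggregateFunction("COUNT", distinct = true), "value")
    // "COUNT(DISTINCT value)"
    BenchAggregateFunction("COUNT", distinct = true).toString
    // "COUNT(DISTINCT)", the short form used in benchmark and case labels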
@@ -58,13 +75,14 @@ object CometAggregateBenchmark extends CometBenchmarkBase {
           dir,
           spark.sql(s"SELECT value, floor(rand() * $groupingKeyCardinality) as key FROM $tbl"))
 
-        val query = s"SELECT key, $aggregateFunction(value) FROM parquetV1Table GROUP BY key"
+        val functionSQL = aggFunctionSQL(aggregateFunction, "value")
+        val query = s"SELECT key, $functionSQL FROM parquetV1Table GROUP BY key"
 
-        benchmark.addCase(s"SQL Parquet - Spark ($aggregateFunction)") { _ =>
+        benchmark.addCase(s"SQL Parquet - Spark (${aggregateFunction.toString})") { _ =>
           spark.sql(query).noop()
         }
 
-        benchmark.addCase(s"SQL Parquet - Comet ($aggregateFunction)") { _ =>
+        benchmark.addCase(s"SQL Parquet - Comet (${aggregateFunction.toString})") { _ =>
           withSQLConf(
             CometConf.COMET_ENABLED.key -> "true",
             CometConf.COMET_EXEC_ENABLED.key -> "true") {
@@ -81,11 +99,11 @@ object CometAggregateBenchmark extends CometBenchmarkBase {
       values: Int,
       dataType: DecimalType,
       groupingKeyCardinality: Int,
-      aggregateFunction: String): Unit = {
+      aggregateFunction: BenchAggregateFunction): Unit = {
     val benchmark =
       new Benchmark(
         s"Grouped HashAgg Exec: single group key (cardinality $groupingKeyCardinality), " +
-          s"single aggregate $aggregateFunction on decimal",
+          s"single aggregate ${aggregateFunction.toString} on decimal",
         values,
         output = output)
 
@@ -99,13 +117,14 @@ object CometAggregateBenchmark extends CometBenchmarkBase {
           spark.sql(
             s"SELECT dec as value, floor(rand() * $groupingKeyCardinality) as key FROM $tbl"))
 
-        val query = s"SELECT key, $aggregateFunction(value) FROM parquetV1Table GROUP BY key"
+        val functionSQL = aggFunctionSQL(aggregateFunction, "value")
+        val query = s"SELECT key, $functionSQL FROM parquetV1Table GROUP BY key"
 
-        benchmark.addCase(s"SQL Parquet - Spark ($aggregateFunction)") { _ =>
+        benchmark.addCase(s"SQL Parquet - Spark (${aggregateFunction.toString})") { _ =>
           spark.sql(query).noop()
         }
 
-        benchmark.addCase(s"SQL Parquet - Comet ($aggregateFunction)") { _ =>
+        benchmark.addCase(s"SQL Parquet - Comet (${aggregateFunction.toString})") { _ =>
           withSQLConf(
             CometConf.COMET_ENABLED.key -> "true",
             CometConf.COMET_EXEC_ENABLED.key -> "true") {
@@ -118,11 +137,14 @@ object CometAggregateBenchmark extends CometBenchmarkBase {
     }
   }
 
-  def multiGroupKeys(values: Int, groupingKeyCard: Int, aggregateFunction: String): Unit = {
+  def multiGroupKeys(
+      values: Int,
+      groupingKeyCard: Int,
+      aggregateFunction: BenchAggregateFunction): Unit = {
     val benchmark =
       new Benchmark(
         s"Grouped HashAgg Exec: multiple group keys (cardinality $groupingKeyCard), " +
-          s"single aggregate $aggregateFunction",
+          s"single aggregate ${aggregateFunction.toString}",
         values,
         output = output)
 
@@ -134,14 +156,15 @@ object CometAggregateBenchmark extends CometBenchmarkBase {
           s"SELECT value, floor(rand() * $groupingKeyCard) as key1, " +
             s"floor(rand() * $groupingKeyCard) as key2 FROM $tbl"))
 
+        val functionSQL = aggFunctionSQL(aggregateFunction, "value")
         val query =
-          s"SELECT key1, key2, $aggregateFunction(value) FROM parquetV1Table GROUP BY key1, key2"
+          s"SELECT key1, key2, $functionSQL FROM parquetV1Table GROUP BY key1, key2"
 
-        benchmark.addCase(s"SQL Parquet - Spark ($aggregateFunction)") { _ =>
+        benchmark.addCase(s"SQL Parquet - Spark (${aggregateFunction.toString})") { _ =>
           spark.sql(query).noop()
         }
 
-        benchmark.addCase(s"SQL Parquet - Comet ($aggregateFunction)") { _ =>
+        benchmark.addCase(s"SQL Parquet - Comet (${aggregateFunction.toString})") { _ =>
           withSQLConf(
             CometConf.COMET_ENABLED.key -> "true",
             CometConf.COMET_EXEC_ENABLED.key -> "true",
@@ -155,11 +178,14 @@ object CometAggregateBenchmark extends CometBenchmarkBase {
     }
   }
 
-  def multiAggregates(values: Int, groupingKeyCard: Int, aggregateFunction: String): Unit = {
+  def multiAggregates(
+      values: Int,
+      groupingKeyCard: Int,
+      aggregateFunction: BenchAggregateFunction): Unit = {
     val benchmark =
       new Benchmark(
         s"Grouped HashAgg Exec: single group key (cardinality $groupingKeyCard), " +
-          s"multiple aggregates $aggregateFunction",
+          s"multiple aggregates ${aggregateFunction.toString}",
         values,
         output = output)
 
@@ -171,14 +197,17 @@ object CometAggregateBenchmark extends CometBenchmarkBase {
           s"SELECT value as value1, value as value2, floor(rand() * $groupingKeyCard) as key " +
             s"FROM $tbl"))
 
-        val query = s"SELECT key, $aggregateFunction(value1), $aggregateFunction(value2) " +
+        val functionSQL1 = aggFunctionSQL(aggregateFunction, "value1")
+        val functionSQL2 = aggFunctionSQL(aggregateFunction, "value2")
+
+        val query = s"SELECT key, $functionSQL1, $functionSQL2 " +
           "FROM parquetV1Table GROUP BY key"
 
-        benchmark.addCase(s"SQL Parquet - Spark ($aggregateFunction)") { _ =>
+        benchmark.addCase(s"SQL Parquet - Spark (${aggregateFunction.toString})") { _ =>
           spark.sql(query).noop()
         }
 
-        benchmark.addCase(s"SQL Parquet - Comet ($aggregateFunction)") { _ =>
+        benchmark.addCase(s"SQL Parquet - Comet (${aggregateFunction.toString})") { _ =>
           withSQLConf(
             CometConf.COMET_ENABLED.key -> "true",
             CometConf.COMET_EXEC_ENABLED.key -> "true") {
@@ -194,9 +223,8 @@ object CometAggregateBenchmark extends CometBenchmarkBase {
   override def runCometBenchmark(mainArgs: Array[String]): Unit = {
     val total = 1024 * 1024 * 10
     val combinations = List(100, 1024, 1024 * 1024) // number of distinct groups
-    val aggregateFunctions = List("SUM", "MIN", "MAX", "COUNT")
 
-    aggregateFunctions.foreach { aggFunc =>
+    benchmarkAggFuncs.foreach { aggFunc =>
       runBenchmarkWithTable(
         s"Grouped Aggregate (single group key + single aggregate $aggFunc)",
         total) { v =>
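Since the driver now iterates over benchmarkAggFuncs, the COUNT DISTINCT variant is benchmarked alongside the plain functions with no further changes. As an illustration (derived from the interpolations above, not part of the patch), BenchAggregateFunction("COUNT", distinct = true) produces:

    // benchmark group:  Grouped Aggregate (single group key + single aggregate COUNT(DISTINCT))
    // generated query:  SELECT key, COUNT(DISTINCT value) FROM parquetV1Table GROUP BY key
    // case labels:      SQL Parquet - Spark (COUNT(DISTINCT)) / SQL Parquet - Comet (COUNT(DISTINCT))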