
Commit 5845227

feat: do not fallback to Spark for COUNT(distinct) (apache#2429)
* feat: do not fallback to Spark for distincts
1 parent: abf5b69 · commit: 5845227

4 files changed, +123 -36 lines changed

spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala

Lines changed: 9 additions & 3 deletions
@@ -558,9 +558,15 @@ object QueryPlanSerde extends Logging with CometExprShim {
       binding: Boolean,
       conf: SQLConf): Option[AggExpr] = {

-    if (aggExpr.isDistinct) {
-      // https://github.com/apache/datafusion-comet/issues/1260
-      withInfo(aggExpr, s"distinct aggregate not supported: $aggExpr")
+    // Support COUNT(DISTINCT single_value):
+    // COUNT(DISTINCT x)    - supported
+    // COUNT(DISTINCT x, x) - supported via rewrite to COUNT(DISTINCT x)
+    // COUNT(DISTINCT x, y) - not supported
+    if (aggExpr.isDistinct
+      &&
+      !(aggExpr.aggregateFunction.prettyName == "count" &&
+        aggExpr.aggregateFunction.children.length == 1)) {
+      withInfo(aggExpr, s"Distinct aggregate not supported for: $aggExpr")
       return None
     }
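This guard is the core of the change: a distinct aggregate now falls back to Spark only if it is not a single-argument COUNT. A minimal sketch of the decision, using stub types of our own (AggFn, AggExpr) rather than Spark's AggregateExpression:

// Stub types standing in for Spark's AggregateExpression/AggregateFunction.
case class AggFn(prettyName: String, children: Seq[String])
case class AggExpr(isDistinct: Boolean, aggregateFunction: AggFn)

// Mirrors the condition in QueryPlanSerde: fall back unless the distinct
// aggregate is a single-argument COUNT.
def mustFallBack(aggExpr: AggExpr): Boolean =
  aggExpr.isDistinct &&
    !(aggExpr.aggregateFunction.prettyName == "count" &&
      aggExpr.aggregateFunction.children.length == 1)

// COUNT(DISTINCT x) now stays native.
assert(!mustFallBack(AggExpr(isDistinct = true, AggFn("count", Seq("x")))))
// COUNT(DISTINCT x, y) still falls back.
assert(mustFallBack(AggExpr(isDistinct = true, AggFn("count", Seq("x", "y")))))
// SUM(DISTINCT x) still falls back.
assert(mustFallBack(AggExpr(isDistinct = true, AggFn("sum", Seq("x")))))
// Non-distinct aggregates are unaffected by this guard.
assert(!mustFallBack(AggExpr(isDistinct = false, AggFn("sum", Seq("x")))))

Per the comment in the diff, COUNT(DISTINCT x, x) never reaches this guard with two children, since Spark rewrites it to COUNT(DISTINCT x) first.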

spark/src/test/scala/org/apache/comet/CometFuzzAggregateSuite.scala

Lines changed: 62 additions & 8 deletions
@@ -19,15 +19,73 @@

 package org.apache.comet

+import org.apache.comet.DataTypeSupport.isComplexType
+
 class CometFuzzAggregateSuite extends CometFuzzTestBase {

-  test("count distinct") {
+  test("count distinct - simple columns") {
     val df = spark.read.parquet(filename)
     df.createOrReplaceTempView("t1")
-    for (col <- df.columns) {
+    for (col <- df.schema.fields.filterNot(f => isComplexType(f.dataType)).map(_.name)) {
+      val sql = s"SELECT count(distinct $col) FROM t1"
+      val (_, cometPlan) = checkSparkAnswer(sql)
+      if (usingDataSourceExec) {
+        assert(1 == collectNativeScans(cometPlan).length)
+      }
+
+      checkSparkAnswerAndOperator(sql)
+    }
+  }
+
+  // Aggregating by complex columns is not yet supported
+  // https://github.com/apache/datafusion-comet/issues/2382
+  test("count distinct - complex columns") {
+    val df = spark.read.parquet(filename)
+    df.createOrReplaceTempView("t1")
+    for (col <- df.schema.fields.filter(f => isComplexType(f.dataType)).map(_.name)) {
       val sql = s"SELECT count(distinct $col) FROM t1"
-      // Comet does not support count distinct yet
-      // https://github.com/apache/datafusion-comet/issues/2292
+      val (_, cometPlan) = checkSparkAnswer(sql)
+      if (usingDataSourceExec) {
+        assert(1 == collectNativeScans(cometPlan).length)
+      }
+    }
+  }
+
+  test("count distinct group by multiple columns - simple columns") {
+    val df = spark.read.parquet(filename)
+    df.createOrReplaceTempView("t1")
+    for (col <- df.schema.fields.filterNot(f => isComplexType(f.dataType)).map(_.name)) {
+      val sql = s"SELECT c1, c2, c3, count(distinct $col) FROM t1 GROUP BY c1, c2, c3"
+      val (_, cometPlan) = checkSparkAnswer(sql)
+      if (usingDataSourceExec) {
+        assert(1 == collectNativeScans(cometPlan).length)
+      }
+
+      checkSparkAnswerAndOperator(sql)
+    }
+  }
+
+  // Aggregating by complex columns is not yet supported
+  // https://github.com/apache/datafusion-comet/issues/2382
+  test("count distinct group by multiple columns - complex columns") {
+    val df = spark.read.parquet(filename)
+    df.createOrReplaceTempView("t1")
+    for (col <- df.schema.fields.filter(f => isComplexType(f.dataType)).map(_.name)) {
+      val sql = s"SELECT c1, c2, c3, count(distinct $col) FROM t1 GROUP BY c1, c2, c3"
+      val (_, cometPlan) = checkSparkAnswer(sql)
+      if (usingDataSourceExec) {
+        assert(1 == collectNativeScans(cometPlan).length)
+      }
+    }
+  }
+
+  // COUNT(distinct x, y, z, ...) is not yet supported
+  // https://github.com/apache/datafusion-comet/issues/2292
+  test("count distinct multiple values and group by multiple columns") {
+    val df = spark.read.parquet(filename)
+    df.createOrReplaceTempView("t1")
+    for (col <- df.columns) {
+      val sql = s"SELECT c1, c2, c3, count(distinct $col, c4, c5) FROM t1 GROUP BY c1, c2, c3"
       val (_, cometPlan) = checkSparkAnswer(sql)
       if (usingDataSourceExec) {
         assert(1 == collectNativeScans(cometPlan).length)
@@ -39,7 +97,6 @@ class CometFuzzAggregateSuite extends CometFuzzTestBase {
     val df = spark.read.parquet(filename)
     df.createOrReplaceTempView("t1")
     for (col <- df.columns) {
-      // cannot run fully natively due to range partitioning and sort
       val sql = s"SELECT $col, count(*) FROM t1 GROUP BY $col ORDER BY $col"
       val (_, cometPlan) = checkSparkAnswer(sql)
       if (usingDataSourceExec) {
@@ -53,7 +110,6 @@ class CometFuzzAggregateSuite extends CometFuzzTestBase {
     df.createOrReplaceTempView("t1")
     val groupCol = df.columns.head
     for (col <- df.columns.drop(1)) {
-      // cannot run fully natively due to range partitioning and sort
       val sql = s"SELECT $groupCol, count($col) FROM t1 GROUP BY $groupCol ORDER BY $groupCol"
       val (_, cometPlan) = checkSparkAnswer(sql)
       if (usingDataSourceExec) {
@@ -67,7 +123,6 @@ class CometFuzzAggregateSuite extends CometFuzzTestBase {
     df.createOrReplaceTempView("t1")
     val groupCol = df.columns.head
     val otherCol = df.columns.drop(1)
-    // cannot run fully natively due to range partitioning and sort
     val sql = s"SELECT $groupCol, count(${otherCol.mkString(", ")}) FROM t1 " +
       s"GROUP BY $groupCol ORDER BY $groupCol"
     val (_, cometPlan) = checkSparkAnswer(sql)
@@ -88,5 +143,4 @@ class CometFuzzAggregateSuite extends CometFuzzTestBase {
       }
     }
   }
-
 }
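Note the assertion pattern in the tests above (our reading of the Comet test base, not spelled out in this diff): `checkSparkAnswer` only compares results against Spark, while `checkSparkAnswerAndOperator` additionally asserts that the plan ran on Comet operators. That is why the supported simple-column tests call both, while the complex-column and multi-argument distinct tests, which are expected to fall back, only compare answers.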

spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala

Lines changed: 2 additions & 3 deletions
@@ -1031,9 +1031,8 @@ class CometExecSuite extends CometTestBase {
         |GROUP BY key
       """.stripMargin)

-    // The above query uses COUNT(DISTINCT) which Comet doesn't support yet, so the plan will
-    // have a mix of `HashAggregate` and `CometHashAggregate`. In the following we check all
-    // operators starting from `CometHashAggregate` are native.
+    // The above query uses SUM(DISTINCT) and count(distinct value1, value2),
+    // which are not yet supported
     checkSparkAnswer(df)
     val subPlan = stripAQEPlan(df.queryExecution.executedPlan).collectFirst {
       case s: CometHashAggregateExec => s
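For illustration, a sketch of the distinction the updated comment draws, with hypothetical queries against a table `t` (names are ours, not from the suite): after this commit the first query can run fully native, while the other two still fall back for the distinct aggregate, producing the HashAggregate/CometHashAggregate mix described above.

// Hypothetical queries (illustrative only; table and column names are ours).
spark.sql("SELECT key, COUNT(DISTINCT value) FROM t GROUP BY key")           // native after this commit
spark.sql("SELECT key, SUM(DISTINCT value) FROM t GROUP BY key")             // still falls back
spark.sql("SELECT key, COUNT(DISTINCT value1, value2) FROM t GROUP BY key")  // still falls back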

spark/src/test/scala/org/apache/spark/sql/benchmark/CometAggregateBenchmark.scala

Lines changed: 50 additions & 22 deletions
@@ -41,14 +41,31 @@ object CometAggregateBenchmark extends CometBenchmarkBase {
     session
   }

+  // Wrapper around a SQL aggregate function
+  case class BenchAggregateFunction(name: String, distinct: Boolean = false) {
+    override def toString: String = if (distinct) s"$name(DISTINCT)" else name
+  }
+
+  // Aggregate functions to benchmark
+  private val benchmarkAggFuncs = Seq(
+    BenchAggregateFunction("SUM"),
+    BenchAggregateFunction("MIN"),
+    BenchAggregateFunction("MAX"),
+    BenchAggregateFunction("COUNT"),
+    BenchAggregateFunction("COUNT", distinct = true))
+
+  def aggFunctionSQL(aggregateFunction: BenchAggregateFunction, input: String): String = {
+    s"${aggregateFunction.name}(${if (aggregateFunction.distinct) s"DISTINCT $input" else input})"
+  }
+
   def singleGroupAndAggregate(
       values: Int,
       groupingKeyCardinality: Int,
-      aggregateFunction: String): Unit = {
+      aggregateFunction: BenchAggregateFunction): Unit = {
     val benchmark =
       new Benchmark(
         s"Grouped HashAgg Exec: single group key (cardinality $groupingKeyCardinality), " +
-          s"single aggregate $aggregateFunction",
+          s"single aggregate ${aggregateFunction.toString}",
         values,
         output = output)

@@ -58,13 +75,14 @@ object CometAggregateBenchmark extends CometBenchmarkBase {
         dir,
         spark.sql(s"SELECT value, floor(rand() * $groupingKeyCardinality) as key FROM $tbl"))

-      val query = s"SELECT key, $aggregateFunction(value) FROM parquetV1Table GROUP BY key"
+      val functionSQL = aggFunctionSQL(aggregateFunction, "value")
+      val query = s"SELECT key, $functionSQL FROM parquetV1Table GROUP BY key"

-      benchmark.addCase(s"SQL Parquet - Spark ($aggregateFunction)") { _ =>
+      benchmark.addCase(s"SQL Parquet - Spark (${aggregateFunction.toString})") { _ =>
         spark.sql(query).noop()
       }

-      benchmark.addCase(s"SQL Parquet - Comet ($aggregateFunction)") { _ =>
+      benchmark.addCase(s"SQL Parquet - Comet (${aggregateFunction.toString})") { _ =>
         withSQLConf(
           CometConf.COMET_ENABLED.key -> "true",
           CometConf.COMET_EXEC_ENABLED.key -> "true") {

@@ -81,11 +99,11 @@ object CometAggregateBenchmark extends CometBenchmarkBase {
       values: Int,
       dataType: DecimalType,
       groupingKeyCardinality: Int,
-      aggregateFunction: String): Unit = {
+      aggregateFunction: BenchAggregateFunction): Unit = {
     val benchmark =
       new Benchmark(
         s"Grouped HashAgg Exec: single group key (cardinality $groupingKeyCardinality), " +
-          s"single aggregate $aggregateFunction on decimal",
+          s"single aggregate ${aggregateFunction.toString} on decimal",
         values,
         output = output)

@@ -99,13 +117,14 @@ object CometAggregateBenchmark extends CometBenchmarkBase {
       spark.sql(
         s"SELECT dec as value, floor(rand() * $groupingKeyCardinality) as key FROM $tbl"))

-      val query = s"SELECT key, $aggregateFunction(value) FROM parquetV1Table GROUP BY key"
+      val functionSQL = aggFunctionSQL(aggregateFunction, "value")
+      val query = s"SELECT key, $functionSQL FROM parquetV1Table GROUP BY key"

-      benchmark.addCase(s"SQL Parquet - Spark ($aggregateFunction)") { _ =>
+      benchmark.addCase(s"SQL Parquet - Spark (${aggregateFunction.toString})") { _ =>
         spark.sql(query).noop()
       }

-      benchmark.addCase(s"SQL Parquet - Comet ($aggregateFunction)") { _ =>
+      benchmark.addCase(s"SQL Parquet - Comet (${aggregateFunction.toString})") { _ =>
         withSQLConf(
           CometConf.COMET_ENABLED.key -> "true",
           CometConf.COMET_EXEC_ENABLED.key -> "true") {

@@ -118,11 +137,14 @@ object CometAggregateBenchmark extends CometBenchmarkBase {
     }
   }

-  def multiGroupKeys(values: Int, groupingKeyCard: Int, aggregateFunction: String): Unit = {
+  def multiGroupKeys(
+      values: Int,
+      groupingKeyCard: Int,
+      aggregateFunction: BenchAggregateFunction): Unit = {
     val benchmark =
       new Benchmark(
         s"Grouped HashAgg Exec: multiple group keys (cardinality $groupingKeyCard), " +
-          s"single aggregate $aggregateFunction",
+          s"single aggregate ${aggregateFunction.toString}",
         values,
         output = output)

@@ -134,14 +156,15 @@ object CometAggregateBenchmark extends CometBenchmarkBase {
         s"SELECT value, floor(rand() * $groupingKeyCard) as key1, " +
           s"floor(rand() * $groupingKeyCard) as key2 FROM $tbl"))

+      val functionSQL = aggFunctionSQL(aggregateFunction, "value")
       val query =
-        s"SELECT key1, key2, $aggregateFunction(value) FROM parquetV1Table GROUP BY key1, key2"
+        s"SELECT key1, key2, $functionSQL FROM parquetV1Table GROUP BY key1, key2"

-      benchmark.addCase(s"SQL Parquet - Spark ($aggregateFunction)") { _ =>
+      benchmark.addCase(s"SQL Parquet - Spark (${aggregateFunction.toString})") { _ =>
         spark.sql(query).noop()
       }

-      benchmark.addCase(s"SQL Parquet - Comet ($aggregateFunction)") { _ =>
+      benchmark.addCase(s"SQL Parquet - Comet (${aggregateFunction.toString})") { _ =>
         withSQLConf(
           CometConf.COMET_ENABLED.key -> "true",
           CometConf.COMET_EXEC_ENABLED.key -> "true",

@@ -155,11 +178,14 @@ object CometAggregateBenchmark extends CometBenchmarkBase {
     }
   }

-  def multiAggregates(values: Int, groupingKeyCard: Int, aggregateFunction: String): Unit = {
+  def multiAggregates(
+      values: Int,
+      groupingKeyCard: Int,
+      aggregateFunction: BenchAggregateFunction): Unit = {
     val benchmark =
       new Benchmark(
         s"Grouped HashAgg Exec: single group key (cardinality $groupingKeyCard), " +
-          s"multiple aggregates $aggregateFunction",
+          s"multiple aggregates ${aggregateFunction.toString}",
         values,
         output = output)

@@ -171,14 +197,17 @@ object CometAggregateBenchmark extends CometBenchmarkBase {
         s"SELECT value as value1, value as value2, floor(rand() * $groupingKeyCard) as key " +
           s"FROM $tbl"))

-      val query = s"SELECT key, $aggregateFunction(value1), $aggregateFunction(value2) " +
+      val functionSQL1 = aggFunctionSQL(aggregateFunction, "value1")
+      val functionSQL2 = aggFunctionSQL(aggregateFunction, "value2")
+
+      val query = s"SELECT key, $functionSQL1, $functionSQL2 " +
         "FROM parquetV1Table GROUP BY key"

-      benchmark.addCase(s"SQL Parquet - Spark ($aggregateFunction)") { _ =>
+      benchmark.addCase(s"SQL Parquet - Spark (${aggregateFunction.toString})") { _ =>
         spark.sql(query).noop()
       }

-      benchmark.addCase(s"SQL Parquet - Comet ($aggregateFunction)") { _ =>
+      benchmark.addCase(s"SQL Parquet - Comet (${aggregateFunction.toString})") { _ =>
         withSQLConf(
           CometConf.COMET_ENABLED.key -> "true",
           CometConf.COMET_EXEC_ENABLED.key -> "true") {

@@ -194,9 +223,8 @@ object CometAggregateBenchmark extends CometBenchmarkBase {
   override def runCometBenchmark(mainArgs: Array[String]): Unit = {
     val total = 1024 * 1024 * 10
     val combinations = List(100, 1024, 1024 * 1024) // number of distinct groups
-    val aggregateFunctions = List("SUM", "MIN", "MAX", "COUNT")

-    aggregateFunctions.foreach { aggFunc =>
+    benchmarkAggFuncs.foreach { aggFunc =>
       runBenchmarkWithTable(
         s"Grouped Aggregate (single group key + single aggregate $aggFunc)",
         total) { v =>
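As a quick check of the new helper, here is what it generates for the added COUNT(DISTINCT) case (a small sketch; the expected strings follow directly from aggFunctionSQL and toString above):

val count = BenchAggregateFunction("COUNT")
val countDistinct = BenchAggregateFunction("COUNT", distinct = true)

// SQL fragments used to build the benchmark queries
assert(aggFunctionSQL(count, "value") == "COUNT(value)")
assert(aggFunctionSQL(countDistinct, "value") == "COUNT(DISTINCT value)")

// toString feeds the benchmark case names, e.g. "SQL Parquet - Comet (COUNT(DISTINCT))"
assert(countDistinct.toString == "COUNT(DISTINCT)")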
