Skip to content

Commit f029bf7

Browse files
committed
feat: do not fallback to Spark for distincts
1 parent 66103c9 commit f029bf7

File tree

2 files changed

+39
-5
lines changed

2 files changed

+39
-5
lines changed

spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -564,7 +564,7 @@ object QueryPlanSerde extends Logging with CometExprShim {
564564
// COUNT(DISTINCT x, y) - not supported
565565
if (aggExpr.isDistinct
566566
&&
567-
(aggExpr.aggregateFunction.prettyName == "count" &&
567+
!(aggExpr.aggregateFunction.prettyName == "count" &&
568568
aggExpr.aggregateFunction.children.length == 1)) {
569569
withInfo(aggExpr, s"Distinct aggregate not supported for: $aggExpr")
570570
return None

spark/src/test/scala/org/apache/comet/CometFuzzAggregateSuite.scala

Lines changed: 38 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,12 +19,14 @@
1919

2020
package org.apache.comet
2121

22+
import org.apache.comet.DataTypeSupport.isComplexType
23+
2224
class CometFuzzAggregateSuite extends CometFuzzTestBase {
2325

24-
test("count distinct") {
26+
test("count distinct - simple columns") {
2527
val df = spark.read.parquet(filename)
2628
df.createOrReplaceTempView("t1")
27-
for (col <- df.columns) {
29+
for (col <- df.schema.fields.filterNot(f => isComplexType(f.dataType)).map(_.name)) {
2830
val sql = s"SELECT count(distinct $col) FROM t1"
2931
val (_, cometPlan) = checkSparkAnswer(sql)
3032
if (usingDataSourceExec) {
@@ -35,10 +37,40 @@ class CometFuzzAggregateSuite extends CometFuzzTestBase {
3537
}
3638
}
3739

38-
test("count distinct group by multiple column") {
40+
// Aggregate by complex columns not yet supported
41+
// https://github.com/apache/datafusion-comet/issues/2382
42+
test("count distinct - complex columns") {
3943
val df = spark.read.parquet(filename)
4044
df.createOrReplaceTempView("t1")
41-
for (col <- df.columns) {
45+
for (col <- df.schema.fields.filter(f => isComplexType(f.dataType)).map(_.name)) {
46+
val sql = s"SELECT count(distinct $col) FROM t1"
47+
val (_, cometPlan) = checkSparkAnswer(sql)
48+
if (usingDataSourceExec) {
49+
assert(1 == collectNativeScans(cometPlan).length)
50+
}
51+
}
52+
}
53+
54+
test("count distinct group by multiple column - simple columns ") {
55+
val df = spark.read.parquet(filename)
56+
df.createOrReplaceTempView("t1")
57+
for (col <- df.schema.fields.filterNot(f => isComplexType(f.dataType)).map(_.name)) {
58+
val sql = s"SELECT c1, c2, c3, count(distinct $col) FROM t1 group by c1, c2, c3"
59+
val (_, cometPlan) = checkSparkAnswer(sql)
60+
if (usingDataSourceExec) {
61+
assert(1 == collectNativeScans(cometPlan).length)
62+
}
63+
64+
checkSparkAnswerAndOperator(sql)
65+
}
66+
}
67+
68+
// Aggregate by complex columns not yet supported
69+
// https://github.com/apache/datafusion-comet/issues/2382
70+
test("count distinct group by multiple column - complex columns ") {
71+
val df = spark.read.parquet(filename)
72+
df.createOrReplaceTempView("t1")
73+
for (col <- df.schema.fields.filter(f => isComplexType(f.dataType)).map(_.name)) {
4274
val sql = s"SELECT c1, c2, c3, count(distinct $col) FROM t1 group by c1, c2, c3"
4375
val (_, cometPlan) = checkSparkAnswer(sql)
4476
if (usingDataSourceExec) {
@@ -47,6 +79,8 @@ class CometFuzzAggregateSuite extends CometFuzzTestBase {
4779
}
4880
}
4981

82+
// Not yet supported
83+
// https://github.com/apache/datafusion-comet/issues/2292
5084
test("count distinct multiple values and group by multiple column") {
5185
val df = spark.read.parquet(filename)
5286
df.createOrReplaceTempView("t1")

0 commit comments

Comments
 (0)