diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
index 258d275e5b..892d8bca63 100644
--- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
+++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
@@ -558,9 +558,15 @@ object QueryPlanSerde extends Logging with CometExprShim {
       binding: Boolean,
       conf: SQLConf): Option[AggExpr] = {
-    if (aggExpr.isDistinct) {
-      // https://github.com/apache/datafusion-comet/issues/1260
-      withInfo(aggExpr, s"distinct aggregate not supported: $aggExpr")
+    // Support COUNT(DISTINCT) with a single child expression:
+    // COUNT(DISTINCT x)    - supported
+    // COUNT(DISTINCT x, x) - supported, since it is rewritten to COUNT(DISTINCT x)
+    // COUNT(DISTINCT x, y) - not supported
+    if (aggExpr.isDistinct
+      &&
+      !(aggExpr.aggregateFunction.prettyName == "count" &&
+        aggExpr.aggregateFunction.children.length == 1)) {
+      withInfo(aggExpr, s"Distinct aggregate not supported for: $aggExpr")
       return None
     }
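// Illustration only, not part of the patch: the predicate the guard above enforces, restated
// as a standalone helper. `DistinctAggSupport` and `isSupportedDistinct` are hypothetical names.
import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression

object DistinctAggSupport {
  // A distinct aggregate stays on the native path only when it is COUNT over exactly one
  // child expression; any other DISTINCT aggregate falls back to Spark via withInfo.
  def isSupportedDistinct(aggExpr: AggregateExpression): Boolean =
    !aggExpr.isDistinct ||
      (aggExpr.aggregateFunction.prettyName == "count" &&
        aggExpr.aggregateFunction.children.length == 1)
}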
diff --git a/spark/src/test/scala/org/apache/comet/CometFuzzAggregateSuite.scala b/spark/src/test/scala/org/apache/comet/CometFuzzAggregateSuite.scala
index 6466f8fc29..19812f38ce 100644
--- a/spark/src/test/scala/org/apache/comet/CometFuzzAggregateSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/CometFuzzAggregateSuite.scala
@@ -19,15 +19,73 @@
 package org.apache.comet
 
+import org.apache.comet.DataTypeSupport.isComplexType
+
 class CometFuzzAggregateSuite extends CometFuzzTestBase {
 
-  test("count distinct") {
+  test("count distinct - simple columns") {
     val df = spark.read.parquet(filename)
     df.createOrReplaceTempView("t1")
-    for (col <- df.columns) {
+    for (col <- df.schema.fields.filterNot(f => isComplexType(f.dataType)).map(_.name)) {
+      val sql = s"SELECT count(distinct $col) FROM t1"
+      val (_, cometPlan) = checkSparkAnswer(sql)
+      if (usingDataSourceExec) {
+        assert(1 == collectNativeScans(cometPlan).length)
+      }
+
+      checkSparkAnswerAndOperator(sql)
+    }
+  }
+
+  // Aggregation on complex columns is not yet supported
+  // https://github.com/apache/datafusion-comet/issues/2382
+  test("count distinct - complex columns") {
+    val df = spark.read.parquet(filename)
+    df.createOrReplaceTempView("t1")
+    for (col <- df.schema.fields.filter(f => isComplexType(f.dataType)).map(_.name)) {
       val sql = s"SELECT count(distinct $col) FROM t1"
-      // Comet does not support count distinct yet
-      // https://github.com/apache/datafusion-comet/issues/2292
+      val (_, cometPlan) = checkSparkAnswer(sql)
+      if (usingDataSourceExec) {
+        assert(1 == collectNativeScans(cometPlan).length)
+      }
+    }
+  }
+
+  test("count distinct group by multiple columns - simple columns") {
+    val df = spark.read.parquet(filename)
+    df.createOrReplaceTempView("t1")
+    for (col <- df.schema.fields.filterNot(f => isComplexType(f.dataType)).map(_.name)) {
+      val sql = s"SELECT c1, c2, c3, count(distinct $col) FROM t1 group by c1, c2, c3"
+      val (_, cometPlan) = checkSparkAnswer(sql)
+      if (usingDataSourceExec) {
+        assert(1 == collectNativeScans(cometPlan).length)
+      }
+
+      checkSparkAnswerAndOperator(sql)
+    }
+  }
+
+  // Aggregation on complex columns is not yet supported
+  // https://github.com/apache/datafusion-comet/issues/2382
+  test("count distinct group by multiple columns - complex columns") {
+    val df = spark.read.parquet(filename)
+    df.createOrReplaceTempView("t1")
+    for (col <- df.schema.fields.filter(f => isComplexType(f.dataType)).map(_.name)) {
+      val sql = s"SELECT c1, c2, c3, count(distinct $col) FROM t1 group by c1, c2, c3"
+      val (_, cometPlan) = checkSparkAnswer(sql)
+      if (usingDataSourceExec) {
+        assert(1 == collectNativeScans(cometPlan).length)
+      }
+    }
+  }
+
+  // COUNT(DISTINCT x, y, z, ...) is not yet supported
+  // https://github.com/apache/datafusion-comet/issues/2292
+  test("count distinct multiple values and group by multiple columns") {
+    val df = spark.read.parquet(filename)
+    df.createOrReplaceTempView("t1")
+    for (col <- df.columns) {
+      val sql = s"SELECT c1, c2, c3, count(distinct $col, c4, c5) FROM t1 group by c1, c2, c3"
       val (_, cometPlan) = checkSparkAnswer(sql)
       if (usingDataSourceExec) {
         assert(1 == collectNativeScans(cometPlan).length)
@@ -39,7 +97,6 @@ class CometFuzzAggregateSuite extends CometFuzzTestBase {
     val df = spark.read.parquet(filename)
     df.createOrReplaceTempView("t1")
     for (col <- df.columns) {
-      // cannot run fully natively due to range partitioning and sort
       val sql = s"SELECT $col, count(*) FROM t1 GROUP BY $col ORDER BY $col"
       val (_, cometPlan) = checkSparkAnswer(sql)
       if (usingDataSourceExec) {
@@ -53,7 +110,6 @@ class CometFuzzAggregateSuite extends CometFuzzTestBase {
     df.createOrReplaceTempView("t1")
     val groupCol = df.columns.head
     for (col <- df.columns.drop(1)) {
-      // cannot run fully natively due to range partitioning and sort
       val sql = s"SELECT $groupCol, count($col) FROM t1 GROUP BY $groupCol ORDER BY $groupCol"
       val (_, cometPlan) = checkSparkAnswer(sql)
       if (usingDataSourceExec) {
@@ -67,7 +123,6 @@ class CometFuzzAggregateSuite extends CometFuzzTestBase {
     df.createOrReplaceTempView("t1")
     val groupCol = df.columns.head
     val otherCol = df.columns.drop(1)
-    // cannot run fully natively due to range partitioning and sort
     val sql = s"SELECT $groupCol, count(${otherCol.mkString(", ")}) FROM t1 " +
       s"GROUP BY $groupCol ORDER BY $groupCol"
     val (_, cometPlan) = checkSparkAnswer(sql)
@@ -88,5 +143,4 @@ class CometFuzzAggregateSuite extends CometFuzzTestBase {
       }
     }
   }
-
 }
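// Illustration only, not part of the patch: how the tests above split the fuzz schema into
// simple and complex columns. The two-column schema is made up for this sketch, and it assumes
// DataTypeSupport.isComplexType reports struct/array/map types as complex.
import org.apache.comet.DataTypeSupport.isComplexType
import org.apache.spark.sql.types._

val schema = StructType(Seq(
  StructField("c1", IntegerType),
  StructField("c6", StructType(Seq(StructField("f", StringType))))))
val simpleCols = schema.fields.filterNot(f => isComplexType(f.dataType)).map(_.name) // Array(c1)
val complexCols = schema.fields.filter(f => isComplexType(f.dataType)).map(_.name) // Array(c6)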
diff --git a/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala b/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala
index 47d2205a08..5dfc4cbac2 100644
--- a/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala
@@ -1031,9 +1031,8 @@ class CometExecSuite extends CometTestBase {
           |GROUP BY key
         """.stripMargin)
 
-      // The above query uses COUNT(DISTINCT) which Comet doesn't support yet, so the plan will
-      // have a mix of `HashAggregate` and `CometHashAggregate`. In the following we check all
-      // operators starting from `CometHashAggregate` are native.
+      // The above query uses SUM(DISTINCT) and COUNT(DISTINCT value1, value2),
+      // which Comet does not yet support natively
       checkSparkAnswer(df)
       val subPlan = stripAQEPlan(df.queryExecution.executedPlan).collectFirst {
         case s: CometHashAggregateExec => s
diff --git a/spark/src/test/scala/org/apache/spark/sql/benchmark/CometAggregateBenchmark.scala b/spark/src/test/scala/org/apache/spark/sql/benchmark/CometAggregateBenchmark.scala
index 47fbe354f5..1efd3974ed 100644
--- a/spark/src/test/scala/org/apache/spark/sql/benchmark/CometAggregateBenchmark.scala
+++ b/spark/src/test/scala/org/apache/spark/sql/benchmark/CometAggregateBenchmark.scala
@@ -41,14 +41,31 @@ object CometAggregateBenchmark extends CometBenchmarkBase {
     session
   }
 
+  // Wrapper around a SQL aggregate function, optionally applied as DISTINCT
+  case class BenchAggregateFunction(name: String, distinct: Boolean = false) {
+    override def toString: String = if (distinct) s"$name(DISTINCT)" else name
+  }
+
+  // Aggregation functions to test
+  private val benchmarkAggFuncs = Seq(
+    BenchAggregateFunction("SUM"),
+    BenchAggregateFunction("MIN"),
+    BenchAggregateFunction("MAX"),
+    BenchAggregateFunction("COUNT"),
+    BenchAggregateFunction("COUNT", distinct = true))
+
+  def aggFunctionSQL(aggregateFunction: BenchAggregateFunction, input: String): String = {
+    s"${aggregateFunction.name}(${if (aggregateFunction.distinct) s"DISTINCT $input" else input})"
+  }
+
   def singleGroupAndAggregate(
       values: Int,
       groupingKeyCardinality: Int,
-      aggregateFunction: String): Unit = {
+      aggregateFunction: BenchAggregateFunction): Unit = {
     val benchmark = new Benchmark(
       s"Grouped HashAgg Exec: single group key (cardinality $groupingKeyCardinality), " +
-        s"single aggregate $aggregateFunction",
+        s"single aggregate ${aggregateFunction.toString}",
       values,
       output = output)
 
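// Illustration only, not part of the patch: the SQL fragment aggFunctionSQL renders for the
// new DISTINCT benchmark entry compared with a plain aggregate.
import org.apache.spark.sql.benchmark.CometAggregateBenchmark.{BenchAggregateFunction, aggFunctionSQL}

assert(aggFunctionSQL(BenchAggregateFunction("COUNT", distinct = true), "value") == "COUNT(DISTINCT value)")
assert(aggFunctionSQL(BenchAggregateFunction("SUM"), "value") == "SUM(value)")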
@@ -58,13 +75,14 @@ object CometAggregateBenchmark extends CometBenchmarkBase {
         dir,
         spark.sql(s"SELECT value, floor(rand() * $groupingKeyCardinality) as key FROM $tbl"))
 
-      val query = s"SELECT key, $aggregateFunction(value) FROM parquetV1Table GROUP BY key"
+      val functionSQL = aggFunctionSQL(aggregateFunction, "value")
+      val query = s"SELECT key, $functionSQL FROM parquetV1Table GROUP BY key"
 
-      benchmark.addCase(s"SQL Parquet - Spark ($aggregateFunction)") { _ =>
+      benchmark.addCase(s"SQL Parquet - Spark (${aggregateFunction.toString})") { _ =>
         spark.sql(query).noop()
       }
 
-      benchmark.addCase(s"SQL Parquet - Comet ($aggregateFunction)") { _ =>
+      benchmark.addCase(s"SQL Parquet - Comet (${aggregateFunction.toString})") { _ =>
         withSQLConf(
           CometConf.COMET_ENABLED.key -> "true",
           CometConf.COMET_EXEC_ENABLED.key -> "true") {
@@ -81,11 +99,11 @@ object CometAggregateBenchmark extends CometBenchmarkBase {
       values: Int,
       dataType: DecimalType,
       groupingKeyCardinality: Int,
-      aggregateFunction: String): Unit = {
+      aggregateFunction: BenchAggregateFunction): Unit = {
     val benchmark = new Benchmark(
       s"Grouped HashAgg Exec: single group key (cardinality $groupingKeyCardinality), " +
-        s"single aggregate $aggregateFunction on decimal",
+        s"single aggregate ${aggregateFunction.toString} on decimal",
       values,
       output = output)
 
@@ -99,13 +117,14 @@ object CometAggregateBenchmark extends CometBenchmarkBase {
         spark.sql(
           s"SELECT dec as value, floor(rand() * $groupingKeyCardinality) as key FROM $tbl"))
 
-      val query = s"SELECT key, $aggregateFunction(value) FROM parquetV1Table GROUP BY key"
+      val functionSQL = aggFunctionSQL(aggregateFunction, "value")
+      val query = s"SELECT key, $functionSQL FROM parquetV1Table GROUP BY key"
 
-      benchmark.addCase(s"SQL Parquet - Spark ($aggregateFunction)") { _ =>
+      benchmark.addCase(s"SQL Parquet - Spark (${aggregateFunction.toString})") { _ =>
         spark.sql(query).noop()
       }
 
-      benchmark.addCase(s"SQL Parquet - Comet ($aggregateFunction)") { _ =>
+      benchmark.addCase(s"SQL Parquet - Comet (${aggregateFunction.toString})") { _ =>
         withSQLConf(
           CometConf.COMET_ENABLED.key -> "true",
           CometConf.COMET_EXEC_ENABLED.key -> "true") {
@@ -118,11 +137,14 @@ object CometAggregateBenchmark extends CometBenchmarkBase {
     }
   }
 
-  def multiGroupKeys(values: Int, groupingKeyCard: Int, aggregateFunction: String): Unit = {
+  def multiGroupKeys(
+      values: Int,
+      groupingKeyCard: Int,
+      aggregateFunction: BenchAggregateFunction): Unit = {
     val benchmark = new Benchmark(
       s"Grouped HashAgg Exec: multiple group keys (cardinality $groupingKeyCard), " +
-        s"single aggregate $aggregateFunction",
+        s"single aggregate ${aggregateFunction.toString}",
       values,
       output = output)
 
@@ -134,14 +156,15 @@ object CometAggregateBenchmark extends CometBenchmarkBase {
           s"SELECT value, floor(rand() * $groupingKeyCard) as key1, " +
             s"floor(rand() * $groupingKeyCard) as key2 FROM $tbl"))
 
+      val functionSQL = aggFunctionSQL(aggregateFunction, "value")
       val query =
-        s"SELECT key1, key2, $aggregateFunction(value) FROM parquetV1Table GROUP BY key1, key2"
+        s"SELECT key1, key2, $functionSQL FROM parquetV1Table GROUP BY key1, key2"
 
-      benchmark.addCase(s"SQL Parquet - Spark ($aggregateFunction)") { _ =>
+      benchmark.addCase(s"SQL Parquet - Spark (${aggregateFunction.toString})") { _ =>
         spark.sql(query).noop()
       }
 
-      benchmark.addCase(s"SQL Parquet - Comet ($aggregateFunction)") { _ =>
+      benchmark.addCase(s"SQL Parquet - Comet (${aggregateFunction.toString})") { _ =>
         withSQLConf(
           CometConf.COMET_ENABLED.key -> "true",
           CometConf.COMET_EXEC_ENABLED.key -> "true",
@@ -155,11 +178,14 @@ object CometAggregateBenchmark extends CometBenchmarkBase {
     }
   }
 
-  def multiAggregates(values: Int, groupingKeyCard: Int, aggregateFunction: String): Unit = {
+  def multiAggregates(
+      values: Int,
+      groupingKeyCard: Int,
+      aggregateFunction: BenchAggregateFunction): Unit = {
     val benchmark = new Benchmark(
       s"Grouped HashAgg Exec: single group key (cardinality $groupingKeyCard), " +
-        s"multiple aggregates $aggregateFunction",
+        s"multiple aggregates ${aggregateFunction.toString}",
       values,
       output = output)
 
@@ -171,14 +197,17 @@ object CometAggregateBenchmark extends CometBenchmarkBase {
           s"SELECT value as value1, value as value2, floor(rand() * $groupingKeyCard) as key " +
             s"FROM $tbl"))
 
-      val query = s"SELECT key, $aggregateFunction(value1), $aggregateFunction(value2) " +
+      val functionSQL1 = aggFunctionSQL(aggregateFunction, "value1")
+      val functionSQL2 = aggFunctionSQL(aggregateFunction, "value2")
+
+      val query = s"SELECT key, $functionSQL1, $functionSQL2 " +
         "FROM parquetV1Table GROUP BY key"
 
-      benchmark.addCase(s"SQL Parquet - Spark ($aggregateFunction)") { _ =>
+      benchmark.addCase(s"SQL Parquet - Spark (${aggregateFunction.toString})") { _ =>
         spark.sql(query).noop()
       }
 
-      benchmark.addCase(s"SQL Parquet - Comet ($aggregateFunction)") { _ =>
+      benchmark.addCase(s"SQL Parquet - Comet (${aggregateFunction.toString})") { _ =>
         withSQLConf(
           CometConf.COMET_ENABLED.key -> "true",
           CometConf.COMET_EXEC_ENABLED.key -> "true") {
@@ -194,9 +223,8 @@ object CometAggregateBenchmark extends CometBenchmarkBase {
   override def runCometBenchmark(mainArgs: Array[String]): Unit = {
     val total = 1024 * 1024 * 10
     val combinations = List(100, 1024, 1024 * 1024) // number of distinct groups
-    val aggregateFunctions = List("SUM", "MIN", "MAX", "COUNT")
-    aggregateFunctions.foreach { aggFunc =>
+    benchmarkAggFuncs.foreach { aggFunc =>
       runBenchmarkWithTable(
         s"Grouped Aggregate (single group key + single aggregate $aggFunc)",
         total) { v =>