Merged
12 changes: 9 additions & 3 deletions spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
@@ -558,9 +558,15 @@ object QueryPlanSerde extends Logging with CometExprShim {
binding: Boolean,
conf: SQLConf): Option[AggExpr] = {

if (aggExpr.isDistinct) {
Member

Don't we need to pass the aggExpr.isDistinct value into the protobuf plan?

Contributor Author

This is a good point; I was thinking the same, but IMO Spark doesn't invoke count distinct in the partial phase.

                +-----------------------------+
                |          Driver             |
                |    COUNT(DISTINCT name)     |
                +--------------+--------------+
                               |
                               v
     +--------------------+          +--------------------------+
     |    Executor 1      |          |    Executor 2            |
     | Partitions P0,P1   |          | Partitions P2,P3         |
     | Local distinct:    |          | Local distinct:          |
     | {Alice,Bob,Eve}    |          | {Mallory,Eve,Bob,Trent}  |
     +---------+----------+          +------------+-------------+
               |                                  |
               |             Shuffle              |
               v                                  v
        +------------------+             +-------------------+
        | Reducer R0       |             | Reducer R1        |
        | {Alice,Bob,Eve}  |             | {Mallory,Trent}   |
        +------------------+             +-------------------+
                \                                 /
                 \                               /
                  \                             /
                   +-------------+--------------+
                                 v
                       Driver Final Merge
                         DISTINCT = 5

The local distinct is done by HashAggregate, so by the time count distinct is called as aggExpr it may not need the flag, since the data has already been deduplicated on the reducers. I'm still checking the Final stage, though.
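
A minimal sketch, independent of this PR, for seeing that phase split from Spark itself; it only uses standard Spark APIs (SparkSession, spark.sql, explain), and the table contents are invented for illustration:

    import org.apache.spark.sql.SparkSession

    object CountDistinctPlanSketch {
      def main(args: Array[String]): Unit = {
        val spark = SparkSession.builder().master("local[*]").appName("count-distinct-plan").getOrCreate()
        import spark.implicits._

        // Toy data; "Bob" is duplicated so the local distinct has work to do.
        Seq("Alice", "Bob", "Eve", "Bob").toDF("name").createOrReplaceTempView("people")

        // The physical plan typically shows a partial HashAggregate keyed on
        // `name` (the per-partition de-duplication) followed by a final
        // HashAggregate computing count(name) over the merged distinct keys.
        spark.sql("SELECT COUNT(DISTINCT name) FROM people").explain()

        spark.stop()
      }
    }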

Member

Thanks for explaining that. We should add tests for other distinct aggregates as well, such as sum and avg. I'm not sure if there are others?
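
For reference, COUNT, SUM, and AVG all accept the DISTINCT modifier in plain Spark SQL. A self-contained sketch (not Comet-specific; table name and values are made up for illustration) that runs a few of them:

    import org.apache.spark.sql.SparkSession

    object DistinctAggregatesSketch {
      def main(args: Array[String]): Unit = {
        val spark = SparkSession.builder().master("local[*]").appName("distinct-aggs").getOrCreate()
        import spark.implicits._

        // Duplicate values so DISTINCT actually changes the result.
        Seq((1, 2.0), (1, 2.0), (2, 3.0), (2, 4.0)).toDF("k", "v").createOrReplaceTempView("t")

        // count/sum/avg all support DISTINCT; the fuzz suite in this PR only
        // covers count, so the others would need their own tests.
        Seq("count", "sum", "avg").foreach { agg =>
          spark.sql(s"SELECT k, $agg(DISTINCT v) FROM t GROUP BY k").show()
        }

        spark.stop()
      }
    }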

Member

Spark has tests for the following aggregates with DISTINCT:

  • count
  • sum
  • avg
  • first
  • last
  • corr
  • var_pop
  • var_samp

Member

For this PR, we could just remove the fallback for COUNT?

Contributor Author

yep, will do once

// https://github.com/apache/datafusion-comet/issues/1260
withInfo(aggExpr, s"distinct aggregate not supported: $aggExpr")
// Support Count(distinct single_value)
// COUNT(DISTINCT x) - supported
// COUNT(DISTINCT x, x) - supported through transition to COUNT(DISTINCT x)
// COUNT(DISTINCT x, y) - not supported
if (aggExpr.isDistinct &&
    !(aggExpr.aggregateFunction.prettyName == "count" &&
      aggExpr.aggregateFunction.children.length == 1)) {
withInfo(aggExpr, s"Distinct aggregate not supported for: $aggExpr")
return None
}
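
Read as a free-standing predicate, the guard above keeps a distinct aggregate native only when it is COUNT over exactly one child. The helper name below is hypothetical, but the fields it touches (isDistinct, aggregateFunction, prettyName, children) are the same Catalyst APIs used in the diff:

    import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression

    // Sketch of the condition, restated as "is this distinct aggregate ok to keep native?"
    def keepNative(aggExpr: AggregateExpression): Boolean =
      !aggExpr.isDistinct ||
        (aggExpr.aggregateFunction.prettyName == "count" &&
          aggExpr.aggregateFunction.children.length == 1)

    // COUNT(DISTINCT x)    -> kept native
    // COUNT(DISTINCT x, x) -> per the comment above, Spark normalizes it to COUNT(DISTINCT x), so kept native
    // COUNT(DISTINCT x, y) -> falls back to Spark
    // SUM(DISTINCT x)      -> falls back to Spark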

@@ -19,15 +19,73 @@

package org.apache.comet

import org.apache.comet.DataTypeSupport.isComplexType

class CometFuzzAggregateSuite extends CometFuzzTestBase {

test("count distinct") {
test("count distinct - simple columns") {
val df = spark.read.parquet(filename)
df.createOrReplaceTempView("t1")
for (col <- df.columns) {
for (col <- df.schema.fields.filterNot(f => isComplexType(f.dataType)).map(_.name)) {
val sql = s"SELECT count(distinct $col) FROM t1"
Member

could you also add tests for count distinct with multiple columns e.g. COUNT(DISTINCT col1, col2, col3)

Contributor Author

Thanks @andygrove, I'll add them separately as well.

val (_, cometPlan) = checkSparkAnswer(sql)
if (usingDataSourceExec) {
assert(1 == collectNativeScans(cometPlan).length)
}

checkSparkAnswerAndOperator(sql)
}
}

// Aggregate by complex columns not yet supported
// https://github.com/apache/datafusion-comet/issues/2382
test("count distinct - complex columns") {
val df = spark.read.parquet(filename)
df.createOrReplaceTempView("t1")
for (col <- df.schema.fields.filter(f => isComplexType(f.dataType)).map(_.name)) {
val sql = s"SELECT count(distinct $col) FROM t1"
// Comet does not support count distinct yet
// https://github.com/apache/datafusion-comet/issues/2292
val (_, cometPlan) = checkSparkAnswer(sql)
if (usingDataSourceExec) {
assert(1 == collectNativeScans(cometPlan).length)
}
}
}

test("count distinct group by multiple column - simple columns ") {
val df = spark.read.parquet(filename)
df.createOrReplaceTempView("t1")
for (col <- df.schema.fields.filterNot(f => isComplexType(f.dataType)).map(_.name)) {
val sql = s"SELECT c1, c2, c3, count(distinct $col) FROM t1 group by c1, c2, c3"
val (_, cometPlan) = checkSparkAnswer(sql)
if (usingDataSourceExec) {
assert(1 == collectNativeScans(cometPlan).length)
}

checkSparkAnswerAndOperator(sql)
}
}

// Aggregate by complex columns not yet supported
// https://github.com/apache/datafusion-comet/issues/2382
test("count distinct group by multiple column - complex columns ") {
val df = spark.read.parquet(filename)
df.createOrReplaceTempView("t1")
for (col <- df.schema.fields.filter(f => isComplexType(f.dataType)).map(_.name)) {
val sql = s"SELECT c1, c2, c3, count(distinct $col) FROM t1 group by c1, c2, c3"
val (_, cometPlan) = checkSparkAnswer(sql)
if (usingDataSourceExec) {
assert(1 == collectNativeScans(cometPlan).length)
}
}
}

// COUNT(distinct x, y, z, ...) not yet supported
// https://github.com/apache/datafusion-comet/issues/2292
test("count distinct multiple values and group by multiple column") {
val df = spark.read.parquet(filename)
df.createOrReplaceTempView("t1")
for (col <- df.columns) {
val sql = s"SELECT c1, c2, c3, count(distinct $col, c4, c5) FROM t1 group by c1, c2, c3"
val (_, cometPlan) = checkSparkAnswer(sql)
if (usingDataSourceExec) {
assert(1 == collectNativeScans(cometPlan).length)
@@ -39,7 +39,6 @@ class CometFuzzAggregateSuite extends CometFuzzTestBase {
val df = spark.read.parquet(filename)
df.createOrReplaceTempView("t1")
for (col <- df.columns) {
// cannot run fully natively due to range partitioning and sort
val sql = s"SELECT $col, count(*) FROM t1 GROUP BY $col ORDER BY $col"
val (_, cometPlan) = checkSparkAnswer(sql)
if (usingDataSourceExec) {
@@ -53,7 +110,6 @@ class CometFuzzAggregateSuite extends CometFuzzTestBase {
df.createOrReplaceTempView("t1")
val groupCol = df.columns.head
for (col <- df.columns.drop(1)) {
// cannot run fully natively due to range partitioning and sort
val sql = s"SELECT $groupCol, count($col) FROM t1 GROUP BY $groupCol ORDER BY $groupCol"
val (_, cometPlan) = checkSparkAnswer(sql)
if (usingDataSourceExec) {
@@ -67,7 +123,6 @@ class CometFuzzAggregateSuite extends CometFuzzTestBase {
df.createOrReplaceTempView("t1")
val groupCol = df.columns.head
val otherCol = df.columns.drop(1)
// cannot run fully natively due to range partitioning and sort
val sql = s"SELECT $groupCol, count(${otherCol.mkString(", ")}) FROM t1 " +
s"GROUP BY $groupCol ORDER BY $groupCol"
val (_, cometPlan) = checkSparkAnswer(sql)
@@ -88,5 +143,4 @@
}
}
}

}
@@ -1031,9 +1031,8 @@ class CometExecSuite extends CometTestBase {
|GROUP BY key
""".stripMargin)

// The above query uses COUNT(DISTINCT) which Comet doesn't support yet, so the plan will
// have a mix of `HashAggregate` and `CometHashAggregate`. In the following we check all
// operators starting from `CometHashAggregate` are native.
// The above query uses SUM(DISTINCT) and COUNT(DISTINCT value1, value2),
// which are not yet supported natively
checkSparkAnswer(df)
val subPlan = stripAQEPlan(df.queryExecution.executedPlan).collectFirst {
case s: CometHashAggregateExec => s
@@ -41,14 +41,31 @@ object CometAggregateBenchmark extends CometBenchmarkBase {
session
}

// Wrapper on SQL aggregation function
case class BenchAggregateFunction(name: String, distinct: Boolean = false) {
override def toString: String = if (distinct) s"$name(DISTINCT)" else name
}

// Aggregation functions to test
private val benchmarkAggFuncs = Seq(
BenchAggregateFunction("SUM"),
BenchAggregateFunction("MIN"),
BenchAggregateFunction("MAX"),
BenchAggregateFunction("COUNT"),
BenchAggregateFunction("COUNT", distinct = true))

def aggFunctionSQL(aggregateFunction: BenchAggregateFunction, input: String): String = {
s"${aggregateFunction.name}(${if (aggregateFunction.distinct) s"DISTINCT $input" else input})"
}
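
As a quick illustration (not code from the PR; it assumes the calls are made inside CometAggregateBenchmark where these helpers are in scope), aggFunctionSQL renders the two COUNT variants as follows:

    // Non-distinct and distinct renderings of the same aggregate.
    assert(aggFunctionSQL(BenchAggregateFunction("COUNT"), "value") == "COUNT(value)")
    assert(
      aggFunctionSQL(BenchAggregateFunction("COUNT", distinct = true), "value") ==
        "COUNT(DISTINCT value)")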

def singleGroupAndAggregate(
values: Int,
groupingKeyCardinality: Int,
aggregateFunction: String): Unit = {
aggregateFunction: BenchAggregateFunction): Unit = {
val benchmark =
new Benchmark(
s"Grouped HashAgg Exec: single group key (cardinality $groupingKeyCardinality), " +
s"single aggregate $aggregateFunction",
s"single aggregate ${aggregateFunction.toString}",
values,
output = output)

@@ -58,13 +75,14 @@ object CometAggregateBenchmark extends CometBenchmarkBase {
dir,
spark.sql(s"SELECT value, floor(rand() * $groupingKeyCardinality) as key FROM $tbl"))

val query = s"SELECT key, $aggregateFunction(value) FROM parquetV1Table GROUP BY key"
val functionSQL = aggFunctionSQL(aggregateFunction, "value")
val query = s"SELECT key, $functionSQL FROM parquetV1Table GROUP BY key"

benchmark.addCase(s"SQL Parquet - Spark ($aggregateFunction)") { _ =>
benchmark.addCase(s"SQL Parquet - Spark (${aggregateFunction.toString})") { _ =>
spark.sql(query).noop()
}

benchmark.addCase(s"SQL Parquet - Comet ($aggregateFunction)") { _ =>
benchmark.addCase(s"SQL Parquet - Comet (${aggregateFunction.toString})") { _ =>
withSQLConf(
CometConf.COMET_ENABLED.key -> "true",
CometConf.COMET_EXEC_ENABLED.key -> "true") {
@@ -81,11 +99,11 @@
values: Int,
dataType: DecimalType,
groupingKeyCardinality: Int,
aggregateFunction: String): Unit = {
aggregateFunction: BenchAggregateFunction): Unit = {
val benchmark =
new Benchmark(
s"Grouped HashAgg Exec: single group key (cardinality $groupingKeyCardinality), " +
s"single aggregate $aggregateFunction on decimal",
s"single aggregate ${aggregateFunction.toString} on decimal",
values,
output = output)

@@ -99,13 +117,14 @@
spark.sql(
s"SELECT dec as value, floor(rand() * $groupingKeyCardinality) as key FROM $tbl"))

val query = s"SELECT key, $aggregateFunction(value) FROM parquetV1Table GROUP BY key"
val functionSQL = aggFunctionSQL(aggregateFunction, "value")
val query = s"SELECT key, $functionSQL FROM parquetV1Table GROUP BY key"

benchmark.addCase(s"SQL Parquet - Spark ($aggregateFunction)") { _ =>
benchmark.addCase(s"SQL Parquet - Spark (${aggregateFunction.toString})") { _ =>
spark.sql(query).noop()
}

benchmark.addCase(s"SQL Parquet - Comet ($aggregateFunction)") { _ =>
benchmark.addCase(s"SQL Parquet - Comet (${aggregateFunction.toString})") { _ =>
withSQLConf(
CometConf.COMET_ENABLED.key -> "true",
CometConf.COMET_EXEC_ENABLED.key -> "true") {
@@ -118,11 +137,14 @@
}
}

def multiGroupKeys(values: Int, groupingKeyCard: Int, aggregateFunction: String): Unit = {
def multiGroupKeys(
values: Int,
groupingKeyCard: Int,
aggregateFunction: BenchAggregateFunction): Unit = {
val benchmark =
new Benchmark(
s"Grouped HashAgg Exec: multiple group keys (cardinality $groupingKeyCard), " +
s"single aggregate $aggregateFunction",
s"single aggregate ${aggregateFunction.toString}",
values,
output = output)

@@ -134,14 +156,15 @@ object CometAggregateBenchmark extends CometBenchmarkBase {
s"SELECT value, floor(rand() * $groupingKeyCard) as key1, " +
s"floor(rand() * $groupingKeyCard) as key2 FROM $tbl"))

val functionSQL = aggFunctionSQL(aggregateFunction, "value")
val query =
s"SELECT key1, key2, $aggregateFunction(value) FROM parquetV1Table GROUP BY key1, key2"
s"SELECT key1, key2, $functionSQL FROM parquetV1Table GROUP BY key1, key2"

benchmark.addCase(s"SQL Parquet - Spark ($aggregateFunction)") { _ =>
benchmark.addCase(s"SQL Parquet - Spark (${aggregateFunction.toString})") { _ =>
spark.sql(query).noop()
}

benchmark.addCase(s"SQL Parquet - Comet ($aggregateFunction)") { _ =>
benchmark.addCase(s"SQL Parquet - Comet (${aggregateFunction.toString})") { _ =>
withSQLConf(
CometConf.COMET_ENABLED.key -> "true",
CometConf.COMET_EXEC_ENABLED.key -> "true",
@@ -155,11 +178,14 @@
}
}

def multiAggregates(values: Int, groupingKeyCard: Int, aggregateFunction: String): Unit = {
def multiAggregates(
values: Int,
groupingKeyCard: Int,
aggregateFunction: BenchAggregateFunction): Unit = {
val benchmark =
new Benchmark(
s"Grouped HashAgg Exec: single group key (cardinality $groupingKeyCard), " +
s"multiple aggregates $aggregateFunction",
s"multiple aggregates ${aggregateFunction.toString}",
values,
output = output)

@@ -171,14 +197,17 @@
s"SELECT value as value1, value as value2, floor(rand() * $groupingKeyCard) as key " +
s"FROM $tbl"))

val query = s"SELECT key, $aggregateFunction(value1), $aggregateFunction(value2) " +
val functionSQL1 = aggFunctionSQL(aggregateFunction, "value1")
val functionSQL2 = aggFunctionSQL(aggregateFunction, "value2")

val query = s"SELECT key, $functionSQL1, $functionSQL2 " +
"FROM parquetV1Table GROUP BY key"

benchmark.addCase(s"SQL Parquet - Spark ($aggregateFunction)") { _ =>
benchmark.addCase(s"SQL Parquet - Spark (${aggregateFunction.toString})") { _ =>
spark.sql(query).noop()
}

benchmark.addCase(s"SQL Parquet - Comet ($aggregateFunction)") { _ =>
benchmark.addCase(s"SQL Parquet - Comet (${aggregateFunction.toString})") { _ =>
withSQLConf(
CometConf.COMET_ENABLED.key -> "true",
CometConf.COMET_EXEC_ENABLED.key -> "true") {
@@ -194,9 +223,8 @@
override def runCometBenchmark(mainArgs: Array[String]): Unit = {
val total = 1024 * 1024 * 10
val combinations = List(100, 1024, 1024 * 1024) // number of distinct groups
val aggregateFunctions = List("SUM", "MIN", "MAX", "COUNT")

aggregateFunctions.foreach { aggFunc =>
benchmarkAggFuncs.foreach { aggFunc =>
runBenchmarkWithTable(
s"Grouped Aggregate (single group key + single aggregate $aggFunc)",
total) { v =>