Commit 3bb1b40

chore: Improve test coverage for count aggregates (#2406)
* refactor fuzz test
* link to issue
* add new test to CI
1 parent 341db1d commit 3bb1b40

File tree

5 files changed: +206 -119 lines changed

.github/workflows/pr_build_linux.yml
.github/workflows/pr_build_macos.yml
spark/src/test/scala/org/apache/comet/CometFuzzAggregateSuite.scala
spark/src/test/scala/org/apache/comet/CometFuzzTestBase.scala
spark/src/test/scala/org/apache/comet/CometFuzzTestSuite.scala

.github/workflows/pr_build_linux.yml

Lines changed: 1 addition & 0 deletions
@@ -102,6 +102,7 @@ jobs:
           - name: "fuzz"
             value: |
               org.apache.comet.CometFuzzTestSuite
+              org.apache.comet.CometFuzzAggregateSuite
               org.apache.comet.DataGeneratorSuite
           - name: "shuffle"
             value: |

.github/workflows/pr_build_macos.yml

Lines changed: 1 addition & 0 deletions
@@ -67,6 +67,7 @@ jobs:
           - name: "fuzz"
             value: |
               org.apache.comet.CometFuzzTestSuite
+              org.apache.comet.CometFuzzAggregateSuite
               org.apache.comet.DataGeneratorSuite
           - name: "shuffle"
             value: |
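
Both workflows add the new suite to the existing "fuzz" matrix entry, so it runs alongside CometFuzzTestSuite and DataGeneratorSuite in CI. To run just the new suite locally, something like the following should work, assuming the repository's standard Maven wrapper and scalatest-maven-plugin setup (the module path and the suites property are assumptions, not part of this commit):

    ./mvnw test -pl spark -Dsuites=org.apache.comet.CometFuzzAggregateSuite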
spark/src/test/scala/org/apache/comet/CometFuzzAggregateSuite.scala

Lines changed: 92 additions & 0 deletions
@@ -0,0 +1,92 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

package org.apache.comet

class CometFuzzAggregateSuite extends CometFuzzTestBase {

  test("count distinct") {
    val df = spark.read.parquet(filename)
    df.createOrReplaceTempView("t1")
    for (col <- df.columns) {
      val sql = s"SELECT count(distinct $col) FROM t1"
      // Comet does not support count distinct yet
      // https://github.com/apache/datafusion-comet/issues/2292
      val (_, cometPlan) = checkSparkAnswer(sql)
      if (usingDataSourceExec) {
        assert(1 == collectNativeScans(cometPlan).length)
      }
    }
  }

  test("count(*) group by single column") {
    val df = spark.read.parquet(filename)
    df.createOrReplaceTempView("t1")
    for (col <- df.columns) {
      // cannot run fully natively due to range partitioning and sort
      val sql = s"SELECT $col, count(*) FROM t1 GROUP BY $col ORDER BY $col"
      val (_, cometPlan) = checkSparkAnswer(sql)
      if (usingDataSourceExec) {
        assert(1 == collectNativeScans(cometPlan).length)
      }
    }
  }

  test("count(col) group by single column") {
    val df = spark.read.parquet(filename)
    df.createOrReplaceTempView("t1")
    val groupCol = df.columns.head
    for (col <- df.columns.drop(1)) {
      // cannot run fully natively due to range partitioning and sort
      val sql = s"SELECT $groupCol, count($col) FROM t1 GROUP BY $groupCol ORDER BY $groupCol"
      val (_, cometPlan) = checkSparkAnswer(sql)
      if (usingDataSourceExec) {
        assert(1 == collectNativeScans(cometPlan).length)
      }
    }
  }

  test("count(col1, col2, ..) group by single column") {
    val df = spark.read.parquet(filename)
    df.createOrReplaceTempView("t1")
    val groupCol = df.columns.head
    val otherCol = df.columns.drop(1)
    // cannot run fully natively due to range partitioning and sort
    val sql = s"SELECT $groupCol, count(${otherCol.mkString(", ")}) FROM t1 " +
      s"GROUP BY $groupCol ORDER BY $groupCol"
    val (_, cometPlan) = checkSparkAnswer(sql)
    if (usingDataSourceExec) {
      assert(1 == collectNativeScans(cometPlan).length)
    }
  }

  test("min/max aggregate") {
    val df = spark.read.parquet(filename)
    df.createOrReplaceTempView("t1")
    for (col <- df.columns) {
      // cannot run fully native due to HashAggregate
      val sql = s"SELECT min($col), max($col) FROM t1"
      val (_, cometPlan) = checkSparkAnswer(sql)
      if (usingDataSourceExec) {
        assert(1 == collectNativeScans(cometPlan).length)
      }
    }
  }
}
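
Note that every test above loops over each column of the fuzz-generated Parquet file, so one logical test expands into one query per column, and CometFuzzTestBase (below) further multiplies each test by six scan/shuffle configurations. A minimal sketch of the query expansion, assuming hypothetical column names c0 and c1 (actual names come from ParquetGenerator):

    // Hypothetical column names for illustration only.
    val columns = Seq("c0", "c1")
    val queries = columns.map(col => s"SELECT count(distinct $col) FROM t1")
    // queries == Seq("SELECT count(distinct c0) FROM t1",
    //                "SELECT count(distinct c1) FROM t1")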
spark/src/test/scala/org/apache/comet/CometFuzzTestBase.scala

Lines changed: 110 additions & 0 deletions
@@ -0,0 +1,110 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

package org.apache.comet

import java.io.File
import java.text.SimpleDateFormat

import scala.util.Random

import org.scalactic.source.Position
import org.scalatest.Tag

import org.apache.commons.io.FileUtils
import org.apache.spark.sql.CometTestBase
import org.apache.spark.sql.comet.{CometNativeScanExec, CometScanExec}
import org.apache.spark.sql.comet.execution.shuffle.CometShuffleExchangeExec
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
import org.apache.spark.sql.internal.SQLConf

import org.apache.comet.testing.{DataGenOptions, ParquetGenerator}

class CometFuzzTestBase extends CometTestBase with AdaptiveSparkPlanHelper {

  var filename: String = null

  /**
   * We use Asia/Kathmandu because it has a non-zero number of minutes as the offset, so is an
   * interesting edge case. Also, this timezone tends to be different from the default system
   * timezone.
   *
   * Represents UTC+5:45
   */
  val defaultTimezone = "Asia/Kathmandu"

  override def beforeAll(): Unit = {
    super.beforeAll()
    val tempDir = System.getProperty("java.io.tmpdir")
    filename = s"$tempDir/CometFuzzTestSuite_${System.currentTimeMillis()}.parquet"
    val random = new Random(42)
    withSQLConf(
      CometConf.COMET_ENABLED.key -> "false",
      SQLConf.SESSION_LOCAL_TIMEZONE.key -> defaultTimezone) {
      val options =
        DataGenOptions(
          generateArray = true,
          generateStruct = true,
          generateNegativeZero = false,
          // override base date due to known issues with experimental scans
          baseDate =
            new SimpleDateFormat("YYYY-MM-DD hh:mm:ss").parse("2024-05-25 12:34:56").getTime)
      ParquetGenerator.makeParquetFile(random, spark, filename, 1000, options)
    }
  }

  protected override def afterAll(): Unit = {
    super.afterAll()
    FileUtils.deleteDirectory(new File(filename))
  }

  override protected def test(testName: String, testTags: Tag*)(testFun: => Any)(implicit
      pos: Position): Unit = {
    Seq("native", "jvm").foreach { shuffleMode =>
      Seq(
        CometConf.SCAN_NATIVE_COMET,
        CometConf.SCAN_NATIVE_DATAFUSION,
        CometConf.SCAN_NATIVE_ICEBERG_COMPAT).foreach { scanImpl =>
        super.test(testName + s" ($scanImpl, $shuffleMode shuffle)", testTags: _*) {
          withSQLConf(
            CometConf.COMET_NATIVE_SCAN_IMPL.key -> scanImpl,
            CometConf.COMET_SCAN_ALLOW_INCOMPATIBLE.key -> "true",
            CometConf.COMET_SHUFFLE_MODE.key -> shuffleMode) {
            testFun
          }
        }
      }
    }
  }

  def collectNativeScans(plan: SparkPlan): Seq[SparkPlan] = {
    collect(plan) {
      case scan: CometScanExec => scan
      case scan: CometNativeScanExec => scan
    }
  }

  def collectCometShuffleExchanges(plan: SparkPlan): Seq[SparkPlan] = {
    collect(plan) { case exchange: CometShuffleExchangeExec =>
      exchange
    }
  }
}
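
The overridden test method is what gives the fuzz suites their breadth: each test defined in a subclass is registered six times, once per combination of the three native scan implementations and the two shuffle modes. A minimal sketch of the fan-out, assuming the scan config constants resolve to strings like those below (an assumption; the actual values live in CometConf):

    // Illustrative only: hypothetical string values standing in for
    // CometConf.SCAN_NATIVE_COMET, SCAN_NATIVE_DATAFUSION and SCAN_NATIVE_ICEBERG_COMPAT.
    val scanImpls = Seq("native_comet", "native_datafusion", "native_iceberg_compat")
    val shuffleModes = Seq("native", "jvm")
    val registered =
      for (shuffle <- shuffleModes; scan <- scanImpls)
        yield s"count distinct ($scan, $shuffle shuffle)"
    // registered.size == 6, e.g. "count distinct (native_comet, native shuffle)"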

spark/src/test/scala/org/apache/comet/CometFuzzTestSuite.scala

Lines changed: 2 additions & 119 deletions
@@ -19,65 +19,18 @@
 
 package org.apache.comet
 
-import java.io.File
-import java.text.SimpleDateFormat
-
 import scala.util.Random
 
-import org.scalactic.source.Position
-import org.scalatest.Tag
-
 import org.apache.commons.codec.binary.Hex
-import org.apache.commons.io.FileUtils
-import org.apache.spark.sql.CometTestBase
-import org.apache.spark.sql.comet.{CometNativeScanExec, CometScanExec}
-import org.apache.spark.sql.comet.execution.shuffle.CometShuffleExchangeExec
-import org.apache.spark.sql.execution.SparkPlan
-import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanExec, AdaptiveSparkPlanHelper}
+import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanExec
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.internal.SQLConf.ParquetOutputTimestampType
 import org.apache.spark.sql.types._
 
 import org.apache.comet.DataTypeSupport.isComplexType
 import org.apache.comet.testing.{DataGenOptions, ParquetGenerator}
 
-class CometFuzzTestSuite extends CometTestBase with AdaptiveSparkPlanHelper {
-
-  private var filename: String = null
-
-  /**
-   * We use Asia/Kathmandu because it has a non-zero number of minutes as the offset, so is an
-   * interesting edge case. Also, this timezone tends to be different from the default system
-   * timezone.
-   *
-   * Represents UTC+5:45
-   */
-  private val defaultTimezone = "Asia/Kathmandu"
-
-  override def beforeAll(): Unit = {
-    super.beforeAll()
-    val tempDir = System.getProperty("java.io.tmpdir")
-    filename = s"$tempDir/CometFuzzTestSuite_${System.currentTimeMillis()}.parquet"
-    val random = new Random(42)
-    withSQLConf(
-      CometConf.COMET_ENABLED.key -> "false",
-      SQLConf.SESSION_LOCAL_TIMEZONE.key -> defaultTimezone) {
-      val options =
-        DataGenOptions(
-          generateArray = true,
-          generateStruct = true,
-          generateNegativeZero = false,
-          // override base date due to known issues with experimental scans
-          baseDate =
-            new SimpleDateFormat("YYYY-MM-DD hh:mm:ss").parse("2024-05-25 12:34:56").getTime)
-      ParquetGenerator.makeParquetFile(random, spark, filename, 1000, options)
-    }
-  }
-
-  protected override def afterAll(): Unit = {
-    super.afterAll()
-    FileUtils.deleteDirectory(new File(filename))
-  }
+class CometFuzzTestSuite extends CometFuzzTestBase {
 
   test("select *") {
     val df = spark.read.parquet(filename)

@@ -168,18 +121,6 @@ class CometFuzzTestSuite extends CometTestBase with AdaptiveSparkPlanHelper {
     }
   }
 
-  test("count distinct") {
-    val df = spark.read.parquet(filename)
-    df.createOrReplaceTempView("t1")
-    for (col <- df.columns) {
-      val sql = s"SELECT count(distinct $col) FROM t1"
-      val (_, cometPlan) = checkSparkAnswer(sql)
-      if (usingDataSourceExec) {
-        assert(1 == collectNativeScans(cometPlan).length)
-      }
-    }
-  }
-
   test("order by multiple columns") {
     val df = spark.read.parquet(filename)
     df.createOrReplaceTempView("t1")

@@ -192,32 +133,6 @@ class CometFuzzTestSuite extends CometTestBase with AdaptiveSparkPlanHelper {
     }
   }
 
-  test("aggregate group by single column") {
-    val df = spark.read.parquet(filename)
-    df.createOrReplaceTempView("t1")
-    for (col <- df.columns) {
-      // cannot run fully natively due to range partitioning and sort
-      val sql = s"SELECT $col, count(*) FROM t1 GROUP BY $col ORDER BY $col"
-      val (_, cometPlan) = checkSparkAnswer(sql)
-      if (usingDataSourceExec) {
-        assert(1 == collectNativeScans(cometPlan).length)
-      }
-    }
-  }
-
-  test("min/max aggregate") {
-    val df = spark.read.parquet(filename)
-    df.createOrReplaceTempView("t1")
-    for (col <- df.columns) {
-      // cannot run fully native due to HashAggregate
-      val sql = s"SELECT min($col), max($col) FROM t1"
-      val (_, cometPlan) = checkSparkAnswer(sql)
-      if (usingDataSourceExec) {
-        assert(1 == collectNativeScans(cometPlan).length)
-      }
-    }
-  }
-
   test("distribute by single column (complex types)") {
     val df = spark.read.parquet(filename)
     df.createOrReplaceTempView("t1")

@@ -371,36 +286,4 @@ class CometFuzzTestSuite extends CometTestBase with AdaptiveSparkPlanHelper {
     }
   }
 
-  override protected def test(testName: String, testTags: Tag*)(testFun: => Any)(implicit
-      pos: Position): Unit = {
-    Seq("native", "jvm").foreach { shuffleMode =>
-      Seq(
-        CometConf.SCAN_NATIVE_COMET,
-        CometConf.SCAN_NATIVE_DATAFUSION,
-        CometConf.SCAN_NATIVE_ICEBERG_COMPAT).foreach { scanImpl =>
-        super.test(testName + s" ($scanImpl, $shuffleMode shuffle)", testTags: _*) {
-          withSQLConf(
-            CometConf.COMET_NATIVE_SCAN_IMPL.key -> scanImpl,
-            CometConf.COMET_SCAN_ALLOW_INCOMPATIBLE.key -> "true",
-            CometConf.COMET_SHUFFLE_MODE.key -> shuffleMode) {
-            testFun
-          }
-        }
-      }
-    }
-  }
-
-  private def collectNativeScans(plan: SparkPlan): Seq[SparkPlan] = {
-    collect(plan) {
-      case scan: CometScanExec => scan
-      case scan: CometNativeScanExec => scan
-    }
-  }
-
-  private def collectCometShuffleExchanges(plan: SparkPlan): Seq[SparkPlan] = {
-    collect(plan) { case exchange: CometShuffleExchangeExec =>
-      exchange
-    }
-  }
 }
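
Net effect of the refactor: the shared fixture (Parquet data generation, the Kathmandu timezone edge case, the scan/shuffle test matrix, and the plan-collection helpers) moves out of CometFuzzTestSuite into the new CometFuzzTestBase; the existing count distinct, count(*) group by, and min/max tests move into CometFuzzAggregateSuite; and the count(col) and multi-column count variants are new coverage, with a link to the open count-distinct issue (#2292).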
