
Commit 4b577f8

chore: Create simple fuzz test as part of test suite (#1610)
1 parent 23dfb03 commit 4b577f8

File tree

6 files changed: +227 −10 lines


common/src/main/scala/org/apache/spark/sql/comet/util/Utils.scala

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,7 @@ import scala.collection.JavaConverters._
2727

2828
import org.apache.arrow.c.CDataDictionaryProvider
2929
import org.apache.arrow.vector.{BigIntVector, BitVector, DateDayVector, DecimalVector, FieldVector, FixedSizeBinaryVector, Float4Vector, Float8Vector, IntVector, SmallIntVector, TimeStampMicroTZVector, TimeStampMicroVector, TinyIntVector, ValueVector, VarBinaryVector, VarCharVector, VectorSchemaRoot}
30-
import org.apache.arrow.vector.complex.MapVector
31-
import org.apache.arrow.vector.complex.StructVector
30+
import org.apache.arrow.vector.complex.{ListVector, MapVector, StructVector}
3231
import org.apache.arrow.vector.dictionary.DictionaryProvider
3332
import org.apache.arrow.vector.ipc.ArrowStreamWriter
3433
import org.apache.arrow.vector.types._
@@ -278,7 +277,7 @@ object Utils {
278277
case v @ (_: BitVector | _: TinyIntVector | _: SmallIntVector | _: IntVector |
279278
_: BigIntVector | _: Float4Vector | _: Float8Vector | _: VarCharVector |
280279
_: DecimalVector | _: DateDayVector | _: TimeStampMicroTZVector | _: VarBinaryVector |
281-
_: FixedSizeBinaryVector | _: TimeStampMicroVector | _: StructVector) =>
280+
_: FixedSizeBinaryVector | _: TimeStampMicroVector | _: StructVector | _: ListVector) =>
282281
v.asInstanceOf[FieldVector]
283282
case _ =>
284283
throw new SparkException(s"Unsupported Arrow Vector for $reason: ${valueVector.getClass}")

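The change above widens the pattern match in `Utils` so that Arrow `ListVector` instances are accepted alongside structs and the primitive vectors. For illustration, here is the same narrowing idiom in isolation (a minimal sketch, not the Comet helper itself; the method name `asSupportedFieldVector` and the reduced type list are hypothetical):

```scala
// Minimal sketch: narrow an Arrow ValueVector to FieldVector only for types we
// know how to handle, now including ListVector. Names here are hypothetical.
import org.apache.arrow.vector.{FieldVector, IntVector, ValueVector}
import org.apache.arrow.vector.complex.{ListVector, StructVector}

def asSupportedFieldVector(valueVector: ValueVector, reason: String): FieldVector =
  valueVector match {
    case v @ (_: IntVector | _: StructVector | _: ListVector) =>
      v.asInstanceOf[FieldVector]
    case other =>
      throw new UnsupportedOperationException(
        s"Unsupported Arrow Vector for $reason: ${other.getClass}")
  }
```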
docs/source/user-guide/datatypes.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,3 +39,5 @@ The following Spark data types are currently available:
3939
- Timestamp
4040
- TimestampNTZ
4141
- Null
42+
- Struct
43+
- Array

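Since `Struct` and `Array` are now listed among the supported types, the following illustrative snippet (not part of this commit; the output path is an arbitrary example) produces a Parquet file containing both, which can then be queried with Comet enabled:

```scala
// Illustrative only: build struct and array columns, the two types added to the list above.
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.{array, col, lit, struct}

val spark = SparkSession.builder().master("local[1]").getOrCreate()
val df = spark
  .range(3)
  .withColumn("arr", array(col("id"), col("id") + 1))
  .withColumn("st", struct(col("id").as("a"), lit("x").as("b")))
df.write.mode("overwrite").parquet("/tmp/struct_array_example.parquet")
```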
docs/source/user-guide/installation.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,7 @@ See the [Comet Kubernetes Guide](kubernetes.md) guide.
8484

8585
Make sure `SPARK_HOME` points to the same Spark version as Comet was built for.
8686

87-
```console
87+
```shell
8888
export COMET_JAR=spark/target/comet-spark-spark3.4_2.12-0.8.0-SNAPSHOT.jar
8989

9090
$SPARK_HOME/bin/spark-shell \
@@ -95,7 +95,7 @@ $SPARK_HOME/bin/spark-shell \
9595
--conf spark.shuffle.manager=org.apache.spark.sql.comet.execution.shuffle.CometShuffleManager \
9696
--conf spark.comet.explainFallback.enabled=true \
9797
--conf spark.memory.offHeap.enabled=true \
98-
--conf spark.memory.offHeap.size=16g \
98+
--conf spark.memory.offHeap.size=16g
9999
```
100100

101101
### Verify Comet enabled for Spark SQL query

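For reference, one way to confirm that Comet picked up a query after launching the shell with the configuration above (an illustrative check, not part of this commit; the file path and the exact operator names are assumptions):

```scala
// Illustrative check: operator names starting with "Comet" should appear in the
// physical plan when Comet handles the query; with explainFallback.enabled=true,
// reasons for any fallback to Spark are also logged.
val data = spark.range(1000).selectExpr("id", "id % 10 AS bucket")
data.write.mode("overwrite").parquet("/tmp/comet_check.parquet")
spark.read.parquet("/tmp/comet_check.parquet").filter("bucket = 1").explain()
```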
spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala

Lines changed: 4 additions & 0 deletions

@@ -2459,6 +2459,10 @@ object QueryPlanSerde extends Logging with CometExprShim {
         }
 
         val groupingExprs = groupingExpressions.map(exprToProto(_, child.output))
+        if (groupingExprs.exists(_.isEmpty)) {
+          withInfo(op, "Not all grouping expressions are supported")
+          return None
+        }
 
         // In some of the cases, the aggregateExpressions could be empty.
         // For example, if the aggregate functions only have group by or if the aggregate

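The new guard in `QueryPlanSerde` rejects the whole aggregate when any grouping expression fails to convert, rather than emitting a partial native plan. The same all-or-nothing idiom in isolation (a generic sketch; `convertAll` is a made-up helper, not Comet API):

```scala
// Generic all-or-nothing conversion: if any element fails, the whole result is None.
def convertAll[A, B](items: Seq[A])(convert: A => Option[B]): Option[Seq[B]] = {
  val converted = items.map(convert)
  if (converted.exists(_.isEmpty)) None else Some(converted.flatten)
}

// Example: only even numbers convert successfully, so the whole batch is rejected.
val result = convertAll(Seq(2, 4, 5))(x => if (x % 2 == 0) Some(x * 10) else None)
assert(result.isEmpty)
```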
spark/src/main/scala/org/apache/comet/testing/ParquetGenerator.scala

Lines changed: 25 additions & 5 deletions

@@ -22,6 +22,8 @@ package org.apache.comet.testing
 import java.math.{BigDecimal, RoundingMode}
 import java.nio.charset.Charset
 import java.sql.Timestamp
+import java.text.SimpleDateFormat
+import java.time.{Instant, LocalDateTime, ZoneId}
 
 import scala.collection.mutable.ListBuffer
 import scala.util.Random
@@ -31,6 +33,13 @@ import org.apache.spark.sql.types.{ArrayType, DataType, DataTypes, DecimalType,
 
 object ParquetGenerator {
 
+  /**
+   * Arbitrary date to use as base for generating temporal columns. Random integers will be added
+   * to or subtracted from this value.
+   */
+  private val baseDate =
+    new SimpleDateFormat("YYYY-MM-DD hh:mm:ss").parse("2024-05-25 12:34:56").getTime
+
   private val primitiveTypes = Seq(
     DataTypes.BooleanType,
     DataTypes.ByteType,
@@ -43,8 +52,7 @@ object ParquetGenerator {
     DataTypes.createDecimalType(36, 18),
     DataTypes.DateType,
     DataTypes.TimestampType,
-    // TimestampNTZType only in Spark 3.4+
-    // DataTypes.TimestampNTZType,
+    DataTypes.TimestampNTZType,
     DataTypes.StringType,
     DataTypes.BinaryType)
 
@@ -58,17 +66,24 @@ object ParquetGenerator {
     val dataTypes = ListBuffer[DataType]()
     dataTypes.appendAll(primitiveTypes)
 
+    val arraysOfPrimitives = primitiveTypes.map(DataTypes.createArrayType)
+
     if (options.generateStruct) {
       dataTypes += StructType(
         primitiveTypes.zipWithIndex.map(x => StructField(s"c${x._2}", x._1, true)))
+
+      if (options.generateArray) {
+        dataTypes += StructType(
+          arraysOfPrimitives.zipWithIndex.map(x => StructField(s"c${x._2}", x._1, true)))
+      }
     }
 
     if (options.generateMap) {
       dataTypes += MapType(DataTypes.IntegerType, DataTypes.StringType)
     }
 
     if (options.generateArray) {
-      dataTypes.appendAll(primitiveTypes.map(DataTypes.createArrayType))
+      dataTypes.appendAll(arraysOfPrimitives)
 
       if (options.generateStruct) {
         dataTypes += DataTypes.createArrayType(
@@ -202,9 +217,14 @@ object ParquetGenerator {
           null
         }
       case DataTypes.DateType =>
-        Range(0, numRows).map(_ => new java.sql.Date(1716645600011L + r.nextInt()))
+        Range(0, numRows).map(_ => new java.sql.Date(baseDate + r.nextInt()))
       case DataTypes.TimestampType =>
-        Range(0, numRows).map(_ => new Timestamp(1716645600011L + r.nextInt()))
+        Range(0, numRows).map(_ => new Timestamp(baseDate + r.nextInt()))
+      case DataTypes.TimestampNTZType =>
+        Range(0, numRows).map(_ =>
+          LocalDateTime.ofInstant(
+            Instant.ofEpochMilli(baseDate + r.nextInt()),
+            ZoneId.systemDefault()))
       case _ => throw new IllegalStateException(s"Cannot generate data for $dataType yet")
     }
   }
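The generator now derives all temporal values from a shared `baseDate` and adds a random signed 32-bit millisecond offset per row, with TimestampNTZ values produced as `LocalDateTime`. A standalone sketch of the same idea (using `java.time` parsing here rather than `SimpleDateFormat`; the base instant and row count are illustrative):

```scala
// Standalone sketch: random dates/timestamps scattered around an arbitrary base instant.
import java.sql.{Date, Timestamp}
import java.time.{Instant, LocalDateTime, ZoneId}
import scala.util.Random

val r = new Random(42)
val base = Instant.parse("2024-05-25T12:34:56Z").toEpochMilli

val dates = Seq.fill(5)(new Date(base + r.nextInt()))
val timestamps = Seq.fill(5)(new Timestamp(base + r.nextInt()))
// TimestampNTZ columns carry no zone, so they are generated as LocalDateTime
val ntz = Seq.fill(5)(
  LocalDateTime.ofInstant(Instant.ofEpochMilli(base + r.nextInt()), ZoneId.systemDefault()))
```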
CometFuzzTestSuite.scala (new file)

Lines changed: 192 additions & 0 deletions

@@ -0,0 +1,192 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

package org.apache.comet

import java.io.File

import scala.util.Random

import org.scalactic.source.Position
import org.scalatest.Tag

import org.apache.commons.io.FileUtils
import org.apache.spark.sql.CometTestBase
import org.apache.spark.sql.comet.{CometNativeScanExec, CometScanExec}
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
import org.apache.spark.sql.internal.SQLConf

import org.apache.comet.testing.{DataGenOptions, ParquetGenerator}

class CometFuzzTestSuite extends CometTestBase with AdaptiveSparkPlanHelper {

  private var filename: String = null

  /**
   * We use Asia/Kathmandu because it has a non-zero number of minutes as the offset, so is an
   * interesting edge case. Also, this timezone tends to be different from the default system
   * timezone.
   *
   * Represents UTC+5:45
   */
  private val defaultTimezone = "Asia/Kathmandu"

  override def beforeAll(): Unit = {
    super.beforeAll()
    val tempDir = System.getProperty("java.io.tmpdir")
    filename = s"$tempDir/CometFuzzTestSuite_${System.currentTimeMillis()}.parquet"
    val random = new Random(42)
    withSQLConf(
      CometConf.COMET_ENABLED.key -> "false",
      SQLConf.SESSION_LOCAL_TIMEZONE.key -> defaultTimezone) {
      val options =
        DataGenOptions(generateArray = true, generateStruct = true, generateNegativeZero = false)
      ParquetGenerator.makeParquetFile(random, spark, filename, 1000, options)
    }
  }

  protected override def afterAll(): Unit = {
    super.afterAll()
    FileUtils.deleteDirectory(new File(filename))
  }

  test("select *") {
    val df = spark.read.parquet(filename)
    df.createOrReplaceTempView("t1")
    val sql = "SELECT * FROM t1"
    if (CometConf.isExperimentalNativeScan) {
      checkSparkAnswerAndOperator(sql)
    } else {
      checkSparkAnswer(sql)
    }
  }

  test("select * with limit") {
    val df = spark.read.parquet(filename)
    df.createOrReplaceTempView("t1")
    val sql = "SELECT * FROM t1 LIMIT 500"
    if (CometConf.isExperimentalNativeScan) {
      checkSparkAnswerAndOperator(sql)
    } else {
      checkSparkAnswer(sql)
    }
  }

  test("order by single column") {
    val df = spark.read.parquet(filename)
    df.createOrReplaceTempView("t1")
    for (col <- df.columns) {
      val sql = s"SELECT $col FROM t1 ORDER BY $col"
      // cannot run fully natively due to range partitioning and sort
      val (_, cometPlan) = checkSparkAnswer(sql)
      if (CometConf.isExperimentalNativeScan) {
        assert(1 == collectNativeScans(cometPlan).length)
      }
    }
  }

  test("count distinct") {
    val df = spark.read.parquet(filename)
    df.createOrReplaceTempView("t1")
    for (col <- df.columns) {
      val sql = s"SELECT count(distinct $col) FROM t1"
      val (_, cometPlan) = checkSparkAnswer(sql)
      if (CometConf.isExperimentalNativeScan) {
        assert(1 == collectNativeScans(cometPlan).length)
      }
    }
  }

  test("order by multiple columns") {
    val df = spark.read.parquet(filename)
    df.createOrReplaceTempView("t1")
    val allCols = df.columns.mkString(",")
    val sql = s"SELECT $allCols FROM t1 ORDER BY $allCols"
    // cannot run fully natively due to range partitioning and sort
    val (_, cometPlan) = checkSparkAnswer(sql)
    if (CometConf.isExperimentalNativeScan) {
      assert(1 == collectNativeScans(cometPlan).length)
    }
  }

  test("aggregate group by single column") {
    val df = spark.read.parquet(filename)
    df.createOrReplaceTempView("t1")
    for (col <- df.columns) {
      // cannot run fully natively due to range partitioning and sort
      val sql = s"SELECT $col, count(*) FROM t1 GROUP BY $col ORDER BY $col"
      val (_, cometPlan) = checkSparkAnswer(sql)
      if (CometConf.isExperimentalNativeScan) {
        assert(1 == collectNativeScans(cometPlan).length)
      }
    }
  }

  test("min/max aggregate") {
    val df = spark.read.parquet(filename)
    df.createOrReplaceTempView("t1")
    for (col <- df.columns) {
      // cannot run fully native due to HashAggregate
      val sql = s"SELECT min($col), max($col) FROM t1"
      val (_, cometPlan) = checkSparkAnswer(sql)
      if (CometConf.isExperimentalNativeScan) {
        assert(1 == collectNativeScans(cometPlan).length)
      }
    }
  }

  test("join") {
    val df = spark.read.parquet(filename)
    df.createOrReplaceTempView("t1")
    df.createOrReplaceTempView("t2")
    for (col <- df.columns) {
      // cannot run fully native due to HashAggregate
      val sql = s"SELECT count(*) FROM t1 JOIN t2 ON t1.$col = t2.$col"
      val (_, cometPlan) = checkSparkAnswer(sql)
      if (CometConf.isExperimentalNativeScan) {
        assert(2 == collectNativeScans(cometPlan).length)
      }
    }
  }

  override protected def test(testName: String, testTags: Tag*)(testFun: => Any)(implicit
      pos: Position): Unit = {
    Seq("native", "jvm").foreach { shuffleMode =>
      Seq("native_comet", "native_datafusion", "native_iceberg_compat").foreach { scanImpl =>
        super.test(testName + s" ($scanImpl, $shuffleMode shuffle)", testTags: _*) {
          withSQLConf(
            CometConf.COMET_NATIVE_SCAN_IMPL.key -> scanImpl,
            CometConf.COMET_SCAN_ALLOW_INCOMPATIBLE.key -> "true",
            CometConf.COMET_SHUFFLE_MODE.key -> shuffleMode) {
            testFun
          }
        }
      }
    }
  }

  private def collectNativeScans(plan: SparkPlan): Seq[SparkPlan] = {
    collect(plan) {
      case scan: CometScanExec => scan
      case scan: CometNativeScanExec => scan
    }
  }
}

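The suite's `test` override registers every test case once per scan implementation and shuffle mode, so each query runs under six configurations. The same registration pattern in a standalone ScalaTest suite (a generic sketch, assuming plain `AnyFunSuite` rather than `CometTestBase`):

```scala
import org.scalatest.funsuite.AnyFunSuite

// Generic sketch: register one logical test per (scan, shuffle) combination.
class MatrixSketchSuite extends AnyFunSuite {
  private val scanImpls = Seq("native_comet", "native_datafusion", "native_iceberg_compat")
  private val shuffleModes = Seq("native", "jvm")

  def matrixTest(name: String)(body: (String, String) => Unit): Unit = {
    for (shuffle <- shuffleModes; scan <- scanImpls) {
      test(s"$name ($scan, $shuffle shuffle)") {
        body(scan, shuffle)
      }
    }
  }

  matrixTest("example") { (scan, shuffle) =>
    assert(scanImpls.contains(scan) && shuffleModes.contains(shuffle))
  }
}
```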