Commit 44170d7

andygrove and viirya authored
fix: Improve testing for array_remove and fallback to Spark for unsupported types (apache#1308)
* Fallback to Spark for unsupported types for ArrayRemove
* save progress
* improve tests
* revert debug logging
* prepare for review
* fix
* format
* remove test failure
* fix
* more testing
* refactor
* update readme
* fix
* format
* fix
* fix
* fix
* fix inverted options
* Update QueryRunner.scala

Co-authored-by: Liang-Chi Hsieh <[email protected]>

* Add descriptions to CLI options

---------

Co-authored-by: Liang-Chi Hsieh <[email protected]>

1 parent f09f8af · commit 44170d7

12 files changed: +421 −112 lines changed

fuzz-testing/README.md

Lines changed: 7 additions & 5 deletions

````diff
@@ -44,7 +44,9 @@ Planned areas of improvement:
 
 ## Usage
 
-Build the jar file first.
+From the root of the project, run `mvn install -DskipTests` to install Comet.
+
+Then build the fuzz testing jar.
 
 ```shell
 mvn package
@@ -59,8 +61,8 @@ Set appropriate values for `SPARK_HOME`, `SPARK_MASTER`, and `COMET_JAR` environ
 $SPARK_HOME/bin/spark-submit \
     --master $SPARK_MASTER \
     --class org.apache.comet.fuzz.Main \
-    target/comet-fuzz-spark3.4_2.12-0.5.0-SNAPSHOT-jar-with-dependencies.jar \
-    data --num-files=2 --num-rows=200 --num-columns=100
+    target/comet-fuzz-spark3.4_2.12-0.6.0-SNAPSHOT-jar-with-dependencies.jar \
+    data --num-files=2 --num-rows=200 --exclude-negative-zero --generate-arrays --generate-structs --generate-maps
 ```
 
 There is an optional `--exclude-negative-zero` flag for excluding `-0.0` from the generated data, which is
@@ -75,7 +77,7 @@ Generate random queries that are based on the available test files.
 $SPARK_HOME/bin/spark-submit \
     --master $SPARK_MASTER \
     --class org.apache.comet.fuzz.Main \
-    target/comet-fuzz-spark3.4_2.12-0.5.0-SNAPSHOT-jar-with-dependencies.jar \
+    target/comet-fuzz-spark3.4_2.12-0.6.0-SNAPSHOT-jar-with-dependencies.jar \
     queries --num-files=2 --num-queries=500
 ```
 
@@ -97,7 +99,7 @@ $SPARK_HOME/bin/spark-submit \
     --conf spark.driver.extraClassPath=$COMET_JAR \
     --conf spark.executor.extraClassPath=$COMET_JAR \
     --class org.apache.comet.fuzz.Main \
-    target/comet-fuzz-spark3.4_2.12-0.5.0-SNAPSHOT-jar-with-dependencies.jar \
+    target/comet-fuzz-spark3.4_2.12-0.6.0-SNAPSHOT-jar-with-dependencies.jar \
     run --num-files=2 --filename=queries.sql
 ```
 
````
fuzz-testing/pom.xml

Lines changed: 5 additions & 0 deletions

```diff
@@ -51,6 +51,11 @@ under the License.
       <artifactId>spark-sql_${scala.binary.version}</artifactId>
       <scope>provided</scope>
     </dependency>
+    <dependency>
+      <groupId>org.apache.datafusion</groupId>
+      <artifactId>comet-spark-spark${spark.version.short}_${scala.binary.version}</artifactId>
+      <version>0.6.0-SNAPSHOT</version>
+    </dependency>
     <dependency>
       <groupId>org.rogach</groupId>
       <artifactId>scallop_${scala.binary.version}</artifactId>
```

fuzz-testing/src/main/scala/org/apache/comet/fuzz/Main.scala

Lines changed: 35 additions & 20 deletions

```diff
@@ -26,23 +26,35 @@ import org.rogach.scallop.ScallopOption
 
 import org.apache.spark.sql.SparkSession
 
+import org.apache.comet.testing.{DataGenOptions, ParquetGenerator}
+
 class Conf(arguments: Seq[String]) extends ScallopConf(arguments) {
   object generateData extends Subcommand("data") {
-    val numFiles: ScallopOption[Int] = opt[Int](required = true)
-    val numRows: ScallopOption[Int] = opt[Int](required = true)
-    val numColumns: ScallopOption[Int] = opt[Int](required = true)
-    val excludeNegativeZero: ScallopOption[Boolean] = opt[Boolean](required = false)
+    val numFiles: ScallopOption[Int] =
+      opt[Int](required = true, descr = "Number of files to generate")
+    val numRows: ScallopOption[Int] = opt[Int](required = true, descr = "Number of rows per file")
+    val generateArrays: ScallopOption[Boolean] =
+      opt[Boolean](required = false, descr = "Whether to generate arrays")
+    val generateStructs: ScallopOption[Boolean] =
+      opt[Boolean](required = false, descr = "Whether to generate structs")
+    val generateMaps: ScallopOption[Boolean] =
+      opt[Boolean](required = false, descr = "Whether to generate maps")
+    val excludeNegativeZero: ScallopOption[Boolean] =
+      opt[Boolean](required = false, descr = "Whether to exclude negative zero")
   }
   addSubcommand(generateData)
   object generateQueries extends Subcommand("queries") {
-    val numFiles: ScallopOption[Int] = opt[Int](required = false)
-    val numQueries: ScallopOption[Int] = opt[Int](required = true)
+    val numFiles: ScallopOption[Int] =
+      opt[Int](required = false, descr = "Number of input files to use")
+    val numQueries: ScallopOption[Int] =
+      opt[Int](required = true, descr = "Number of queries to generate")
   }
   addSubcommand(generateQueries)
   object runQueries extends Subcommand("run") {
-    val filename: ScallopOption[String] = opt[String](required = true)
-    val numFiles: ScallopOption[Int] = opt[Int](required = false)
-    val showMatchingResults: ScallopOption[Boolean] = opt[Boolean](required = false)
+    val filename: ScallopOption[String] =
+      opt[String](required = true, descr = "File to write queries to")
+    val numFiles: ScallopOption[Int] =
+      opt[Int](required = false, descr = "Number of input files to use")
   }
   addSubcommand(runQueries)
   verify()
@@ -60,25 +72,28 @@ object Main {
     val conf = new Conf(args.toIndexedSeq)
     conf.subcommand match {
       case Some(conf.generateData) =>
-        DataGen.generateRandomFiles(
-          r,
-          spark,
-          numFiles = conf.generateData.numFiles(),
-          numRows = conf.generateData.numRows(),
-          numColumns = conf.generateData.numColumns(),
+        val options = DataGenOptions(
+          allowNull = true,
+          generateArray = conf.generateData.generateArrays(),
+          generateStruct = conf.generateData.generateStructs(),
+          generateMap = conf.generateData.generateMaps(),
           generateNegativeZero = !conf.generateData.excludeNegativeZero())
+        for (i <- 0 until conf.generateData.numFiles()) {
+          ParquetGenerator.makeParquetFile(
+            r,
+            spark,
+            s"test$i.parquet",
+            numRows = conf.generateData.numRows(),
+            options)
+        }
       case Some(conf.generateQueries) =>
         QueryGen.generateRandomQueries(
          r,
          spark,
          numFiles = conf.generateQueries.numFiles(),
          conf.generateQueries.numQueries())
       case Some(conf.runQueries) =>
-        QueryRunner.runQueries(
-          spark,
-          conf.runQueries.numFiles(),
-          conf.runQueries.filename(),
-          conf.runQueries.showMatchingResults())
+        QueryRunner.runQueries(spark, conf.runQueries.numFiles(), conf.runQueries.filename())
       case _ =>
         // scalastyle:off println
         println("Invalid subcommand")
```

fuzz-testing/src/main/scala/org/apache/comet/fuzz/Meta.scala

Lines changed: 9 additions & 1 deletion

```diff
@@ -103,8 +103,16 @@ object Meta {
   val miscScalarFunc: Seq[Function] =
     Seq(Function("isnan", 1), Function("isnull", 1), Function("isnotnull", 1))
 
+  val arrayScalarFunc: Seq[Function] = Seq(
+    Function("array", 2),
+    Function("array_remove", 2),
+    Function("array_insert", 2),
+    Function("array_contains", 2),
+    Function("array_intersect", 2),
+    Function("array_append", 2))
+
   val scalarFunc: Seq[Function] = stringScalarFunc ++ dateScalarFunc ++
-    mathScalarFunc ++ miscScalarFunc
+    mathScalarFunc ++ miscScalarFunc ++ arrayScalarFunc
 
   val aggFunc: Seq[Function] = Seq(
     Function("min", 1),
```

fuzz-testing/src/main/scala/org/apache/comet/fuzz/QueryRunner.scala

Lines changed: 32 additions & 19 deletions

```diff
@@ -21,6 +21,7 @@ package org.apache.comet.fuzz
 
 import java.io.{BufferedWriter, FileWriter, PrintWriter, StringWriter}
 
+import scala.collection.mutable.WrappedArray
 import scala.io.Source
 
 import org.apache.spark.sql.{Row, SparkSession}
@@ -31,7 +32,6 @@ object QueryRunner {
       spark: SparkSession,
       numFiles: Int,
       filename: String,
-      showMatchingResults: Boolean,
       showFailedSparkQueries: Boolean = false): Unit = {
 
     val outputFilename = s"results-${System.currentTimeMillis()}.md"
@@ -64,8 +64,12 @@
         val sparkRows = df.collect()
         val sparkPlan = df.queryExecution.executedPlan.toString
 
+        // execute with Comet
         try {
           spark.conf.set("spark.comet.enabled", "true")
+          // complex type support until we support it natively
+          spark.conf.set("spark.comet.sparkToColumnar.enabled", "true")
+          spark.conf.set("spark.comet.convert.parquet.enabled", "true")
           val df = spark.sql(sql)
           val cometRows = df.collect()
           val cometPlan = df.queryExecution.executedPlan.toString
@@ -77,17 +81,7 @@
               val r = cometRows(i)
               assert(l.length == r.length)
               for (j <- 0 until l.length) {
-                val same = (l(j), r(j)) match {
-                  case (a: Float, b: Float) if a.isInfinity => b.isInfinity
-                  case (a: Float, b: Float) if a.isNaN => b.isNaN
-                  case (a: Float, b: Float) => (a - b).abs <= 0.000001f
-                  case (a: Double, b: Double) if a.isInfinity => b.isInfinity
-                  case (a: Double, b: Double) if a.isNaN => b.isNaN
-                  case (a: Double, b: Double) => (a - b).abs <= 0.000001
-                  case (a: Array[Byte], b: Array[Byte]) => a.sameElements(b)
-                  case (a, b) => a == b
-                }
-                if (!same) {
+                if (!same(l(j), r(j))) {
                   showSQL(w, sql)
                   showPlans(w, sparkPlan, cometPlan)
                   w.write(s"First difference at row $i:\n")
@@ -138,14 +132,33 @@
     }
   }
 
+  private def same(l: Any, r: Any): Boolean = {
+    (l, r) match {
+      case (a: Float, b: Float) if a.isInfinity => b.isInfinity
+      case (a: Float, b: Float) if a.isNaN => b.isNaN
+      case (a: Float, b: Float) => (a - b).abs <= 0.000001f
+      case (a: Double, b: Double) if a.isInfinity => b.isInfinity
+      case (a: Double, b: Double) if a.isNaN => b.isNaN
+      case (a: Double, b: Double) => (a - b).abs <= 0.000001
+      case (a: Array[_], b: Array[_]) =>
+        a.length == b.length && a.zip(b).forall(x => same(x._1, x._2))
+      case (a: WrappedArray[_], b: WrappedArray[_]) =>
+        a.length == b.length && a.zip(b).forall(x => same(x._1, x._2))
+      case (a, b) => a == b
+    }
+  }
+
+  private def format(value: Any): String = {
+    value match {
+      case null => "NULL"
+      case v: WrappedArray[_] => s"[${v.map(format).mkString(",")}]"
+      case v: Array[Byte] => s"[${v.mkString(",")}]"
+      case other => other.toString
+    }
+  }
+
   private def formatRow(row: Row): String = {
-    row.toSeq
-      .map {
-        case null => "NULL"
-        case v: Array[Byte] => v.mkString
-        case other => other.toString
-      }
-      .mkString(",")
+    row.toSeq.map(format).mkString(",")
   }
 
   private def showSQL(w: BufferedWriter, sql: String, maxLength: Int = 120): Unit = {
```
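The key change in the result comparison is that the old inline match could not descend into array-valued columns, so the extracted `same` helper recurses. A standalone sketch (simplified from the helper above; `WrappedArray` is the Scala 2.12 wrapper Spark returns for array columns) of why the recursion matters:

```scala
import scala.collection.mutable.WrappedArray

// Simplified recursive comparison: floating-point tolerance must be applied
// element by element inside array values, not only to top-level columns.
def approxSame(l: Any, r: Any): Boolean = (l, r) match {
  case (a: Double, b: Double) => (a - b).abs <= 0.000001
  case (a: WrappedArray[_], b: WrappedArray[_]) =>
    a.length == b.length && a.zip(b).forall { case (x, y) => approxSame(x, y) }
  case (a, b) => a == b
}

val sparkResult: WrappedArray[Double] = Array(1.0, 2.0000000001)
val cometResult: WrappedArray[Double] = Array(1.0, 2.0)
assert(approxSame(sparkResult, cometResult)) // passes: per-element tolerance
assert(sparkResult != cometResult)           // plain equality would report a spurious mismatch
```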

spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala

Lines changed: 10 additions & 6 deletions

```diff
@@ -2284,12 +2284,16 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim
           withInfo(expr, "unsupported arguments for GetArrayStructFields", child)
           None
         }
-      case expr if expr.prettyName == "array_remove" =>
-        createBinaryExpr(
-          expr.children(0),
-          expr.children(1),
-          inputs,
-          (builder, binaryExpr) => builder.setArrayRemove(binaryExpr))
+      case expr: ArrayRemove =>
+        if (CometArrayRemove.checkSupport(expr)) {
+          createBinaryExpr(
+            expr.children(0),
+            expr.children(1),
+            inputs,
+            (builder, binaryExpr) => builder.setArrayRemove(binaryExpr))
+        } else {
+          None
+        }
       case expr if expr.prettyName == "array_contains" =>
         createBinaryExpr(
           expr.children(0),
```
New file in package org.apache.comet.serde

Lines changed: 60 additions & 0 deletions

```scala
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

package org.apache.comet.serde

import org.apache.spark.sql.catalyst.expressions.{ArrayRemove, Expression}
import org.apache.spark.sql.types.{ArrayType, DataType, DataTypes, DecimalType, StructType}

import org.apache.comet.CometSparkSessionExtensions.withInfo
import org.apache.comet.shims.CometExprShim

trait CometExpression {
  def checkSupport(expr: Expression): Boolean
}

object CometArrayRemove extends CometExpression with CometExprShim {

  def isTypeSupported(dt: DataType): Boolean = {
    import DataTypes._
    dt match {
      case BooleanType | ByteType | ShortType | IntegerType | LongType | FloatType | DoubleType |
          _: DecimalType | DateType | TimestampType | StringType | BinaryType =>
        true
      case t if isTimestampNTZType(t) => true
      case ArrayType(elementType, _) => isTypeSupported(elementType)
      case _: StructType =>
        // https://github.com/apache/datafusion-comet/issues/1307
        false
      case _ => false
    }
  }

  override def checkSupport(expr: Expression): Boolean = {
    val ar = expr.asInstanceOf[ArrayRemove]
    val inputTypes: Set[DataType] = ar.children.map(_.dataType).toSet
    for (dt <- inputTypes) {
      if (!isTypeSupported(dt)) {
        withInfo(expr, s"data type not supported: $dt")
        return false
      }
    }
    true
  }
}
```
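A quick illustration (assuming the Comet classes above are on the classpath) of how the support check behaves on a few input types: the recursion handles nested arrays, while struct element types are rejected until issue 1307 is resolved.

```scala
import org.apache.spark.sql.types._

// Illustrative only: expected results of the type check defined above.
CometArrayRemove.isTypeSupported(IntegerType)                                    // true
CometArrayRemove.isTypeSupported(ArrayType(ArrayType(StringType)))               // true (recurses into element types)
CometArrayRemove.isTypeSupported(StructType(Seq(StructField("a", IntegerType)))) // false (issue 1307)
CometArrayRemove.isTypeSupported(MapType(StringType, IntegerType))               // false (catch-all case)
```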
