Skip to content

Commit ddab352

Browse files
Chore: Improve array contains test coverage (#2030)
1 parent e73fff0 commit ddab352

File tree

2 files changed

+115
-10
lines changed

2 files changed

+115
-10
lines changed

spark/src/main/scala/org/apache/comet/serde/arrays.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -136,7 +136,7 @@ object CometArrayAppend extends CometExpressionSerde with IncompatExpr {
136136
}
137137
}
138138

139-
object CometArrayContains extends CometExpressionSerde with IncompatExpr {
139+
object CometArrayContains extends CometExpressionSerde {
140140
override def convert(
141141
expr: Expression,
142142
inputs: Seq[Attribute],

spark/src/test/scala/org/apache/comet/CometArrayExpressionSuite.scala

Lines changed: 114 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ import scala.util.Random
2525
import org.apache.hadoop.fs.Path
2626
import org.apache.spark.sql.CometTestBase
2727
import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
28-
import org.apache.spark.sql.functions.{array, col, expr, lit, udf}
28+
import org.apache.spark.sql.functions._
2929

3030
import org.apache.comet.CometSparkSessionExtensions.{isSpark35Plus, isSpark40Plus}
3131
import org.apache.comet.serde.CometArrayExcept
@@ -218,16 +218,121 @@ class CometArrayExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelp
218218
}
219219
}
220220

221-
test("array_contains") {
222-
withSQLConf(CometConf.COMET_EXPR_ALLOW_INCOMPATIBLE.key -> "true") {
223-
withTempDir { dir =>
224-
val path = new Path(dir.toURI.toString, "test.parquet")
225-
makeParquetFileAllPrimitiveTypes(path, dictionaryEnabled = false, n = 10000)
226-
spark.read.parquet(path.toString).createOrReplaceTempView("t1");
221+
test("array_contains - int values") {
222+
withTempDir { dir =>
223+
val path = new Path(dir.toURI.toString, "test.parquet")
224+
makeParquetFileAllPrimitiveTypes(path, dictionaryEnabled = false, n = 10000)
225+
spark.read.parquet(path.toString).createOrReplaceTempView("t1");
226+
checkSparkAnswerAndOperator(
227+
spark.sql("SELECT array_contains(array(_2, _3, _4), _2) FROM t1"))
228+
checkSparkAnswerAndOperator(
229+
spark.sql("SELECT array_contains((CASE WHEN _2 =_3 THEN array(_4) END), _4) FROM t1"));
230+
}
231+
}
232+
233+
test("array_contains - test all types (native Parquet reader)") {
234+
withTempDir { dir =>
235+
val path = new Path(dir.toURI.toString, "test.parquet")
236+
val filename = path.toString
237+
val random = new Random(42)
238+
withSQLConf(CometConf.COMET_ENABLED.key -> "false") {
239+
ParquetGenerator.makeParquetFile(
240+
random,
241+
spark,
242+
filename,
243+
100,
244+
DataGenOptions(
245+
allowNull = true,
246+
generateNegativeZero = true,
247+
generateArray = true,
248+
generateStruct = true,
249+
generateMap = false))
250+
}
251+
val table = spark.read.parquet(filename)
252+
table.createOrReplaceTempView("t1")
253+
val complexTypeFields =
254+
table.schema.fields.filter(field => isComplexType(field.dataType))
255+
val primitiveTypeFields =
256+
table.schema.fields.filterNot(field => isComplexType(field.dataType))
257+
for (field <- primitiveTypeFields) {
258+
val fieldName = field.name
259+
val typeName = field.dataType.typeName
260+
sql(s"SELECT array($fieldName, $fieldName) as a, $fieldName as b FROM t1")
261+
.createOrReplaceTempView("t2")
262+
checkSparkAnswerAndOperator(sql("SELECT array_contains(a, b) FROM t2"))
227263
checkSparkAnswerAndOperator(
228-
spark.sql("SELECT array_contains(array(_2, _3, _4), _2) FROM t1"))
264+
sql(s"SELECT array_contains(a, cast(null as $typeName)) FROM t2"))
265+
}
266+
for (field <- complexTypeFields) {
267+
val fieldName = field.name
268+
sql(s"SELECT array($fieldName, $fieldName) as a, $fieldName as b FROM t1")
269+
.createOrReplaceTempView("t3")
270+
checkSparkAnswer(sql("SELECT array_contains(a, b) FROM t3"))
271+
}
272+
}
273+
}
274+
275+
// https://github.com/apache/datafusion-comet/issues/1929
276+
ignore("array_contains - array literals") {
277+
withTempDir { dir =>
278+
val path = new Path(dir.toURI.toString, "test.parquet")
279+
val filename = path.toString
280+
val random = new Random(42)
281+
withSQLConf(CometConf.COMET_ENABLED.key -> "false") {
282+
ParquetGenerator.makeParquetFile(
283+
random,
284+
spark,
285+
filename,
286+
100,
287+
DataGenOptions(
288+
allowNull = true,
289+
generateNegativeZero = true,
290+
generateArray = false,
291+
generateStruct = false,
292+
generateMap = false))
293+
}
294+
val table = spark.read.parquet(filename)
295+
for (field <- table.schema.fields) {
296+
val typeName = field.dataType.typeName
229297
checkSparkAnswerAndOperator(
230-
spark.sql("SELECT array_contains((CASE WHEN _2 =_3 THEN array(_4) END), _4) FROM t1"));
298+
sql(s"SELECT array_contains(cast(null as array<$typeName>), b) FROM t2"))
299+
checkSparkAnswerAndOperator(sql(
300+
s"SELECT array_contains(cast(array() as array<$typeName>), cast(null as $typeName)) FROM t2"))
301+
checkSparkAnswerAndOperator(sql("SELECT array_contains(array(), 1) FROM t2"))
302+
}
303+
}
304+
}
305+
306+
test("array_contains - test all types (convert from Parquet)") {
307+
withTempDir { dir =>
308+
val path = new Path(dir.toURI.toString, "test.parquet")
309+
val filename = path.toString
310+
val random = new Random(42)
311+
withSQLConf(CometConf.COMET_ENABLED.key -> "false") {
312+
ParquetGenerator.makeParquetFile(
313+
random,
314+
spark,
315+
filename,
316+
100,
317+
DataGenOptions(
318+
allowNull = true,
319+
generateNegativeZero = true,
320+
generateArray = true,
321+
generateStruct = true,
322+
generateMap = false))
323+
}
324+
withSQLConf(
325+
CometConf.COMET_NATIVE_SCAN_ENABLED.key -> "false",
326+
CometConf.COMET_SPARK_TO_ARROW_ENABLED.key -> "true",
327+
CometConf.COMET_CONVERT_FROM_PARQUET_ENABLED.key -> "true") {
328+
val table = spark.read.parquet(filename)
329+
table.createOrReplaceTempView("t1")
330+
for (field <- table.schema.fields) {
331+
val fieldName = field.name
332+
sql(s"SELECT array($fieldName, $fieldName) as a, $fieldName as b FROM t1")
333+
.createOrReplaceTempView("t2")
334+
checkSparkAnswer(sql("SELECT array_contains(a, b) FROM t2"))
335+
}
231336
}
232337
}
233338
}

0 commit comments

Comments (0)