@@ -25,7 +25,7 @@ import scala.util.Random
2525import org .apache .hadoop .fs .Path
2626import org .apache .spark .sql .CometTestBase
2727import org .apache .spark .sql .execution .adaptive .AdaptiveSparkPlanHelper
28- import org .apache .spark .sql .functions .{ array , col , expr , lit , udf }
28+ import org .apache .spark .sql .functions ._
2929
3030import org .apache .comet .CometSparkSessionExtensions .{isSpark35Plus , isSpark40Plus }
3131import org .apache .comet .serde .CometArrayExcept
@@ -218,16 +218,121 @@ class CometArrayExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelp
218218 }
219219 }
220220
221- test(" array_contains" ) {
222- withSQLConf(CometConf .COMET_EXPR_ALLOW_INCOMPATIBLE .key -> " true" ) {
223- withTempDir { dir =>
224- val path = new Path (dir.toURI.toString, " test.parquet" )
225- makeParquetFileAllPrimitiveTypes(path, dictionaryEnabled = false , n = 10000 )
226- spark.read.parquet(path.toString).createOrReplaceTempView(" t1" );
221+ test(" array_contains - int values" ) {
222+ withTempDir { dir =>
223+ val path = new Path (dir.toURI.toString, " test.parquet" )
224+ makeParquetFileAllPrimitiveTypes(path, dictionaryEnabled = false , n = 10000 )
225+ spark.read.parquet(path.toString).createOrReplaceTempView(" t1" );
226+ checkSparkAnswerAndOperator(
227+ spark.sql(" SELECT array_contains(array(_2, _3, _4), _2) FROM t1" ))
228+ checkSparkAnswerAndOperator(
229+ spark.sql(" SELECT array_contains((CASE WHEN _2 =_3 THEN array(_4) END), _4) FROM t1" ));
230+ }
231+ }
232+
233+ test(" array_contains - test all types (native Parquet reader)" ) {
234+ withTempDir { dir =>
235+ val path = new Path (dir.toURI.toString, " test.parquet" )
236+ val filename = path.toString
237+ val random = new Random (42 )
238+ withSQLConf(CometConf .COMET_ENABLED .key -> " false" ) {
239+ ParquetGenerator .makeParquetFile(
240+ random,
241+ spark,
242+ filename,
243+ 100 ,
244+ DataGenOptions (
245+ allowNull = true ,
246+ generateNegativeZero = true ,
247+ generateArray = true ,
248+ generateStruct = true ,
249+ generateMap = false ))
250+ }
251+ val table = spark.read.parquet(filename)
252+ table.createOrReplaceTempView(" t1" )
253+ val complexTypeFields =
254+ table.schema.fields.filter(field => isComplexType(field.dataType))
255+ val primitiveTypeFields =
256+ table.schema.fields.filterNot(field => isComplexType(field.dataType))
257+ for (field <- primitiveTypeFields) {
258+ val fieldName = field.name
259+ val typeName = field.dataType.typeName
260+ sql(s " SELECT array( $fieldName, $fieldName) as a, $fieldName as b FROM t1 " )
261+ .createOrReplaceTempView(" t2" )
262+ checkSparkAnswerAndOperator(sql(" SELECT array_contains(a, b) FROM t2" ))
227263 checkSparkAnswerAndOperator(
228- spark.sql(" SELECT array_contains(array(_2, _3, _4), _2) FROM t1" ))
264+ sql(s " SELECT array_contains(a, cast(null as $typeName)) FROM t2 " ))
265+ }
266+ for (field <- complexTypeFields) {
267+ val fieldName = field.name
268+ sql(s " SELECT array( $fieldName, $fieldName) as a, $fieldName as b FROM t1 " )
269+ .createOrReplaceTempView(" t3" )
270+ checkSparkAnswer(sql(" SELECT array_contains(a, b) FROM t3" ))
271+ }
272+ }
273+ }
274+
275+ // https://github.com/apache/datafusion-comet/issues/1929
276+ ignore(" array_contains - array literals" ) {
277+ withTempDir { dir =>
278+ val path = new Path (dir.toURI.toString, " test.parquet" )
279+ val filename = path.toString
280+ val random = new Random (42 )
281+ withSQLConf(CometConf .COMET_ENABLED .key -> " false" ) {
282+ ParquetGenerator .makeParquetFile(
283+ random,
284+ spark,
285+ filename,
286+ 100 ,
287+ DataGenOptions (
288+ allowNull = true ,
289+ generateNegativeZero = true ,
290+ generateArray = false ,
291+ generateStruct = false ,
292+ generateMap = false ))
293+ }
294+ val table = spark.read.parquet(filename)
295+ for (field <- table.schema.fields) {
296+ val typeName = field.dataType.typeName
229297 checkSparkAnswerAndOperator(
230- spark.sql(" SELECT array_contains((CASE WHEN _2 =_3 THEN array(_4) END), _4) FROM t1" ));
298+ sql(s " SELECT array_contains(cast(null as array< $typeName>), b) FROM t2 " ))
299+ checkSparkAnswerAndOperator(sql(
300+ s " SELECT array_contains(cast(array() as array< $typeName>), cast(null as $typeName)) FROM t2 " ))
301+ checkSparkAnswerAndOperator(sql(" SELECT array_contains(array(), 1) FROM t2" ))
302+ }
303+ }
304+ }
305+
306+ test(" array_contains - test all types (convert from Parquet)" ) {
307+ withTempDir { dir =>
308+ val path = new Path (dir.toURI.toString, " test.parquet" )
309+ val filename = path.toString
310+ val random = new Random (42 )
311+ withSQLConf(CometConf .COMET_ENABLED .key -> " false" ) {
312+ ParquetGenerator .makeParquetFile(
313+ random,
314+ spark,
315+ filename,
316+ 100 ,
317+ DataGenOptions (
318+ allowNull = true ,
319+ generateNegativeZero = true ,
320+ generateArray = true ,
321+ generateStruct = true ,
322+ generateMap = false ))
323+ }
324+ withSQLConf(
325+ CometConf .COMET_NATIVE_SCAN_ENABLED .key -> " false" ,
326+ CometConf .COMET_SPARK_TO_ARROW_ENABLED .key -> " true" ,
327+ CometConf .COMET_CONVERT_FROM_PARQUET_ENABLED .key -> " true" ) {
328+ val table = spark.read.parquet(filename)
329+ table.createOrReplaceTempView(" t1" )
330+ for (field <- table.schema.fields) {
331+ val fieldName = field.name
332+ sql(s " SELECT array( $fieldName, $fieldName) as a, $fieldName as b FROM t1 " )
333+ .createOrReplaceTempView(" t2" )
334+ checkSparkAnswer(sql(" SELECT array_contains(a, b) FROM t2" ))
335+ }
231336 }
232337 }
233338 }
0 commit comments