@@ -232,24 +232,53 @@ class CometArrayExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper
     }
   }
 
235+ test(" array_distinct" ) {
236+ withSQLConf(CometConf .COMET_EXPR_ALLOW_INCOMPATIBLE .key -> " true" ) {
237+ Seq (true , false ).foreach { dictionaryEnabled =>
238+ withTempDir { dir =>
239+ val path = new Path (dir.toURI.toString, " test.parquet" )
240+ makeParquetFileAllPrimitiveTypes(path, dictionaryEnabled, n = 10000 )
241+ spark.read.parquet(path.toString).createOrReplaceTempView(" t1" )
242+ // The result needs to be in ascending order for checkSparkAnswerAndOperator to pass
243+ // because datafusion array_distinct sorts the elements and then removes the duplicates
244+ checkSparkAnswerAndOperator(
245+ spark.sql(" SELECT array_distinct(array(_2, _2, _3, _4, _4)) FROM t1" ))
246+ checkSparkAnswerAndOperator(
247+ spark.sql(" SELECT array_distinct((CASE WHEN _2 =_3 THEN array(_4) END)) FROM t1" ))
248+ checkSparkAnswerAndOperator(spark.sql(
249+ " SELECT array_distinct((CASE WHEN _2 =_3 THEN array(_2, _2, _4, _4, _5) END)) FROM t1" ))
250+ // NULL needs to be the first element for checkSparkAnswerAndOperator to pass because
251+ // datafusion array_distinct sorts the elements and then removes the duplicates
252+ checkSparkAnswerAndOperator(
253+ spark.sql(
254+ " SELECT array_distinct(array(CAST(NULL AS INT), _2, _2, _3, _4, _4)) FROM t1" ))
255+ checkSparkAnswerAndOperator(spark.sql(
256+ " SELECT array_distinct(array(CAST(NULL AS INT), CAST(NULL AS INT), _2, _2, _3, _4, _4)) FROM t1" ))
257+ }
258+ }
259+ }
260+ }
261+
   test("array_max") {
-    withTempDir { dir =>
-      val path = new Path(dir.toURI.toString, "test.parquet")
-      makeParquetFileAllPrimitiveTypes(path, dictionaryEnabled = false, n = 10000)
-      spark.read.parquet(path.toString).createOrReplaceTempView("t1");
-      checkSparkAnswerAndOperator(spark.sql("SELECT array_max(array(_2, _3, _4)) FROM t1"))
-      checkSparkAnswerAndOperator(
-        spark.sql("SELECT array_max((CASE WHEN _2 = _3 THEN array(_4) END)) FROM t1"));
-      checkSparkAnswerAndOperator(
-        spark.sql("SELECT array_max((CASE WHEN _2 = _3 THEN array(_2, _4) END)) FROM t1"));
-      checkSparkAnswerAndOperator(
-        spark.sql("SELECT array_max(array(CAST(NULL AS INT), CAST(NULL AS INT))) FROM t1"))
-      checkSparkAnswerAndOperator(
-        spark.sql("SELECT array_max(array(_2, CAST(NULL AS INT))) FROM t1"))
-      checkSparkAnswerAndOperator(spark.sql("SELECT array_max(array()) FROM t1"))
-      checkSparkAnswerAndOperator(
-        spark.sql(
-          "SELECT array_max(array(double('-Infinity'), 0.0, double('Infinity'))) FROM t1"))
+    Seq(true, false).foreach { dictionaryEnabled =>
+      withTempDir { dir =>
+        val path = new Path(dir.toURI.toString, "test.parquet")
+        makeParquetFileAllPrimitiveTypes(path, dictionaryEnabled, n = 10000)
+        spark.read.parquet(path.toString).createOrReplaceTempView("t1");
+        checkSparkAnswerAndOperator(spark.sql("SELECT array_max(array(_2, _3, _4)) FROM t1"))
+        checkSparkAnswerAndOperator(
+          spark.sql("SELECT array_max((CASE WHEN _2 = _3 THEN array(_4) END)) FROM t1"))
+        checkSparkAnswerAndOperator(
+          spark.sql("SELECT array_max((CASE WHEN _2 = _3 THEN array(_2, _4) END)) FROM t1"))
+        checkSparkAnswerAndOperator(
+          spark.sql("SELECT array_max(array(CAST(NULL AS INT), CAST(NULL AS INT))) FROM t1"))
+        checkSparkAnswerAndOperator(
+          spark.sql("SELECT array_max(array(_2, CAST(NULL AS INT))) FROM t1"))
+        checkSparkAnswerAndOperator(spark.sql("SELECT array_max(array()) FROM t1"))
+        checkSparkAnswerAndOperator(
+          spark.sql(
+            "SELECT array_max(array(double('-Infinity'), 0.0, double('Infinity'))) FROM t1"))
+      }
     }
   }
 
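Aside: the comments in the array_distinct test above hinge on an ordering difference between the two engines. Spark's array_distinct keeps elements in first-occurrence order, while (per those comments) DataFusion's sorts the array, NULLs first, before deduplicating, so checkSparkAnswerAndOperator only passes on inputs whose Spark-side result is already ascending with NULL leading. Below is a minimal plain-Scala sketch of that difference; it is illustrative only, not part of the PR, and the input values are made up.

object ArrayDistinctOrdering extends App {
  // Stand-in for an array column; None models SQL NULL.
  val input: Seq[Option[Int]] = Seq(None, Some(3), Some(3), Some(1), Some(2))

  // Spark semantics: deduplicate, preserving first-occurrence order.
  val sparkResult = input.distinct
  // => Seq(None, Some(3), Some(1), Some(2))

  // DataFusion semantics, as described in the test comments: sort (NULL
  // first), then deduplicate. getOrElse(Int.MinValue) is just a hack to
  // make None sort first in this sketch.
  val dataFusionResult = input.distinct.sortBy(_.getOrElse(Int.MinValue))
  // => Seq(None, Some(1), Some(2), Some(3))

  // The two orderings disagree on this input, which is why the suite only
  // uses inputs like array(NULL, _2, _2, _3, _4, _4): first occurrences
  // already ascending, NULL already first.
  assert(sparkResult != dataFusionResult)
}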