diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
index 892d8bca63..4cc92c2beb 100644
--- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
+++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
@@ -915,6 +915,8 @@ object QueryPlanSerde extends Logging with CometExprShim {
       case l @ Length(child) if child.dataType == BinaryType =>
         withInfo(l, "Length on BinaryType is not supported")
         None
+      case r @ Reverse(child) if child.dataType.isInstanceOf[ArrayType] =>
+        convert(r, CometArrayReverse)
       case expr =>
         QueryPlanSerde.exprSerdeMap.get(expr.getClass) match {
           case Some(handler) =>
diff --git a/spark/src/main/scala/org/apache/comet/serde/arrays.scala b/spark/src/main/scala/org/apache/comet/serde/arrays.scala
index 5b1603aafa..09ea547cc3 100644
--- a/spark/src/main/scala/org/apache/comet/serde/arrays.scala
+++ b/spark/src/main/scala/org/apache/comet/serde/arrays.scala
@@ -21,7 +21,7 @@ package org.apache.comet.serde
 
 import scala.annotation.tailrec
 
-import org.apache.spark.sql.catalyst.expressions.{ArrayAppend, ArrayContains, ArrayDistinct, ArrayExcept, ArrayInsert, ArrayIntersect, ArrayJoin, ArrayMax, ArrayMin, ArrayRemove, ArrayRepeat, ArraysOverlap, ArrayUnion, Attribute, CreateArray, ElementAt, Expression, Flatten, GetArrayItem, Literal}
+import org.apache.spark.sql.catalyst.expressions.{ArrayAppend, ArrayContains, ArrayDistinct, ArrayExcept, ArrayInsert, ArrayIntersect, ArrayJoin, ArrayMax, ArrayMin, ArrayRemove, ArrayRepeat, ArraysOverlap, ArrayUnion, Attribute, CreateArray, ElementAt, Expression, Flatten, GetArrayItem, Literal, Reverse}
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types._
 
@@ -432,6 +432,22 @@ object CometGetArrayItem extends CometExpressionSerde[GetArrayItem] {
   }
 }
 
+object CometArrayReverse extends CometExpressionSerde[Reverse] with ArraysBase {
+  override def convert(
+      expr: Reverse,
+      inputs: Seq[Attribute],
+      binding: Boolean): Option[ExprOuterClass.Expr] = {
+    if (!isTypeSupported(expr.child.dataType)) {
+      withInfo(expr, s"child data type not supported: ${expr.child.dataType}")
+      return None
+    }
+    val reverseExprProto = exprToProto(expr.child, inputs, binding)
+    val reverseScalarExpr = scalarFunctionExprToProto("array_reverse", reverseExprProto)
+    optExprWithInfo(reverseScalarExpr, expr, expr.children: _*)
+  }
+
+}
+
 object CometElementAt extends CometExpressionSerde[ElementAt] {
 
   override def convert(
diff --git a/spark/src/test/scala/org/apache/comet/CometArrayExpressionSuite.scala b/spark/src/test/scala/org/apache/comet/CometArrayExpressionSuite.scala
index 56d9b3b429..2adb7a9ed6 100644
--- a/spark/src/test/scala/org/apache/comet/CometArrayExpressionSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/CometArrayExpressionSuite.scala
@@ -29,88 +29,94 @@ import org.apache.spark.sql.functions._
 
 import org.apache.comet.CometSparkSessionExtensions.{isSpark35Plus, isSpark40Plus}
 import org.apache.comet.DataTypeSupport.isComplexType
-import org.apache.comet.serde.{CometArrayExcept, CometArrayRemove, CometFlatten}
+import org.apache.comet.serde.{CometArrayExcept, CometArrayRemove, CometArrayReverse, CometFlatten}
 import org.apache.comet.testing.{DataGenOptions, ParquetGenerator}
 
 class CometArrayExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper {
 
   test("array_remove - integer") {
     Seq(true, false).foreach { dictionaryEnabled =>
-      withTempDir { dir =>
-        val path = new Path(dir.toURI.toString, "test.parquet")
-        makeParquetFileAllPrimitiveTypes(path, dictionaryEnabled, 10000)
-        spark.read.parquet(path.toString).createOrReplaceTempView("t1")
-        checkSparkAnswerAndOperator(
-          sql("SELECT array_remove(array(_2, _3,_4), _2) from t1 where _2 is null"))
-        checkSparkAnswerAndOperator(
-          sql("SELECT array_remove(array(_2, _3,_4), _3) from t1 where _3 is not null"))
-        checkSparkAnswerAndOperator(sql(
-          "SELECT array_remove(case when _2 = _3 THEN array(_2, _3,_4) ELSE null END, _3) from t1"))
+      withTempView("t1") {
+        withTempDir { dir =>
+          val path = new Path(dir.toURI.toString, "test.parquet")
+          makeParquetFileAllPrimitiveTypes(path, dictionaryEnabled, 10000)
+          spark.read.parquet(path.toString).createOrReplaceTempView("t1")
+          checkSparkAnswerAndOperator(
+            sql("SELECT array_remove(array(_2, _3,_4), _2) from t1 where _2 is null"))
+          checkSparkAnswerAndOperator(
+            sql("SELECT array_remove(array(_2, _3,_4), _3) from t1 where _3 is not null"))
+          checkSparkAnswerAndOperator(sql(
+            "SELECT array_remove(case when _2 = _3 THEN array(_2, _3,_4) ELSE null END, _3) from t1"))
+        }
       }
     }
   }
 
   test("array_remove - test all types (native Parquet reader)") {
     withTempDir { dir =>
-      val path = new Path(dir.toURI.toString, "test.parquet")
-      val filename = path.toString
-      val random = new Random(42)
-      withSQLConf(CometConf.COMET_ENABLED.key -> "false") {
-        ParquetGenerator.makeParquetFile(
-          random,
-          spark,
-          filename,
-          100,
-          DataGenOptions(
-            allowNull = true,
-            generateNegativeZero = true,
-            generateArray = false,
-            generateStruct = false,
-            generateMap = false))
-      }
-      val table = spark.read.parquet(filename)
-      table.createOrReplaceTempView("t1")
-      // test with array of each column
-      val fieldNames =
-        table.schema.fields
-          .filter(field => CometArrayRemove.isTypeSupported(field.dataType))
-          .map(_.name)
-      for (fieldName <- fieldNames) {
-        sql(s"SELECT array($fieldName, $fieldName) as a, $fieldName as b FROM t1")
-          .createOrReplaceTempView("t2")
-        val df = sql("SELECT array_remove(a, b) FROM t2")
-        checkSparkAnswerAndOperator(df)
-      }
-    }
-  }
-
-  test("array_remove - test all types (convert from Parquet)") {
-    withTempDir { dir =>
-      val path = new Path(dir.toURI.toString, "test.parquet")
-      val filename = path.toString
-      val random = new Random(42)
-      withSQLConf(CometConf.COMET_ENABLED.key -> "false") {
-        val options = DataGenOptions(
-          allowNull = true,
-          generateNegativeZero = true,
-          generateArray = true,
-          generateStruct = true,
-          generateMap = false)
-        ParquetGenerator.makeParquetFile(random, spark, filename, 100, options)
-      }
-      withSQLConf(
-        CometConf.COMET_NATIVE_SCAN_ENABLED.key -> "false",
-        CometConf.COMET_SPARK_TO_ARROW_ENABLED.key -> "true",
-        CometConf.COMET_CONVERT_FROM_PARQUET_ENABLED.key -> "true") {
+      withTempView("t1") {
+        val path = new Path(dir.toURI.toString, "test.parquet")
+        val filename = path.toString
+        val random = new Random(42)
+        withSQLConf(CometConf.COMET_ENABLED.key -> "false") {
+          ParquetGenerator.makeParquetFile(
+            random,
+            spark,
+            filename,
+            100,
+            DataGenOptions(
+              allowNull = true,
+              generateNegativeZero = true,
+              generateArray = false,
+              generateStruct = false,
+              generateMap = false))
+        }
         val table = spark.read.parquet(filename)
         table.createOrReplaceTempView("t1")
         // test with array of each column
-        for (field <- table.schema.fields) {
-          val fieldName = field.name
+        val fieldNames =
+          table.schema.fields
+            .filter(field => CometArrayRemove.isTypeSupported(field.dataType))
+            .map(_.name)
+        for (fieldName <- fieldNames) {
           sql(s"SELECT array($fieldName, $fieldName) as a, $fieldName as b FROM t1")
             .createOrReplaceTempView("t2")
           val df = sql("SELECT array_remove(a, b) FROM t2")
-          checkSparkAnswer(df)
+          checkSparkAnswerAndOperator(df)
+        }
+      }
+    }
+  }
+
+  test("array_remove - test all types (convert from Parquet)") {
+    withTempDir { dir =>
+      withTempView("t1") {
+        val path = new Path(dir.toURI.toString, "test.parquet")
+        val filename = path.toString
+        val random = new Random(42)
+        withSQLConf(CometConf.COMET_ENABLED.key -> "false") {
+          val options = DataGenOptions(
+            allowNull = true,
+            generateNegativeZero = true,
+            generateArray = true,
+            generateStruct = true,
+            generateMap = false)
+          ParquetGenerator.makeParquetFile(random, spark, filename, 100, options)
+        }
+        withSQLConf(
+          CometConf.COMET_NATIVE_SCAN_ENABLED.key -> "false",
+          CometConf.COMET_SPARK_TO_ARROW_ENABLED.key -> "true",
+          CometConf.COMET_CONVERT_FROM_PARQUET_ENABLED.key -> "true") {
+          val table = spark.read.parquet(filename)
+          table.createOrReplaceTempView("t1")
+          // test with array of each column
+          for (field <- table.schema.fields) {
+            val fieldName = field.name
+            sql(s"SELECT array($fieldName, $fieldName) as a, $fieldName as b FROM t1")
+              .createOrReplaceTempView("t2")
+            val df = sql("SELECT array_remove(a, b) FROM t2")
+            checkSparkAnswer(df)
+          }
         }
       }
     }
   }
@@ -118,19 +124,21 @@ class CometArrayExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelp
 
   test("array_remove - fallback for unsupported type struct") {
     withTempDir { dir =>
-      val path = new Path(dir.toURI.toString, "test.parquet")
-      makeParquetFileAllPrimitiveTypes(path, dictionaryEnabled = true, 100)
-      spark.read.parquet(path.toString).createOrReplaceTempView("t1")
-      sql("SELECT array(struct(_1, _2)) as a, struct(_1, _2) as b FROM t1")
-        .createOrReplaceTempView("t2")
-      val expectedFallbackReasons = HashSet(
-        "data type not supported: ArrayType(StructType(StructField(_1,BooleanType,true),StructField(_2,ByteType,true)),false)")
-      // note that checkExtended is disabled here due to an unrelated issue
-      // https://github.com/apache/datafusion-comet/issues/1313
-      checkSparkAnswerAndCompareExplainPlan(
-        sql("SELECT array_remove(a, b) FROM t2"),
-        expectedFallbackReasons,
-        checkExplainString = false)
+      withTempView("t1", "t2") {
+        val path = new Path(dir.toURI.toString, "test.parquet")
+        makeParquetFileAllPrimitiveTypes(path, dictionaryEnabled = true, 100)
+        spark.read.parquet(path.toString).createOrReplaceTempView("t1")
+        sql("SELECT array(struct(_1, _2)) as a, struct(_1, _2) as b FROM t1")
+          .createOrReplaceTempView("t2")
+        val expectedFallbackReasons = HashSet(
+          "data type not supported: ArrayType(StructType(StructField(_1,BooleanType,true),StructField(_2,ByteType,true)),false)")
+        // note that checkExtended is disabled here due to an unrelated issue
+        // https://github.com/apache/datafusion-comet/issues/1313
+        checkSparkAnswerAndCompareExplainPlan(
+          sql("SELECT array_remove(a, b) FROM t2"),
+          expectedFallbackReasons,
+          checkExplainString = false)
+      }
     }
   }
 
@@ -138,21 +146,25 @@ class CometArrayExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelp
     withSQLConf(CometConf.COMET_EXPR_ALLOW_INCOMPATIBLE.key -> "true") {
       Seq(true, false).foreach { dictionaryEnabled =>
         withTempDir { dir =>
-          val path = new Path(dir.toURI.toString, "test.parquet")
-          makeParquetFileAllPrimitiveTypes(path, dictionaryEnabled = dictionaryEnabled, 10000)
-          spark.read.parquet(path.toString).createOrReplaceTempView("t1");
-          checkSparkAnswerAndOperator(spark.sql("Select array_append(array(_1),false) from t1"))
-          checkSparkAnswerAndOperator(
-            spark.sql("SELECT array_append(array(_2, _3, _4), 4) FROM t1"))
-          checkSparkAnswerAndOperator(
-            spark.sql("SELECT array_append(array(_2, _3, _4), null) FROM t1"));
-          checkSparkAnswerAndOperator(
-            spark.sql("SELECT array_append(array(_6, _7), CAST(6.5 AS DOUBLE)) FROM t1"));
-          checkSparkAnswerAndOperator(
-            spark.sql("SELECT array_append(array(_8), 'test') FROM t1"));
-          checkSparkAnswerAndOperator(spark.sql("SELECT array_append(array(_19), _19) FROM t1"));
-          checkSparkAnswerAndOperator(
-            spark.sql("SELECT array_append((CASE WHEN _2 =_3 THEN array(_4) END), _4) FROM t1"));
+          withTempView("t1") {
+            val path = new Path(dir.toURI.toString, "test.parquet")
+            makeParquetFileAllPrimitiveTypes(path, dictionaryEnabled = dictionaryEnabled, 10000)
+            spark.read.parquet(path.toString).createOrReplaceTempView("t1");
+            checkSparkAnswerAndOperator(spark.sql("Select array_append(array(_1),false) from t1"))
+            checkSparkAnswerAndOperator(
+              spark.sql("SELECT array_append(array(_2, _3, _4), 4) FROM t1"))
+            checkSparkAnswerAndOperator(
+              spark.sql("SELECT array_append(array(_2, _3, _4), null) FROM t1"));
+            checkSparkAnswerAndOperator(
+              spark.sql("SELECT array_append(array(_6, _7), CAST(6.5 AS DOUBLE)) FROM t1"));
+            checkSparkAnswerAndOperator(
+              spark.sql("SELECT array_append(array(_8), 'test') FROM t1"));
+            checkSparkAnswerAndOperator(
+              spark.sql("SELECT array_append(array(_19), _19) FROM t1"));
+            checkSparkAnswerAndOperator(
+              spark.sql(
+                "SELECT array_append((CASE WHEN _2 =_3 THEN array(_4) END), _4) FROM t1"));
+          }
         }
       }
     }
@@ -163,21 +175,26 @@ class CometArrayExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelp
     withSQLConf(CometConf.COMET_EXPR_ALLOW_INCOMPATIBLE.key -> "true") {
       Seq(true, false).foreach { dictionaryEnabled =>
        withTempDir { dir =>
-          val path = new Path(dir.toURI.toString, "test.parquet")
-          makeParquetFileAllPrimitiveTypes(path, dictionaryEnabled = dictionaryEnabled, 10000)
-          spark.read.parquet(path.toString).createOrReplaceTempView("t1");
-          checkSparkAnswerAndOperator(spark.sql("Select array_prepend(array(_1),false) from t1"))
-          checkSparkAnswerAndOperator(
-            spark.sql("SELECT array_prepend(array(_2, _3, _4), 4) FROM t1"))
-          checkSparkAnswerAndOperator(
-            spark.sql("SELECT array_prepend(array(_2, _3, _4), null) FROM t1"));
-          checkSparkAnswerAndOperator(
-            spark.sql("SELECT array_prepend(array(_6, _7), CAST(6.5 AS DOUBLE)) FROM t1"));
-          checkSparkAnswerAndOperator(
-            spark.sql("SELECT array_prepend(array(_8), 'test') FROM t1"));
-          checkSparkAnswerAndOperator(spark.sql("SELECT array_prepend(array(_19), _19) FROM t1"));
-          checkSparkAnswerAndOperator(
-            spark.sql("SELECT array_prepend((CASE WHEN _2 =_3 THEN array(_4) END), _4) FROM t1"));
+          withTempView("t1") {
+            val path = new Path(dir.toURI.toString, "test.parquet")
+            makeParquetFileAllPrimitiveTypes(path, dictionaryEnabled = dictionaryEnabled, 10000)
+            spark.read.parquet(path.toString).createOrReplaceTempView("t1");
+            checkSparkAnswerAndOperator(
+              spark.sql("Select array_prepend(array(_1),false) from t1"))
+            checkSparkAnswerAndOperator(
+              spark.sql("SELECT array_prepend(array(_2, _3, _4), 4) FROM t1"))
+            checkSparkAnswerAndOperator(
+              spark.sql("SELECT array_prepend(array(_2, _3, _4), null) FROM t1"));
+            checkSparkAnswerAndOperator(
+              spark.sql("SELECT array_prepend(array(_6, _7), CAST(6.5 AS DOUBLE)) FROM t1"));
+            checkSparkAnswerAndOperator(
+              spark.sql("SELECT array_prepend(array(_8), 'test') FROM t1"));
+            checkSparkAnswerAndOperator(
+              spark.sql("SELECT array_prepend(array(_19), _19) FROM t1"));
+            checkSparkAnswerAndOperator(
+              spark.sql(
+                "SELECT array_prepend((CASE WHEN _2 =_3 THEN array(_4) END), _4) FROM t1"));
+          }
         }
       }
     }
@@ -225,84 +242,90 @@ class CometArrayExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelp
 
   test("array_contains - int values") {
     withTempDir { dir =>
-      val path = new Path(dir.toURI.toString, "test.parquet")
-      makeParquetFileAllPrimitiveTypes(path, dictionaryEnabled = false, n = 10000)
-      spark.read.parquet(path.toString).createOrReplaceTempView("t1");
-      checkSparkAnswerAndOperator(
-        spark.sql("SELECT array_contains(array(_2, _3, _4), _2) FROM t1"))
-      checkSparkAnswerAndOperator(
-        spark.sql("SELECT array_contains((CASE WHEN _2 =_3 THEN array(_4) END), _4) FROM t1"));
+      withTempView("t1") {
+        val path = new Path(dir.toURI.toString, "test.parquet")
+        makeParquetFileAllPrimitiveTypes(path, dictionaryEnabled = false, n = 10000)
+        spark.read.parquet(path.toString).createOrReplaceTempView("t1");
+        checkSparkAnswerAndOperator(
+          spark.sql("SELECT array_contains(array(_2, _3, _4), _2) FROM t1"))
+        checkSparkAnswerAndOperator(
+          spark.sql("SELECT array_contains((CASE WHEN _2 =_3 THEN array(_4) END), _4) FROM t1"));
+      }
     }
   }
 
   test("array_contains - test all types (native Parquet reader)") {
     withTempDir { dir =>
-      val path = new Path(dir.toURI.toString, "test.parquet")
-      val filename = path.toString
-      val random = new Random(42)
-      withSQLConf(CometConf.COMET_ENABLED.key -> "false") {
-        ParquetGenerator.makeParquetFile(
-          random,
-          spark,
-          filename,
-          100,
-          DataGenOptions(
-            allowNull = true,
-            generateNegativeZero = true,
-            generateArray = true,
-            generateStruct = true,
-            generateMap = false))
-      }
-      val table = spark.read.parquet(filename)
-      table.createOrReplaceTempView("t1")
-      val complexTypeFields =
-        table.schema.fields.filter(field => isComplexType(field.dataType))
-      val primitiveTypeFields =
-        table.schema.fields.filterNot(field => isComplexType(field.dataType))
-      for (field <- primitiveTypeFields) {
-        val fieldName = field.name
-        val typeName = field.dataType.typeName
-        sql(s"SELECT array($fieldName, $fieldName) as a, $fieldName as b FROM t1")
-          .createOrReplaceTempView("t2")
-        checkSparkAnswerAndOperator(sql("SELECT array_contains(a, b) FROM t2"))
-        checkSparkAnswerAndOperator(
-          sql(s"SELECT array_contains(a, cast(null as $typeName)) FROM t2"))
-      }
-      for (field <- complexTypeFields) {
-        val fieldName = field.name
-        sql(s"SELECT array($fieldName, $fieldName) as a, $fieldName as b FROM t1")
-          .createOrReplaceTempView("t3")
-        checkSparkAnswer(sql("SELECT array_contains(a, b) FROM t3"))
+      withTempView("t1", "t2", "t3") {
+        val path = new Path(dir.toURI.toString, "test.parquet")
+        val filename = path.toString
+        val random = new Random(42)
+        withSQLConf(CometConf.COMET_ENABLED.key -> "false") {
+          ParquetGenerator.makeParquetFile(
+            random,
+            spark,
+            filename,
+            100,
+            DataGenOptions(
+              allowNull = true,
+              generateNegativeZero = true,
+              generateArray = true,
+              generateStruct = true,
+              generateMap = false))
+        }
+        val table = spark.read.parquet(filename)
+        table.createOrReplaceTempView("t1")
+        val complexTypeFields =
+          table.schema.fields.filter(field => isComplexType(field.dataType))
+        val primitiveTypeFields =
+          table.schema.fields.filterNot(field => isComplexType(field.dataType))
+        for (field <- primitiveTypeFields) {
+          val fieldName = field.name
+          val typeName = field.dataType.typeName
+          sql(s"SELECT array($fieldName, $fieldName) as a, $fieldName as b FROM t1")
+            .createOrReplaceTempView("t2")
+          checkSparkAnswerAndOperator(sql("SELECT array_contains(a, b) FROM t2"))
+          checkSparkAnswerAndOperator(
+            sql(s"SELECT array_contains(a, cast(null as $typeName)) FROM t2"))
+        }
+        for (field <- complexTypeFields) {
+          val fieldName = field.name
+          sql(s"SELECT array($fieldName, $fieldName) as a, $fieldName as b FROM t1")
+            .createOrReplaceTempView("t3")
+          checkSparkAnswer(sql("SELECT array_contains(a, b) FROM t3"))
+        }
       }
     }
   }
 
   test("array_contains - array literals") {
     withTempDir { dir =>
-      val path = new Path(dir.toURI.toString, "test.parquet")
-      val filename = path.toString
-      val random = new Random(42)
-      withSQLConf(CometConf.COMET_ENABLED.key -> "false") {
-        ParquetGenerator.makeParquetFile(
-          random,
-          spark,
-          filename,
-          100,
-          DataGenOptions(
-            allowNull = true,
-            generateNegativeZero = true,
-            generateArray = false,
-            generateStruct = false,
-            generateMap = false))
-      }
-      val table = spark.read.parquet(filename)
-      table.createOrReplaceTempView("t2")
-      for (field <- table.schema.fields) {
-        val typeName = field.dataType.typeName
-        checkSparkAnswerAndOperator(sql(
-          s"SELECT array_contains(cast(null as array<$typeName>), cast(null as $typeName)) FROM t2"))
+      withTempView("t2") {
+        val path = new Path(dir.toURI.toString, "test.parquet")
+        val filename = path.toString
+        val random = new Random(42)
+        withSQLConf(CometConf.COMET_ENABLED.key -> "false") {
+          ParquetGenerator.makeParquetFile(
+            random,
+            spark,
+            filename,
+            100,
+            DataGenOptions(
+              allowNull = true,
+              generateNegativeZero = true,
+              generateArray = false,
+              generateStruct = false,
+              generateMap = false))
+        }
+        val table = spark.read.parquet(filename)
+        table.createOrReplaceTempView("t2")
+        for (field <- table.schema.fields) {
+          val typeName = field.dataType.typeName
+          checkSparkAnswerAndOperator(sql(
+            s"SELECT array_contains(cast(null as array<$typeName>), cast(null as $typeName)) FROM t2"))
+        }
+        checkSparkAnswerAndOperator(sql("SELECT array_contains(array(), 1) FROM t2"))
       }
-      checkSparkAnswerAndOperator(sql("SELECT array_contains(array(), 1) FROM t2"))
     }
   }
 
@@ -328,13 +351,15 @@ class CometArrayExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelp
         CometConf.COMET_NATIVE_SCAN_ENABLED.key -> "false",
         CometConf.COMET_SPARK_TO_ARROW_ENABLED.key -> "true",
         CometConf.COMET_CONVERT_FROM_PARQUET_ENABLED.key -> "true") {
-        val table = spark.read.parquet(filename)
-        table.createOrReplaceTempView("t1")
-        for (field <- table.schema.fields) {
-          val fieldName = field.name
-          sql(s"SELECT array($fieldName, $fieldName) as a, $fieldName as b FROM t1")
-            .createOrReplaceTempView("t2")
-          checkSparkAnswer(sql("SELECT array_contains(a, b) FROM t2"))
+        withTempView("t1", "t2") {
+          val table = spark.read.parquet(filename)
+          table.createOrReplaceTempView("t1")
+          for (field <- table.schema.fields) {
+            val fieldName = field.name
+            sql(s"SELECT array($fieldName, $fieldName) as a, $fieldName as b FROM t1")
+              .createOrReplaceTempView("t2")
+            checkSparkAnswer(sql("SELECT array_contains(a, b) FROM t2"))
+          }
         }
       }
     }
@@ -344,24 +369,26 @@ class CometArrayExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelp
    withSQLConf(CometConf.COMET_EXPR_ALLOW_INCOMPATIBLE.key -> "true") {
      Seq(true, false).foreach { dictionaryEnabled =>
        withTempDir { dir =>
-          val path = new Path(dir.toURI.toString, "test.parquet")
-          makeParquetFileAllPrimitiveTypes(path, dictionaryEnabled, n = 10000)
-          spark.read.parquet(path.toString).createOrReplaceTempView("t1")
-          // The result needs to be in ascending order for checkSparkAnswerAndOperator to pass
-          // because datafusion array_distinct sorts the elements and then removes the duplicates
-          checkSparkAnswerAndOperator(
-            spark.sql("SELECT array_distinct(array(_2, _2, _3, _4, _4)) FROM t1"))
-          checkSparkAnswerAndOperator(
-            spark.sql("SELECT array_distinct((CASE WHEN _2 =_3 THEN array(_4) END)) FROM t1"))
-          checkSparkAnswerAndOperator(spark.sql(
-            "SELECT array_distinct((CASE WHEN _2 =_3 THEN array(_2, _2, _4, _4, _5) END)) FROM t1"))
-          // NULL needs to be the first element for checkSparkAnswerAndOperator to pass because
-          // datafusion array_distinct sorts the elements and then removes the duplicates
-          checkSparkAnswerAndOperator(
-            spark.sql(
-              "SELECT array_distinct(array(CAST(NULL AS INT), _2, _2, _3, _4, _4)) FROM t1"))
-          checkSparkAnswerAndOperator(spark.sql(
-            "SELECT array_distinct(array(CAST(NULL AS INT), CAST(NULL AS INT), _2, _2, _3, _4, _4)) FROM t1"))
+          withTempView("t1") {
+            val path = new Path(dir.toURI.toString, "test.parquet")
+            makeParquetFileAllPrimitiveTypes(path, dictionaryEnabled, n = 10000)
+            spark.read.parquet(path.toString).createOrReplaceTempView("t1")
+            // The result needs to be in ascending order for checkSparkAnswerAndOperator to pass
+            // because datafusion array_distinct sorts the elements and then removes the duplicates
+            checkSparkAnswerAndOperator(
+              spark.sql("SELECT array_distinct(array(_2, _2, _3, _4, _4)) FROM t1"))
+            checkSparkAnswerAndOperator(
+              spark.sql("SELECT array_distinct((CASE WHEN _2 =_3 THEN array(_4) END)) FROM t1"))
+            checkSparkAnswerAndOperator(spark.sql(
+              "SELECT array_distinct((CASE WHEN _2 =_3 THEN array(_2, _2, _4, _4, _5) END)) FROM t1"))
+            // NULL needs to be the first element for checkSparkAnswerAndOperator to pass because
+            // datafusion array_distinct sorts the elements and then removes the duplicates
+            checkSparkAnswerAndOperator(
+              spark.sql(
+                "SELECT array_distinct(array(CAST(NULL AS INT), _2, _2, _3, _4, _4)) FROM t1"))
+            checkSparkAnswerAndOperator(spark.sql(
+              "SELECT array_distinct(array(CAST(NULL AS INT), CAST(NULL AS INT), _2, _2, _3, _4, _4)) FROM t1"))
+          }
         }
       }
     }
@@ -371,16 +398,18 @@ class CometArrayExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelp
    withSQLConf(CometConf.COMET_EXPR_ALLOW_INCOMPATIBLE.key -> "true") {
      Seq(true, false).foreach { dictionaryEnabled =>
        withTempDir { dir =>
-          val path = new Path(dir.toURI.toString, "test.parquet")
-          makeParquetFileAllPrimitiveTypes(path, dictionaryEnabled, n = 10000)
-          spark.read.parquet(path.toString).createOrReplaceTempView("t1")
-          checkSparkAnswerAndOperator(
-            spark.sql("SELECT array_union(array(_2, _3, _4), array(_3, _4)) FROM t1"))
-          checkSparkAnswerAndOperator(sql("SELECT array_union(array(_18), array(_19)) from t1"))
-          checkSparkAnswerAndOperator(spark.sql(
-            "SELECT array_union(array(CAST(NULL AS INT), _2, _3, _4), array(CAST(NULL AS INT), _2, _3)) FROM t1"))
-          checkSparkAnswerAndOperator(spark.sql(
-            "SELECT array_union(array(CAST(NULL AS INT), CAST(NULL AS INT), _2, _3, _4), array(CAST(NULL AS INT), CAST(NULL AS INT), _2, _3)) FROM t1"))
+          withTempView("t1") {
+            val path = new Path(dir.toURI.toString, "test.parquet")
+            makeParquetFileAllPrimitiveTypes(path, dictionaryEnabled, n = 10000)
+            spark.read.parquet(path.toString).createOrReplaceTempView("t1")
+            checkSparkAnswerAndOperator(
+              spark.sql("SELECT array_union(array(_2, _3, _4), array(_3, _4)) FROM t1"))
+            checkSparkAnswerAndOperator(sql("SELECT array_union(array(_18), array(_19)) from t1"))
+            checkSparkAnswerAndOperator(spark.sql(
+              "SELECT array_union(array(CAST(NULL AS INT), _2, _3, _4), array(CAST(NULL AS INT), _2, _3)) FROM t1"))
+            checkSparkAnswerAndOperator(spark.sql(
+              "SELECT array_union(array(CAST(NULL AS INT), CAST(NULL AS INT), _2, _3, _4), array(CAST(NULL AS INT), CAST(NULL AS INT), _2, _3)) FROM t1"))
+          }
         }
       }
     }
@@ -389,22 +418,24 @@ class CometArrayExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelp
   test("array_max") {
     Seq(true, false).foreach { dictionaryEnabled =>
       withTempDir { dir =>
-        val path = new Path(dir.toURI.toString, "test.parquet")
-        makeParquetFileAllPrimitiveTypes(path, dictionaryEnabled, n = 10000)
-        spark.read.parquet(path.toString).createOrReplaceTempView("t1");
-        checkSparkAnswerAndOperator(spark.sql("SELECT array_max(array(_2, _3, _4)) FROM t1"))
-        checkSparkAnswerAndOperator(
-          spark.sql("SELECT array_max((CASE WHEN _2 =_3 THEN array(_4) END)) FROM t1"))
-        checkSparkAnswerAndOperator(
-          spark.sql("SELECT array_max((CASE WHEN _2 =_3 THEN array(_2, _4) END)) FROM t1"))
-        checkSparkAnswerAndOperator(
-          spark.sql("SELECT array_max(array(CAST(NULL AS INT), CAST(NULL AS INT))) FROM t1"))
-        checkSparkAnswerAndOperator(
-          spark.sql("SELECT array_max(array(_2, CAST(NULL AS INT))) FROM t1"))
-        checkSparkAnswerAndOperator(spark.sql("SELECT array_max(array()) FROM t1"))
-        checkSparkAnswerAndOperator(
-          spark.sql(
-            "SELECT array_max(array(double('-Infinity'), 0.0, double('Infinity'))) FROM t1"))
+        withTempView("t1") {
+          val path = new Path(dir.toURI.toString, "test.parquet")
+          makeParquetFileAllPrimitiveTypes(path, dictionaryEnabled, n = 10000)
+          spark.read.parquet(path.toString).createOrReplaceTempView("t1");
+          checkSparkAnswerAndOperator(spark.sql("SELECT array_max(array(_2, _3, _4)) FROM t1"))
+          checkSparkAnswerAndOperator(
+            spark.sql("SELECT array_max((CASE WHEN _2 =_3 THEN array(_4) END)) FROM t1"))
+          checkSparkAnswerAndOperator(
+            spark.sql("SELECT array_max((CASE WHEN _2 =_3 THEN array(_2, _4) END)) FROM t1"))
+          checkSparkAnswerAndOperator(
+            spark.sql("SELECT array_max(array(CAST(NULL AS INT), CAST(NULL AS INT))) FROM t1"))
+          checkSparkAnswerAndOperator(
+            spark.sql("SELECT array_max(array(_2, CAST(NULL AS INT))) FROM t1"))
+          checkSparkAnswerAndOperator(spark.sql("SELECT array_max(array()) FROM t1"))
+          checkSparkAnswerAndOperator(
+            spark.sql(
+              "SELECT array_max(array(double('-Infinity'), 0.0, double('Infinity'))) FROM t1"))
+        }
       }
     }
   }
@@ -412,40 +443,43 @@ class CometArrayExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelp
   test("array_min") {
     Seq(true, false).foreach { dictionaryEnabled =>
       withTempDir { dir =>
-        val path = new Path(dir.toURI.toString, "test.parquet")
-        makeParquetFileAllPrimitiveTypes(path, dictionaryEnabled, n = 10000)
-        spark.read.parquet(path.toString).createOrReplaceTempView("t1");
-        checkSparkAnswerAndOperator(spark.sql("SELECT array_min(array(_2, _3, _4)) FROM t1"))
-        checkSparkAnswerAndOperator(
-          spark.sql("SELECT array_min((CASE WHEN _2 =_3 THEN array(_4) END)) FROM t1"))
-        checkSparkAnswerAndOperator(
-          spark.sql("SELECT array_min((CASE WHEN _2 =_3 THEN array(_2, _4) END)) FROM t1"))
-        checkSparkAnswerAndOperator(
-          spark.sql("SELECT array_min(array(CAST(NULL AS INT), CAST(NULL AS INT))) FROM t1"))
-        checkSparkAnswerAndOperator(
-          spark.sql("SELECT array_min(array(_2, CAST(NULL AS INT))) FROM t1"))
-        checkSparkAnswerAndOperator(spark.sql("SELECT array_min(array()) FROM t1"))
-        checkSparkAnswerAndOperator(
-          spark.sql(
-            "SELECT array_min(array(double('-Infinity'), 0.0, double('Infinity'))) FROM t1"))
+        withTempView("t1") {
+          val path = new Path(dir.toURI.toString, "test.parquet")
+          makeParquetFileAllPrimitiveTypes(path, dictionaryEnabled, n = 10000)
+          spark.read.parquet(path.toString).createOrReplaceTempView("t1");
+          checkSparkAnswerAndOperator(spark.sql("SELECT array_min(array(_2, _3, _4)) FROM t1"))
+          checkSparkAnswerAndOperator(
+            spark.sql("SELECT array_min((CASE WHEN _2 =_3 THEN array(_4) END)) FROM t1"))
+          checkSparkAnswerAndOperator(
+            spark.sql("SELECT array_min((CASE WHEN _2 =_3 THEN array(_2, _4) END)) FROM t1"))
+          checkSparkAnswerAndOperator(
+            spark.sql("SELECT array_min(array(CAST(NULL AS INT), CAST(NULL AS INT))) FROM t1"))
+          checkSparkAnswerAndOperator(
+            spark.sql("SELECT array_min(array(_2, CAST(NULL AS INT))) FROM t1"))
+          checkSparkAnswerAndOperator(spark.sql("SELECT array_min(array()) FROM t1"))
+          checkSparkAnswerAndOperator(
+            spark.sql(
+              "SELECT array_min(array(double('-Infinity'), 0.0, double('Infinity'))) FROM t1"))
+        }
       }
     }
   }
 
   test("array_intersect") {
     withSQLConf(CometConf.COMET_EXPR_ALLOW_INCOMPATIBLE.key -> "true") {
-
       Seq(true, false).foreach { dictionaryEnabled =>
         withTempDir { dir =>
-          val path = new Path(dir.toURI.toString, "test.parquet")
-          makeParquetFileAllPrimitiveTypes(path, dictionaryEnabled, 10000)
-          spark.read.parquet(path.toString).createOrReplaceTempView("t1")
-          checkSparkAnswerAndOperator(
-            sql("SELECT array_intersect(array(_2, _3, _4), array(_3, _4)) from t1"))
-          checkSparkAnswerAndOperator(
-            sql("SELECT array_intersect(array(_4 * -1), array(_5)) from t1"))
-          checkSparkAnswerAndOperator(
-            sql("SELECT array_intersect(array(_18), array(_19)) from t1"))
+          withTempView("t1") {
+            val path = new Path(dir.toURI.toString, "test.parquet")
+            makeParquetFileAllPrimitiveTypes(path, dictionaryEnabled, 10000)
+            spark.read.parquet(path.toString).createOrReplaceTempView("t1")
+            checkSparkAnswerAndOperator(
+              sql("SELECT array_intersect(array(_2, _3, _4), array(_3, _4)) from t1"))
+            checkSparkAnswerAndOperator(
+              sql("SELECT array_intersect(array(_4 * -1), array(_5)) from t1"))
+            checkSparkAnswerAndOperator(
+              sql("SELECT array_intersect(array(_18), array(_19)) from t1"))
+          }
        }
      }
    }
@@ -455,18 +489,19 @@ class CometArrayExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelp
     withSQLConf(CometConf.COMET_EXPR_ALLOW_INCOMPATIBLE.key -> "true") {
       Seq(true, false).foreach { dictionaryEnabled =>
         withTempDir { dir =>
-          val path = new Path(dir.toURI.toString, "test.parquet")
-          makeParquetFileAllPrimitiveTypes(path, dictionaryEnabled, 10000)
-          spark.read.parquet(path.toString).createOrReplaceTempView("t1")
-          checkSparkAnswerAndOperator(sql(
-            "SELECT array_join(array(cast(_1 as string), cast(_2 as string), cast(_6 as string)), ' @ ') from t1"))
-          checkSparkAnswerAndOperator(sql(
-            "SELECT array_join(array(cast(_1 as string), cast(_2 as string), cast(_6 as string)), ' @ ', ' +++ ') from t1"))
-          checkSparkAnswerAndOperator(sql(
-            "SELECT array_join(array('hello', 'world', cast(_2 as string)), ' ') from t1 where _2 is not null"))
-          checkSparkAnswerAndOperator(
-            sql(
+          withTempView("t1") {
+            val path = new Path(dir.toURI.toString, "test.parquet")
+            makeParquetFileAllPrimitiveTypes(path, dictionaryEnabled, 10000)
+            spark.read.parquet(path.toString).createOrReplaceTempView("t1")
+            checkSparkAnswerAndOperator(sql(
+              "SELECT array_join(array(cast(_1 as string), cast(_2 as string), cast(_6 as string)), ' @ ') from t1"))
+            checkSparkAnswerAndOperator(sql(
+              "SELECT array_join(array(cast(_1 as string), cast(_2 as string), cast(_6 as string)), ' @ ', ' +++ ') from t1"))
+            checkSparkAnswerAndOperator(sql(
+              "SELECT array_join(array('hello', 'world', cast(_2 as string)), ' ') from t1 where _2 is not null"))
+            checkSparkAnswerAndOperator(sql(
              "SELECT array_join(array('hello', '-', 'world', cast(_2 as string)), ' ') from t1"))
+          }
         }
       }
     }
@@ -476,17 +511,19 @@ class CometArrayExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelp
     withSQLConf(CometConf.COMET_EXPR_ALLOW_INCOMPATIBLE.key -> "true") {
       Seq(true, false).foreach { dictionaryEnabled =>
         withTempDir { dir =>
-          val path = new Path(dir.toURI.toString, "test.parquet")
-          makeParquetFileAllPrimitiveTypes(path, dictionaryEnabled, 10000)
-          spark.read.parquet(path.toString).createOrReplaceTempView("t1")
-          checkSparkAnswerAndOperator(sql(
-            "SELECT arrays_overlap(array(_2, _3, _4), array(_3, _4)) from t1 where _2 is not null"))
-          checkSparkAnswerAndOperator(sql(
-            "SELECT arrays_overlap(array('a', null, cast(_1 as string)), array('b', cast(_1 as string), cast(_2 as string))) from t1 where _1 is not null"))
-          checkSparkAnswerAndOperator(sql(
-            "SELECT arrays_overlap(array('a', null), array('b', null)) from t1 where _1 is not null"))
-          checkSparkAnswerAndOperator(spark.sql(
-            "SELECT arrays_overlap((CASE WHEN _2 =_3 THEN array(_6, _7) END), array(_6, _7)) FROM t1"));
+          withTempView("t1") {
+            val path = new Path(dir.toURI.toString, "test.parquet")
+            makeParquetFileAllPrimitiveTypes(path, dictionaryEnabled, 10000)
+            spark.read.parquet(path.toString).createOrReplaceTempView("t1")
+            checkSparkAnswerAndOperator(sql(
+              "SELECT arrays_overlap(array(_2, _3, _4), array(_3, _4)) from t1 where _2 is not null"))
+            checkSparkAnswerAndOperator(sql(
+              "SELECT arrays_overlap(array('a', null, cast(_1 as string)), array('b', cast(_1 as string), cast(_2 as string))) from t1 where _1 is not null"))
+            checkSparkAnswerAndOperator(sql(
+              "SELECT arrays_overlap(array('a', null), array('b', null)) from t1 where _1 is not null"))
+            checkSparkAnswerAndOperator(spark.sql(
+              "SELECT arrays_overlap((CASE WHEN _2 =_3 THEN array(_6, _7) END), array(_6, _7)) FROM t1"));
+          }
         }
       }
     }
@@ -498,16 +535,21 @@ class CometArrayExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelp
     withSQLConf(CometConf.COMET_EXPR_ALLOW_INCOMPATIBLE.key -> "true") {
       Seq(true, false).foreach { dictionaryEnabled =>
         withTempDir { dir =>
-          val path = new Path(dir.toURI.toString, "test.parquet")
-          makeParquetFileAllPrimitiveTypes(path, dictionaryEnabled = dictionaryEnabled, n = 10000)
-          spark.read.parquet(path.toString).createOrReplaceTempView("t1")
-
-          checkSparkAnswerAndOperator(
-            sql("SELECT array_compact(array(_2)) FROM t1 WHERE _2 IS NULL"))
-          checkSparkAnswerAndOperator(
-            sql("SELECT array_compact(array(_2)) FROM t1 WHERE _2 IS NOT NULL"))
-          checkSparkAnswerAndOperator(
-            sql("SELECT array_compact(array(_2, _3, null)) FROM t1 WHERE _2 IS NOT NULL"))
+          withTempView("t1") {
+            val path = new Path(dir.toURI.toString, "test.parquet")
+            makeParquetFileAllPrimitiveTypes(
+              path,
+              dictionaryEnabled = dictionaryEnabled,
+              n = 10000)
+            spark.read.parquet(path.toString).createOrReplaceTempView("t1")
+
+            checkSparkAnswerAndOperator(
+              sql("SELECT array_compact(array(_2)) FROM t1 WHERE _2 IS NULL"))
+            checkSparkAnswerAndOperator(
+              sql("SELECT array_compact(array(_2)) FROM t1 WHERE _2 IS NOT NULL"))
+            checkSparkAnswerAndOperator(
+              sql("SELECT array_compact(array(_2, _3, null)) FROM t1 WHERE _2 IS NOT NULL"))
+          }
         }
       }
     }
@@ -517,16 +559,19 @@ class CometArrayExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelp
     withSQLConf(CometConf.COMET_EXPR_ALLOW_INCOMPATIBLE.key -> "true") {
       Seq(true, false).foreach { dictionaryEnabled =>
         withTempDir { dir =>
-          val path = new Path(dir.toURI.toString, "test.parquet")
-          makeParquetFileAllPrimitiveTypes(path, dictionaryEnabled, 10000)
-          spark.read.parquet(path.toString).createOrReplaceTempView("t1")
-
-          checkSparkAnswerAndOperator(
-            sql("SELECT array_except(array(_2, _3, _4), array(_3, _4)) from t1"))
-          checkSparkAnswerAndOperator(sql("SELECT array_except(array(_18), array(_19)) from t1"))
-          checkSparkAnswerAndOperator(
-            spark.sql(
-              "SELECT array_except(array(_2, _2, _4), array(_4)) FROM t1 WHERE _2 IS NOT NULL"))
+          withTempView("t1") {
+            val path = new Path(dir.toURI.toString, "test.parquet")
+            makeParquetFileAllPrimitiveTypes(path, dictionaryEnabled, 10000)
+            spark.read.parquet(path.toString).createOrReplaceTempView("t1")
+
+            checkSparkAnswerAndOperator(
+              sql("SELECT array_except(array(_2, _3, _4), array(_3, _4)) from t1"))
+            checkSparkAnswerAndOperator(
+              sql("SELECT array_except(array(_18), array(_19)) from t1"))
+            checkSparkAnswerAndOperator(
+              spark.sql(
+                "SELECT array_except(array(_2, _2, _4), array(_4)) FROM t1 WHERE _2 IS NOT NULL"))
+          }
         }
       }
     }
@@ -551,19 +596,21 @@ class CometArrayExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelp
             generateMap = false))
       }
       withSQLConf(CometConf.COMET_EXPR_ALLOW_INCOMPATIBLE.key -> "true") {
-        val table = spark.read.parquet(filename)
-        table.createOrReplaceTempView("t1")
-        // test with array of each column
-        val fields =
-          table.schema.fields.filter(field => CometArrayExcept.isTypeSupported(field.dataType))
-        for (field <- fields) {
-          val fieldName = field.name
-          val typeName = field.dataType.typeName
-          sql(
-            s"SELECT cast(array($fieldName, $fieldName) as array<$typeName>) as a, cast(array($fieldName) as array<$typeName>) as b FROM t1")
-            .createOrReplaceTempView("t2")
-          val df = sql("SELECT array_except(a, b) FROM t2")
-          checkSparkAnswerAndOperator(df)
+        withTempView("t1", "t2") {
+          val table = spark.read.parquet(filename)
+          table.createOrReplaceTempView("t1")
+          // test with array of each column
+          val fields =
+            table.schema.fields.filter(field => CometArrayExcept.isTypeSupported(field.dataType))
+          for (field <- fields) {
+            val fieldName = field.name
+            val typeName = field.dataType.typeName
+            sql(
+              s"SELECT cast(array($fieldName, $fieldName) as array<$typeName>) as a, cast(array($fieldName) as array<$typeName>) as b FROM t1")
+              .createOrReplaceTempView("t2")
+            val df = sql("SELECT array_except(a, b) FROM t2")
+            checkSparkAnswerAndOperator(df)
+          }
         }
       }
     }
@@ -588,17 +635,19 @@ class CometArrayExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelp
         CometConf.COMET_SPARK_TO_ARROW_ENABLED.key -> "true",
         CometConf.COMET_CONVERT_FROM_PARQUET_ENABLED.key -> "true",
         CometConf.COMET_EXPR_ALLOW_INCOMPATIBLE.key -> "true") {
-        val table = spark.read.parquet(filename)
-        table.createOrReplaceTempView("t1")
-        // test with array of each column
-        val fields =
-          table.schema.fields.filter(field => CometArrayExcept.isTypeSupported(field.dataType))
-        for (field <- fields) {
-          val fieldName = field.name
-          sql(s"SELECT array($fieldName, $fieldName) as a, array($fieldName) as b FROM t1")
-            .createOrReplaceTempView("t2")
-          val df = sql("SELECT array_except(a, b) FROM t2")
-          checkSparkAnswer(df)
+        withTempView("t1", "t2") {
+          val table = spark.read.parquet(filename)
+          table.createOrReplaceTempView("t1")
+          // test with array of each column
+          val fields =
+            table.schema.fields.filter(field => CometArrayExcept.isTypeSupported(field.dataType))
+          for (field <- fields) {
+            val fieldName = field.name
+            sql(s"SELECT array($fieldName, $fieldName) as a, array($fieldName) as b FROM t1")
+              .createOrReplaceTempView("t2")
+            val df = sql("SELECT array_except(a, b) FROM t2")
+            checkSparkAnswer(df)
+          }
         }
       }
     }
@@ -610,19 +659,22 @@ class CometArrayExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelp
       CometConf.COMET_EXPLAIN_FALLBACK_ENABLED.key -> "true") {
       Seq(true, false).foreach { dictionaryEnabled =>
         withTempDir { dir =>
-          val path = new Path(dir.toURI.toString, "test.parquet")
-          makeParquetFileAllPrimitiveTypes(path, dictionaryEnabled, 100)
-          spark.read.parquet(path.toString).createOrReplaceTempView("t1")
-
-          checkSparkAnswerAndOperator(sql("SELECT array_repeat(_4, null) from t1"))
-          checkSparkAnswerAndOperator(sql("SELECT array_repeat(_4, 0) from t1"))
-          checkSparkAnswerAndOperator(
-            sql("SELECT array_repeat(_2, 5) from t1 where _2 is not null"))
-          checkSparkAnswerAndOperator(sql("SELECT array_repeat(_2, 5) from t1 where _2 is null"))
-          checkSparkAnswerAndOperator(
-            sql("SELECT array_repeat(_3, _4) from t1 where _3 is not null"))
-          checkSparkAnswerAndOperator(sql("SELECT array_repeat(cast(_3 as string), 2) from t1"))
-          checkSparkAnswerAndOperator(sql("SELECT array_repeat(array(_2, _3, _4), 2) from t1"))
+          withTempView("t1") {
+            val path = new Path(dir.toURI.toString, "test.parquet")
+            makeParquetFileAllPrimitiveTypes(path, dictionaryEnabled, 100)
+            spark.read.parquet(path.toString).createOrReplaceTempView("t1")
+
+            checkSparkAnswerAndOperator(sql("SELECT array_repeat(_4, null) from t1"))
+            checkSparkAnswerAndOperator(sql("SELECT array_repeat(_4, 0) from t1"))
+            checkSparkAnswerAndOperator(
+              sql("SELECT array_repeat(_2, 5) from t1 where _2 is not null"))
+            checkSparkAnswerAndOperator(
+              sql("SELECT array_repeat(_2, 5) from t1 where _2 is null"))
+            checkSparkAnswerAndOperator(
+              sql("SELECT array_repeat(_3, _4) from t1 where _3 is not null"))
+            checkSparkAnswerAndOperator(sql("SELECT array_repeat(cast(_3 as string), 2) from t1"))
+            checkSparkAnswerAndOperator(sql("SELECT array_repeat(array(_2, _3, _4), 2) from t1"))
+          }
         }
       }
     }
@@ -630,32 +682,34 @@ class CometArrayExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelp
 
   test("flatten - test all types (native Parquet reader)") {
     withTempDir { dir =>
-      val path = new Path(dir.toURI.toString, "test.parquet")
-      val filename = path.toString
-      val random = new Random(42)
-      withSQLConf(CometConf.COMET_ENABLED.key -> "false") {
-        ParquetGenerator.makeParquetFile(
-          random,
-          spark,
-          filename,
-          100,
-          DataGenOptions(
-            allowNull = true,
-            generateNegativeZero = true,
-            generateArray = false,
-            generateStruct = false,
-            generateMap = false))
-      }
-      val table = spark.read.parquet(filename)
-      table.createOrReplaceTempView("t1")
-      val fieldNames =
-        table.schema.fields
-          .filter(field => CometFlatten.isTypeSupported(field.dataType))
-          .map(_.name)
-      for (fieldName <- fieldNames) {
-        sql(s"SELECT array(array($fieldName, $fieldName), array($fieldName)) as a FROM t1")
-          .createOrReplaceTempView("t2")
-        checkSparkAnswerAndOperator(sql("SELECT flatten(a) FROM t2"))
+      withTempView("t1", "t2") {
+        val path = new Path(dir.toURI.toString, "test.parquet")
+        val filename = path.toString
+        val random = new Random(42)
+        withSQLConf(CometConf.COMET_ENABLED.key -> "false") {
+          ParquetGenerator.makeParquetFile(
+            random,
+            spark,
+            filename,
+            100,
+            DataGenOptions(
+              allowNull = true,
+              generateNegativeZero = true,
+              generateArray = false,
+              generateStruct = false,
+              generateMap = false))
+        }
+        val table = spark.read.parquet(filename)
+        table.createOrReplaceTempView("t1")
+        val fieldNames =
+          table.schema.fields
            .filter(field => CometFlatten.isTypeSupported(field.dataType))
+            .map(_.name)
+        for (fieldName <- fieldNames) {
+          sql(s"SELECT array(array($fieldName, $fieldName), array($fieldName)) as a FROM t1")
+            .createOrReplaceTempView("t2")
+          checkSparkAnswerAndOperator(sql("SELECT flatten(a) FROM t2"))
+        }
       }
     }
   }
@@ -678,16 +732,18 @@ class CometArrayExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelp
       CometConf.COMET_NATIVE_SCAN_ENABLED.key -> "false",
       CometConf.COMET_SPARK_TO_ARROW_ENABLED.key -> "true",
       CometConf.COMET_CONVERT_FROM_PARQUET_ENABLED.key -> "true") {
-        val table = spark.read.parquet(filename)
-        table.createOrReplaceTempView("t1")
-        val fieldNames =
-          table.schema.fields
-            .filter(field => CometFlatten.isTypeSupported(field.dataType))
-            .map(_.name)
-        for (fieldName <- fieldNames) {
-          sql(s"SELECT array(array($fieldName, $fieldName), array($fieldName)) as a FROM t1")
-            .createOrReplaceTempView("t2")
-          checkSparkAnswer(sql("SELECT flatten(a) FROM t2"))
+        withTempView("t1", "t2") {
+          val table = spark.read.parquet(filename)
+          table.createOrReplaceTempView("t1")
+          val fieldNames =
+            table.schema.fields
+              .filter(field => CometFlatten.isTypeSupported(field.dataType))
+              .map(_.name)
+          for (fieldName <- fieldNames) {
+            sql(s"SELECT array(array($fieldName, $fieldName), array($fieldName)) as a FROM t1")
+              .createOrReplaceTempView("t2")
+            checkSparkAnswer(sql("SELECT flatten(a) FROM t2"))
+          }
        }
      }
    }
@@ -699,11 +755,48 @@ class CometArrayExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelp
       CometConf.COMET_EXPLAIN_FALLBACK_ENABLED.key -> "true") {
       Seq(true, false).foreach { dictionaryEnabled =>
         withTempDir { dir =>
-          val path = new Path(dir.toURI.toString, "test.parquet")
-          makeParquetFileAllPrimitiveTypes(path, dictionaryEnabled, 100)
-          spark.read.parquet(path.toString).createOrReplaceTempView("t1")
-          checkSparkAnswerAndOperator(
-            sql("SELECT array(array(1, 2, 3), null, array(), array(null), array(1)) from t1"))
+          withTempView("t1") {
+            val path = new Path(dir.toURI.toString, "test.parquet")
+            makeParquetFileAllPrimitiveTypes(path, dictionaryEnabled, 100)
+            spark.read.parquet(path.toString).createOrReplaceTempView("t1")
+            checkSparkAnswerAndOperator(
+              sql("SELECT array(array(1, 2, 3), null, array(), array(null), array(1)) from t1"))
+          }
         }
       }
     }
   }
+
+  test("array_reverse") {
+    withTempDir { dir =>
+      val path = new Path(dir.toURI.toString, "test.parquet")
+      val filename = path.toString
+      val random = new Random(42)
+      withSQLConf(CometConf.COMET_ENABLED.key -> "false") {
+        val options = DataGenOptions(
+          allowNull = true,
+          generateNegativeZero = true,
+          generateArray = true,
+          generateStruct = true,
+          generateMap = false)
+        ParquetGenerator.makeParquetFile(random, spark, filename, 100, options)
+      }
+      withSQLConf(
+        CometConf.COMET_NATIVE_SCAN_ENABLED.key -> "false",
+        CometConf.COMET_SPARK_TO_ARROW_ENABLED.key -> "true",
+        CometConf.COMET_CONVERT_FROM_PARQUET_ENABLED.key -> "true") {
+        withTempView("t1", "t2") {
+          val table = spark.read.parquet(filename)
+          table.createOrReplaceTempView("t1")
+          val fieldNames =
+            table.schema.fields
+              .filter(field => CometArrayReverse.isTypeSupported(field.dataType))
+              .map(_.name)
+          for (fieldName <- fieldNames) {
+            sql(s"SELECT $fieldName as a FROM t1")
+              .createOrReplaceTempView("t2")
+            checkSparkAnswer(sql("SELECT reverse(a) FROM t2"))
+          }
+        }
+      }