diff --git a/spark/src/main/scala/org/apache/comet/serde/arrays.scala b/spark/src/main/scala/org/apache/comet/serde/arrays.scala
index 5d989b4a35..b552a071d6 100644
--- a/spark/src/main/scala/org/apache/comet/serde/arrays.scala
+++ b/spark/src/main/scala/org/apache/comet/serde/arrays.scala
@@ -546,7 +546,6 @@ object CometArrayFilter extends CometExpressionSerde[ArrayFilter] {
 object CometSize extends CometExpressionSerde[Size] {
 
   override def getSupportLevel(expr: Size): SupportLevel = {
-    // TODO respect spark.sql.legacy.sizeOfNull
     expr.child.dataType match {
       case _: ArrayType => Compatible()
       case _: MapType => Unsupported(Some("size does not support map inputs"))
@@ -554,7 +553,6 @@ object CometSize extends CometExpressionSerde[Size] {
         // this should be unreachable because Spark only supports map and array inputs
         Unsupported(Some(s"Unsupported child data type: $other"))
     }
-
   }
 
   override def convert(
@@ -562,10 +560,45 @@ object CometSize extends CometExpressionSerde[Size] {
       inputs: Seq[Attribute],
       binding: Boolean): Option[ExprOuterClass.Expr] = {
     val arrayExprProto = exprToProto(expr.child, inputs, binding)
 
+    // Serialize size(child) as
+    //   CASE WHEN child IS NOT NULL THEN size(child) ELSE <default> END
+    // where <default> is -1 if spark.sql.legacy.sizeOfNull is enabled and NULL otherwise
+    for {
+      isNotNullExprProto <- createIsNotNullExprProto(expr, inputs, binding)
+      sizeScalarExprProto <- scalarFunctionExprToProto("size", arrayExprProto)
+      emptyLiteralExprProto <- createLiteralExprProto(SQLConf.get.legacySizeOfNull)
+    } yield {
+      val caseWhenExpr = ExprOuterClass.CaseWhen
+        .newBuilder()
+        .addWhen(isNotNullExprProto)
+        .addThen(sizeScalarExprProto)
+        .setElseExpr(emptyLiteralExprProto)
+        .build()
+      ExprOuterClass.Expr
+        .newBuilder()
+        .setCaseWhen(caseWhenExpr)
+        .build()
+    }
+  }
+
+  private def createIsNotNullExprProto(
+      expr: Size,
+      inputs: Seq[Attribute],
+      binding: Boolean): Option[ExprOuterClass.Expr] = {
+    createUnaryExpr(
+      expr,
+      expr.child,
+      inputs,
+      binding,
+      (builder, unaryExpr) => builder.setIsNotNull(unaryExpr))
+  }
-    val sizeScalarExpr = scalarFunctionExprToProto("size", arrayExprProto)
-    optExprWithInfo(sizeScalarExpr, expr)
+  private def createLiteralExprProto(legacySizeOfNull: Boolean): Option[ExprOuterClass.Expr] = {
+    // legacy semantics: size(NULL) evaluates to -1 instead of NULL
+    val value = if (legacySizeOfNull) -1 else null
+    exprToProto(Literal(value, IntegerType), Seq.empty)
   }
+
 }
 
 trait ArraysBase {
diff --git a/spark/src/test/scala/org/apache/comet/CometArrayExpressionSuite.scala b/spark/src/test/scala/org/apache/comet/CometArrayExpressionSuite.scala
index 9f908e741e..cf49117364 100644
--- a/spark/src/test/scala/org/apache/comet/CometArrayExpressionSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/CometArrayExpressionSuite.scala
@@ -26,6 +26,7 @@ import org.apache.spark.sql.CometTestBase
 import org.apache.spark.sql.catalyst.expressions.{ArrayAppend, ArrayDistinct, ArrayExcept, ArrayInsert, ArrayIntersect, ArrayJoin, ArrayRepeat, ArraysOverlap, ArrayUnion}
 import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
 import org.apache.spark.sql.functions._
+import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types.ArrayType
 
 import org.apache.comet.CometSparkSessionExtensions.{isSpark35Plus, isSpark40Plus}
@@ -871,4 +872,24 @@ class CometArrayExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper {
       }
     }
   }
+
+  test("size - respect legacySizeOfNull") {
+    val table = "t1"
+    withSQLConf(CometConf.COMET_NATIVE_SCAN_IMPL.key -> CometConf.SCAN_NATIVE_ICEBERG_COMPAT) {
+      withTable(table) {
+        sql(s"create table $table(col array<int>) using parquet")
+        sql(s"insert into $table values(null)")
+        withSQLConf(SQLConf.LEGACY_SIZE_OF_NULL.key -> "false") {
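+          // size(NULL) is NULL when spark.sql.legacy.sizeOfNull is false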
+          checkSparkAnswerAndOperator(sql(s"select size(col) from $table"))
+        }
+        // with legacy semantics enabled, size(NULL) returns -1
+        withSQLConf(
+          SQLConf.LEGACY_SIZE_OF_NULL.key -> "true",
+          SQLConf.ANSI_ENABLED.key -> "false") {
+          checkSparkAnswerAndOperator(sql(s"select size(col) from $table"))
+        }
+      }
+    }
+  }
 }