diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala index d4c3be1877..84ffd11daa 100644 --- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala +++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala @@ -192,7 +192,8 @@ object QueryPlanSerde extends Logging with CometExprShim { classOf[Log2] -> CometLog2, classOf[Pow] -> CometScalarFunction[Pow]("pow"), classOf[If] -> CometIf, - classOf[CaseWhen] -> CometCaseWhen) + classOf[CaseWhen] -> CometCaseWhen, + classOf[Coalesce] -> CometCoalesce) /** * Mapping of Spark aggregate expression class to Comet expression handler. @@ -999,10 +1000,6 @@ object QueryPlanSerde extends Logging with CometExprShim { None } - case a @ Coalesce(_) => - val exprChildren = a.children.map(exprToProtoInternal(_, inputs, binding)) - scalarFunctionExprToProto("coalesce", exprChildren: _*) - // With Spark 3.4, CharVarcharCodegenUtils.readSidePadding gets called to pad spaces for // char types. 
object CometCoalesce extends CometExpressionSerde[Coalesce] {

  /**
   * Converts a Spark `Coalesce` expression into the Comet protobuf `CaseWhen`
   * representation.
   *
   * `coalesce(a, b, ..., z)` is rewritten as
   * `CASE WHEN a IS NOT NULL THEN a WHEN b IS NOT NULL THEN b ... ELSE z END`,
   * which preserves lazy evaluation of later arguments — only the first
   * non-null argument is evaluated, matching Spark semantics (important under
   * ANSI mode, where eagerly evaluating a later argument could throw, e.g. an
   * overflowing cast).
   *
   * @param expr the Spark `Coalesce` expression to convert
   * @param inputs input attributes used to resolve bound references
   * @param binding whether expressions should be bound to the inputs
   * @return the protobuf expression, or `None` (with `withInfo` fallback
   *         diagnostics attached to `expr`) if any child fails to convert
   */
  override def convert(
      expr: Coalesce,
      inputs: Seq[Attribute],
      binding: Boolean): Option[ExprOuterClass.Expr] = {
    // All children but the last become `WHEN child IS NOT NULL THEN child`
    // branches; the last child is the ELSE value. Spark's analyzer rejects a
    // zero-argument Coalesce, so `last` is safe here — TODO confirm this holds
    // for all supported Spark versions.
    val branches = expr.children.dropRight(1).map(child => (IsNotNull(child), child))
    // whenSeq and thenSeq are both mapped from `branches`, so their lengths
    // always match; no assertion is needed.
    val whenSeq = branches.map { case (cond, _) => exprToProtoInternal(cond, inputs, binding) }
    val thenSeq = branches.map { case (_, value) => exprToProtoInternal(value, inputs, binding) }
    val elseExpr = exprToProtoInternal(expr.children.last, inputs, binding)

    if (whenSeq.forall(_.isDefined) && thenSeq.forall(_.isDefined) && elseExpr.isDefined) {
      val builder = ExprOuterClass.CaseWhen.newBuilder()
      builder.addAllWhen(whenSeq.map(_.get).asJava)
      builder.addAllThen(thenSeq.map(_.get).asJava)
      builder.setElseExpr(elseExpr.get)
      Some(
        ExprOuterClass.Expr
          .newBuilder()
          .setCaseWhen(builder)
          .build())
    } else {
      // At least one child failed to convert: fall back to Spark and record
      // every child so the EXPLAIN output shows the reason consistently
      // (the previous code reported only a subset depending on which part failed).
      withInfo(expr, expr.children: _*)
      None
    }
  }
}
a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala index bf1733b3ff..51bc709078 100644 --- a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala +++ b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala @@ -394,6 +394,20 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper { } } + test("test coalesce lazy eval") { + withSQLConf( + SQLConf.ANSI_ENABLED.key -> "true", + CometConf.COMET_EXPR_ALLOW_INCOMPATIBLE.key -> "true") { + val data = Seq((9999999999999L, 0)) + withParquetTable(data, "t1") { + val res = spark.sql(""" + |SELECT coalesce(_1, CAST(_1 AS TINYINT)) from t1; + | """.stripMargin) + checkSparkAnswerAndOperator(res) + } + } + } + test("dictionary arithmetic") { // TODO: test ANSI mode withSQLConf(SQLConf.ANSI_ENABLED.key -> "false", "parquet.enable.dictionary" -> "true") {