Commit e0cdb86

lazy_coalesce_fallback_case_statement
1 parent f6369dd · commit e0cdb86

3 files changed: +54 −10 lines

spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala

Lines changed: 2 additions & 5 deletions
@@ -191,7 +191,8 @@ object QueryPlanSerde extends Logging with CometExprShim {
     classOf[Log2] -> CometLog2,
     classOf[Pow] -> CometScalarFunction[Pow]("pow"),
     classOf[If] -> CometIf,
-    classOf[CaseWhen] -> CometCaseWhen)
+    classOf[CaseWhen] -> CometCaseWhen,
+    classOf[Coalesce] -> CometCoalesce)
 
   /**
    * Mapping of Spark aggregate expression class to Comet expression handler.
@@ -1078,10 +1079,6 @@ object QueryPlanSerde extends Logging with CometExprShim {
         None
       }
 
-      case a @ Coalesce(_) =>
-        val exprChildren = a.children.map(exprToProtoInternal(_, inputs, binding))
-        scalarFunctionExprToProto("coalesce", exprChildren: _*)
-
       // With Spark 3.4, CharVarcharCodegenUtils.readSidePadding gets called to pad spaces for
       // char types.
       // See https://github.com/apache/spark/pull/38151
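For context, this change removes the ad-hoc `case a @ Coalesce(_)` branch that serialized Coalesce as the `coalesce` scalar function and instead routes it through the class-keyed handler map, where `CometCoalesce` does the conversion. A rough REPL-style sketch of that dispatch pattern; the `ExprHandler` trait and handler objects below are illustrative stand-ins, not Comet's actual types:

import org.apache.spark.sql.catalyst.expressions.{CaseWhen, Coalesce, Expression, If, Literal}

// Illustrative stand-in for per-expression serde handlers.
trait ExprHandler { def describe: String }
object IfHandler extends ExprHandler { val describe = "serialize If" }
object CaseWhenHandler extends ExprHandler { val describe = "serialize CaseWhen" }
object CoalesceHandler extends ExprHandler { val describe = "serialize Coalesce as CaseWhen" }

// Class-keyed map, mirroring the `classOf[...] -> Comet...` entries in the diff above.
val handlers: Map[Class[_ <: Expression], ExprHandler] = Map(
  classOf[If] -> IfHandler,
  classOf[CaseWhen] -> CaseWhenHandler,
  classOf[Coalesce] -> CoalesceHandler)

// Lookup by runtime class instead of pattern matching on each expression type.
val expr: Expression = Coalesce(Seq(Literal(1), Literal(2)))
println(handlers.get(expr.getClass).map(_.describe)) // Some(serialize Coalesce as CaseWhen)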

spark/src/main/scala/org/apache/comet/serde/conditional.scala

Lines changed: 46 additions & 1 deletion
@@ -21,7 +21,7 @@ package org.apache.comet.serde
 
 import scala.collection.JavaConverters._
 
-import org.apache.spark.sql.catalyst.expressions.{Attribute, CaseWhen, Expression, If}
+import org.apache.spark.sql.catalyst.expressions.{Attribute, CaseWhen, Coalesce, Expression, If, IsNotNull}
 
 import org.apache.comet.CometSparkSessionExtensions.withInfo
 import org.apache.comet.serde.QueryPlanSerde.exprToProtoInternal
@@ -91,3 +91,48 @@ object CometCaseWhen extends CometExpressionSerde[CaseWhen] {
     }
   }
 }
+
+object CometCoalesce extends CometExpressionSerde[Coalesce] {
+  override def convert(
+      expr: Coalesce,
+      inputs: Seq[Attribute],
+      binding: Boolean): Option[ExprOuterClass.Expr] = {
+    val branches = expr.children.dropRight(1).map { child =>
+      (IsNotNull(child), child)
+    }
+    val elseValue = Some(expr.children.last)
+    var allBranches: Seq[Expression] = Seq()
+    val whenSeq = branches.map(elements => {
+      allBranches = allBranches :+ elements._1
+      exprToProtoInternal(elements._1, inputs, binding)
+    })
+    val thenSeq = branches.map(elements => {
+      allBranches = allBranches :+ elements._2
+      exprToProtoInternal(elements._2, inputs, binding)
+    })
+    assert(whenSeq.length == thenSeq.length)
+    if (whenSeq.forall(_.isDefined) && thenSeq.forall(_.isDefined)) {
+      val builder = ExprOuterClass.CaseWhen.newBuilder()
+      builder.addAllWhen(whenSeq.map(_.get).asJava)
+      builder.addAllThen(thenSeq.map(_.get).asJava)
+      if (elseValue.isDefined) {
+        val elseValueExpr =
+          exprToProtoInternal(elseValue.get, inputs, binding)
+        if (elseValueExpr.isDefined) {
+          builder.setElseExpr(elseValueExpr.get)
+        } else {
+          withInfo(expr, elseValue.get)
+          return None
+        }
+      }
+      Some(
+        ExprOuterClass.Expr
+          .newBuilder()
+          .setCaseWhen(builder)
+          .build())
+    } else {
+      withInfo(expr, allBranches: _*)
+      None
+    }
+  }
+}
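The core of `CometCoalesce.convert` is a rewrite of `coalesce(c1, ..., cn)` into `CASE WHEN c1 IS NOT NULL THEN c1 ... ELSE cn END` before serializing to the CaseWhen protobuf message. A minimal REPL-style sketch of that equivalence using plain Catalyst expressions, independent of Comet's protobuf builders:

import org.apache.spark.sql.catalyst.expressions.{CaseWhen, Coalesce, IsNotNull, Literal}
import org.apache.spark.sql.types.IntegerType

// coalesce(null, 2, 3): every child except the last becomes an IsNotNull-guarded branch,
// and the last child becomes the ELSE value.
val children = Seq(Literal(null, IntegerType), Literal(2), Literal(3))
val coalesceExpr = Coalesce(children)

val branches = children.dropRight(1).map(c => (IsNotNull(c), c))
val caseExpr = CaseWhen(branches, Some(children.last))

// Both evaluate to 2; the CASE form makes the lazy semantics explicit:
// a later child is only evaluated when all earlier guards fail.
println(coalesceExpr.eval()) // 2
println(caseExpr.eval())     // 2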

spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala

Lines changed: 6 additions & 4 deletions
@@ -395,13 +395,15 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper {
   }
 
   test("test coalesce lazy eval") {
-    withSQLConf(SQLConf.ANSI_ENABLED.key -> "true") {
-      val data = Seq((100, 0))
+    withSQLConf(
+      SQLConf.ANSI_ENABLED.key -> "true",
+      CometConf.COMET_CAST_ALLOW_INCOMPATIBLE.key -> "true") {
+      val data = Seq((9999999999999L, 0))
       withParquetTable(data, "t1") {
         val res = spark.sql("""
-          |SELECT coalesce(_1 , 1/0) from t1;
+          |SELECT coalesce(_1, CAST(_1 AS TINYINT)) from t1;
           | """.stripMargin)
-        checkSparkAnswer(res)
+        checkSparkAnswerAndOperator(res)
       }
     }
   }
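The updated test depends on exactly that lazy behaviour: under ANSI mode, `CAST(9999999999999 AS TINYINT)` would raise an overflow error if evaluated eagerly, but the CASE-style rewrite only evaluates the fallback when `_1` is NULL. A rough spark-shell sketch of the same query outside the Comet test harness (the session setup below is illustrative, not part of the test suite):

import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().master("local[1]").appName("coalesce-lazy-sketch").getOrCreate()
import spark.implicits._

spark.conf.set("spark.sql.ansi.enabled", "true")
Seq((9999999999999L, 0)).toDF("_1", "_2").createOrReplaceTempView("t1")

// _1 is non-null, so the overflowing TINYINT cast in the fallback is never evaluated
// and no ANSI error is raised.
spark.sql("SELECT coalesce(_1, CAST(_1 AS TINYINT)) FROM t1").show()

// Roughly what the serde now produces, written as an explicit CASE:
spark.sql("SELECT CASE WHEN _1 IS NOT NULL THEN _1 ELSE CAST(_1 AS TINYINT) END FROM t1").show()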
