Commit 9de380d

check_upstream_json_enrichments

1 parent 0050ed8 commit 9de380d

2 files changed: +61, -26 lines


spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
Lines changed: 29 additions & 10 deletions

@@ -1248,13 +1248,36 @@ object QueryPlanSerde extends Logging with CometExprShim {
           None
         }
 
-      case CaseWhen(branches, elseValue) =>
+      // case a @ Coalesce(_) =>
+      //   val branches = a.children.dropRight(1).map { child =>
+      //     (IsNotNull(child), child)
+      //   }
+      //   val elseValue = Some(a.children.last)
+      //
+      //   exprToProtoInternal(CaseWhen(branches, elseValue), inputs, binding)
+      //
+      //   val exprChildren = a.children.map(exprToProtoInternal(_, inputs, binding))
+      //   scalarFunctionExprToProto("coalesce", exprChildren: _*)
+
+      case c @ (CaseWhen(_, _) | Coalesce(_)) =>
+        val (finalBranches, finalElse) = c match {
+          case CaseWhen(branches, elseValue) =>
+            (branches, elseValue)
+
+          case Coalesce(children) =>
+            val branches = children.dropRight(1).map { child =>
+              (IsNotNull(child), child)
+            }
+            val elseValue = Some(children.last)
+            (branches, elseValue)
+        }
+
         var allBranches: Seq[Expression] = Seq()
-        val whenSeq = branches.map(elements => {
+        val whenSeq = finalBranches.map(elements => {
           allBranches = allBranches :+ elements._1
           exprToProtoInternal(elements._1, inputs, binding)
         })
-        val thenSeq = branches.map(elements => {
+        val thenSeq = finalBranches.map(elements => {
           allBranches = allBranches :+ elements._2
           exprToProtoInternal(elements._2, inputs, binding)
         })
@@ -1263,13 +1286,13 @@ object QueryPlanSerde extends Logging with CometExprShim {
         val builder = ExprOuterClass.CaseWhen.newBuilder()
         builder.addAllWhen(whenSeq.map(_.get).asJava)
         builder.addAllThen(thenSeq.map(_.get).asJava)
-        if (elseValue.isDefined) {
+        if (finalElse.isDefined) {
           val elseValueExpr =
-            exprToProtoInternal(elseValue.get, inputs, binding)
+            exprToProtoInternal(finalElse.get, inputs, binding)
           if (elseValueExpr.isDefined) {
             builder.setElseExpr(elseValueExpr.get)
           } else {
-            withInfo(expr, elseValue.get)
+            withInfo(expr, finalElse.get)
             return None
           }
         }
@@ -1319,10 +1342,6 @@ object QueryPlanSerde extends Logging with CometExprShim {
           None
         }
 
-      case a @ Coalesce(_) =>
-        val exprChildren = a.children.map(exprToProtoInternal(_, inputs, binding))
-        scalarFunctionExprToProto("coalesce", exprChildren: _*)
-
       // With Spark 3.4, CharVarcharCodegenUtils.readSidePadding gets called to pad spaces for
       // char types.
      // See https://github.com/apache/spark/pull/38151
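
Net effect of this file's change: Coalesce is no longer serialized as a native "coalesce" scalar function (the case removed in the last hunk); instead it is rewritten into an equivalent CaseWhen and sent through the existing CaseWhen serialization path. The commented-out block kept at the top of the first hunk records the earlier approaches. A minimal sketch of the rewrite in isolation, using Spark's Catalyst expression classes; the helper name rewriteCoalesce is illustrative and not part of this commit:

import org.apache.spark.sql.catalyst.expressions.{CaseWhen, Coalesce, Expression, IsNotNull}

// coalesce(a, b, c) is equivalent to
//   CASE WHEN a IS NOT NULL THEN a
//        WHEN b IS NOT NULL THEN b
//        ELSE c END,
// which preserves coalesce's first-non-null, short-circuiting semantics.
def rewriteCoalesce(c: Coalesce): CaseWhen = {
  val branches: Seq[(Expression, Expression)] =
    c.children.dropRight(1).map(child => (IsNotNull(child), child))
  CaseWhen(branches, Some(c.children.last))
}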

spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
Lines changed: 32 additions & 16 deletions

@@ -44,6 +44,7 @@ import org.apache.spark.sql.types._
 import org.apache.comet.CometSparkSessionExtensions.isSpark40Plus
 
 class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper {
+
   import testImplicits._
 
   override protected def test(testName: String, testTags: Tag*)(testFun: => Any)(implicit
@@ -367,15 +368,15 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper {
     withParquetTable(data, "tbl") {
       checkSparkAnswerAndOperator("SELECT try_divide(_1, _2) FROM tbl")
       checkSparkAnswerAndOperator("""
-          |SELECT
-          | try_divide(10, 0),
-          | try_divide(NULL, 5),
-          | try_divide(5, NULL),
-          | try_divide(-2147483648, -1),
-          | try_divide(-9223372036854775808, -1),
-          | try_divide(DECIMAL('9999999999999999999999999999'), 0.1)
-          | from tbl
-          |""".stripMargin)
+        |SELECT
+        | try_divide(10, 0),
+        | try_divide(NULL, 5),
+        | try_divide(5, NULL),
+        | try_divide(-2147483648, -1),
+        | try_divide(-9223372036854775808, -1),
+        | try_divide(DECIMAL('9999999999999999999999999999'), 0.1)
+        | from tbl
+        |""".stripMargin)
     }
   }
 
@@ -384,13 +385,28 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper {
     withParquetTable(data, "tbl") {
       checkSparkAnswerAndOperator("SELECT try_divide(_1, _2) FROM tbl")
       checkSparkAnswerAndOperator("""
-          |SELECT try_divide(-128, -1),
-          |try_divide(-32768, -1),
-          |try_divide(-2147483648, -1),
-          |try_divide(-9223372036854775808, -1),
-          |try_divide(CAST(99999 AS DECIMAL(5,0)), CAST(0.0001 AS DECIMAL(5,4)))
-          |from tbl
-          |""".stripMargin)
+        |SELECT try_divide(-128, -1),
+        |try_divide(-32768, -1),
+        |try_divide(-2147483648, -1),
+        |try_divide(-9223372036854775808, -1),
+        |try_divide(CAST(99999 AS DECIMAL(5,0)), CAST(0.0001 AS DECIMAL(5,4)))
+        |from tbl
+        |""".stripMargin)
+    }
+  }
+
+  test("test coalesce lazy eval") {
+    withSQLConf(SQLConf.ANSI_ENABLED.key -> "true") {
+      val data = Seq((100, 0))
+      withParquetTable(data, "t1") {
+        val res = spark.sql("""
+          |SELECT coalesce(_1 , 1/0) from t1;
+          | """.stripMargin)
+
+        res.explain(true)
+
+        checkSparkAnswer(res)
+      }
     }
   }
 
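
The new "test coalesce lazy eval" case pins down the semantics the CaseWhen rewrite must preserve: with ANSI mode enabled, coalesce(_1, 1/0) has to return _1 for the row (100, 0) without ever evaluating 1/0, whereas an eager implementation of coalesce would raise a divide-by-zero error. A rough standalone reproduction of the same scenario outside the test harness, assuming a running SparkSession named spark (the view name t1 mirrors the test):

// Enable ANSI mode so that integer division by zero throws unless skipped.
spark.conf.set("spark.sql.ansi.enabled", "true")

import spark.implicits._
Seq((100, 0)).toDF("_1", "_2").createOrReplaceTempView("t1")

// Succeeds only if coalesce short-circuits on the first non-null argument;
// the suite compares this result against vanilla Spark via checkSparkAnswer.
spark.sql("SELECT coalesce(_1, 1/0) FROM t1").show()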
