Skip to content

Commit 30093e3

Browse files
authored
fix: check overflow for decimal integral division (#1512)
1 parent 2f8bb14 commit 30093e3

File tree

2 files changed

+40
-1
lines changed

2 files changed

+40
-1
lines changed

spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -655,8 +655,23 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim
655655
(builder, mathExpr) => builder.setIntegralDivide(mathExpr))
656656

657657
if (divideExpr.isDefined) {
658+
val childExpr = if (dataType.isInstanceOf[DecimalType]) {
659+
// check overflow for decimal type
660+
val builder = ExprOuterClass.CheckOverflow.newBuilder()
661+
builder.setChild(divideExpr.get)
662+
builder.setFailOnError(getFailOnError(div))
663+
builder.setDatatype(serializeDataType(dataType).get)
664+
Some(
665+
ExprOuterClass.Expr
666+
.newBuilder()
667+
.setCheckOverflow(builder)
668+
.build())
669+
} else {
670+
divideExpr
671+
}
672+
658673
// cast result to long
659-
castToProto(expr, None, LongType, divideExpr.get, CometEvalMode.LEGACY)
674+
castToProto(expr, None, LongType, childExpr.get, CometEvalMode.LEGACY)
660675
} else {
661676
None
662677
}

spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2734,4 +2734,28 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper {
27342734
}
27352735
}
27362736

2737+
test("test integral divide overflow for decimal") {
2738+
// decimal support requires Spark 3.4 or later
2739+
assume(isSpark34Plus)
2740+
if (isSpark40Plus) {
2741+
Seq(true, false)
2742+
} else
2743+
{
2744+
// ansi mode only supported in Spark 4.0+
2745+
Seq(false)
2746+
}.foreach { ansiMode =>
2747+
withSQLConf(SQLConf.ANSI_ENABLED.key -> ansiMode.toString) {
2748+
withTable("t1") {
2749+
sql("create table t1(a decimal(38,0), b decimal(2,2)) using parquet")
2750+
sql(
2751+
"insert into t1 values(-62672277069777110394022909049981876593,-0.40)," +
2752+
" (-68299431870253176399167726913574455270,-0.22), (-77532633078952291817347741106477071062,0.36)," +
2753+
" (-79918484954351746825313746420585672848,0.44), (54400354300704342908577384819323710194,0.18)," +
2754+
" (78585488402645143056239590008272527352,-0.51)")
2755+
checkSparkAnswerAndOperator("select a div b from t1")
2756+
}
2757+
}
2758+
}
2759+
}
2760+
27372761
}

0 commit comments

Comments
 (0)