Skip to content

Commit a5c19d9

Browse files
committed
[VL] Support ANSI mode decimal Add/Subtract with checked overflow
Routes decimal `Add` and `Subtract` to Velox's `checked_add` and `checked_subtract` functions when ANSI mode is enabled (`nullOnOverflow = false`). These checked variants throw on overflow instead of returning null, matching Spark's ANSI behavior. Depends on facebookincubator/velox#16302 which adds `checked_add` and `checked_subtract` support for decimal types.
1 parent fbcd3e7 commit a5c19d9

File tree

3 files changed

+65
-2
lines changed

3 files changed

+65
-2
lines changed

backends-velox/src/test/scala/org/apache/gluten/execution/VeloxLiteralSuite.scala

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -139,4 +139,5 @@ class VeloxLiteralSuite extends VeloxWholeStageTransformerSuite {
139139
validateFallbackResult("SELECT struct(cast(null as struct<a: string>))")
140140
validateFallbackResult("SELECT array(struct(1, 'a'), null)")
141141
}
142+
142143
}

backends-velox/src/test/scala/org/apache/gluten/functions/ArithmeticAnsiValidateSuite.scala

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -100,4 +100,46 @@ class ArithmeticAnsiValidateSuite extends FunctionsValidateSuite {
100100
}
101101
}
102102

103+
test("decimal add overflow") {
104+
// Normal decimal add should succeed and match Spark results
105+
runQueryAndCompare(
106+
"SELECT CAST(1.0 AS DECIMAL(10,2)) + CAST(2.0 AS DECIMAL(10,2))") {
107+
checkGlutenPlan[ProjectExecTransformer]
108+
}
109+
110+
// Overflow: max DECIMAL(38,0) + 1 should throw in ANSI mode
111+
if (isSparkVersionGE("4.0")) {
112+
intercept[SparkException] {
113+
sql("SELECT CAST(99999999999999999999999999999999999999 AS DECIMAL(38,0)) + " +
114+
"CAST(1 AS DECIMAL(38,0))").collect()
115+
}
116+
} else {
117+
intercept[ArithmeticException] {
118+
sql("SELECT CAST(99999999999999999999999999999999999999 AS DECIMAL(38,0)) + " +
119+
"CAST(1 AS DECIMAL(38,0))").collect()
120+
}
121+
}
122+
}
123+
124+
test("decimal subtract overflow") {
125+
// Normal decimal subtract should succeed and match Spark results
126+
runQueryAndCompare(
127+
"SELECT CAST(5.0 AS DECIMAL(10,2)) - CAST(2.0 AS DECIMAL(10,2))") {
128+
checkGlutenPlan[ProjectExecTransformer]
129+
}
130+
131+
// Overflow: -max DECIMAL(38,0) - 1 should throw in ANSI mode
132+
if (isSparkVersionGE("4.0")) {
133+
intercept[SparkException] {
134+
sql("SELECT CAST(-99999999999999999999999999999999999999 AS DECIMAL(38,0)) - " +
135+
"CAST(1 AS DECIMAL(38,0))").collect()
136+
}
137+
} else {
138+
intercept[ArithmeticException] {
139+
sql("SELECT CAST(-99999999999999999999999999999999999999 AS DECIMAL(38,0)) - " +
140+
"CAST(1 AS DECIMAL(38,0))").collect()
141+
}
142+
}
143+
}
144+
103145
}

gluten-substrait/src/main/scala/org/apache/gluten/expression/ExpressionConverter.scala

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -621,17 +621,37 @@ object ExpressionConverter extends SQLConfHelper with Logging {
621621
substraitExprName,
622622
expr.children.map(replaceWithExpressionTransformer0(_, attributeSeq, expressionsMap)),
623623
expr)
624-
case CheckOverflow(b: BinaryArithmetic, decimalType, _)
624+
case CheckOverflow(b: BinaryArithmetic, decimalType, nullOnOverflow)
625625
if !BackendsApiManager.getSettings.transformCheckOverflow &&
626626
DecimalArithmeticUtil.isDecimalArithmetic(b) =>
627-
val arithmeticExprName =
627+
val baseExprName =
628628
BackendsApiManager.getSparkPlanExecApiInstance.getDecimalArithmeticExprName(
629629
getAndCheckSubstraitName(b, expressionsMap))
630+
// When nullOnOverflow is false, it's ANSI mode - use checked_ prefix for overflow errors
631+
val arithmeticExprName = if (!nullOnOverflow) {
632+
"checked_" + baseExprName
633+
} else {
634+
baseExprName
635+
}
630636
val left =
631637
replaceWithExpressionTransformer0(b.left, attributeSeq, expressionsMap)
632638
val right =
633639
replaceWithExpressionTransformer0(b.right, attributeSeq, expressionsMap)
634640
DecimalArithmeticExpressionTransformer(arithmeticExprName, left, right, decimalType, b)
641+
// Velox path: ANSI mode decimal Add/Subtract uses checked_ variants
642+
// that throw on overflow instead of returning null.
643+
case c @ CheckOverflow(b: BinaryArithmetic, _, nullOnOverflow)
644+
if BackendsApiManager.getSettings.transformCheckOverflow &&
645+
DecimalArithmeticUtil.isDecimalArithmetic(b) &&
646+
!nullOnOverflow &&
647+
(b.isInstanceOf[Add] || b.isInstanceOf[Subtract]) =>
648+
val baseExprName =
649+
BackendsApiManager.getSparkPlanExecApiInstance.getDecimalArithmeticExprName(
650+
getAndCheckSubstraitName(b, expressionsMap))
651+
val checkedExprName = "checked_" + baseExprName
652+
val childTransformer =
653+
genRescaleDecimalTransformer(checkedExprName, b, attributeSeq, expressionsMap)
654+
CheckOverflowTransformer(substraitExprName, childTransformer, c)
635655
case c: CheckOverflow =>
636656
CheckOverflowTransformer(
637657
substraitExprName,

0 commit comments

Comments
 (0)