
Commit bba1c1a

support_ansi_mode_remainder_function

1 parent 8b8a27b

2 files changed: +28 -42 lines

native/spark-expr/src/math_funcs/round.rs

Lines changed: 9 additions & 18 deletions
@@ -30,34 +30,25 @@ macro_rules! integer_round {
         let rem = $X % $DIV;
         if rem <= -$HALF {
             if $FAIL_ON_ERROR {
-                match ($X - rem).sub_checked($DIV) {
-                    Ok(v) => Ok(v),
-                    Err(_e) => Err(ArrowError::ComputeError(
-                        arithmetic_overflow_error("integer").to_string(),
-                    )),
-                }
+                ($X - rem).sub_checked($DIV).map_err(|_| {
+                    ArrowError::ComputeError(arithmetic_overflow_error("integer").to_string())
+                })
             } else {
                 Ok(($X - rem).sub_wrapping($DIV))
             }
         } else if rem >= $HALF {
             if $FAIL_ON_ERROR {
-                match ($X - rem).add_checked($DIV) {
-                    Ok(v) => Ok(v),
-                    Err(_e) => Err(ArrowError::ComputeError(
-                        arithmetic_overflow_error("integer").to_string(),
-                    )),
-                }
+                ($X - rem).add_checked($DIV).map_err(|_| {
+                    ArrowError::ComputeError(arithmetic_overflow_error("integer").to_string())
+                })
             } else {
                 Ok(($X - rem).add_wrapping($DIV))
             }
         } else {
             if $FAIL_ON_ERROR {
-                match $X.sub_checked(rem) {
-                    Ok(v) => Ok(v),
-                    Err(_e) => Err(ArrowError::ComputeError(
-                        arithmetic_overflow_error("integer").to_string(),
-                    )),
-                }
+                $X.sub_checked(rem).map_err(|_| {
+                    ArrowError::ComputeError(arithmetic_overflow_error("integer").to_string())
+                })
             } else {
                 Ok($X.sub_wrapping(rem))
             }
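Note on the refactor above: the hand-written match re-wrapped both arms of the checked-arithmetic Result, so it collapses cleanly into map_err, which rewrites only the Err arm and passes Ok values through unchanged. A minimal standalone sketch of the same pattern (not the Comet macro itself; sub_checked below is a hypothetical stand-in for arrow's checked arithmetic, which returns a Result rather than std's Option):

// Hypothetical stand-in for arrow's checked subtraction: std's checked_sub
// returns an Option, so we adapt it to a Result to mirror the macro's call sites.
fn sub_checked(x: i64, y: i64) -> Result<i64, String> {
    x.checked_sub(y).ok_or_else(|| "overflow".to_string())
}

// Before: explicit match that re-wraps both arms.
fn round_term_match(x: i64, y: i64) -> Result<i64, String> {
    match sub_checked(x, y) {
        Ok(v) => Ok(v),
        Err(_e) => Err("ARITHMETIC_OVERFLOW: integer".to_string()),
    }
}

// After: map_err transforms only the Err arm; Ok values flow through as-is.
fn round_term_map_err(x: i64, y: i64) -> Result<i64, String> {
    sub_checked(x, y).map_err(|_| "ARITHMETIC_OVERFLOW: integer".to_string())
}

fn main() {
    // The two formulations agree on both success and overflow.
    assert_eq!(round_term_match(10, 3), round_term_map_err(10, 3)); // Ok(7)
    assert_eq!(round_term_match(i64::MIN, 1), round_term_map_err(i64::MIN, 1)); // Err
}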

spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala

Lines changed: 19 additions & 24 deletions
@@ -3018,30 +3018,25 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper {
   }
 
   test("ANSI support for round function") {
-    val data = Seq((Integer.MAX_VALUE, Integer.MIN_VALUE, Long.MinValue, Long.MaxValue))
-    Seq("true", "false").foreach { p =>
-      withSQLConf(SQLConf.ANSI_ENABLED.key -> p) {
-        withParquetTable(data, "tbl") {
-          val res = spark.sql(s"""
-               |SELECT
-               |  round(_1, -1) ,
-               |  round(_1, -10) ,
-               |  round(${Int.MaxValue}, -10)
-               | from tbl
-               | """.stripMargin)
-
-          checkSparkMaybeThrows(res) match {
-            case (Some(sparkException), Some(cometException)) =>
-              assert(sparkException.getMessage.contains("ARITHMETIC_OVERFLOW"))
-              assert(cometException.getMessage.contains("ARITHMETIC_OVERFLOW"))
-            case (None, None) => checkSparkAnswerAndOperator(res)
-            case (None, Some(ex)) =>
-              fail(
-                "Comet threw an exception but Spark did not. Comet exception: " + ex.getMessage)
-            case (Some(sparkException), None) =>
-              fail(
-                "Spark threw an exception but Comet did not. Spark exception: " +
-                  sparkException.getMessage)
+    Seq((Integer.MAX_VALUE, Integer.MIN_VALUE, Long.MinValue, Long.MaxValue)).foreach { value =>
+      val data = Seq(value)
+      withParquetTable(data, "tbl") {
+        Seq(-1000, -100, -10, -1, 0, 1, 10, 100, 1000).foreach { scale =>
+          Seq(true, false).foreach { ansi =>
+            withSQLConf(SQLConf.ANSI_ENABLED.key -> ansi.toString) {
+              val res = spark.sql(s"SELECT round(_1, $scale) from tbl")
+              checkSparkMaybeThrows(res) match {
+                case (Some(sparkException), Some(cometException)) =>
+                  assert(sparkException.getMessage.contains("ARITHMETIC_OVERFLOW"))
+                  assert(cometException.getMessage.contains("ARITHMETIC_OVERFLOW"))
+                case (None, None) => checkSparkAnswerAndOperator(res)
+                case (None, Some(ex)) =>
+                  fail("Comet threw an exception but Spark did not. Comet exception: " + ex.getMessage)
+                case (Some(sparkException), None) =>
+                  fail("Spark threw an exception but Comet did not. Spark exception: " +
+                    sparkException.getMessage)
+              }
+            }
           }
         }
       }
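For intuition on why the test pairs extreme values with negative scales: rounding Int.MaxValue (2147483647) to the nearest 10 lands on 2147483650, which does not fit in a 32-bit integer, so ANSI mode must raise ARITHMETIC_OVERFLOW while legacy mode wraps. A standalone sketch (hypothetical code mirroring the rem >= $HALF branch of the macro above, not part of this commit):

fn main() {
    let x: i32 = i32::MAX; // 2_147_483_647
    let (div, half) = (10, 5);
    let rem = x % div; // 7
    assert!(rem >= half); // so we take the round-up branch: (x - rem) + div

    // ANSI mode: checked addition detects the overflow, which Comet
    // surfaces as an ARITHMETIC_OVERFLOW error.
    assert_eq!((x - rem).checked_add(div), None);

    // Legacy (non-ANSI) mode: wrapping addition wraps around silently,
    // matching Spark's non-ANSI behavior.
    assert_eq!((x - rem).wrapping_add(div), i32::MIN + 2); // -2_147_483_646
}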
