Skip to content

Commit b1af9e3

Browse files
committed
impl_ansi_integral_divide_remainder
1 parent fa016e3 commit b1af9e3

File tree

5 files changed

+67
-26
lines changed

5 files changed

+67
-26
lines changed

native/core/src/execution/planner.rs

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -283,8 +283,6 @@ impl PhysicalPlanner {
283283
)
284284
}
285285
ExprStruct::IntegralDivide(expr) => {
286-
// TODO respect eval mode
287-
// https://github.com/apache/datafusion-comet/issues/533
288286
let eval_mode = from_protobuf_eval_mode(expr.eval_mode)?;
289287
self.create_binary_expr_with_options(
290288
expr.left.as_ref().unwrap(),

native/spark-expr/src/comet_scalar_funcs.rs

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -136,10 +136,12 @@ pub fn create_comet_physical_fun_with_eval_mode(
136136
make_comet_scalar_udf!("decimal_div", spark_decimal_div, data_type)
137137
}
138138
"decimal_integral_div" => {
139+
let is_ansi_div = eval_mode == EvalMode::Ansi;
139140
make_comet_scalar_udf!(
140141
"decimal_integral_div",
141142
spark_decimal_integral_div,
142-
data_type
143+
data_type,
144+
is_ansi_div
143145
)
144146
}
145147
"checked_add" => {

native/spark-expr/src/math_funcs/div.rs

Lines changed: 17 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -15,30 +15,33 @@
1515
// specific language governing permissions and limitations
1616
// under the License.
1717

18+
use crate::divide_by_zero_error;
1819
use crate::math_funcs::utils::get_precision_scale;
1920
use arrow::array::{Array, Decimal128Array};
2021
use arrow::datatypes::{DataType, DECIMAL128_MAX_PRECISION};
22+
use arrow::error::ArrowError;
2123
use arrow::{
2224
array::{ArrayRef, AsArray},
2325
datatypes::Decimal128Type,
2426
};
2527
use datafusion::common::DataFusionError;
2628
use datafusion::physical_plan::ColumnarValue;
27-
use num::{BigInt, Signed, ToPrimitive};
29+
use num::{BigInt, Signed, ToPrimitive, Zero};
2830
use std::sync::Arc;
2931

3032
pub fn spark_decimal_div(
3133
args: &[ColumnarValue],
3234
data_type: &DataType,
3335
) -> Result<ColumnarValue, DataFusionError> {
34-
spark_decimal_div_internal(args, data_type, false)
36+
spark_decimal_div_internal(args, data_type, false, false)
3537
}
3638

3739
pub fn spark_decimal_integral_div(
3840
args: &[ColumnarValue],
3941
data_type: &DataType,
42+
is_ansi_div: bool,
4043
) -> Result<ColumnarValue, DataFusionError> {
41-
spark_decimal_div_internal(args, data_type, true)
44+
spark_decimal_div_internal(args, data_type, true, is_ansi_div)
4245
}
4346

4447
// Let Decimal(p3, s3) as return type i.e. Decimal(p1, s1) / Decimal(p2, s2) = Decimal(p3, s3).
@@ -50,6 +53,7 @@ fn spark_decimal_div_internal(
5053
args: &[ColumnarValue],
5154
data_type: &DataType,
5255
is_integral_div: bool,
56+
is_ansi_div: bool,
5357
) -> Result<ColumnarValue, DataFusionError> {
5458
let left = &args[0];
5559
let right = &args[1];
@@ -80,9 +84,12 @@ fn spark_decimal_div_internal(
8084
let r_mul = ten.pow(r_exp);
8185
let five = BigInt::from(5);
8286
let zero = BigInt::from(0);
83-
arrow::compute::kernels::arity::binary(left, right, |l, r| {
87+
arrow::compute::kernels::arity::try_binary(left, right, |l, r| {
8488
let l = BigInt::from(l) * &l_mul;
8589
let r = BigInt::from(r) * &r_mul;
90+
if is_integral_div && is_ansi_div && r.is_zero() {
91+
return Err(ArrowError::ComputeError(divide_by_zero_error().to_string()));
92+
}
8693
let div = if r.eq(&zero) { zero.clone() } else { &l / &r };
8794
let res = if is_integral_div {
8895
div
@@ -91,14 +98,17 @@ fn spark_decimal_div_internal(
9198
} else {
9299
div + &five
93100
} / &ten;
94-
res.to_i128().unwrap_or(i128::MAX)
101+
Ok(res.to_i128().unwrap_or(i128::MAX))
95102
})?
96103
} else {
97104
let l_mul = 10_i128.pow(l_exp);
98105
let r_mul = 10_i128.pow(r_exp);
99-
arrow::compute::kernels::arity::binary(left, right, |l, r| {
106+
arrow::compute::kernels::arity::try_binary(left, right, |l, r| {
100107
let l = l * l_mul;
101108
let r = r * r_mul;
109+
if is_integral_div && is_ansi_div && r.is_zero() {
110+
return Err(ArrowError::ComputeError(divide_by_zero_error().to_string()));
111+
}
102112
let div = if r == 0 { 0 } else { l / r };
103113
let res = if is_integral_div {
104114
div
@@ -107,7 +117,7 @@ fn spark_decimal_div_internal(
107117
} else {
108118
div + 5
109119
} / 10;
110-
res.to_i128().unwrap_or(i128::MAX)
120+
Ok(res.to_i128().unwrap_or(i128::MAX))
111121
})?
112122
};
113123
let result = result.with_data_type(DataType::Decimal128(p3, s3));

spark/src/main/scala/org/apache/comet/serde/arithmetic.scala

Lines changed: 1 addition & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -180,18 +180,11 @@ object CometDivide extends CometExpressionSerde[Divide] with MathBase {
180180

181181
object CometIntegralDivide extends CometExpressionSerde[IntegralDivide] with MathBase {
182182

183-
override def getSupportLevel(expr: IntegralDivide): SupportLevel = {
184-
if (expr.evalMode == EvalMode.ANSI) {
185-
Incompatible(Some("ANSI mode is not supported"))
186-
} else {
187-
Compatible(None)
188-
}
189-
}
190-
191183
override def convert(
192184
expr: IntegralDivide,
193185
inputs: Seq[Attribute],
194186
binding: Boolean): Option[ExprOuterClass.Expr] = {
187+
195188
if (!supportedDataType(expr.left.dataType)) {
196189
withInfo(expr, s"Unsupported datatype ${expr.left.dataType}")
197190
return None
@@ -252,14 +245,6 @@ object CometIntegralDivide extends CometExpressionSerde[IntegralDivide] with Mat
252245

253246
object CometRemainder extends CometExpressionSerde[Remainder] with MathBase {
254247

255-
override def getSupportLevel(expr: Remainder): SupportLevel = {
256-
if (expr.evalMode == EvalMode.ANSI) {
257-
Incompatible(Some("ANSI mode is not supported"))
258-
} else {
259-
Compatible(None)
260-
}
261-
}
262-
263248
override def convert(
264249
expr: Remainder,
265250
inputs: Seq[Attribute],

spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -479,6 +479,52 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper {
479479
}
480480
}
481481

482+
test("ANSI support for integral divide (division by zero)") {
483+
// Verifies ANSI-mode division-by-zero behavior for integral divide (div)
484+
val data = Seq((Integer.MIN_VALUE, 0))
485+
withSQLConf(SQLConf.ANSI_ENABLED.key -> "true") {
486+
withParquetTable(data, "tbl") {
487+
val res = spark.sql("""
488+
|SELECT
489+
| _1 div _2
490+
| from tbl
491+
| """.stripMargin)
492+
493+
checkSparkMaybeThrows(res) match {
494+
case (Some(sparkExc), Some(cometExc)) =>
495+
val cometErrorPattern =
496+
"""[DIVIDE_BY_ZERO] Division by zero. Use `try_divide` to tolerate divisor being 0 and return NULL instead"""
497+
assert(cometExc.getMessage.contains(cometErrorPattern))
498+
assert(sparkExc.getMessage.contains("Division by zero"))
499+
case _ => fail("Exception should be thrown")
500+
}
501+
}
502+
}
503+
}
504+
505+
test("ANSI support for remainder (division by zero)") {
506+
// Verifies ANSI-mode division-by-zero behavior for remainder (mod)
507+
val data = Seq((Integer.MIN_VALUE, 0))
508+
withSQLConf(SQLConf.ANSI_ENABLED.key -> "true") {
509+
withParquetTable(data, "tbl") {
510+
val res = spark.sql("""
511+
|SELECT
512+
| mod(_1,_2)
513+
| from tbl
514+
| """.stripMargin)
515+
516+
checkSparkMaybeThrows(res) match {
517+
case (Some(sparkExc), Some(cometExc)) =>
518+
val cometErrorPattern =
519+
"""[DIVIDE_BY_ZERO] Division by zero. Use `try_divide` to tolerate divisor being 0 and return NULL instead"""
520+
assert(cometExc.getMessage.contains(cometErrorPattern))
521+
assert(sparkExc.getMessage.contains("Division by zero"))
522+
case _ => fail("Exception should be thrown")
523+
}
524+
}
525+
}
526+
}
527+
482528
test("Verify coalesce performs lazy evaluation") {
483529
val data = Seq((Integer.MAX_VALUE, 9999999999999L))
484530
withSQLConf(SQLConf.ANSI_ENABLED.key -> "true") {

0 commit comments

Comments
 (0)