Skip to content

Commit bc21af0

Browse files
davidlghellina authored and alamb committed
fix(spark): handle divide-by-zero in Spark mod/pmod with ANSI mode support (apache#20461)

## Which issue does this PR close?

- N/A.

## Rationale for this change

Spark's `mod` and `pmod` functions return `NULL` on integer division by zero in legacy mode (ANSI off), but DataFusion's implementation always threw a `DivideByZero` error regardless of the ANSI mode setting.

## What changes are included in this PR?

- Add ANSI mode support to `spark_mod` and `spark_pmod` via the `enable_ansi_mode` config option.
- In legacy mode (ANSI off): division by zero returns `NULL` per element.
- In ANSI mode (ANSI on): division by zero throws an error (unchanged behavior).
- Add a `try_rem` helper that handles per-element zero-divisor masking for integer arrays.

## Are these changes tested?

Yes:
- 18 unit tests in modulus.rs (including new tests for both ANSI modes)
- Updated pmod.slt and mod.slt sqllogictests with ANSI on/off coverage

## Are there any user-facing changes?

Yes — `mod(10, 0)` and `pmod(10, 0)` now return `NULL` instead of erroring when `enable_ansi_mode = false` (the default), matching Spark behavior.
1 parent 7698fdc commit bc21af0

File tree

3 files changed

+164
-31
lines changed

3 files changed

+164
-31
lines changed

datafusion/spark/src/function/math/modulus.rs

Lines changed: 114 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -15,37 +15,75 @@
1515
// specific language governing permissions and limitations
1616
// under the License.
1717

18+
use arrow::array::{Scalar, new_null_array};
1819
use arrow::compute::kernels::numeric::add;
19-
use arrow::compute::kernels::{cmp::lt, numeric::rem, zip::zip};
20+
use arrow::compute::kernels::{
21+
cmp::{eq, lt},
22+
numeric::rem,
23+
zip::zip,
24+
};
2025
use arrow::datatypes::DataType;
2126
use datafusion_common::{Result, ScalarValue, assert_eq_or_internal_err};
2227
use datafusion_expr::{
2328
ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility,
2429
};
2530
use std::any::Any;
2631

32+
/// Attempts `rem(left, right)` with per-element divide-by-zero handling.
33+
/// In ANSI mode, any zero divisor causes an error.
34+
/// In legacy mode (ANSI off), positions where the divisor is zero return NULL
35+
/// while other positions compute normally.
36+
fn try_rem(
37+
left: &arrow::array::ArrayRef,
38+
right: &arrow::array::ArrayRef,
39+
enable_ansi_mode: bool,
40+
) -> Result<arrow::array::ArrayRef> {
41+
match rem(left, right) {
42+
Ok(result) => Ok(result),
43+
Err(arrow::error::ArrowError::DivideByZero) if !enable_ansi_mode => {
44+
// Integer rem fails when ANY divisor element is zero.
45+
// Handle per-element: null out zero divisors
46+
let zero = ScalarValue::new_zero(right.data_type())?.to_array()?;
47+
let zero = Scalar::new(zero);
48+
let null = Scalar::new(new_null_array(right.data_type(), 1));
49+
let is_zero = eq(right, &zero)?;
50+
let safe_right = zip(&is_zero, &null, right)?;
51+
Ok(rem(left, &safe_right)?)
52+
}
53+
Err(e) => Err(e.into()),
54+
}
55+
}
56+
2757
/// Spark-compatible `mod` function
28-
/// This function directly uses Arrow's arithmetic_op function for modulo operations
29-
pub fn spark_mod(args: &[ColumnarValue]) -> Result<ColumnarValue> {
58+
/// In ANSI mode, division by zero throws an error.
59+
/// In legacy mode, division by zero returns NULL (Spark behavior).
60+
pub fn spark_mod(
61+
args: &[ColumnarValue],
62+
enable_ansi_mode: bool,
63+
) -> Result<ColumnarValue> {
3064
assert_eq_or_internal_err!(args.len(), 2, "mod expects exactly two arguments");
3165
let args = ColumnarValue::values_to_arrays(args)?;
32-
let result = rem(&args[0], &args[1])?;
66+
let result = try_rem(&args[0], &args[1], enable_ansi_mode)?;
3367
Ok(ColumnarValue::Array(result))
3468
}
3569

3670
/// Spark-compatible `pmod` function
37-
/// This function directly uses Arrow's arithmetic_op function for modulo operations
38-
pub fn spark_pmod(args: &[ColumnarValue]) -> Result<ColumnarValue> {
71+
/// In ANSI mode, division by zero throws an error.
72+
/// In legacy mode, division by zero returns NULL (Spark behavior).
73+
pub fn spark_pmod(
74+
args: &[ColumnarValue],
75+
enable_ansi_mode: bool,
76+
) -> Result<ColumnarValue> {
3977
assert_eq_or_internal_err!(args.len(), 2, "pmod expects exactly two arguments");
4078
let args = ColumnarValue::values_to_arrays(args)?;
4179
let left = &args[0];
4280
let right = &args[1];
4381
let zero = ScalarValue::new_zero(left.data_type())?.to_array_of_size(left.len())?;
44-
let result = rem(left, right)?;
82+
let result = try_rem(left, right, enable_ansi_mode)?;
4583
let neg = lt(&result, &zero)?;
4684
let plus = zip(&neg, right, &zero)?;
4785
let result = add(&plus, &result)?;
48-
let result = rem(&result, right)?;
86+
let result = try_rem(&result, right, enable_ansi_mode)?;
4987
Ok(ColumnarValue::Array(result))
5088
}
5189

@@ -95,7 +133,7 @@ impl ScalarUDFImpl for SparkMod {
95133
}
96134

97135
fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
98-
spark_mod(&args.args)
136+
spark_mod(&args.args, args.config_options.execution.enable_ansi_mode)
99137
}
100138
}
101139

@@ -145,7 +183,7 @@ impl ScalarUDFImpl for SparkPmod {
145183
}
146184

147185
fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
148-
spark_pmod(&args.args)
186+
spark_pmod(&args.args, args.config_options.execution.enable_ansi_mode)
149187
}
150188
}
151189

@@ -165,7 +203,7 @@ mod test {
165203
let left_value = ColumnarValue::Array(Arc::new(left));
166204
let right_value = ColumnarValue::Array(Arc::new(right));
167205

168-
let result = spark_mod(&[left_value, right_value]).unwrap();
206+
let result = spark_mod(&[left_value, right_value], false).unwrap();
169207

170208
if let ColumnarValue::Array(result_array) = result {
171209
let result_int32 =
@@ -187,7 +225,7 @@ mod test {
187225
let left_value = ColumnarValue::Array(Arc::new(left));
188226
let right_value = ColumnarValue::Array(Arc::new(right));
189227

190-
let result = spark_mod(&[left_value, right_value]).unwrap();
228+
let result = spark_mod(&[left_value, right_value], false).unwrap();
191229

192230
if let ColumnarValue::Array(result_array) = result {
193231
let result_int64 =
@@ -228,7 +266,7 @@ mod test {
228266
let left_value = ColumnarValue::Array(Arc::new(left));
229267
let right_value = ColumnarValue::Array(Arc::new(right));
230268

231-
let result = spark_mod(&[left_value, right_value]).unwrap();
269+
let result = spark_mod(&[left_value, right_value], false).unwrap();
232270

233271
if let ColumnarValue::Array(result_array) = result {
234272
let result_float64 = result_array
@@ -284,7 +322,7 @@ mod test {
284322
let left_value = ColumnarValue::Array(Arc::new(left));
285323
let right_value = ColumnarValue::Array(Arc::new(right));
286324

287-
let result = spark_mod(&[left_value, right_value]).unwrap();
325+
let result = spark_mod(&[left_value, right_value], false).unwrap();
288326

289327
if let ColumnarValue::Array(result_array) = result {
290328
let result_float32 = result_array
@@ -319,7 +357,7 @@ mod test {
319357

320358
let left_value = ColumnarValue::Array(Arc::new(left));
321359

322-
let result = spark_mod(&[left_value, right_value]).unwrap();
360+
let result = spark_mod(&[left_value, right_value], false).unwrap();
323361

324362
if let ColumnarValue::Array(result_array) = result {
325363
let result_int32 =
@@ -337,20 +375,43 @@ mod test {
337375
let left = Int32Array::from(vec![Some(10)]);
338376
let left_value = ColumnarValue::Array(Arc::new(left));
339377

340-
let result = spark_mod(&[left_value]);
378+
let result = spark_mod(&[left_value], false);
341379
assert!(result.is_err());
342380
}
343381

344382
#[test]
345-
fn test_mod_zero_division() {
383+
fn test_mod_zero_division_legacy() {
384+
// In legacy mode (ANSI off), division by zero returns NULL per-element
385+
let left = Int32Array::from(vec![Some(10), Some(7), Some(15)]);
386+
let right = Int32Array::from(vec![Some(0), Some(2), Some(4)]);
387+
388+
let left_value = ColumnarValue::Array(Arc::new(left));
389+
let right_value = ColumnarValue::Array(Arc::new(right));
390+
391+
let result = spark_mod(&[left_value, right_value], false).unwrap();
392+
393+
if let ColumnarValue::Array(result_array) = result {
394+
let result_int32 =
395+
result_array.as_any().downcast_ref::<Int32Array>().unwrap();
396+
assert!(result_int32.is_null(0)); // 10 % 0 = NULL
397+
assert_eq!(result_int32.value(1), 1); // 7 % 2 = 1
398+
assert_eq!(result_int32.value(2), 3); // 15 % 4 = 3
399+
} else {
400+
panic!("Expected array result");
401+
}
402+
}
403+
404+
#[test]
405+
fn test_mod_zero_division_ansi() {
406+
// In ANSI mode, division by zero should error
346407
let left = Int32Array::from(vec![Some(10), Some(7), Some(15)]);
347408
let right = Int32Array::from(vec![Some(0), Some(2), Some(4)]);
348409

349410
let left_value = ColumnarValue::Array(Arc::new(left));
350411
let right_value = ColumnarValue::Array(Arc::new(right));
351412

352-
let result = spark_mod(&[left_value, right_value]);
353-
assert!(result.is_err()); // Division by zero should error
413+
let result = spark_mod(&[left_value, right_value], true);
414+
assert!(result.is_err());
354415
}
355416

356417
// PMOD tests
@@ -362,7 +423,7 @@ mod test {
362423
let left_value = ColumnarValue::Array(Arc::new(left));
363424
let right_value = ColumnarValue::Array(Arc::new(right));
364425

365-
let result = spark_pmod(&[left_value, right_value]).unwrap();
426+
let result = spark_pmod(&[left_value, right_value], false).unwrap();
366427

367428
if let ColumnarValue::Array(result_array) = result {
368429
let result_int32 =
@@ -385,7 +446,7 @@ mod test {
385446
let left_value = ColumnarValue::Array(Arc::new(left));
386447
let right_value = ColumnarValue::Array(Arc::new(right));
387448

388-
let result = spark_pmod(&[left_value, right_value]).unwrap();
449+
let result = spark_pmod(&[left_value, right_value], false).unwrap();
389450

390451
if let ColumnarValue::Array(result_array) = result {
391452
let result_int64 =
@@ -425,7 +486,7 @@ mod test {
425486
let left_value = ColumnarValue::Array(Arc::new(left));
426487
let right_value = ColumnarValue::Array(Arc::new(right));
427488

428-
let result = spark_pmod(&[left_value, right_value]).unwrap();
489+
let result = spark_pmod(&[left_value, right_value], false).unwrap();
429490

430491
if let ColumnarValue::Array(result_array) = result {
431492
let result_float64 = result_array
@@ -476,7 +537,7 @@ mod test {
476537
let left_value = ColumnarValue::Array(Arc::new(left));
477538
let right_value = ColumnarValue::Array(Arc::new(right));
478539

479-
let result = spark_pmod(&[left_value, right_value]).unwrap();
540+
let result = spark_pmod(&[left_value, right_value], false).unwrap();
480541

481542
if let ColumnarValue::Array(result_array) = result {
482543
let result_float32 = result_array
@@ -508,7 +569,7 @@ mod test {
508569

509570
let left_value = ColumnarValue::Array(Arc::new(left));
510571

511-
let result = spark_pmod(&[left_value, right_value]).unwrap();
572+
let result = spark_pmod(&[left_value, right_value], false).unwrap();
512573

513574
if let ColumnarValue::Array(result_array) = result {
514575
let result_int32 =
@@ -527,20 +588,43 @@ mod test {
527588
let left = Int32Array::from(vec![Some(10)]);
528589
let left_value = ColumnarValue::Array(Arc::new(left));
529590

530-
let result = spark_pmod(&[left_value]);
591+
let result = spark_pmod(&[left_value], false);
531592
assert!(result.is_err());
532593
}
533594

534595
#[test]
535-
fn test_pmod_zero_division() {
596+
fn test_pmod_zero_division_legacy() {
597+
// In legacy mode (ANSI off), division by zero returns NULL per-element
536598
let left = Int32Array::from(vec![Some(10), Some(-7), Some(15)]);
537599
let right = Int32Array::from(vec![Some(0), Some(0), Some(4)]);
538600

539601
let left_value = ColumnarValue::Array(Arc::new(left));
540602
let right_value = ColumnarValue::Array(Arc::new(right));
541603

542-
let result = spark_pmod(&[left_value, right_value]);
543-
assert!(result.is_err()); // Division by zero should error
604+
let result = spark_pmod(&[left_value, right_value], false).unwrap();
605+
606+
if let ColumnarValue::Array(result_array) = result {
607+
let result_int32 =
608+
result_array.as_any().downcast_ref::<Int32Array>().unwrap();
609+
assert!(result_int32.is_null(0)); // 10 pmod 0 = NULL
610+
assert!(result_int32.is_null(1)); // -7 pmod 0 = NULL
611+
assert_eq!(result_int32.value(2), 3); // 15 pmod 4 = 3
612+
} else {
613+
panic!("Expected array result");
614+
}
615+
}
616+
617+
#[test]
618+
fn test_pmod_zero_division_ansi() {
619+
// In ANSI mode, division by zero should error
620+
let left = Int32Array::from(vec![Some(10), Some(-7), Some(15)]);
621+
let right = Int32Array::from(vec![Some(0), Some(0), Some(4)]);
622+
623+
let left_value = ColumnarValue::Array(Arc::new(left));
624+
let right_value = ColumnarValue::Array(Arc::new(right));
625+
626+
let result = spark_pmod(&[left_value, right_value], true);
627+
assert!(result.is_err());
544628
}
545629

546630
#[test]
@@ -552,7 +636,7 @@ mod test {
552636
let left_value = ColumnarValue::Array(Arc::new(left));
553637
let right_value = ColumnarValue::Array(Arc::new(right));
554638

555-
let result = spark_pmod(&[left_value, right_value]).unwrap();
639+
let result = spark_pmod(&[left_value, right_value], false).unwrap();
556640

557641
if let ColumnarValue::Array(result_array) = result {
558642
let result_int32 =
@@ -590,7 +674,7 @@ mod test {
590674
let left_value = ColumnarValue::Array(Arc::new(left));
591675
let right_value = ColumnarValue::Array(Arc::new(right));
592676

593-
let result = spark_pmod(&[left_value, right_value]).unwrap();
677+
let result = spark_pmod(&[left_value, right_value], false).unwrap();
594678

595679
if let ColumnarValue::Array(result_array) = result {
596680
let result_int32 =

datafusion/sqllogictest/test_files/spark/math/mod.slt

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -144,6 +144,35 @@ SELECT MOD(10.0::decimal(3,1), 3.0::decimal(2,1)) as mod_decimal_2;
144144
----
145145
1
146146

147+
# Division by zero returns NULL in legacy mode (ANSI off)
148+
query I
149+
SELECT MOD(10::int, 0::int) as mod_div_zero_1;
150+
----
151+
NULL
152+
153+
query I
154+
SELECT MOD(-7::int, 0::int) as mod_div_zero_2;
155+
----
156+
NULL
157+
158+
query R
159+
SELECT MOD(10.5::float8, 0.0::float8) as mod_div_zero_float;
160+
----
161+
NaN
162+
163+
# Division by zero errors in ANSI mode
164+
statement ok
165+
set datafusion.execution.enable_ansi_mode = true;
166+
167+
statement error DataFusion error: Arrow error: Divide by zero error
168+
SELECT MOD(10::int, 0::int);
169+
170+
statement error DataFusion error: Arrow error: Divide by zero error
171+
SELECT MOD(-7::int, 0::int);
172+
173+
statement ok
174+
set datafusion.execution.enable_ansi_mode = false;
175+
147176
# Edge cases
148177
query I
149178
SELECT MOD(0::int, 5::int) as mod_zero_1;

datafusion/sqllogictest/test_files/spark/math/pmod.slt

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,8 +64,28 @@ SELECT pmod(0::int, 5::int) as pmod_zero_1;
6464
----
6565
0
6666

67-
statement error DataFusion error: Arrow error: Divide by zero error
67+
query I
6868
SELECT pmod(10::int, 0::int) as pmod_zero_2;
69+
----
70+
NULL
71+
72+
query I
73+
SELECT pmod(-7::int, 0::int) as pmod_zero_3;
74+
----
75+
NULL
76+
77+
# Division by zero errors in ANSI mode
78+
statement ok
79+
set datafusion.execution.enable_ansi_mode = true;
80+
81+
statement error DataFusion error: Arrow error: Divide by zero error
82+
SELECT pmod(10::int, 0::int);
83+
84+
statement error DataFusion error: Arrow error: Divide by zero error
85+
SELECT pmod(-7::int, 0::int);
86+
87+
statement ok
88+
set datafusion.execution.enable_ansi_mode = false;
6989

7090
# PMOD tests with NULL values
7191
query I

0 commit comments

Comments
 (0)