1515// specific language governing permissions and limitations
1616// under the License.
1717
18+ use arrow:: array:: { Scalar , new_null_array} ;
1819use arrow:: compute:: kernels:: numeric:: add;
19- use arrow:: compute:: kernels:: { cmp:: lt, numeric:: rem, zip:: zip} ;
20+ use arrow:: compute:: kernels:: {
21+ cmp:: { eq, lt} ,
22+ numeric:: rem,
23+ zip:: zip,
24+ } ;
2025use arrow:: datatypes:: DataType ;
2126use datafusion_common:: { Result , ScalarValue , assert_eq_or_internal_err} ;
2227use datafusion_expr:: {
2328 ColumnarValue , ScalarFunctionArgs , ScalarUDFImpl , Signature , Volatility ,
2429} ;
2530use std:: any:: Any ;
2631
32+ /// Attempts `rem(left, right)` with per-element divide-by-zero handling.
33+ /// In ANSI mode, any zero divisor causes an error.
34+ /// In legacy mode (ANSI off), positions where the divisor is zero return NULL
35+ /// while other positions compute normally.
36+ fn try_rem (
37+ left : & arrow:: array:: ArrayRef ,
38+ right : & arrow:: array:: ArrayRef ,
39+ enable_ansi_mode : bool ,
40+ ) -> Result < arrow:: array:: ArrayRef > {
41+ match rem ( left, right) {
42+ Ok ( result) => Ok ( result) ,
43+ Err ( arrow:: error:: ArrowError :: DivideByZero ) if !enable_ansi_mode => {
44+ // Integer rem fails when ANY divisor element is zero.
45+ // Handle per-element: null out zero divisors
46+ let zero = ScalarValue :: new_zero ( right. data_type ( ) ) ?. to_array ( ) ?;
47+ let zero = Scalar :: new ( zero) ;
48+ let null = Scalar :: new ( new_null_array ( right. data_type ( ) , 1 ) ) ;
49+ let is_zero = eq ( right, & zero) ?;
50+ let safe_right = zip ( & is_zero, & null, right) ?;
51+ Ok ( rem ( left, & safe_right) ?)
52+ }
53+ Err ( e) => Err ( e. into ( ) ) ,
54+ }
55+ }
56+
2757/// Spark-compatible `mod` function
28- /// This function directly uses Arrow's arithmetic_op function for modulo operations
29- pub fn spark_mod ( args : & [ ColumnarValue ] ) -> Result < ColumnarValue > {
58+ /// In ANSI mode, division by zero throws an error.
59+ /// In legacy mode, division by zero returns NULL (Spark behavior).
60+ pub fn spark_mod (
61+ args : & [ ColumnarValue ] ,
62+ enable_ansi_mode : bool ,
63+ ) -> Result < ColumnarValue > {
3064 assert_eq_or_internal_err ! ( args. len( ) , 2 , "mod expects exactly two arguments" ) ;
3165 let args = ColumnarValue :: values_to_arrays ( args) ?;
32- let result = rem ( & args[ 0 ] , & args[ 1 ] ) ?;
66+ let result = try_rem ( & args[ 0 ] , & args[ 1 ] , enable_ansi_mode ) ?;
3367 Ok ( ColumnarValue :: Array ( result) )
3468}
3569
3670/// Spark-compatible `pmod` function
37- /// This function directly uses Arrow's arithmetic_op function for modulo operations
38- pub fn spark_pmod ( args : & [ ColumnarValue ] ) -> Result < ColumnarValue > {
71+ /// In ANSI mode, division by zero throws an error.
72+ /// In legacy mode, division by zero returns NULL (Spark behavior).
73+ pub fn spark_pmod (
74+ args : & [ ColumnarValue ] ,
75+ enable_ansi_mode : bool ,
76+ ) -> Result < ColumnarValue > {
3977 assert_eq_or_internal_err ! ( args. len( ) , 2 , "pmod expects exactly two arguments" ) ;
4078 let args = ColumnarValue :: values_to_arrays ( args) ?;
4179 let left = & args[ 0 ] ;
4280 let right = & args[ 1 ] ;
4381 let zero = ScalarValue :: new_zero ( left. data_type ( ) ) ?. to_array_of_size ( left. len ( ) ) ?;
44- let result = rem ( left, right) ?;
82+ let result = try_rem ( left, right, enable_ansi_mode ) ?;
4583 let neg = lt ( & result, & zero) ?;
4684 let plus = zip ( & neg, right, & zero) ?;
4785 let result = add ( & plus, & result) ?;
48- let result = rem ( & result, right) ?;
86+ let result = try_rem ( & result, right, enable_ansi_mode ) ?;
4987 Ok ( ColumnarValue :: Array ( result) )
5088}
5189
@@ -95,7 +133,7 @@ impl ScalarUDFImpl for SparkMod {
95133 }
96134
97135 fn invoke_with_args ( & self , args : ScalarFunctionArgs ) -> Result < ColumnarValue > {
98- spark_mod ( & args. args )
136+ spark_mod ( & args. args , args . config_options . execution . enable_ansi_mode )
99137 }
100138}
101139
@@ -145,7 +183,7 @@ impl ScalarUDFImpl for SparkPmod {
145183 }
146184
147185 fn invoke_with_args ( & self , args : ScalarFunctionArgs ) -> Result < ColumnarValue > {
148- spark_pmod ( & args. args )
186+ spark_pmod ( & args. args , args . config_options . execution . enable_ansi_mode )
149187 }
150188}
151189
@@ -165,7 +203,7 @@ mod test {
165203 let left_value = ColumnarValue :: Array ( Arc :: new ( left) ) ;
166204 let right_value = ColumnarValue :: Array ( Arc :: new ( right) ) ;
167205
168- let result = spark_mod ( & [ left_value, right_value] ) . unwrap ( ) ;
206+ let result = spark_mod ( & [ left_value, right_value] , false ) . unwrap ( ) ;
169207
170208 if let ColumnarValue :: Array ( result_array) = result {
171209 let result_int32 =
@@ -187,7 +225,7 @@ mod test {
187225 let left_value = ColumnarValue :: Array ( Arc :: new ( left) ) ;
188226 let right_value = ColumnarValue :: Array ( Arc :: new ( right) ) ;
189227
190- let result = spark_mod ( & [ left_value, right_value] ) . unwrap ( ) ;
228+ let result = spark_mod ( & [ left_value, right_value] , false ) . unwrap ( ) ;
191229
192230 if let ColumnarValue :: Array ( result_array) = result {
193231 let result_int64 =
@@ -228,7 +266,7 @@ mod test {
228266 let left_value = ColumnarValue :: Array ( Arc :: new ( left) ) ;
229267 let right_value = ColumnarValue :: Array ( Arc :: new ( right) ) ;
230268
231- let result = spark_mod ( & [ left_value, right_value] ) . unwrap ( ) ;
269+ let result = spark_mod ( & [ left_value, right_value] , false ) . unwrap ( ) ;
232270
233271 if let ColumnarValue :: Array ( result_array) = result {
234272 let result_float64 = result_array
@@ -284,7 +322,7 @@ mod test {
284322 let left_value = ColumnarValue :: Array ( Arc :: new ( left) ) ;
285323 let right_value = ColumnarValue :: Array ( Arc :: new ( right) ) ;
286324
287- let result = spark_mod ( & [ left_value, right_value] ) . unwrap ( ) ;
325+ let result = spark_mod ( & [ left_value, right_value] , false ) . unwrap ( ) ;
288326
289327 if let ColumnarValue :: Array ( result_array) = result {
290328 let result_float32 = result_array
@@ -319,7 +357,7 @@ mod test {
319357
320358 let left_value = ColumnarValue :: Array ( Arc :: new ( left) ) ;
321359
322- let result = spark_mod ( & [ left_value, right_value] ) . unwrap ( ) ;
360+ let result = spark_mod ( & [ left_value, right_value] , false ) . unwrap ( ) ;
323361
324362 if let ColumnarValue :: Array ( result_array) = result {
325363 let result_int32 =
@@ -337,20 +375,43 @@ mod test {
337375 let left = Int32Array :: from ( vec ! [ Some ( 10 ) ] ) ;
338376 let left_value = ColumnarValue :: Array ( Arc :: new ( left) ) ;
339377
340- let result = spark_mod ( & [ left_value] ) ;
378+ let result = spark_mod ( & [ left_value] , false ) ;
341379 assert ! ( result. is_err( ) ) ;
342380 }
343381
#[test]
fn test_mod_zero_division_legacy() {
    // With ANSI mode off, a zero divisor nulls out only its own row.
    let dividends = ColumnarValue::Array(Arc::new(Int32Array::from(vec![
        Some(10),
        Some(7),
        Some(15),
    ])));
    let divisors = ColumnarValue::Array(Arc::new(Int32Array::from(vec![
        Some(0),
        Some(2),
        Some(4),
    ])));

    let ColumnarValue::Array(output) = spark_mod(&[dividends, divisors], false).unwrap()
    else {
        panic!("Expected array result");
    };
    let output = output.as_any().downcast_ref::<Int32Array>().unwrap();

    assert!(output.is_null(0)); // 10 % 0 = NULL
    assert_eq!(output.value(1), 1); // 7 % 2 = 1
    assert_eq!(output.value(2), 3); // 15 % 4 = 3
}
403+
#[test]
fn test_mod_zero_division_ansi() {
    // With ANSI mode on, a single zero divisor fails the whole call.
    let dividends = ColumnarValue::Array(Arc::new(Int32Array::from(vec![
        Some(10),
        Some(7),
        Some(15),
    ])));
    let divisors = ColumnarValue::Array(Arc::new(Int32Array::from(vec![
        Some(0),
        Some(2),
        Some(4),
    ])));

    let outcome = spark_mod(&[dividends, divisors], true);
    assert!(outcome.is_err());
}
355416
356417 // PMOD tests
@@ -362,7 +423,7 @@ mod test {
362423 let left_value = ColumnarValue :: Array ( Arc :: new ( left) ) ;
363424 let right_value = ColumnarValue :: Array ( Arc :: new ( right) ) ;
364425
365- let result = spark_pmod ( & [ left_value, right_value] ) . unwrap ( ) ;
426+ let result = spark_pmod ( & [ left_value, right_value] , false ) . unwrap ( ) ;
366427
367428 if let ColumnarValue :: Array ( result_array) = result {
368429 let result_int32 =
@@ -385,7 +446,7 @@ mod test {
385446 let left_value = ColumnarValue :: Array ( Arc :: new ( left) ) ;
386447 let right_value = ColumnarValue :: Array ( Arc :: new ( right) ) ;
387448
388- let result = spark_pmod ( & [ left_value, right_value] ) . unwrap ( ) ;
449+ let result = spark_pmod ( & [ left_value, right_value] , false ) . unwrap ( ) ;
389450
390451 if let ColumnarValue :: Array ( result_array) = result {
391452 let result_int64 =
@@ -425,7 +486,7 @@ mod test {
425486 let left_value = ColumnarValue :: Array ( Arc :: new ( left) ) ;
426487 let right_value = ColumnarValue :: Array ( Arc :: new ( right) ) ;
427488
428- let result = spark_pmod ( & [ left_value, right_value] ) . unwrap ( ) ;
489+ let result = spark_pmod ( & [ left_value, right_value] , false ) . unwrap ( ) ;
429490
430491 if let ColumnarValue :: Array ( result_array) = result {
431492 let result_float64 = result_array
@@ -476,7 +537,7 @@ mod test {
476537 let left_value = ColumnarValue :: Array ( Arc :: new ( left) ) ;
477538 let right_value = ColumnarValue :: Array ( Arc :: new ( right) ) ;
478539
479- let result = spark_pmod ( & [ left_value, right_value] ) . unwrap ( ) ;
540+ let result = spark_pmod ( & [ left_value, right_value] , false ) . unwrap ( ) ;
480541
481542 if let ColumnarValue :: Array ( result_array) = result {
482543 let result_float32 = result_array
@@ -508,7 +569,7 @@ mod test {
508569
509570 let left_value = ColumnarValue :: Array ( Arc :: new ( left) ) ;
510571
511- let result = spark_pmod ( & [ left_value, right_value] ) . unwrap ( ) ;
572+ let result = spark_pmod ( & [ left_value, right_value] , false ) . unwrap ( ) ;
512573
513574 if let ColumnarValue :: Array ( result_array) = result {
514575 let result_int32 =
@@ -527,20 +588,43 @@ mod test {
527588 let left = Int32Array :: from ( vec ! [ Some ( 10 ) ] ) ;
528589 let left_value = ColumnarValue :: Array ( Arc :: new ( left) ) ;
529590
530- let result = spark_pmod ( & [ left_value] ) ;
591+ let result = spark_pmod ( & [ left_value] , false ) ;
531592 assert ! ( result. is_err( ) ) ;
532593 }
533594
#[test]
fn test_pmod_zero_division_legacy() {
    // With ANSI mode off, each zero divisor nulls out only its own row.
    let dividends = ColumnarValue::Array(Arc::new(Int32Array::from(vec![
        Some(10),
        Some(-7),
        Some(15),
    ])));
    let divisors = ColumnarValue::Array(Arc::new(Int32Array::from(vec![
        Some(0),
        Some(0),
        Some(4),
    ])));

    let ColumnarValue::Array(output) = spark_pmod(&[dividends, divisors], false).unwrap()
    else {
        panic!("Expected array result");
    };
    let output = output.as_any().downcast_ref::<Int32Array>().unwrap();

    assert!(output.is_null(0)); // 10 pmod 0 = NULL
    assert!(output.is_null(1)); // -7 pmod 0 = NULL
    assert_eq!(output.value(2), 3); // 15 pmod 4 = 3
}
616+
#[test]
fn test_pmod_zero_division_ansi() {
    // With ANSI mode on, any zero divisor fails the whole call.
    let dividends = ColumnarValue::Array(Arc::new(Int32Array::from(vec![
        Some(10),
        Some(-7),
        Some(15),
    ])));
    let divisors = ColumnarValue::Array(Arc::new(Int32Array::from(vec![
        Some(0),
        Some(0),
        Some(4),
    ])));

    let outcome = spark_pmod(&[dividends, divisors], true);
    assert!(outcome.is_err());
}
545629
546630 #[ test]
@@ -552,7 +636,7 @@ mod test {
552636 let left_value = ColumnarValue :: Array ( Arc :: new ( left) ) ;
553637 let right_value = ColumnarValue :: Array ( Arc :: new ( right) ) ;
554638
555- let result = spark_pmod ( & [ left_value, right_value] ) . unwrap ( ) ;
639+ let result = spark_pmod ( & [ left_value, right_value] , false ) . unwrap ( ) ;
556640
557641 if let ColumnarValue :: Array ( result_array) = result {
558642 let result_int32 =
@@ -590,7 +674,7 @@ mod test {
590674 let left_value = ColumnarValue :: Array ( Arc :: new ( left) ) ;
591675 let right_value = ColumnarValue :: Array ( Arc :: new ( right) ) ;
592676
593- let result = spark_pmod ( & [ left_value, right_value] ) . unwrap ( ) ;
677+ let result = spark_pmod ( & [ left_value, right_value] , false ) . unwrap ( ) ;
594678
595679 if let ColumnarValue :: Array ( result_array) = result {
596680 let result_int32 =
0 commit comments