1515// specific language governing permissions and limitations
1616// under the License.
1717
18+ use crate :: arithmetic_overflow_error;
1819use crate :: math_funcs:: utils:: { get_precision_scale, make_decimal_array, make_decimal_scalar} ;
1920use arrow:: array:: { Array , ArrowNativeTypeOp } ;
2021use arrow:: array:: { Int16Array , Int32Array , Int64Array , Int8Array } ;
2122use arrow:: datatypes:: DataType ;
23+ use arrow:: error:: ArrowError ;
2224use datafusion:: common:: { exec_err, internal_err, DataFusionError , ScalarValue } ;
2325use datafusion:: { functions:: math:: round:: round, physical_plan:: ColumnarValue } ;
2426use std:: { cmp:: min, sync:: Arc } ;
2527
2628macro_rules! integer_round {
27- ( $X: expr, $DIV: expr, $HALF: expr) => { {
29+ ( $X: expr, $DIV: expr, $HALF: expr, $FAIL_ON_ERROR : expr ) => { {
2830 let rem = $X % $DIV;
2931 if rem <= -$HALF {
30- ( $X - rem) . sub_wrapping( $DIV)
32+ if $FAIL_ON_ERROR {
33+ ( $X - rem) . sub_checked( $DIV) . map_err( |_| {
34+ ArrowError :: ComputeError ( arithmetic_overflow_error( "integer" ) . to_string( ) )
35+ } )
36+ } else {
37+ Ok ( ( $X - rem) . sub_wrapping( $DIV) )
38+ }
3139 } else if rem >= $HALF {
32- ( $X - rem) . add_wrapping( $DIV)
40+ if $FAIL_ON_ERROR {
41+ ( $X - rem) . add_checked( $DIV) . map_err( |_| {
42+ ArrowError :: ComputeError ( arithmetic_overflow_error( "integer" ) . to_string( ) )
43+ } )
44+ } else {
45+ Ok ( ( $X - rem) . add_wrapping( $DIV) )
46+ }
3347 } else {
34- $X - rem
48+ if $FAIL_ON_ERROR {
49+ $X. sub_checked( rem) . map_err( |_| {
50+ ArrowError :: ComputeError ( arithmetic_overflow_error( "integer" ) . to_string( ) )
51+ } )
52+ } else {
53+ Ok ( $X. sub_wrapping( rem) )
54+ }
3555 }
3656 } } ;
3757}
3858
3959macro_rules! round_integer_array {
40- ( $ARRAY: expr, $POINT: expr, $TYPE: ty, $NATIVE: ty) => { {
60+ ( $ARRAY: expr, $POINT: expr, $TYPE: ty, $NATIVE: ty, $FAIL_ON_ERROR : expr ) => { {
4161 let array = $ARRAY. as_any( ) . downcast_ref:: <$TYPE>( ) . unwrap( ) ;
4262 let ten: $NATIVE = 10 ;
4363 let result: $TYPE = if let Some ( div) = ten. checked_pow( ( -( * $POINT) ) as u32 ) {
4464 let half = div / 2 ;
45- arrow:: compute:: kernels:: arity:: unary( array, |x| integer_round!( x, div, half) )
65+ arrow:: compute:: kernels:: arity:: try_unary( array, |x| {
66+ integer_round!( x, div, half, $FAIL_ON_ERROR)
67+ } ) ?
4668 } else {
47- arrow:: compute:: kernels:: arity:: unary ( array, |_| 0 )
69+ arrow:: compute:: kernels:: arity:: try_unary ( array, |_| Ok ( 0 ) ) ?
4870 } ;
4971 Ok ( ColumnarValue :: Array ( Arc :: new( result) ) )
5072 } } ;
5173}
5274
5375macro_rules! round_integer_scalar {
54- ( $SCALAR: expr, $POINT: expr, $TYPE: expr, $NATIVE: ty) => { {
76+ ( $SCALAR: expr, $POINT: expr, $TYPE: expr, $NATIVE: ty, $FAIL_ON_ERROR : expr ) => { {
5577 let ten: $NATIVE = 10 ;
5678 if let Some ( div) = ten. checked_pow( ( -( * $POINT) ) as u32 ) {
5779 let half = div / 2 ;
58- Ok ( ColumnarValue :: Scalar ( $TYPE(
59- $SCALAR. map( |x| integer_round!( x, div, half) ) ,
60- ) ) )
80+ let scalar_opt = match $SCALAR {
81+ Some ( x) => match integer_round!( x, div, half, $FAIL_ON_ERROR) {
82+ Ok ( v) => Some ( v) ,
83+ Err ( e) => {
84+ return Err ( DataFusionError :: ArrowError (
85+ Box :: from( e) ,
86+ Some ( DataFusionError :: get_back_trace( ) ) ,
87+ ) )
88+ }
89+ } ,
90+ None => None ,
91+ } ;
92+ Ok ( ColumnarValue :: Scalar ( $TYPE( scalar_opt) ) )
6193 } else {
6294 Ok ( ColumnarValue :: Scalar ( $TYPE( Some ( 0 ) ) ) )
6395 }
@@ -68,6 +100,7 @@ macro_rules! round_integer_scalar {
68100pub fn spark_round (
69101 args : & [ ColumnarValue ] ,
70102 data_type : & DataType ,
103+ fail_on_error : bool ,
71104) -> Result < ColumnarValue , DataFusionError > {
72105 let value = & args[ 0 ] ;
73106 let point = & args[ 1 ] ;
@@ -76,10 +109,18 @@ pub fn spark_round(
76109 } ;
77110 match value {
78111 ColumnarValue :: Array ( array) => match array. data_type ( ) {
79- DataType :: Int64 if * point < 0 => round_integer_array ! ( array, point, Int64Array , i64 ) ,
80- DataType :: Int32 if * point < 0 => round_integer_array ! ( array, point, Int32Array , i32 ) ,
81- DataType :: Int16 if * point < 0 => round_integer_array ! ( array, point, Int16Array , i16 ) ,
82- DataType :: Int8 if * point < 0 => round_integer_array ! ( array, point, Int8Array , i8 ) ,
112+ DataType :: Int64 if * point < 0 => {
113+ round_integer_array ! ( array, point, Int64Array , i64 , fail_on_error)
114+ }
115+ DataType :: Int32 if * point < 0 => {
116+ round_integer_array ! ( array, point, Int32Array , i32 , fail_on_error)
117+ }
118+ DataType :: Int16 if * point < 0 => {
119+ round_integer_array ! ( array, point, Int16Array , i16 , fail_on_error)
120+ }
121+ DataType :: Int8 if * point < 0 => {
122+ round_integer_array ! ( array, point, Int8Array , i8 , fail_on_error)
123+ }
83124 DataType :: Decimal128 ( _, scale) if * scale >= 0 => {
84125 let f = decimal_round_f ( scale, point) ;
85126 let ( precision, scale) = get_precision_scale ( data_type) ;
@@ -93,16 +134,16 @@ pub fn spark_round(
93134 } ,
94135 ColumnarValue :: Scalar ( a) => match a {
95136 ScalarValue :: Int64 ( a) if * point < 0 => {
96- round_integer_scalar ! ( a, point, ScalarValue :: Int64 , i64 )
137+ round_integer_scalar ! ( a, point, ScalarValue :: Int64 , i64 , fail_on_error )
97138 }
98139 ScalarValue :: Int32 ( a) if * point < 0 => {
99- round_integer_scalar ! ( a, point, ScalarValue :: Int32 , i32 )
140+ round_integer_scalar ! ( a, point, ScalarValue :: Int32 , i32 , fail_on_error )
100141 }
101142 ScalarValue :: Int16 ( a) if * point < 0 => {
102- round_integer_scalar ! ( a, point, ScalarValue :: Int16 , i16 )
143+ round_integer_scalar ! ( a, point, ScalarValue :: Int16 , i16 , fail_on_error )
103144 }
104145 ScalarValue :: Int8 ( a) if * point < 0 => {
105- round_integer_scalar ! ( a, point, ScalarValue :: Int8 , i8 )
146+ round_integer_scalar ! ( a, point, ScalarValue :: Int8 , i8 , fail_on_error )
106147 }
107148 ScalarValue :: Decimal128 ( a, _, scale) if * scale >= 0 => {
108149 let f = decimal_round_f ( scale, point) ;
@@ -158,7 +199,7 @@ mod test {
158199 ] ) ) ) ,
159200 ColumnarValue :: Scalar ( ScalarValue :: Int64 ( Some ( 2 ) ) ) ,
160201 ] ;
161- let ColumnarValue :: Array ( result) = spark_round ( & args, & DataType :: Float32 ) ? else {
202+ let ColumnarValue :: Array ( result) = spark_round ( & args, & DataType :: Float32 , false ) ? else {
162203 unreachable ! ( )
163204 } ;
164205 let floats = as_float32_array ( & result) ?;
@@ -176,7 +217,7 @@ mod test {
176217 ] ) ) ) ,
177218 ColumnarValue :: Scalar ( ScalarValue :: Int64 ( Some ( 2 ) ) ) ,
178219 ] ;
179- let ColumnarValue :: Array ( result) = spark_round ( & args, & DataType :: Float64 ) ? else {
220+ let ColumnarValue :: Array ( result) = spark_round ( & args, & DataType :: Float64 , false ) ? else {
180221 unreachable ! ( )
181222 } ;
182223 let floats = as_float64_array ( & result) ?;
@@ -193,7 +234,7 @@ mod test {
193234 ColumnarValue :: Scalar ( ScalarValue :: Int64 ( Some ( 2 ) ) ) ,
194235 ] ;
195236 let ColumnarValue :: Scalar ( ScalarValue :: Float32 ( Some ( result) ) ) =
196- spark_round ( & args, & DataType :: Float32 ) ?
237+ spark_round ( & args, & DataType :: Float32 , false ) ?
197238 else {
198239 unreachable ! ( )
199240 } ;
@@ -209,7 +250,7 @@ mod test {
209250 ColumnarValue :: Scalar ( ScalarValue :: Int64 ( Some ( 2 ) ) ) ,
210251 ] ;
211252 let ColumnarValue :: Scalar ( ScalarValue :: Float64 ( Some ( result) ) ) =
212- spark_round ( & args, & DataType :: Float64 ) ?
253+ spark_round ( & args, & DataType :: Float64 , false ) ?
213254 else {
214255 unreachable ! ( )
215256 } ;
0 commit comments