@@ -113,6 +113,18 @@ macro_rules! make_function_inputs2 {
113113 } )
114114 . collect:: <$ARRAY_TYPE>( )
115115 } } ;
116+ ( $ARG1: expr, $ARG2: expr, $NAME1: expr, $NAME2: expr, $ARRAY_TYPE1: ident, $ARRAY_TYPE2: ident, $FUNC: block) => { {
117+ let arg1 = downcast_arg!( $ARG1, $NAME1, $ARRAY_TYPE1) ;
118+ let arg2 = downcast_arg!( $ARG2, $NAME2, $ARRAY_TYPE2) ;
119+
120+ arg1. iter( )
121+ . zip( arg2. iter( ) )
122+ . map( |( a1, a2) | match ( a1, a2) {
123+ ( Some ( a1) , Some ( a2) ) => Some ( $FUNC( a1, a2. try_into( ) . ok( ) ?) ) ,
124+ _ => None ,
125+ } )
126+ . collect:: <$ARRAY_TYPE1>( )
127+ } } ;
116128}
117129
118130math_unary_function ! ( "sqrt" , sqrt) ;
@@ -124,7 +136,6 @@ math_unary_function!("acos", acos);
124136math_unary_function ! ( "atan" , atan) ;
125137math_unary_function ! ( "floor" , floor) ;
126138math_unary_function ! ( "ceil" , ceil) ;
127- math_unary_function ! ( "round" , round) ;
128139math_unary_function ! ( "trunc" , trunc) ;
129140math_unary_function ! ( "abs" , abs) ;
130141math_unary_function ! ( "signum" , signum) ;
@@ -160,6 +171,59 @@ pub fn random(args: &[ColumnarValue]) -> Result<ColumnarValue> {
160171 Ok ( ColumnarValue :: Array ( Arc :: new ( array) ) )
161172}
162173
174+ /// Round SQL function
175+ pub fn round ( args : & [ ArrayRef ] ) -> Result < ArrayRef > {
176+ if args. len ( ) != 1 && args. len ( ) != 2 {
177+ return Err ( DataFusionError :: Internal ( format ! (
178+ "round function requires one or two arguments, got {}" ,
179+ args. len( )
180+ ) ) ) ;
181+ }
182+
183+ let mut decimal_places =
184+ & ( Arc :: new ( Int64Array :: from_value ( 0 , args[ 0 ] . len ( ) ) ) as ArrayRef ) ;
185+
186+ if args. len ( ) == 2 {
187+ decimal_places = & args[ 1 ] ;
188+ }
189+
190+ match args[ 0 ] . data_type ( ) {
191+ DataType :: Float64 => Ok ( Arc :: new ( make_function_inputs2 ! (
192+ & args[ 0 ] ,
193+ decimal_places,
194+ "value" ,
195+ "decimal_places" ,
196+ Float64Array ,
197+ Int64Array ,
198+ {
199+ |value: f64 , decimal_places: i64 | {
200+ ( value * 10.0_f64 . powi( decimal_places. try_into( ) . unwrap( ) ) ) . round( )
201+ / 10.0_f64 . powi( decimal_places. try_into( ) . unwrap( ) )
202+ }
203+ }
204+ ) ) as ArrayRef ) ,
205+
206+ DataType :: Float32 => Ok ( Arc :: new ( make_function_inputs2 ! (
207+ & args[ 0 ] ,
208+ decimal_places,
209+ "value" ,
210+ "decimal_places" ,
211+ Float32Array ,
212+ Int64Array ,
213+ {
214+ |value: f32 , decimal_places: i64 | {
215+ ( value * 10.0_f32 . powi( decimal_places. try_into( ) . unwrap( ) ) ) . round( )
216+ / 10.0_f32 . powi( decimal_places. try_into( ) . unwrap( ) )
217+ }
218+ }
219+ ) ) as ArrayRef ) ,
220+
221+ other => Err ( DataFusionError :: Internal ( format ! (
222+ "Unsupported data type {other:?} for function round"
223+ ) ) ) ,
224+ }
225+ }
226+
163227pub fn power ( args : & [ ArrayRef ] ) -> Result < ArrayRef > {
164228 match args[ 0 ] . data_type ( ) {
165229 DataType :: Float64 => Ok ( Arc :: new ( make_function_inputs2 ! (
@@ -202,4 +266,44 @@ mod tests {
202266 assert_eq ! ( floats. len( ) , 1 ) ;
203267 assert ! ( 0.0 <= floats. value( 0 ) && floats. value( 0 ) < 1.0 ) ;
204268 }
269+
270+ #[ test]
271+ fn test_round_f32 ( ) {
272+ let args: Vec < ArrayRef > = vec ! [
273+ Arc :: new( Float32Array :: from( vec![ 125.2345 ; 10 ] ) ) , // input
274+ Arc :: new( Int64Array :: from( vec![ 0 , 1 , 2 , 3 , 4 , 5 , -1 , -2 , -3 , -4 ] ) ) , // decimal_places
275+ ] ;
276+
277+ let result = round ( & args) . expect ( "failed to initialize function round" ) ;
278+ let floats = result
279+ . as_any ( )
280+ . downcast_ref :: < Float32Array > ( )
281+ . expect ( "failed to initialize function round" ) ;
282+
283+ let expected = Float32Array :: from ( vec ! [
284+ 125.0 , 125.2 , 125.23 , 125.235 , 125.2345 , 125.2345 , 130.0 , 100.0 , 0.0 , 0.0 ,
285+ ] ) ;
286+
287+ assert_eq ! ( floats, & expected) ;
288+ }
289+
290+ #[ test]
291+ fn test_round_f64 ( ) {
292+ let args: Vec < ArrayRef > = vec ! [
293+ Arc :: new( Float64Array :: from( vec![ 125.2345 ; 10 ] ) ) , // input
294+ Arc :: new( Int64Array :: from( vec![ 0 , 1 , 2 , 3 , 4 , 5 , -1 , -2 , -3 , -4 ] ) ) , // decimal_places
295+ ] ;
296+
297+ let result = round ( & args) . expect ( "failed to initialize function round" ) ;
298+ let floats = result
299+ . as_any ( )
300+ . downcast_ref :: < Float64Array > ( )
301+ . expect ( "failed to initialize function round" ) ;
302+
303+ let expected = Float64Array :: from ( vec ! [
304+ 125.0 , 125.2 , 125.23 , 125.235 , 125.2345 , 125.2345 , 130.0 , 100.0 , 0.0 , 0.0 ,
305+ ] ) ;
306+
307+ assert_eq ! ( floats, & expected) ;
308+ }
205309}
0 commit comments