1818//! Math expressions
1919use super :: { ColumnarValue , ScalarValue } ;
2020use crate :: error:: { DataFusionError , Result } ;
21- use arrow:: array:: { Float32Array , Float64Array } ;
21+ use arrow:: array:: { ArrayRef , Float32Array , Float64Array , Int64Array } ;
2222use arrow:: datatypes:: DataType ;
2323use rand:: { thread_rng, Rng } ;
24+ use std:: any:: type_name;
2425use std:: iter;
2526use std:: sync:: Arc ;
2627
@@ -84,6 +85,33 @@ macro_rules! math_unary_function {
8485 } ;
8586}
8687
88+ macro_rules! downcast_arg {
89+ ( $ARG: expr, $NAME: expr, $ARRAY_TYPE: ident) => { {
90+ $ARG. as_any( ) . downcast_ref:: <$ARRAY_TYPE>( ) . ok_or_else( || {
91+ DataFusionError :: Internal ( format!(
92+ "could not cast {} to {}" ,
93+ $NAME,
94+ type_name:: <$ARRAY_TYPE>( )
95+ ) )
96+ } ) ?
97+ } } ;
98+ }
99+
100+ macro_rules! make_function_inputs2 {
101+ ( $ARG1: expr, $ARG2: expr, $NAME1: expr, $NAME2: expr, $ARRAY_TYPE1: ident, $ARRAY_TYPE2: ident, $FUNC: block) => { {
102+ let arg1 = downcast_arg!( $ARG1, $NAME1, $ARRAY_TYPE1) ;
103+ let arg2 = downcast_arg!( $ARG2, $NAME2, $ARRAY_TYPE2) ;
104+
105+ arg1. iter( )
106+ . zip( arg2. iter( ) )
107+ . map( |( a1, a2) | match ( a1, a2) {
108+ ( Some ( a1) , Some ( a2) ) => Some ( $FUNC( a1, a2. try_into( ) . ok( ) ?) ) ,
109+ _ => None ,
110+ } )
111+ . collect:: <$ARRAY_TYPE1>( )
112+ } } ;
113+ }
114+
87115math_unary_function ! ( "sqrt" , sqrt) ;
88116math_unary_function ! ( "sin" , sin) ;
89117math_unary_function ! ( "cos" , cos) ;
@@ -93,7 +121,6 @@ math_unary_function!("acos", acos);
93121math_unary_function ! ( "atan" , atan) ;
94122math_unary_function ! ( "floor" , floor) ;
95123math_unary_function ! ( "ceil" , ceil) ;
96- math_unary_function ! ( "round" , round) ;
97124math_unary_function ! ( "trunc" , trunc) ;
98125math_unary_function ! ( "abs" , abs) ;
99126math_unary_function ! ( "signum" , signum) ;
@@ -118,11 +145,64 @@ pub fn random(args: &[ColumnarValue]) -> Result<ColumnarValue> {
118145 Ok ( ColumnarValue :: Array ( Arc :: new ( array) ) )
119146}
120147
148+ /// Round SQL function
149+ pub fn round ( args : & [ ArrayRef ] ) -> Result < ArrayRef > {
150+ if args. len ( ) != 1 && args. len ( ) != 2 {
151+ return Err ( DataFusionError :: Internal ( format ! (
152+ "round function requires one or two arguments, got {}" ,
153+ args. len( )
154+ ) ) ) ;
155+ }
156+
157+ let mut decimal_places =
158+ & ( Arc :: new ( Int64Array :: from_value ( 0 , args[ 0 ] . len ( ) ) ) as ArrayRef ) ;
159+
160+ if args. len ( ) == 2 {
161+ decimal_places = & args[ 1 ] ;
162+ }
163+
164+ match args[ 0 ] . data_type ( ) {
165+ DataType :: Float64 => Ok ( Arc :: new ( make_function_inputs2 ! (
166+ & args[ 0 ] ,
167+ decimal_places,
168+ "value" ,
169+ "decimal_places" ,
170+ Float64Array ,
171+ Int64Array ,
172+ {
173+ |value: f64 , decimal_places: i64 | {
174+ ( value * 10.0_f64 . powi( decimal_places. try_into( ) . unwrap( ) ) ) . round( )
175+ / 10.0_f64 . powi( decimal_places. try_into( ) . unwrap( ) )
176+ }
177+ }
178+ ) ) as ArrayRef ) ,
179+
180+ DataType :: Float32 => Ok ( Arc :: new ( make_function_inputs2 ! (
181+ & args[ 0 ] ,
182+ decimal_places,
183+ "value" ,
184+ "decimal_places" ,
185+ Float32Array ,
186+ Int64Array ,
187+ {
188+ |value: f32 , decimal_places: i64 | {
189+ ( value * 10.0_f32 . powi( decimal_places. try_into( ) . unwrap( ) ) ) . round( )
190+ / 10.0_f32 . powi( decimal_places. try_into( ) . unwrap( ) )
191+ }
192+ }
193+ ) ) as ArrayRef ) ,
194+
195+ other => Err ( DataFusionError :: Internal ( format ! (
196+ "Unsupported data type {other:?} for function round"
197+ ) ) ) ,
198+ }
199+ }
200+
121201#[ cfg( test) ]
122202mod tests {
123203
124204 use super :: * ;
125- use arrow:: array:: { Float64Array , NullArray } ;
205+ use arrow:: array:: { Float32Array , Float64Array , NullArray } ;
126206
127207 #[ test]
128208 fn test_random_expression ( ) {
@@ -133,4 +213,44 @@ mod tests {
133213 assert_eq ! ( floats. len( ) , 1 ) ;
134214 assert ! ( 0.0 <= floats. value( 0 ) && floats. value( 0 ) < 1.0 ) ;
135215 }
216+
217+ #[ test]
218+ fn test_round_f32 ( ) {
219+ let args: Vec < ArrayRef > = vec ! [
220+ Arc :: new( Float32Array :: from( vec![ 125.2345 ; 10 ] ) ) , // input
221+ Arc :: new( Int64Array :: from( vec![ 0 , 1 , 2 , 3 , 4 , 5 , -1 , -2 , -3 , -4 ] ) ) , // decimal_places
222+ ] ;
223+
224+ let result = round ( & args) . expect ( "failed to initialize function round" ) ;
225+ let floats = result
226+ . as_any ( )
227+ . downcast_ref :: < Float32Array > ( )
228+ . expect ( "failed to initialize function round" ) ;
229+
230+ let expected = Float32Array :: from ( vec ! [
231+ 125.0 , 125.2 , 125.23 , 125.235 , 125.2345 , 125.2345 , 130.0 , 100.0 , 0.0 , 0.0 ,
232+ ] ) ;
233+
234+ assert_eq ! ( floats, & expected) ;
235+ }
236+
237+ #[ test]
238+ fn test_round_f64 ( ) {
239+ let args: Vec < ArrayRef > = vec ! [
240+ Arc :: new( Float64Array :: from( vec![ 125.2345 ; 10 ] ) ) , // input
241+ Arc :: new( Int64Array :: from( vec![ 0 , 1 , 2 , 3 , 4 , 5 , -1 , -2 , -3 , -4 ] ) ) , // decimal_places
242+ ] ;
243+
244+ let result = round ( & args) . expect ( "failed to initialize function round" ) ;
245+ let floats = result
246+ . as_any ( )
247+ . downcast_ref :: < Float64Array > ( )
248+ . expect ( "failed to initialize function round" ) ;
249+
250+ let expected = Float64Array :: from ( vec ! [
251+ 125.0 , 125.2 , 125.23 , 125.235 , 125.2345 , 125.2345 , 130.0 , 100.0 , 0.0 , 0.0 ,
252+ ] ) ;
253+
254+ assert_eq ! ( floats, & expected) ;
255+ }
136256}
0 commit comments