@@ -85,9 +85,10 @@ pub fn spark_round(
8585 let ( precision, scale) = get_precision_scale ( data_type) ;
8686 make_decimal_array ( array, precision, scale, & f)
8787 }
88- DataType :: Float32 | DataType :: Float64 => {
89- Ok ( ColumnarValue :: Array ( round ( & [ Arc :: clone ( array) ] ) ?) )
90- }
88+ DataType :: Float32 | DataType :: Float64 => Ok ( ColumnarValue :: Array ( round ( & [
89+ Arc :: clone ( array) ,
90+ args[ 1 ] . to_array ( array. len ( ) ) ?,
91+ ] ) ?) ) ,
9192 dt => exec_err ! ( "Not supported datatype for ROUND: {dt}" ) ,
9293 } ,
9394 ColumnarValue :: Scalar ( a) => match a {
@@ -109,7 +110,7 @@ pub fn spark_round(
109110 make_decimal_scalar ( a, precision, scale, & f)
110111 }
111112 ScalarValue :: Float32 ( _) | ScalarValue :: Float64 ( _) => Ok ( ColumnarValue :: Scalar (
112- ScalarValue :: try_from_array ( & round ( & [ a. to_array ( ) ?] ) ?, 0 ) ?,
113+ ScalarValue :: try_from_array ( & round ( & [ a. to_array ( ) ?, args [ 1 ] . to_array ( 1 ) ? ] ) ?, 0 ) ?,
113114 ) ) ,
114115 dt => exec_err ! ( "Not supported datatype for ROUND: {dt}" ) ,
115116 } ,
@@ -135,3 +136,80 @@ fn decimal_round_f(scale: &i8, point: &i64) -> Box<dyn Fn(i128) -> i128> {
135136 Box :: new ( move |x : i128 | ( x + x. signum ( ) * half) / div)
136137 }
137138}
139+
140+ #[ cfg( test) ]
141+ mod test {
142+ use std:: sync:: Arc ;
143+
144+ use crate :: spark_round;
145+
146+ use arrow:: array:: { Float32Array , Float64Array } ;
147+ use arrow_schema:: DataType ;
148+ use datafusion_common:: cast:: { as_float32_array, as_float64_array} ;
149+ use datafusion_common:: { Result , ScalarValue } ;
150+ use datafusion_expr:: ColumnarValue ;
151+
152+ #[ test]
153+ fn test_round_f32_array ( ) -> Result < ( ) > {
154+ let args = vec ! [
155+ ColumnarValue :: Array ( Arc :: new( Float32Array :: from( vec![
156+ 125.2345 , 15.3455 , 0.1234 , 0.125 , 0.785 , 123.123 ,
157+ ] ) ) ) ,
158+ ColumnarValue :: Scalar ( ScalarValue :: Int64 ( Some ( 2 ) ) ) ,
159+ ] ;
160+ let ColumnarValue :: Array ( result) = spark_round ( & args, & DataType :: Float32 ) ? else {
161+ unreachable ! ( )
162+ } ;
163+ let floats = as_float32_array ( & result) ?;
164+ let expected = Float32Array :: from ( vec ! [ 125.23 , 15.35 , 0.12 , 0.13 , 0.79 , 123.12 ] ) ;
165+ assert_eq ! ( floats, & expected) ;
166+ Ok ( ( ) )
167+ }
168+
169+ #[ test]
170+ fn test_round_f64_array ( ) -> Result < ( ) > {
171+ let args = vec ! [
172+ ColumnarValue :: Array ( Arc :: new( Float64Array :: from( vec![
173+ 125.2345 , 15.3455 , 0.1234 , 0.125 , 0.785 , 123.123 ,
174+ ] ) ) ) ,
175+ ColumnarValue :: Scalar ( ScalarValue :: Int64 ( Some ( 2 ) ) ) ,
176+ ] ;
177+ let ColumnarValue :: Array ( result) = spark_round ( & args, & DataType :: Float64 ) ? else {
178+ unreachable ! ( )
179+ } ;
180+ let floats = as_float64_array ( & result) ?;
181+ let expected = Float64Array :: from ( vec ! [ 125.23 , 15.35 , 0.12 , 0.13 , 0.79 , 123.12 ] ) ;
182+ assert_eq ! ( floats, & expected) ;
183+ Ok ( ( ) )
184+ }
185+
186+ #[ test]
187+ fn test_round_f32_scalar ( ) -> Result < ( ) > {
188+ let args = vec ! [
189+ ColumnarValue :: Scalar ( ScalarValue :: Float32 ( Some ( 125.2345 ) ) ) ,
190+ ColumnarValue :: Scalar ( ScalarValue :: Int64 ( Some ( 2 ) ) ) ,
191+ ] ;
192+ let ColumnarValue :: Scalar ( ScalarValue :: Float32 ( Some ( result) ) ) =
193+ spark_round ( & args, & DataType :: Float32 ) ?
194+ else {
195+ unreachable ! ( )
196+ } ;
197+ assert_eq ! ( result, 125.23 ) ;
198+ Ok ( ( ) )
199+ }
200+
201+ #[ test]
202+ fn test_round_f64_scalar ( ) -> Result < ( ) > {
203+ let args = vec ! [
204+ ColumnarValue :: Scalar ( ScalarValue :: Float64 ( Some ( 125.2345 ) ) ) ,
205+ ColumnarValue :: Scalar ( ScalarValue :: Int64 ( Some ( 2 ) ) ) ,
206+ ] ;
207+ let ColumnarValue :: Scalar ( ScalarValue :: Float64 ( Some ( result) ) ) =
208+ spark_round ( & args, & DataType :: Float64 ) ?
209+ else {
210+ unreachable ! ( )
211+ } ;
212+ assert_eq ! ( result, 125.23 ) ;
213+ Ok ( ( ) )
214+ }
215+ }
0 commit comments