1515// specific language governing permissions and limitations
1616// under the License.
1717
18- use arrow:: array:: { ArrayRef , Int64Array } ;
18+ use arrow:: array:: { new_null_array, ArrayRef , AsArray , Int64Array , PrimitiveArray } ;
19+ use arrow:: compute:: try_binary;
20+ use arrow:: datatypes:: { DataType , Int64Type } ;
1921use arrow:: error:: ArrowError ;
2022use std:: any:: Any ;
2123use std:: mem:: swap;
2224use std:: sync:: Arc ;
2325
24- use arrow:: datatypes:: DataType ;
25- use arrow:: datatypes:: DataType :: Int64 ;
26-
27- use crate :: utils:: make_scalar_function;
28- use datafusion_common:: {
29- arrow_datafusion_err, exec_err, internal_datafusion_err, DataFusionError , Result ,
30- } ;
26+ use datafusion_common:: { exec_err, internal_datafusion_err, Result , ScalarValue } ;
3127use datafusion_expr:: {
3228 ColumnarValue , Documentation , ScalarFunctionArgs , ScalarUDFImpl , Signature ,
3329 Volatility ,
@@ -54,9 +50,12 @@ impl Default for GcdFunc {
5450
5551impl GcdFunc {
5652 pub fn new ( ) -> Self {
57- use DataType :: * ;
5853 Self {
59- signature : Signature :: uniform ( 2 , vec ! [ Int64 ] , Volatility :: Immutable ) ,
54+ signature : Signature :: uniform (
55+ 2 ,
56+ vec ! [ DataType :: Int64 ] ,
57+ Volatility :: Immutable ,
58+ ) ,
6059 }
6160 }
6261}
@@ -75,36 +74,69 @@ impl ScalarUDFImpl for GcdFunc {
7574 }
7675
7776 fn return_type ( & self , _arg_types : & [ DataType ] ) -> Result < DataType > {
78- Ok ( Int64 )
77+ Ok ( DataType :: Int64 )
7978 }
8079
8180 fn invoke_with_args ( & self , args : ScalarFunctionArgs ) -> Result < ColumnarValue > {
82- make_scalar_function ( gcd, vec ! [ ] ) ( & args. args )
81+ let args: [ ColumnarValue ; 2 ] = args. args . try_into ( ) . map_err ( |_| {
82+ internal_datafusion_err ! ( "Expected 2 arguments for function gcd" )
83+ } ) ?;
84+
85+ match args {
86+ [ ColumnarValue :: Array ( a) , ColumnarValue :: Array ( b) ] => {
87+ compute_gcd_for_arrays ( & a, & b)
88+ }
89+ [ ColumnarValue :: Scalar ( ScalarValue :: Int64 ( a) ) , ColumnarValue :: Scalar ( ScalarValue :: Int64 ( b) ) ] => {
90+ match ( a, b) {
91+ ( Some ( a) , Some ( b) ) => Ok ( ColumnarValue :: Scalar ( ScalarValue :: Int64 (
92+ Some ( compute_gcd ( a, b) ?) ,
93+ ) ) ) ,
94+ _ => Ok ( ColumnarValue :: Scalar ( ScalarValue :: Int64 ( None ) ) ) ,
95+ }
96+ }
97+ [ ColumnarValue :: Array ( a) , ColumnarValue :: Scalar ( ScalarValue :: Int64 ( b) ) ] => {
98+ compute_gcd_with_scalar ( & a, b)
99+ }
100+ [ ColumnarValue :: Scalar ( ScalarValue :: Int64 ( a) ) , ColumnarValue :: Array ( b) ] => {
101+ compute_gcd_with_scalar ( & b, a)
102+ }
103+ _ => exec_err ! ( "Unsupported argument types for function gcd" ) ,
104+ }
83105 }
84106
85107 fn documentation ( & self ) -> Option < & Documentation > {
86108 self . doc ( )
87109 }
88110}
89111
90- /// Gcd SQL function
91- fn gcd ( args : & [ ArrayRef ] ) -> Result < ArrayRef > {
92- match args[ 0 ] . data_type ( ) {
93- Int64 => {
94- let arg1 = downcast_named_arg ! ( & args[ 0 ] , "x" , Int64Array ) ;
95- let arg2 = downcast_named_arg ! ( & args[ 1 ] , "y" , Int64Array ) ;
112+ fn compute_gcd_for_arrays ( a : & ArrayRef , b : & ArrayRef ) -> Result < ColumnarValue > {
113+ let a = a. as_primitive :: < Int64Type > ( ) ;
114+ let b = b. as_primitive :: < Int64Type > ( ) ;
115+ try_binary ( a, b, compute_gcd)
116+ . map ( |arr : PrimitiveArray < Int64Type > | {
117+ ColumnarValue :: Array ( Arc :: new ( arr) as ArrayRef )
118+ } )
119+ . map_err ( Into :: into) // convert ArrowError to DataFusionError
120+ }
96121
97- Ok ( arg1
122+ fn compute_gcd_with_scalar ( arr : & ArrayRef , scalar : Option < i64 > ) -> Result < ColumnarValue > {
123+ match scalar {
124+ Some ( scalar_value) => {
125+ let result: Result < Int64Array > = arr
126+ . as_primitive :: < Int64Type > ( )
98127 . iter ( )
99- . zip ( arg2. iter ( ) )
100- . map ( |( a1, a2) | match ( a1, a2) {
101- ( Some ( a1) , Some ( a2) ) => Ok ( Some ( compute_gcd ( a1, a2) ?) ) ,
128+ . map ( |val| match val {
129+ Some ( val) => Ok ( Some ( compute_gcd ( val, scalar_value) ?) ) ,
102130 _ => Ok ( None ) ,
103131 } )
104- . collect :: < Result < Int64Array > > ( )
105- . map ( Arc :: new) ? as ArrayRef )
132+ . collect ( ) ;
133+
134+ result. map ( |arr| ColumnarValue :: Array ( Arc :: new ( arr) as ArrayRef ) )
106135 }
107- other => exec_err ! ( "Unsupported data type {other:?} for function gcd" ) ,
136+ None => Ok ( ColumnarValue :: Array ( new_null_array (
137+ & DataType :: Int64 ,
138+ arr. len ( ) ,
139+ ) ) ) ,
108140 }
109141}
110142
@@ -132,61 +164,12 @@ pub(super) fn unsigned_gcd(mut a: u64, mut b: u64) -> u64 {
132164}
133165
134166/// Computes greatest common divisor using Binary GCD algorithm.
135- pub fn compute_gcd ( x : i64 , y : i64 ) -> Result < i64 > {
167+ pub fn compute_gcd ( x : i64 , y : i64 ) -> Result < i64 , ArrowError > {
136168 let a = x. unsigned_abs ( ) ;
137169 let b = y. unsigned_abs ( ) ;
138170 let r = unsigned_gcd ( a, b) ;
139171 // gcd(i64::MIN, i64::MIN) = i64::MIN.unsigned_abs() cannot fit into i64
140172 r. try_into ( ) . map_err ( |_| {
141- arrow_datafusion_err ! ( ArrowError :: ComputeError ( format!(
142- "Signed integer overflow in GCD({x}, {y})"
143- ) ) )
173+ ArrowError :: ComputeError ( format ! ( "Signed integer overflow in GCD({x}, {y})" ) )
144174 } )
145175}
146-
147- #[ cfg( test) ]
148- mod test {
149- use std:: sync:: Arc ;
150-
151- use arrow:: {
152- array:: { ArrayRef , Int64Array } ,
153- error:: ArrowError ,
154- } ;
155-
156- use crate :: math:: gcd:: gcd;
157- use datafusion_common:: { cast:: as_int64_array, DataFusionError } ;
158-
159- #[ test]
160- fn test_gcd_i64 ( ) {
161- let args: Vec < ArrayRef > = vec ! [
162- Arc :: new( Int64Array :: from( vec![ 0 , 3 , 25 , -16 ] ) ) , // x
163- Arc :: new( Int64Array :: from( vec![ 0 , -2 , 15 , 8 ] ) ) , // y
164- ] ;
165-
166- let result = gcd ( & args) . expect ( "failed to initialize function gcd" ) ;
167- let ints = as_int64_array ( & result) . expect ( "failed to initialize function gcd" ) ;
168-
169- assert_eq ! ( ints. len( ) , 4 ) ;
170- assert_eq ! ( ints. value( 0 ) , 0 ) ;
171- assert_eq ! ( ints. value( 1 ) , 1 ) ;
172- assert_eq ! ( ints. value( 2 ) , 5 ) ;
173- assert_eq ! ( ints. value( 3 ) , 8 ) ;
174- }
175-
176- #[ test]
177- fn overflow_on_both_param_i64_min ( ) {
178- let args: Vec < ArrayRef > = vec ! [
179- Arc :: new( Int64Array :: from( vec![ i64 :: MIN ] ) ) , // x
180- Arc :: new( Int64Array :: from( vec![ i64 :: MIN ] ) ) , // y
181- ] ;
182-
183- match gcd ( & args) {
184- // we expect a overflow
185- Err ( DataFusionError :: ArrowError ( ArrowError :: ComputeError ( _) , _) ) => { }
186- Err ( _) => {
187- panic ! ( "failed to initialize function gcd" )
188- }
189- Ok ( _) => panic ! ( "GCD({0}, {0}) should have overflown" , i64 :: MIN ) ,
190- } ;
191- }
192- }
0 commit comments