1515// specific language governing permissions and limitations
1616// under the License.
1717
18- use arrow:: array:: { make_comparator, Array , ArrayRef , BooleanArray } ;
18+ use crate :: core:: greatest_least_utils:: GreatestLeastOperator ;
19+ use arrow:: array:: { make_comparator, Array , BooleanArray } ;
1920use arrow:: compute:: kernels:: cmp;
20- use arrow:: compute:: kernels:: zip:: zip;
2121use arrow:: compute:: SortOptions ;
2222use arrow:: datatypes:: DataType ;
2323use arrow_buffer:: BooleanBuffer ;
24- use datafusion_common:: { exec_err , plan_err , Result , ScalarValue } ;
24+ use datafusion_common:: { internal_err , Result , ScalarValue } ;
2525use datafusion_doc:: Documentation ;
26- use datafusion_expr:: binary:: type_union_resolution;
2726use datafusion_expr:: scalar_doc_sections:: DOC_SECTION_CONDITIONAL ;
2827use datafusion_expr:: ColumnarValue ;
2928use datafusion_expr:: { ScalarUDFImpl , Signature , Volatility } ;
3029use std:: any:: Any ;
31- use std:: sync:: { Arc , OnceLock } ;
30+ use std:: sync:: OnceLock ;
3231
3332const SORT_OPTIONS : SortOptions = SortOptions {
3433 // We want greatest first
@@ -57,79 +56,57 @@ impl GreatestFunc {
5756 }
5857}
5958
60- fn get_logical_null_count ( arr : & dyn Array ) -> usize {
61- arr. logical_nulls ( )
62- . map ( |n| n. null_count ( ) )
63- . unwrap_or_default ( )
64- }
59+ impl GreatestLeastOperator for GreatestFunc {
60+ const NAME : & ' static str = "greatest" ;
6561
66- /// Return boolean array where `arr[i] = lhs[i] >= rhs[i]` for all i, where `arr` is the result array
67- /// Nulls are always considered smaller than any other value
68- fn get_larger ( lhs : & dyn Array , rhs : & dyn Array ) -> Result < BooleanArray > {
69- // Fast path:
70- // If both arrays are not nested, have the same length and no nulls, we can use the faster vectorised kernel
71- // - If both arrays are not nested: Nested types, such as lists, are not supported as the null semantics are not well-defined.
72- // - both array does not have any nulls: cmp::gt_eq will return null if any of the input is null while we want to return false in that case
73- if !lhs. data_type ( ) . is_nested ( )
74- && get_logical_null_count ( lhs) == 0
75- && get_logical_null_count ( rhs) == 0
76- {
77- return cmp:: gt_eq ( & lhs, & rhs) . map_err ( |e| e. into ( ) ) ;
78- }
62+ fn keep_scalar < ' a > (
63+ lhs : & ' a ScalarValue ,
64+ rhs : & ' a ScalarValue ,
65+ ) -> Result < & ' a ScalarValue > {
66+ if !lhs. data_type ( ) . is_nested ( ) {
67+ return if lhs >= rhs { Ok ( lhs) } else { Ok ( rhs) } ;
68+ }
7969
80- let cmp = make_comparator ( lhs, rhs, SORT_OPTIONS ) ?;
70+ // For nested (complex) types we can't compare the scalars directly, as we want null values to be smaller
71+ let cmp = make_comparator (
72+ lhs. to_array ( ) ?. as_ref ( ) ,
73+ rhs. to_array ( ) ?. as_ref ( ) ,
74+ SORT_OPTIONS ,
75+ ) ?;
8176
82- if lhs. len ( ) != rhs. len ( ) {
83- return exec_err ! (
84- "All arrays should have the same length for greatest comparison"
85- ) ;
77+ if cmp ( 0 , 0 ) . is_ge ( ) {
78+ Ok ( lhs)
79+ } else {
80+ Ok ( rhs)
81+ }
8682 }
8783
88- let values = BooleanBuffer :: collect_bool ( lhs. len ( ) , |i| cmp ( i, i) . is_ge ( ) ) ;
89-
90- // No nulls as we only want to keep the values that are larger, its either true or false
91- Ok ( BooleanArray :: new ( values, None ) )
92- }
93-
94- /// Return array where the largest value at each index is kept
95- fn keep_larger ( lhs : ArrayRef , rhs : ArrayRef ) -> Result < ArrayRef > {
96- // True for values that we should keep from the left array
97- let keep_lhs = get_larger ( lhs. as_ref ( ) , rhs. as_ref ( ) ) ?;
98-
99- let larger = zip ( & keep_lhs, & lhs, & rhs) ?;
84+ /// Return boolean array where `arr[i] = lhs[i] >= rhs[i]` for all i, where `arr` is the result array
85+ /// Nulls are always considered smaller than any other value
86+ fn get_indexes_to_keep ( lhs : & dyn Array , rhs : & dyn Array ) -> Result < BooleanArray > {
87+ // Fast path:
88+ // If both arrays are not nested, have the same length and no nulls, we can use the faster vectorised kernel
89+ // - If both arrays are not nested: Nested types, such as lists, are not supported as the null semantics are not well-defined.
90+ // - Neither array has any nulls: cmp::gt_eq returns null if either input is null, whereas we want to return false in that case
91+ if !lhs. data_type ( ) . is_nested ( )
92+ && lhs. logical_null_count ( ) == 0
93+ && rhs. logical_null_count ( ) == 0
94+ {
95+ return cmp:: gt_eq ( & lhs, & rhs) . map_err ( |e| e. into ( ) ) ;
96+ }
10097
101- Ok ( larger)
102- }
98+ let cmp = make_comparator ( lhs, rhs, SORT_OPTIONS ) ?;
10399
104- fn keep_larger_scalar < ' a > (
105- lhs : & ' a ScalarValue ,
106- rhs : & ' a ScalarValue ,
107- ) -> Result < & ' a ScalarValue > {
108- if !lhs. data_type ( ) . is_nested ( ) {
109- return if lhs >= rhs { Ok ( lhs) } else { Ok ( rhs) } ;
110- }
111-
112- // If complex type we can't compare directly as we want null values to be smaller
113- let cmp = make_comparator (
114- lhs. to_array ( ) ?. as_ref ( ) ,
115- rhs. to_array ( ) ?. as_ref ( ) ,
116- SORT_OPTIONS ,
117- ) ?;
100+ if lhs. len ( ) != rhs. len ( ) {
101+ return internal_err ! (
102+ "All arrays should have the same length for greatest comparison"
103+ ) ;
104+ }
118105
119- if cmp ( 0 , 0 ) . is_ge ( ) {
120- Ok ( lhs)
121- } else {
122- Ok ( rhs)
123- }
124- }
106+ let values = BooleanBuffer :: collect_bool ( lhs. len ( ) , |i| cmp ( i, i) . is_ge ( ) ) ;
125107
126- fn find_coerced_type ( data_types : & [ DataType ] ) -> Result < DataType > {
127- if data_types. is_empty ( ) {
128- plan_err ! ( "greatest was called without any arguments. It requires at least 1." )
129- } else if let Some ( coerced_type) = type_union_resolution ( data_types) {
130- Ok ( coerced_type)
131- } else {
132- plan_err ! ( "Cannot find a common type for arguments" )
108+ // No nulls, as we only want to keep the values that are larger: each entry is either true or false
109+ Ok ( BooleanArray :: new ( values, None ) )
133110 }
134111}
135112
@@ -151,74 +128,12 @@ impl ScalarUDFImpl for GreatestFunc {
151128 }
152129
153130 fn invoke ( & self , args : & [ ColumnarValue ] ) -> Result < ColumnarValue > {
154- if args. is_empty ( ) {
155- return exec_err ! (
156- "greatest was called with no arguments. It requires at least 1."
157- ) ;
158- }
159-
160- // Some engines (e.g. SQL Server) allow greatest with single arg, it's a noop
161- if args. len ( ) == 1 {
162- return Ok ( args[ 0 ] . clone ( ) ) ;
163- }
164-
165- // Split to scalars and arrays for later optimization
166- let ( scalars, arrays) : ( Vec < _ > , Vec < _ > ) = args. iter ( ) . partition ( |x| match x {
167- ColumnarValue :: Scalar ( _) => true ,
168- ColumnarValue :: Array ( _) => false ,
169- } ) ;
170-
171- let mut arrays_iter = arrays. iter ( ) . map ( |x| match x {
172- ColumnarValue :: Array ( a) => a,
173- _ => unreachable ! ( ) ,
174- } ) ;
175-
176- let first_array = arrays_iter. next ( ) ;
177-
178- let mut largest: ArrayRef ;
179-
180- // Optimization: merge all scalars into one to avoid recomputing
181- if !scalars. is_empty ( ) {
182- let mut scalars_iter = scalars. iter ( ) . map ( |x| match x {
183- ColumnarValue :: Scalar ( s) => s,
184- _ => unreachable ! ( ) ,
185- } ) ;
186-
187- // We have at least one scalar
188- let mut largest_scalar = scalars_iter. next ( ) . unwrap ( ) ;
189-
190- for scalar in scalars_iter {
191- largest_scalar = keep_larger_scalar ( largest_scalar, scalar) ?;
192- }
193-
194- // If we only have scalars, return the largest one
195- if arrays. is_empty ( ) {
196- return Ok ( ColumnarValue :: Scalar ( largest_scalar. clone ( ) ) ) ;
197- }
198-
199- // We have at least one array
200- let first_array = first_array. unwrap ( ) ;
201-
202- // Start with the largest value
203- largest = keep_larger (
204- Arc :: clone ( first_array) ,
205- largest_scalar. to_array_of_size ( first_array. len ( ) ) ?,
206- ) ?;
207- } else {
208- // If we only have arrays, start with the first array
209- // (We must have at least one array)
210- largest = Arc :: clone ( first_array. unwrap ( ) ) ;
211- }
212-
213- for array in arrays_iter {
214- largest = keep_larger ( Arc :: clone ( array) , largest) ?;
215- }
216-
217- Ok ( ColumnarValue :: Array ( largest) )
131+ super :: greatest_least_utils:: execute_conditional :: < Self > ( args)
218132 }
219133
220134 fn coerce_types ( & self , arg_types : & [ DataType ] ) -> Result < Vec < DataType > > {
221- let coerced_type = find_coerced_type ( arg_types) ?;
135+ let coerced_type =
136+ super :: greatest_least_utils:: find_coerced_type :: < Self > ( arg_types) ?;
222137
223138 Ok ( vec ! [ coerced_type; arg_types. len( ) ] )
224139 }
0 commit comments