1515// specific language governing permissions and limitations
1616// under the License.
1717
18- use arrow:: array:: {
19- Array , ArrayRef , AsArray , FixedSizeListArray , Int32Array , Int32Builder ,
20- } ;
18+ use arrow:: array:: { Array , ArrayRef , AsArray , Int32Array } ;
19+ use arrow:: compute:: kernels:: length:: length as arrow_length;
2120use arrow:: datatypes:: { DataType , Field , FieldRef } ;
2221use datafusion_common:: { Result , plan_err} ;
2322use datafusion_expr:: {
24- ColumnarValue , ReturnFieldArgs , ScalarFunctionArgs , ScalarUDFImpl , Signature ,
25- TypeSignature , Volatility ,
23+ ArrayFunctionArgument , ArrayFunctionSignature , ColumnarValue , ReturnFieldArgs ,
24+ ScalarFunctionArgs , ScalarUDFImpl , Signature , TypeSignature , Volatility ,
2625} ;
2726use datafusion_functions:: utils:: make_scalar_function;
2827use std:: any:: Any ;
@@ -31,7 +30,7 @@ use std::sync::Arc;
3130/// Spark-compatible `size` function.
3231///
3332/// Returns the number of elements in an array or the number of key-value pairs in a map.
34- /// Returns null for null input.
33+ /// Returns -1 for null input (Spark behavior) .
3534#[ derive( Debug , PartialEq , Eq , Hash ) ]
3635pub struct SparkSize {
3736 signature : Signature ,
@@ -47,7 +46,15 @@ impl SparkSize {
4746 pub fn new ( ) -> Self {
4847 Self {
4948 signature : Signature :: one_of (
50- vec ! [ TypeSignature :: Any ( 1 ) ] ,
49+ vec ! [
50+ // Array Type
51+ TypeSignature :: ArraySignature ( ArrayFunctionSignature :: Array {
52+ arguments: vec![ ArrayFunctionArgument :: Array ] ,
53+ array_coercion: None ,
54+ } ) ,
55+ // Map Type
56+ TypeSignature :: ArraySignature ( ArrayFunctionSignature :: MapArray ) ,
57+ ] ,
5158 Volatility :: Immutable ,
5259 ) ,
5360 }
@@ -72,47 +79,14 @@ impl ScalarUDFImpl for SparkSize {
7279 }
7380
7481 fn return_field_from_args ( & self , args : ReturnFieldArgs ) -> Result < FieldRef > {
75- if args. arg_fields . len ( ) != 1 {
76- return plan_err ! ( "size expects exactly 1 argument" ) ;
77- }
78-
79- let input_field = & args. arg_fields [ 0 ] ;
80-
81- match input_field. data_type ( ) {
82- DataType :: List ( _)
83- | DataType :: LargeList ( _)
84- | DataType :: FixedSizeList ( _, _)
85- | DataType :: Map ( _, _)
86- | DataType :: Null => { }
87- dt => {
88- return plan_err ! (
89- "size function requires array or map types, got: {}" ,
90- dt
91- ) ;
92- }
93- }
94-
95- let mut out_nullable = input_field. is_nullable ( ) ;
96-
97- let scala_null_present = args
98- . scalar_arguments
99- . iter ( )
100- . any ( |opt_s| opt_s. is_some_and ( |sv| sv. is_null ( ) ) ) ;
101- if scala_null_present {
102- out_nullable = true ;
103- }
104-
10582 Ok ( Arc :: new ( Field :: new (
10683 self . name ( ) ,
10784 DataType :: Int32 ,
108- out_nullable ,
85+ args . arg_fields [ 0 ] . is_nullable ( ) ,
10986 ) ) )
11087 }
11188
11289 fn invoke_with_args ( & self , args : ScalarFunctionArgs ) -> Result < ColumnarValue > {
113- if args. args . len ( ) != 1 {
114- return plan_err ! ( "size expects exactly 1 argument" ) ;
115- }
11690 make_scalar_function ( spark_size_inner, vec ! [ ] ) ( & args. args )
11791 }
11892}
@@ -122,228 +96,70 @@ fn spark_size_inner(args: &[ArrayRef]) -> Result<ArrayRef> {
12296
12397 match array. data_type ( ) {
12498 DataType :: List ( _) => {
125- let list_array = array. as_list :: < i32 > ( ) ;
126- let mut builder = Int32Builder :: with_capacity ( list_array. len ( ) ) ;
127- for i in 0 ..list_array. len ( ) {
128- if list_array. is_null ( i) {
129- builder. append_null ( ) ;
130- } else {
131- let len = list_array. value ( i) . len ( ) ;
132- builder. append_value ( len as i32 )
133- }
99+ if array. null_count ( ) == 0 {
100+ Ok ( arrow_length ( array) ?)
101+ } else {
102+ let list_array = array. as_list :: < i32 > ( ) ;
103+ let lengths: Vec < i32 > = list_array
104+ . offsets ( )
105+ . lengths ( )
106+ . enumerate ( )
107+ . map ( |( i, len) | if array. is_null ( i) { -1 } else { len as i32 } )
108+ . collect ( ) ;
109+ Ok ( Arc :: new ( Int32Array :: from ( lengths) ) )
110+ }
111+ }
112+ DataType :: FixedSizeList ( _, size) => {
113+ if array. null_count ( ) == 0 {
114+ Ok ( arrow_length ( array) ?)
115+ } else {
116+ let length: Vec < i32 > = ( 0 ..array. len ( ) )
117+ . map ( |i| if array. is_null ( i) { -1 } else { * size } )
118+ . collect ( ) ;
119+ Ok ( Arc :: new ( Int32Array :: from ( length) ) )
134120 }
135-
136- Ok ( Arc :: new ( builder. finish ( ) ) )
137121 }
138122 DataType :: LargeList ( _) => {
123+ // Arrow length kernel returns Int64 for LargeList
139124 let list_array = array. as_list :: < i64 > ( ) ;
140- let mut builder = Int32Builder :: with_capacity ( list_array. len ( ) ) ;
141- for i in 0 ..list_array. len ( ) {
142- if list_array. is_null ( i) {
143- builder. append_null ( ) ;
144- } else {
145- let len = list_array. value ( i) . len ( ) ;
146- builder. append_value ( len as i32 )
147- }
125+ if array. null_count ( ) == 0 {
126+ let lengths: Vec < i32 > = list_array
127+ . offsets ( )
128+ . lengths ( )
129+ . map ( |len| len as i32 )
130+ . collect ( ) ;
131+ Ok ( Arc :: new ( Int32Array :: from ( lengths) ) )
132+ } else {
133+ let lengths: Vec < i32 > = list_array
134+ . offsets ( )
135+ . lengths ( )
136+ . enumerate ( )
137+ . map ( |( i, len) | if array. is_null ( i) { -1 } else { len as i32 } )
138+ . collect ( ) ;
139+ Ok ( Arc :: new ( Int32Array :: from ( lengths) ) )
148140 }
149-
150- Ok ( Arc :: new ( builder. finish ( ) ) )
151- }
152- DataType :: FixedSizeList ( _, size) => {
153- let list_array: & FixedSizeListArray = array. as_fixed_size_list ( ) ;
154- let fixed_size = * size;
155- let result: Int32Array = ( 0 ..list_array. len ( ) )
156- . map ( |i| {
157- if list_array. is_null ( i) {
158- None
159- } else {
160- Some ( fixed_size)
161- }
162- } )
163- . collect ( ) ;
164-
165- Ok ( Arc :: new ( result) )
166141 }
167142 DataType :: Map ( _, _) => {
168143 let map_array = array. as_map ( ) ;
169- let mut builder = Int32Builder :: with_capacity ( map_array. len ( ) ) ;
170-
171- for i in 0 ..map_array. len ( ) {
172- if map_array. is_null ( i) {
173- builder. append_null ( ) ;
174- } else {
175- let len = map_array. value ( i) . len ( ) ;
176- builder. append_value ( len as i32 )
177- }
178- }
179-
180- Ok ( Arc :: new ( builder. finish ( ) ) )
144+ let length: Vec < i32 > = if array. null_count ( ) == 0 {
145+ map_array
146+ . offsets ( )
147+ . lengths ( )
148+ . map ( |len| len as i32 )
149+ . collect ( )
150+ } else {
151+ map_array
152+ . offsets ( )
153+ . lengths ( )
154+ . enumerate ( )
155+ . map ( |( i, len) | if array. is_null ( i) { -1 } else { len as i32 } )
156+ . collect ( )
157+ } ;
158+ Ok ( Arc :: new ( Int32Array :: from ( length) ) )
181159 }
182- DataType :: Null => Ok ( Arc :: new ( Int32Array :: new_null ( array. len ( ) ) ) ) ,
160+ DataType :: Null => Ok ( Arc :: new ( Int32Array :: from ( vec ! [ - 1 ; array. len( ) ] ) ) ) ,
183161 dt => {
184162 plan_err ! ( "size function does not support type: {}" , dt)
185163 }
186164 }
187165}
188-
189- #[ cfg( test) ]
190- mod tests {
191- use super :: * ;
192- use arrow:: array:: { Int32Array , ListArray , MapArray , StringArray , StructArray } ;
193- use arrow:: buffer:: { NullBuffer , OffsetBuffer } ;
194- use arrow:: datatypes:: { DataType , Field , Fields } ;
195- use datafusion_common:: ScalarValue ;
196- use datafusion_expr:: ReturnFieldArgs ;
197-
198- #[ test]
199- fn test_size_nullability ( ) {
200- let size_fn = SparkSize :: new ( ) ;
201-
202- // Non-nullable list input -> non-nullable output
203- let non_nullable_list = Arc :: new ( Field :: new (
204- "col" ,
205- DataType :: List ( Arc :: new ( Field :: new ( "item" , DataType :: Int32 , true ) ) ) ,
206- false ,
207- ) ) ;
208- let out = size_fn
209- . return_field_from_args ( ReturnFieldArgs {
210- arg_fields : & [ Arc :: clone ( & non_nullable_list) ] ,
211- scalar_arguments : & [ None ] ,
212- } )
213- . unwrap ( ) ;
214-
215- assert ! ( !out. is_nullable( ) ) ;
216- assert_eq ! ( out. data_type( ) , & DataType :: Int32 ) ;
217-
218- // Nullable list output -> nullable output
219- let nullable_list = Arc :: new ( Field :: new (
220- "col" ,
221- DataType :: List ( Arc :: new ( Field :: new ( "item" , DataType :: Int32 , true ) ) ) ,
222- true ,
223- ) ) ;
224- let out = size_fn
225- . return_field_from_args ( ReturnFieldArgs {
226- arg_fields : & [ Arc :: clone ( & nullable_list) ] ,
227- scalar_arguments : & [ None ] ,
228- } )
229- . unwrap ( ) ;
230-
231- assert ! ( out. is_nullable( ) ) ;
232- }
233-
234- #[ test]
235- fn test_size_with_null_scalar ( ) {
236- let size_fn = SparkSize :: new ( ) ;
237-
238- let non_nullable_list = Arc :: new ( Field :: new (
239- "col" ,
240- DataType :: List ( Arc :: new ( Field :: new ( "item" , DataType :: Int32 , true ) ) ) ,
241- false ,
242- ) ) ;
243-
244- // With null scalar argument
245- let null_scalar = ScalarValue :: List ( Arc :: new ( ListArray :: new_null (
246- Arc :: new ( Field :: new ( "item" , DataType :: Int32 , true ) ) ,
247- 1 ,
248- ) ) ) ;
249- let out = size_fn
250- . return_field_from_args ( ReturnFieldArgs {
251- arg_fields : & [ Arc :: clone ( & non_nullable_list) ] ,
252- scalar_arguments : & [ Some ( & null_scalar) ] ,
253- } )
254- . unwrap ( ) ;
255-
256- assert ! ( out. is_nullable( ) ) ;
257- }
258-
259- #[ test]
260- fn test_size_list_array ( ) -> Result < ( ) > {
261- // Create a list array: [[1, 2, 3], [4, 5], null, []]
262- let values = Int32Array :: from ( vec ! [ 1 , 2 , 3 , 4 , 5 ] ) ;
263- let offsets = OffsetBuffer :: new ( vec ! [ 0 , 3 , 5 , 5 , 5 ] . into ( ) ) ;
264- let nulls = NullBuffer :: from ( vec ! [ true , true , false , true ] ) ;
265- let list_array = ListArray :: new (
266- Arc :: new ( Field :: new ( "item" , DataType :: Int32 , true ) ) ,
267- offsets,
268- Arc :: new ( values) ,
269- Some ( nulls) ,
270- ) ;
271-
272- let result = spark_size_inner ( & [ Arc :: new ( list_array) ] ) ?;
273- let result = result. as_any ( ) . downcast_ref :: < Int32Array > ( ) . unwrap ( ) ;
274-
275- assert_eq ! ( result. len( ) , 4 ) ;
276- assert_eq ! ( result. value( 0 ) , 3 ) ; // [1, 2, 3]
277- assert_eq ! ( result. value( 1 ) , 2 ) ; // [4, 5]
278- assert ! ( result. is_null( 2 ) ) ; // null
279- assert_eq ! ( result. value( 3 ) , 0 ) ; // []
280-
281- Ok ( ( ) )
282- }
283-
284- #[ test]
285- fn test_size_map_array ( ) -> Result < ( ) > {
286- // Create a map array with entries
287- let keys = StringArray :: from ( vec ! [ "a" , "b" , "c" , "d" ] ) ;
288- let values = Int32Array :: from ( vec ! [ 1 , 2 , 3 , 4 ] ) ;
289-
290- let entries_field = Arc :: new ( Field :: new (
291- "entries" ,
292- DataType :: Struct ( Fields :: from ( vec ! [
293- Field :: new( "key" , DataType :: Utf8 , false ) ,
294- Field :: new( "value" , DataType :: Int32 , true ) ,
295- ] ) ) ,
296- false ,
297- ) ) ;
298-
299- let entries = StructArray :: from ( vec ! [
300- (
301- Arc :: new( Field :: new( "key" , DataType :: Utf8 , false ) ) ,
302- Arc :: new( keys) as ArrayRef ,
303- ) ,
304- (
305- Arc :: new( Field :: new( "value" , DataType :: Int32 , true ) ) ,
306- Arc :: new( values) as ArrayRef ,
307- ) ,
308- ] ) ;
309-
310- // Map with 3 rows: {a:1, b:2}, {c:3}, null
311- let offsets = OffsetBuffer :: new ( vec ! [ 0 , 2 , 3 , 4 ] . into ( ) ) ;
312- let nulls = NullBuffer :: from ( vec ! [ true , true , false ] ) ;
313- let map_array =
314- MapArray :: new ( entries_field, offsets, entries, Some ( nulls) , false ) ;
315-
316- let result = spark_size_inner ( & [ Arc :: new ( map_array) ] ) ?;
317- let result = result. as_any ( ) . downcast_ref :: < Int32Array > ( ) . unwrap ( ) ;
318-
319- assert_eq ! ( result. len( ) , 3 ) ;
320- assert_eq ! ( result. value( 0 ) , 2 ) ; // {a:1, b:2}
321- assert_eq ! ( result. value( 1 ) , 1 ) ; // {c:3}
322- assert ! ( result. is_null( 2 ) ) ; // null
323-
324- Ok ( ( ) )
325- }
326-
327- #[ test]
328- fn test_size_fixed_size_list ( ) -> Result < ( ) > {
329- // Create a fixed size list of size 3
330- let values = Int32Array :: from ( vec ! [ 1 , 2 , 3 , 4 , 5 , 6 , 7 , 8 , 9 ] ) ;
331- let nulls = NullBuffer :: from ( vec ! [ true , true , false ] ) ;
332- let list_array = FixedSizeListArray :: new (
333- Arc :: new ( Field :: new ( "item" , DataType :: Int32 , true ) ) ,
334- 3 ,
335- Arc :: new ( values) ,
336- Some ( nulls) ,
337- ) ;
338-
339- let result = spark_size_inner ( & [ Arc :: new ( list_array) ] ) ?;
340- let result = result. as_any ( ) . downcast_ref :: < Int32Array > ( ) . unwrap ( ) ;
341-
342- assert_eq ! ( result. len( ) , 3 ) ;
343- assert_eq ! ( result. value( 0 ) , 3 ) ;
344- assert_eq ! ( result. value( 1 ) , 3 ) ;
345- assert ! ( result. is_null( 2 ) ) ;
346-
347- Ok ( ( ) )
348- }
349- }
0 commit comments