Skip to content

Commit ce4e96c

Browse files
feat(spark): implement size function for arrays and maps
Implements the Spark-compatible `size` function, which returns the number of elements in an array or the number of key-value pairs in a map. Supports List, LargeList, FixedSizeList, and Map types; returns NULL for NULL input (modern Spark 3.0+ behavior); returns Int32 to match Spark's IntegerType.
1 parent 8959b3d commit ce4e96c

File tree

3 files changed

+466
-2
lines changed

3 files changed

+466
-2
lines changed

datafusion/spark/src/function/collection/mod.rs

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,11 +15,20 @@
1515
// specific language governing permissions and limitations
1616
// under the License.
1717

18+
pub mod size;
19+
1820
use datafusion_expr::ScalarUDF;
21+
use datafusion_functions::make_udf_function;
1922
use std::sync::Arc;
2023

21-
pub mod expr_fn {}
24+
make_udf_function!(size::SparkSize, size);
25+
26+
pub mod expr_fn {
27+
use datafusion_functions::export_functions;
28+
29+
export_functions!((size, "Return the size of an array or map.", arg));
30+
}
2231

2332
pub fn functions() -> Vec<Arc<ScalarUDF>> {
24-
vec![]
33+
vec![size()]
2534
}
Lines changed: 349 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,349 @@
1+
// Licensed to the Apache Software Foundation (ASF) under one
2+
// or more contributor license agreements. See the NOTICE file
3+
// distributed with this work for additional information
4+
// regarding copyright ownership. The ASF licenses this file
5+
// to you under the Apache License, Version 2.0 (the
6+
// "License"); you may not use this file except in compliance
7+
// with the License. You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing,
12+
// software distributed under the License is distributed on an
13+
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
// KIND, either express or implied. See the License for the
15+
// specific language governing permissions and limitations
16+
// under the License.
17+
18+
use arrow::array::{
19+
Array, ArrayRef, AsArray, FixedSizeListArray, Int32Array, Int32Builder,
20+
};
21+
use arrow::datatypes::{DataType, Field, FieldRef};
22+
use datafusion_common::{Result, plan_err};
23+
use datafusion_expr::{
24+
ColumnarValue, ReturnFieldArgs, ScalarFunctionArgs, ScalarUDFImpl, Signature,
25+
TypeSignature, Volatility,
26+
};
27+
use datafusion_functions::utils::make_scalar_function;
28+
use std::any::Any;
29+
use std::sync::Arc;
30+
31+
/// Spark-compatible `size` function.
32+
///
33+
/// Returns the number of elements in an array or the number of key-value pairs in a map.
34+
/// Returns null for null input.
35+
#[derive(Debug, PartialEq, Eq, Hash)]
36+
pub struct SparkSize {
37+
signature: Signature,
38+
}
39+
40+
impl Default for SparkSize {
41+
fn default() -> Self {
42+
Self::new()
43+
}
44+
}
45+
46+
impl SparkSize {
47+
pub fn new() -> Self {
48+
Self {
49+
signature: Signature::one_of(
50+
vec![TypeSignature::Any(1)],
51+
Volatility::Immutable,
52+
),
53+
}
54+
}
55+
}
56+
57+
impl ScalarUDFImpl for SparkSize {
58+
fn as_any(&self) -> &dyn Any {
59+
self
60+
}
61+
62+
fn name(&self) -> &str {
63+
"size"
64+
}
65+
66+
fn signature(&self) -> &Signature {
67+
&self.signature
68+
}
69+
70+
fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
71+
Ok(DataType::Int32)
72+
}
73+
74+
fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result<FieldRef> {
75+
if args.arg_fields.len() != 1 {
76+
return plan_err!("size expects exactly 1 argument");
77+
}
78+
79+
let input_field = &args.arg_fields[0];
80+
81+
match input_field.data_type() {
82+
DataType::List(_)
83+
| DataType::LargeList(_)
84+
| DataType::FixedSizeList(_, _)
85+
| DataType::Map(_, _)
86+
| DataType::Null => {}
87+
dt => {
88+
return plan_err!(
89+
"size function requires array or map types, got: {}",
90+
dt
91+
);
92+
}
93+
}
94+
95+
let mut out_nullable = input_field.is_nullable();
96+
97+
let scala_null_present = args
98+
.scalar_arguments
99+
.iter()
100+
.any(|opt_s| opt_s.is_some_and(|sv| sv.is_null()));
101+
if scala_null_present {
102+
out_nullable = true;
103+
}
104+
105+
Ok(Arc::new(Field::new(
106+
self.name(),
107+
DataType::Int32,
108+
out_nullable,
109+
)))
110+
}
111+
112+
fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
113+
if args.args.len() != 1 {
114+
return plan_err!("size expects exactly 1 argument");
115+
}
116+
make_scalar_function(spark_size_inner, vec![])(&args.args)
117+
}
118+
}
119+
120+
fn spark_size_inner(args: &[ArrayRef]) -> Result<ArrayRef> {
121+
let array = &args[0];
122+
123+
match array.data_type() {
124+
DataType::List(_) => {
125+
let list_array = array.as_list::<i32>();
126+
let mut builder = Int32Builder::with_capacity(list_array.len());
127+
for i in 0..list_array.len() {
128+
if list_array.is_null(i) {
129+
builder.append_null();
130+
} else {
131+
let len = list_array.value(i).len();
132+
builder.append_value(len as i32)
133+
}
134+
}
135+
136+
Ok(Arc::new(builder.finish()))
137+
}
138+
DataType::LargeList(_) => {
139+
let list_array = array.as_list::<i64>();
140+
let mut builder = Int32Builder::with_capacity(list_array.len());
141+
for i in 0..list_array.len() {
142+
if list_array.is_null(i) {
143+
builder.append_null();
144+
} else {
145+
let len = list_array.value(i).len();
146+
builder.append_value(len as i32)
147+
}
148+
}
149+
150+
Ok(Arc::new(builder.finish()))
151+
}
152+
DataType::FixedSizeList(_, size) => {
153+
let list_array: &FixedSizeListArray = array.as_fixed_size_list();
154+
let fixed_size = *size;
155+
let result: Int32Array = (0..list_array.len())
156+
.map(|i| {
157+
if list_array.is_null(i) {
158+
None
159+
} else {
160+
Some(fixed_size)
161+
}
162+
})
163+
.collect();
164+
165+
Ok(Arc::new(result))
166+
}
167+
DataType::Map(_, _) => {
168+
let map_array = array.as_map();
169+
let mut builder = Int32Builder::with_capacity(map_array.len());
170+
171+
for i in 0..map_array.len() {
172+
if map_array.is_null(i) {
173+
builder.append_null();
174+
} else {
175+
let len = map_array.value(i).len();
176+
builder.append_value(len as i32)
177+
}
178+
}
179+
180+
Ok(Arc::new(builder.finish()))
181+
}
182+
DataType::Null => Ok(Arc::new(Int32Array::new_null(array.len()))),
183+
dt => {
184+
plan_err!("size function does not support type: {}", dt)
185+
}
186+
}
187+
}
188+
189+
#[cfg(test)]
190+
mod tests {
191+
use super::*;
192+
use arrow::array::{Int32Array, ListArray, MapArray, StringArray, StructArray};
193+
use arrow::buffer::{NullBuffer, OffsetBuffer};
194+
use arrow::datatypes::{DataType, Field, Fields};
195+
use datafusion_common::ScalarValue;
196+
use datafusion_expr::ReturnFieldArgs;
197+
198+
#[test]
199+
fn test_size_nullability() {
200+
let size_fn = SparkSize::new();
201+
202+
// Non-nullable list input -> non-nullable output
203+
let non_nullable_list = Arc::new(Field::new(
204+
"col",
205+
DataType::List(Arc::new(Field::new("item", DataType::Int32, true))),
206+
false,
207+
));
208+
let out = size_fn
209+
.return_field_from_args(ReturnFieldArgs {
210+
arg_fields: &[Arc::clone(&non_nullable_list)],
211+
scalar_arguments: &[None],
212+
})
213+
.unwrap();
214+
215+
assert!(!out.is_nullable());
216+
assert_eq!(out.data_type(), &DataType::Int32);
217+
218+
// Nullable list output -> nullable output
219+
let nullable_list = Arc::new(Field::new(
220+
"col",
221+
DataType::List(Arc::new(Field::new("item", DataType::Int32, true))),
222+
true,
223+
));
224+
let out = size_fn
225+
.return_field_from_args(ReturnFieldArgs {
226+
arg_fields: &[Arc::clone(&nullable_list)],
227+
scalar_arguments: &[None],
228+
})
229+
.unwrap();
230+
231+
assert!(out.is_nullable());
232+
}
233+
234+
#[test]
235+
fn test_size_with_null_scalar() {
236+
let size_fn = SparkSize::new();
237+
238+
let non_nullable_list = Arc::new(Field::new(
239+
"col",
240+
DataType::List(Arc::new(Field::new("item", DataType::Int32, true))),
241+
false,
242+
));
243+
244+
// With null scalar argument
245+
let null_scalar = ScalarValue::List(Arc::new(ListArray::new_null(
246+
Arc::new(Field::new("item", DataType::Int32, true)),
247+
1,
248+
)));
249+
let out = size_fn
250+
.return_field_from_args(ReturnFieldArgs {
251+
arg_fields: &[Arc::clone(&non_nullable_list)],
252+
scalar_arguments: &[Some(&null_scalar)],
253+
})
254+
.unwrap();
255+
256+
assert!(out.is_nullable());
257+
}
258+
259+
#[test]
260+
fn test_size_list_array() -> Result<()> {
261+
// Create a list array: [[1, 2, 3], [4, 5], null, []]
262+
let values = Int32Array::from(vec![1, 2, 3, 4, 5]);
263+
let offsets = OffsetBuffer::new(vec![0, 3, 5, 5, 5].into());
264+
let nulls = NullBuffer::from(vec![true, true, false, true]);
265+
let list_array = ListArray::new(
266+
Arc::new(Field::new("item", DataType::Int32, true)),
267+
offsets,
268+
Arc::new(values),
269+
Some(nulls),
270+
);
271+
272+
let result = spark_size_inner(&[Arc::new(list_array)])?;
273+
let result = result.as_any().downcast_ref::<Int32Array>().unwrap();
274+
275+
assert_eq!(result.len(), 4);
276+
assert_eq!(result.value(0), 3); // [1, 2, 3]
277+
assert_eq!(result.value(1), 2); // [4, 5]
278+
assert!(result.is_null(2)); // null
279+
assert_eq!(result.value(3), 0); // []
280+
281+
Ok(())
282+
}
283+
284+
#[test]
285+
fn test_size_map_array() -> Result<()> {
286+
// Create a map array with entries
287+
let keys = StringArray::from(vec!["a", "b", "c", "d"]);
288+
let values = Int32Array::from(vec![1, 2, 3, 4]);
289+
290+
let entries_field = Arc::new(Field::new(
291+
"entries",
292+
DataType::Struct(Fields::from(vec![
293+
Field::new("key", DataType::Utf8, false),
294+
Field::new("value", DataType::Int32, true),
295+
])),
296+
false,
297+
));
298+
299+
let entries = StructArray::from(vec![
300+
(
301+
Arc::new(Field::new("key", DataType::Utf8, false)),
302+
Arc::new(keys) as ArrayRef,
303+
),
304+
(
305+
Arc::new(Field::new("value", DataType::Int32, true)),
306+
Arc::new(values) as ArrayRef,
307+
),
308+
]);
309+
310+
// Map with 3 rows: {a:1, b:2}, {c:3}, null
311+
let offsets = OffsetBuffer::new(vec![0, 2, 3, 4].into());
312+
let nulls = NullBuffer::from(vec![true, true, false]);
313+
let map_array =
314+
MapArray::new(entries_field, offsets, entries, Some(nulls), false);
315+
316+
let result = spark_size_inner(&[Arc::new(map_array)])?;
317+
let result = result.as_any().downcast_ref::<Int32Array>().unwrap();
318+
319+
assert_eq!(result.len(), 3);
320+
assert_eq!(result.value(0), 2); // {a:1, b:2}
321+
assert_eq!(result.value(1), 1); // {c:3}
322+
assert!(result.is_null(2)); // null
323+
324+
Ok(())
325+
}
326+
327+
#[test]
328+
fn test_size_fixed_size_list() -> Result<()> {
329+
// Create a fixed size list of size 3
330+
let values = Int32Array::from(vec![1, 2, 3, 4, 5, 6, 7, 8, 9]);
331+
let nulls = NullBuffer::from(vec![true, true, false]);
332+
let list_array = FixedSizeListArray::new(
333+
Arc::new(Field::new("item", DataType::Int32, true)),
334+
3,
335+
Arc::new(values),
336+
Some(nulls),
337+
);
338+
339+
let result = spark_size_inner(&[Arc::new(list_array)])?;
340+
let result = result.as_any().downcast_ref::<Int32Array>().unwrap();
341+
342+
assert_eq!(result.len(), 3);
343+
assert_eq!(result.value(0), 3);
344+
assert_eq!(result.value(1), 3);
345+
assert!(result.is_null(2));
346+
347+
Ok(())
348+
}
349+
}

0 commit comments

Comments
 (0)