Skip to content

Commit 5e7ff11

Browse files
Use Arrow kernels and stricter signature for size function
1 parent e347504 commit 5e7ff11

File tree

2 files changed

+93
-254
lines changed
  • datafusion
    • spark/src/function/collection
    • sqllogictest/test_files/spark/collection

2 files changed

+93
-254
lines changed

datafusion/spark/src/function/collection/size.rs

Lines changed: 68 additions & 252 deletions
Original file line numberDiff line numberDiff line change
@@ -15,14 +15,13 @@
1515
// specific language governing permissions and limitations
1616
// under the License.
1717

18-
use arrow::array::{
19-
Array, ArrayRef, AsArray, FixedSizeListArray, Int32Array, Int32Builder,
20-
};
18+
use arrow::array::{Array, ArrayRef, AsArray, Int32Array};
19+
use arrow::compute::kernels::length::length as arrow_length;
2120
use arrow::datatypes::{DataType, Field, FieldRef};
2221
use datafusion_common::{Result, plan_err};
2322
use datafusion_expr::{
24-
ColumnarValue, ReturnFieldArgs, ScalarFunctionArgs, ScalarUDFImpl, Signature,
25-
TypeSignature, Volatility,
23+
ArrayFunctionArgument, ArrayFunctionSignature, ColumnarValue, ReturnFieldArgs,
24+
ScalarFunctionArgs, ScalarUDFImpl, Signature, TypeSignature, Volatility,
2625
};
2726
use datafusion_functions::utils::make_scalar_function;
2827
use std::any::Any;
@@ -31,7 +30,7 @@ use std::sync::Arc;
3130
/// Spark-compatible `size` function.
3231
///
3332
/// Returns the number of elements in an array or the number of key-value pairs in a map.
34-
/// Returns null for null input.
33+
/// Returns -1 for null input (Spark behavior).
3534
#[derive(Debug, PartialEq, Eq, Hash)]
3635
pub struct SparkSize {
3736
signature: Signature,
@@ -47,7 +46,15 @@ impl SparkSize {
4746
pub fn new() -> Self {
4847
Self {
4948
signature: Signature::one_of(
50-
vec![TypeSignature::Any(1)],
49+
vec![
50+
// Array Type
51+
TypeSignature::ArraySignature(ArrayFunctionSignature::Array {
52+
arguments: vec![ArrayFunctionArgument::Array],
53+
array_coercion: None,
54+
}),
55+
// Map Type
56+
TypeSignature::ArraySignature(ArrayFunctionSignature::MapArray),
57+
],
5158
Volatility::Immutable,
5259
),
5360
}
@@ -72,47 +79,14 @@ impl ScalarUDFImpl for SparkSize {
7279
}
7380

7481
fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result<FieldRef> {
75-
if args.arg_fields.len() != 1 {
76-
return plan_err!("size expects exactly 1 argument");
77-
}
78-
79-
let input_field = &args.arg_fields[0];
80-
81-
match input_field.data_type() {
82-
DataType::List(_)
83-
| DataType::LargeList(_)
84-
| DataType::FixedSizeList(_, _)
85-
| DataType::Map(_, _)
86-
| DataType::Null => {}
87-
dt => {
88-
return plan_err!(
89-
"size function requires array or map types, got: {}",
90-
dt
91-
);
92-
}
93-
}
94-
95-
let mut out_nullable = input_field.is_nullable();
96-
97-
let scala_null_present = args
98-
.scalar_arguments
99-
.iter()
100-
.any(|opt_s| opt_s.is_some_and(|sv| sv.is_null()));
101-
if scala_null_present {
102-
out_nullable = true;
103-
}
104-
10582
Ok(Arc::new(Field::new(
10683
self.name(),
10784
DataType::Int32,
108-
out_nullable,
85+
args.arg_fields[0].is_nullable(),
10986
)))
11087
}
11188

11289
fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
113-
if args.args.len() != 1 {
114-
return plan_err!("size expects exactly 1 argument");
115-
}
11690
make_scalar_function(spark_size_inner, vec![])(&args.args)
11791
}
11892
}
@@ -122,228 +96,70 @@ fn spark_size_inner(args: &[ArrayRef]) -> Result<ArrayRef> {
12296

12397
match array.data_type() {
12498
DataType::List(_) => {
125-
let list_array = array.as_list::<i32>();
126-
let mut builder = Int32Builder::with_capacity(list_array.len());
127-
for i in 0..list_array.len() {
128-
if list_array.is_null(i) {
129-
builder.append_null();
130-
} else {
131-
let len = list_array.value(i).len();
132-
builder.append_value(len as i32)
133-
}
99+
if array.null_count() == 0 {
100+
Ok(arrow_length(array)?)
101+
} else {
102+
let list_array = array.as_list::<i32>();
103+
let lengths: Vec<i32> = list_array
104+
.offsets()
105+
.lengths()
106+
.enumerate()
107+
.map(|(i, len)| if array.is_null(i) { -1 } else { len as i32 })
108+
.collect();
109+
Ok(Arc::new(Int32Array::from(lengths)))
110+
}
111+
}
112+
DataType::FixedSizeList(_, size) => {
113+
if array.null_count() == 0 {
114+
Ok(arrow_length(array)?)
115+
} else {
116+
let length: Vec<i32> = (0..array.len())
117+
.map(|i| if array.is_null(i) { -1 } else { *size })
118+
.collect();
119+
Ok(Arc::new(Int32Array::from(length)))
134120
}
135-
136-
Ok(Arc::new(builder.finish()))
137121
}
138122
DataType::LargeList(_) => {
123+
// Arrow length kernel returns Int64 for LargeList
139124
let list_array = array.as_list::<i64>();
140-
let mut builder = Int32Builder::with_capacity(list_array.len());
141-
for i in 0..list_array.len() {
142-
if list_array.is_null(i) {
143-
builder.append_null();
144-
} else {
145-
let len = list_array.value(i).len();
146-
builder.append_value(len as i32)
147-
}
125+
if array.null_count() == 0 {
126+
let lengths: Vec<i32> = list_array
127+
.offsets()
128+
.lengths()
129+
.map(|len| len as i32)
130+
.collect();
131+
Ok(Arc::new(Int32Array::from(lengths)))
132+
} else {
133+
let lengths: Vec<i32> = list_array
134+
.offsets()
135+
.lengths()
136+
.enumerate()
137+
.map(|(i, len)| if array.is_null(i) { -1 } else { len as i32 })
138+
.collect();
139+
Ok(Arc::new(Int32Array::from(lengths)))
148140
}
149-
150-
Ok(Arc::new(builder.finish()))
151-
}
152-
DataType::FixedSizeList(_, size) => {
153-
let list_array: &FixedSizeListArray = array.as_fixed_size_list();
154-
let fixed_size = *size;
155-
let result: Int32Array = (0..list_array.len())
156-
.map(|i| {
157-
if list_array.is_null(i) {
158-
None
159-
} else {
160-
Some(fixed_size)
161-
}
162-
})
163-
.collect();
164-
165-
Ok(Arc::new(result))
166141
}
167142
DataType::Map(_, _) => {
168143
let map_array = array.as_map();
169-
let mut builder = Int32Builder::with_capacity(map_array.len());
170-
171-
for i in 0..map_array.len() {
172-
if map_array.is_null(i) {
173-
builder.append_null();
174-
} else {
175-
let len = map_array.value(i).len();
176-
builder.append_value(len as i32)
177-
}
178-
}
179-
180-
Ok(Arc::new(builder.finish()))
144+
let length: Vec<i32> = if array.null_count() == 0 {
145+
map_array
146+
.offsets()
147+
.lengths()
148+
.map(|len| len as i32)
149+
.collect()
150+
} else {
151+
map_array
152+
.offsets()
153+
.lengths()
154+
.enumerate()
155+
.map(|(i, len)| if array.is_null(i) { -1 } else { len as i32 })
156+
.collect()
157+
};
158+
Ok(Arc::new(Int32Array::from(length)))
181159
}
182-
DataType::Null => Ok(Arc::new(Int32Array::new_null(array.len()))),
160+
DataType::Null => Ok(Arc::new(Int32Array::from(vec![-1; array.len()]))),
183161
dt => {
184162
plan_err!("size function does not support type: {}", dt)
185163
}
186164
}
187165
}
188-
189-
#[cfg(test)]
190-
mod tests {
191-
use super::*;
192-
use arrow::array::{Int32Array, ListArray, MapArray, StringArray, StructArray};
193-
use arrow::buffer::{NullBuffer, OffsetBuffer};
194-
use arrow::datatypes::{DataType, Field, Fields};
195-
use datafusion_common::ScalarValue;
196-
use datafusion_expr::ReturnFieldArgs;
197-
198-
#[test]
199-
fn test_size_nullability() {
200-
let size_fn = SparkSize::new();
201-
202-
// Non-nullable list input -> non-nullable output
203-
let non_nullable_list = Arc::new(Field::new(
204-
"col",
205-
DataType::List(Arc::new(Field::new("item", DataType::Int32, true))),
206-
false,
207-
));
208-
let out = size_fn
209-
.return_field_from_args(ReturnFieldArgs {
210-
arg_fields: &[Arc::clone(&non_nullable_list)],
211-
scalar_arguments: &[None],
212-
})
213-
.unwrap();
214-
215-
assert!(!out.is_nullable());
216-
assert_eq!(out.data_type(), &DataType::Int32);
217-
218-
// Nullable list output -> nullable output
219-
let nullable_list = Arc::new(Field::new(
220-
"col",
221-
DataType::List(Arc::new(Field::new("item", DataType::Int32, true))),
222-
true,
223-
));
224-
let out = size_fn
225-
.return_field_from_args(ReturnFieldArgs {
226-
arg_fields: &[Arc::clone(&nullable_list)],
227-
scalar_arguments: &[None],
228-
})
229-
.unwrap();
230-
231-
assert!(out.is_nullable());
232-
}
233-
234-
#[test]
235-
fn test_size_with_null_scalar() {
236-
let size_fn = SparkSize::new();
237-
238-
let non_nullable_list = Arc::new(Field::new(
239-
"col",
240-
DataType::List(Arc::new(Field::new("item", DataType::Int32, true))),
241-
false,
242-
));
243-
244-
// With null scalar argument
245-
let null_scalar = ScalarValue::List(Arc::new(ListArray::new_null(
246-
Arc::new(Field::new("item", DataType::Int32, true)),
247-
1,
248-
)));
249-
let out = size_fn
250-
.return_field_from_args(ReturnFieldArgs {
251-
arg_fields: &[Arc::clone(&non_nullable_list)],
252-
scalar_arguments: &[Some(&null_scalar)],
253-
})
254-
.unwrap();
255-
256-
assert!(out.is_nullable());
257-
}
258-
259-
#[test]
260-
fn test_size_list_array() -> Result<()> {
261-
// Create a list array: [[1, 2, 3], [4, 5], null, []]
262-
let values = Int32Array::from(vec![1, 2, 3, 4, 5]);
263-
let offsets = OffsetBuffer::new(vec![0, 3, 5, 5, 5].into());
264-
let nulls = NullBuffer::from(vec![true, true, false, true]);
265-
let list_array = ListArray::new(
266-
Arc::new(Field::new("item", DataType::Int32, true)),
267-
offsets,
268-
Arc::new(values),
269-
Some(nulls),
270-
);
271-
272-
let result = spark_size_inner(&[Arc::new(list_array)])?;
273-
let result = result.as_any().downcast_ref::<Int32Array>().unwrap();
274-
275-
assert_eq!(result.len(), 4);
276-
assert_eq!(result.value(0), 3); // [1, 2, 3]
277-
assert_eq!(result.value(1), 2); // [4, 5]
278-
assert!(result.is_null(2)); // null
279-
assert_eq!(result.value(3), 0); // []
280-
281-
Ok(())
282-
}
283-
284-
#[test]
285-
fn test_size_map_array() -> Result<()> {
286-
// Create a map array with entries
287-
let keys = StringArray::from(vec!["a", "b", "c", "d"]);
288-
let values = Int32Array::from(vec![1, 2, 3, 4]);
289-
290-
let entries_field = Arc::new(Field::new(
291-
"entries",
292-
DataType::Struct(Fields::from(vec![
293-
Field::new("key", DataType::Utf8, false),
294-
Field::new("value", DataType::Int32, true),
295-
])),
296-
false,
297-
));
298-
299-
let entries = StructArray::from(vec![
300-
(
301-
Arc::new(Field::new("key", DataType::Utf8, false)),
302-
Arc::new(keys) as ArrayRef,
303-
),
304-
(
305-
Arc::new(Field::new("value", DataType::Int32, true)),
306-
Arc::new(values) as ArrayRef,
307-
),
308-
]);
309-
310-
// Map with 3 rows: {a:1, b:2}, {c:3}, null
311-
let offsets = OffsetBuffer::new(vec![0, 2, 3, 4].into());
312-
let nulls = NullBuffer::from(vec![true, true, false]);
313-
let map_array =
314-
MapArray::new(entries_field, offsets, entries, Some(nulls), false);
315-
316-
let result = spark_size_inner(&[Arc::new(map_array)])?;
317-
let result = result.as_any().downcast_ref::<Int32Array>().unwrap();
318-
319-
assert_eq!(result.len(), 3);
320-
assert_eq!(result.value(0), 2); // {a:1, b:2}
321-
assert_eq!(result.value(1), 1); // {c:3}
322-
assert!(result.is_null(2)); // null
323-
324-
Ok(())
325-
}
326-
327-
#[test]
328-
fn test_size_fixed_size_list() -> Result<()> {
329-
// Create a fixed size list of size 3
330-
let values = Int32Array::from(vec![1, 2, 3, 4, 5, 6, 7, 8, 9]);
331-
let nulls = NullBuffer::from(vec![true, true, false]);
332-
let list_array = FixedSizeListArray::new(
333-
Arc::new(Field::new("item", DataType::Int32, true)),
334-
3,
335-
Arc::new(values),
336-
Some(nulls),
337-
);
338-
339-
let result = spark_size_inner(&[Arc::new(list_array)])?;
340-
let result = result.as_any().downcast_ref::<Int32Array>().unwrap();
341-
342-
assert_eq!(result.len(), 3);
343-
assert_eq!(result.value(0), 3);
344-
assert_eq!(result.value(1), 3);
345-
assert!(result.is_null(2));
346-
347-
Ok(())
348-
}
349-
}

0 commit comments

Comments (0)