diff --git a/datafusion/spark/src/function/collection/mod.rs b/datafusion/spark/src/function/collection/mod.rs index a87df9a2c87a0..6871e3aba6469 100644 --- a/datafusion/spark/src/function/collection/mod.rs +++ b/datafusion/spark/src/function/collection/mod.rs @@ -15,11 +15,20 @@ // specific language governing permissions and limitations // under the License. +pub mod size; + use datafusion_expr::ScalarUDF; +use datafusion_functions::make_udf_function; use std::sync::Arc; -pub mod expr_fn {} +make_udf_function!(size::SparkSize, size); + +pub mod expr_fn { + use datafusion_functions::export_functions; + + export_functions!((size, "Return the size of an array or map.", arg)); +} pub fn functions() -> Vec> { - vec![] + vec![size()] } diff --git a/datafusion/spark/src/function/collection/size.rs b/datafusion/spark/src/function/collection/size.rs new file mode 100644 index 0000000000000..99e6fe485b0d8 --- /dev/null +++ b/datafusion/spark/src/function/collection/size.rs @@ -0,0 +1,165 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow::array::{Array, ArrayRef, AsArray, Int32Array}; +use arrow::compute::kernels::length::length as arrow_length; +use arrow::datatypes::{DataType, Field, FieldRef}; +use datafusion_common::{Result, plan_err}; +use datafusion_expr::{ + ArrayFunctionArgument, ArrayFunctionSignature, ColumnarValue, ReturnFieldArgs, + ScalarFunctionArgs, ScalarUDFImpl, Signature, TypeSignature, Volatility, +}; +use datafusion_functions::utils::make_scalar_function; +use std::any::Any; +use std::sync::Arc; + +/// Spark-compatible `size` function. +/// +/// Returns the number of elements in an array or the number of key-value pairs in a map. +/// Returns -1 for null input (Spark behavior). +#[derive(Debug, PartialEq, Eq, Hash)] +pub struct SparkSize { + signature: Signature, +} + +impl Default for SparkSize { + fn default() -> Self { + Self::new() + } +} + +impl SparkSize { + pub fn new() -> Self { + Self { + signature: Signature::one_of( + vec![ + // Array Type + TypeSignature::ArraySignature(ArrayFunctionSignature::Array { + arguments: vec![ArrayFunctionArgument::Array], + array_coercion: None, + }), + // Map Type + TypeSignature::ArraySignature(ArrayFunctionSignature::MapArray), + ], + Volatility::Immutable, + ), + } + } +} + +impl ScalarUDFImpl for SparkSize { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "size" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + Ok(DataType::Int32) + } + + fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result { + Ok(Arc::new(Field::new( + self.name(), + DataType::Int32, + args.arg_fields[0].is_nullable(), + ))) + } + + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { + make_scalar_function(spark_size_inner, vec![])(&args.args) + } +} + +fn spark_size_inner(args: &[ArrayRef]) -> Result { + let array = &args[0]; + + match array.data_type() { + DataType::List(_) => { + if array.null_count() == 0 { + Ok(arrow_length(array)?) + } else { + let list_array = array.as_list::(); + let lengths: Vec = list_array + .offsets() + .lengths() + .enumerate() + .map(|(i, len)| if array.is_null(i) { -1 } else { len as i32 }) + .collect(); + Ok(Arc::new(Int32Array::from(lengths))) + } + } + DataType::FixedSizeList(_, size) => { + if array.null_count() == 0 { + Ok(arrow_length(array)?) + } else { + let length: Vec = (0..array.len()) + .map(|i| if array.is_null(i) { -1 } else { *size }) + .collect(); + Ok(Arc::new(Int32Array::from(length))) + } + } + DataType::LargeList(_) => { + // Arrow length kernel returns Int64 for LargeList + let list_array = array.as_list::(); + if array.null_count() == 0 { + let lengths: Vec = list_array + .offsets() + .lengths() + .map(|len| len as i32) + .collect(); + Ok(Arc::new(Int32Array::from(lengths))) + } else { + let lengths: Vec = list_array + .offsets() + .lengths() + .enumerate() + .map(|(i, len)| if array.is_null(i) { -1 } else { len as i32 }) + .collect(); + Ok(Arc::new(Int32Array::from(lengths))) + } + } + DataType::Map(_, _) => { + let map_array = array.as_map(); + let length: Vec = if array.null_count() == 0 { + map_array + .offsets() + .lengths() + .map(|len| len as i32) + .collect() + } else { + map_array + .offsets() + .lengths() + .enumerate() + .map(|(i, len)| if array.is_null(i) { -1 } else { len as i32 }) + .collect() + }; + Ok(Arc::new(Int32Array::from(length))) + } + DataType::Null => Ok(Arc::new(Int32Array::from(vec![-1; array.len()]))), + dt => { + plan_err!("size function does not support type: {}", dt) + } + } +} diff --git a/datafusion/sqllogictest/test_files/spark/collection/size.slt b/datafusion/sqllogictest/test_files/spark/collection/size.slt new file mode 100644 index 0000000000000..67200b1de653b --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/collection/size.slt @@ -0,0 +1,129 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. +# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT size(array(1, 2, 3)); +## PySpark 3.5.5 Result: {'size(array(1, 2, 3))': 3} + +# Basic array size tests +query I +SELECT size(make_array(1, 2, 3)); +---- +3 + +query I +SELECT size(make_array(1, 2, 3, 4, 5)); +---- +5 + +query I +SELECT size(make_array(1)); +---- +1 + +# Empty array +query I +SELECT size(arrow_cast(make_array(), 'List(Int32)')); +---- +0 + +# Nested arrays +query I +SELECT size(make_array(make_array(1, 2), make_array(3, 4, 5))); +---- +2 + +# Array with NULL elements (size counts elements including NULLs) +query I +SELECT size(make_array(1, NULL, 3)); +---- +3 + +# NULL array returns -1 (Spark behavior) +query I +SELECT size(NULL::int[]); +---- +-1 + +# Map size tests +query I +SELECT size(map(make_array('a', 'b', 'c'), make_array(1, 2, 3))); +---- +3 + +query I +SELECT size(map(make_array('a'), make_array(1))); +---- +1 + +# Empty map +query I +SELECT size(map(arrow_cast(make_array(), 'List(Utf8)'), arrow_cast(make_array(), 'List(Int32)'))); +---- +0 + +# Map with multiple entries +query I +SELECT size(map(make_array('x', 'y', 'z', 'w'), make_array(10, 20, 30, 40))); +---- +4 + +# String array +query I +SELECT size(make_array('hello', 'world')); +---- +2 + +# Boolean array +query I +SELECT size(make_array(true, false, true)); +---- +3 + +# Float array +query I +SELECT size(make_array(1.5, 2.5, 3.5, 4.5)); +---- +4 + +# LargeList tests +query I +SELECT size(arrow_cast(make_array(1, 2, 3), 'LargeList(Int32)')); +---- +3 + +query I +SELECT size(arrow_cast(make_array(1, 2, 3, 4, 5), 'LargeList(Int64)')); +---- +5 + +# FixedSizeList tests +query I +SELECT size(arrow_cast(make_array(1, 2, 3), 'FixedSizeList(3, Int32)')); +---- +3 + +query I +SELECT size(arrow_cast(make_array(1, 2, 3, 4), 'FixedSizeList(4, Int32)')); +---- +4 +