diff --git a/datafusion/functions/src/string/concat.rs b/datafusion/functions/src/string/concat.rs index 42d455a05760a..4cca065945352 100644 --- a/datafusion/functions/src/string/concat.rs +++ b/datafusion/functions/src/string/concat.rs @@ -20,6 +20,7 @@ use arrow::datatypes::DataType; use datafusion_expr::sort_properties::ExprProperties; use std::any::Any; use std::sync::Arc; +use crate::array::array_concat; use crate::string::concat; use crate::strings::{ @@ -108,6 +109,16 @@ impl ScalarUDFImpl for ConcatFunc { fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { let ScalarFunctionArgs { args, .. } = args; + // If all arguments are arrays, delegate to array_concat + // If all arguments are arrays, delegate to array_concat + let all_arrays = args.iter().all(|arg| matches!(arg, ColumnarValue::Array(_))); + if all_arrays { + use crate::array::array_concat; + return array_concat().invoke_with_args(args.into()); + } + + + let mut return_datatype = DataType::Utf8; args.iter().for_each(|col| { if col.data_type() == DataType::Utf8View { @@ -385,6 +396,58 @@ mod tests { use arrow::datatypes::Field; use datafusion_common::config::ConfigOptions; + #[test] + fn concat_array_should_match_array_concat() -> Result<()> { + use arrow::array::Int64Array; + + let a = ColumnarValue::Array(Arc::new( + Int64Array::from(vec![Some(1), Some(2), Some(3)]) + )); + let b = ColumnarValue::Array(Arc::new( + Int64Array::from(vec![Some(4), Some(5)]) + )); + + let args = ScalarFunctionArgs { + args: vec![a, b], + arg_fields: vec![ + Arc::new(Field::new("a", DataType::Int64, true)), + Arc::new(Field::new("b", DataType::Int64, true)), + ], + number_rows: 3, + return_field: Field::new( + "f", + List(Arc::new(Field::new("item", Int64, true))), + true, + ) + .into(), + config_options: Arc::new(ConfigOptions::default()), + }; + + let result = ConcatFunc::new().invoke_with_args(args)?; + + match result { + ColumnarValue::Array(array) => { + let values = array + .as_any() + .downcast_ref::() + .unwrap(); + + let inner = values + .value(0) + .as_any() + .downcast_ref::() + .unwrap(); + + assert_eq!(inner.values(), &[1, 2, 3, 4, 5]); + + } + _ => panic!("Expected array output"), + } + + Ok(()) + } + + #[test] fn test_functions() -> Result<()> { test_function!(