Skip to content

Commit 3472aa1

Browse files
authored
Implement spark array function array (#16936)
* feat(spark): implement spark array function array Signed-off-by: Alan Tang <[email protected]> * chore: add license header Signed-off-by: Alan Tang <[email protected]> * chore: fix clippy error Signed-off-by: Alan Tang <[email protected]> * feat: add with_list_field_name method and more tests Signed-off-by: Alan Tang <[email protected]> * feat: add name field to SparkArray structure Signed-off-by: Alan Tang <[email protected]> * chore: hardcode field name Signed-off-by: Alan Tang <[email protected]> * chore: fix clippy error Signed-off-by: Alan Tang <[email protected]> --------- Signed-off-by: Alan Tang <[email protected]>
1 parent b0c8dd6 commit 3472aa1

File tree

6 files changed

+395
-13
lines changed

6 files changed

+395
-13
lines changed

datafusion/spark/src/function/array/mod.rs

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,11 +15,20 @@
1515
// specific language governing permissions and limitations
1616
// under the License.
1717

18+
pub mod spark_array;
19+
1820
use datafusion_expr::ScalarUDF;
21+
use datafusion_functions::make_udf_function;
1922
use std::sync::Arc;
2023

21-
pub mod expr_fn {}
24+
make_udf_function!(spark_array::SparkArray, array);
25+
26+
pub mod expr_fn {
27+
use datafusion_functions::export_functions;
28+
29+
export_functions!((array, "Returns an array with the given elements.", args));
30+
}
2231

2332
pub fn functions() -> Vec<Arc<ScalarUDF>> {
24-
vec![]
33+
vec![array()]
2534
}
Lines changed: 275 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,275 @@
1+
// Licensed to the Apache Software Foundation (ASF) under one
2+
// or more contributor license agreements. See the NOTICE file
3+
// distributed with this work for additional information
4+
// regarding copyright ownership. The ASF licenses this file
5+
// to you under the Apache License, Version 2.0 (the
6+
// "License"); you may not use this file except in compliance
7+
// with the License. You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing,
12+
// software distributed under the License is distributed on an
13+
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
// KIND, either express or implied. See the License for the
15+
// specific language governing permissions and limitations
16+
// under the License.
17+
18+
use std::{any::Any, sync::Arc};
19+
20+
use arrow::array::{
21+
make_array, new_null_array, Array, ArrayData, ArrayRef, Capacities, GenericListArray,
22+
MutableArrayData, NullArray, OffsetSizeTrait,
23+
};
24+
use arrow::buffer::OffsetBuffer;
25+
use arrow::datatypes::{DataType, Field, FieldRef};
26+
use datafusion_common::utils::SingleRowListArrayBuilder;
27+
use datafusion_common::{plan_datafusion_err, plan_err, Result};
28+
use datafusion_expr::type_coercion::binary::comparison_coercion;
29+
use datafusion_expr::{
30+
ColumnarValue, ReturnFieldArgs, ScalarFunctionArgs, ScalarUDFImpl, Signature,
31+
TypeSignature, Volatility,
32+
};
33+
34+
use crate::function::functions_nested_utils::make_scalar_function;
35+
36+
const ARRAY_FIELD_DEFAULT_NAME: &str = "element";
37+
38+
#[derive(Debug)]
39+
pub struct SparkArray {
40+
signature: Signature,
41+
aliases: Vec<String>,
42+
}
43+
44+
impl Default for SparkArray {
45+
fn default() -> Self {
46+
Self::new()
47+
}
48+
}
49+
50+
impl SparkArray {
51+
pub fn new() -> Self {
52+
Self {
53+
signature: Signature::one_of(
54+
vec![TypeSignature::UserDefined, TypeSignature::Nullary],
55+
Volatility::Immutable,
56+
),
57+
aliases: vec![String::from("spark_make_array")],
58+
}
59+
}
60+
}
61+
62+
impl ScalarUDFImpl for SparkArray {
63+
fn as_any(&self) -> &dyn Any {
64+
self
65+
}
66+
67+
fn name(&self) -> &str {
68+
"array"
69+
}
70+
71+
fn signature(&self) -> &Signature {
72+
&self.signature
73+
}
74+
75+
fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
76+
match arg_types.len() {
77+
0 => Ok(empty_array_type()),
78+
_ => {
79+
let mut expr_type = DataType::Null;
80+
for arg_type in arg_types {
81+
if !arg_type.equals_datatype(&DataType::Null) {
82+
expr_type = arg_type.clone();
83+
break;
84+
}
85+
}
86+
87+
if expr_type.is_null() {
88+
expr_type = DataType::Int32;
89+
}
90+
91+
Ok(DataType::List(Arc::new(Field::new_list_field(
92+
expr_type, true,
93+
))))
94+
}
95+
}
96+
}
97+
98+
fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result<FieldRef> {
99+
let data_types = args
100+
.arg_fields
101+
.iter()
102+
.map(|f| f.data_type())
103+
.cloned()
104+
.collect::<Vec<_>>();
105+
let return_type = self.return_type(&data_types)?;
106+
Ok(Arc::new(Field::new(
107+
ARRAY_FIELD_DEFAULT_NAME,
108+
return_type,
109+
false,
110+
)))
111+
}
112+
113+
fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
114+
let ScalarFunctionArgs { args, .. } = args;
115+
make_scalar_function(make_array_inner)(args.as_slice())
116+
}
117+
118+
fn aliases(&self) -> &[String] {
119+
&self.aliases
120+
}
121+
122+
fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
123+
let first_type = arg_types.first().ok_or_else(|| {
124+
plan_datafusion_err!("Spark array function requires at least one argument")
125+
})?;
126+
let new_type =
127+
arg_types
128+
.iter()
129+
.skip(1)
130+
.try_fold(first_type.clone(), |acc, x| {
131+
// The coerced types found by `comparison_coercion` are not guaranteed to be
132+
// coercible for the arguments. `comparison_coercion` returns more loose
133+
// types that can be coerced to both `acc` and `x` for comparison purpose.
134+
// See `maybe_data_types` for the actual coercion.
135+
let coerced_type = comparison_coercion(&acc, x);
136+
if let Some(coerced_type) = coerced_type {
137+
Ok(coerced_type)
138+
} else {
139+
plan_err!("Coercion from {acc:?} to {x:?} failed.")
140+
}
141+
})?;
142+
Ok(vec![new_type; arg_types.len()])
143+
}
144+
}
145+
146+
// Empty array is a special case that is useful for many other array functions
147+
pub(super) fn empty_array_type() -> DataType {
148+
DataType::List(Arc::new(Field::new(
149+
ARRAY_FIELD_DEFAULT_NAME,
150+
DataType::Int32,
151+
true,
152+
)))
153+
}
154+
155+
/// `make_array_inner` is the implementation of the `make_array` function.
156+
/// Constructs an array using the input `data` as `ArrayRef`.
157+
/// Returns a reference-counted `Array` instance result.
158+
pub fn make_array_inner(arrays: &[ArrayRef]) -> Result<ArrayRef> {
159+
let mut data_type = DataType::Null;
160+
for arg in arrays {
161+
let arg_data_type = arg.data_type();
162+
if !arg_data_type.equals_datatype(&DataType::Null) {
163+
data_type = arg_data_type.clone();
164+
break;
165+
}
166+
}
167+
168+
match data_type {
169+
// Either an empty array or all nulls:
170+
DataType::Null => {
171+
let length = arrays.iter().map(|a| a.len()).sum();
172+
// By default Int32
173+
let array = new_null_array(&DataType::Int32, length);
174+
Ok(Arc::new(
175+
SingleRowListArrayBuilder::new(array)
176+
.with_nullable(true)
177+
.build_list_array(),
178+
))
179+
}
180+
DataType::LargeList(..) => array_array::<i64>(arrays, data_type),
181+
_ => array_array::<i32>(arrays, data_type),
182+
}
183+
}
184+
185+
/// Convert one or more [`ArrayRef`] of the same type into a
186+
/// `ListArray` or 'LargeListArray' depending on the offset size.
187+
///
188+
/// # Example (non nested)
189+
///
190+
/// Calling `array(col1, col2)` where col1 and col2 are non nested
191+
/// would return a single new `ListArray`, where each row was a list
192+
/// of 2 elements:
193+
///
194+
/// ```text
195+
/// ┌─────────┐ ┌─────────┐ ┌──────────────┐
196+
/// │ ┌─────┐ │ │ ┌─────┐ │ │ ┌──────────┐ │
197+
/// │ │ A │ │ │ │ X │ │ │ │ [A, X] │ │
198+
/// │ ├─────┤ │ │ ├─────┤ │ │ ├──────────┤ │
199+
/// │ │NULL │ │ │ │ Y │ │──────────▶│ │[NULL, Y] │ │
200+
/// │ ├─────┤ │ │ ├─────┤ │ │ ├──────────┤ │
201+
/// │ │ C │ │ │ │ Z │ │ │ │ [C, Z] │ │
202+
/// │ └─────┘ │ │ └─────┘ │ │ └──────────┘ │
203+
/// └─────────┘ └─────────┘ └──────────────┘
204+
/// col1 col2 output
205+
/// ```
206+
///
207+
/// # Example (nested)
208+
///
209+
/// Calling `array(col1, col2)` where col1 and col2 are lists
210+
/// would return a single new `ListArray`, where each row was a list
211+
/// of the corresponding elements of col1 and col2.
212+
///
213+
/// ``` text
214+
/// ┌──────────────┐ ┌──────────────┐ ┌─────────────────────────────┐
215+
/// │ ┌──────────┐ │ │ ┌──────────┐ │ │ ┌────────────────────────┐ │
216+
/// │ │ [A, X] │ │ │ │ [] │ │ │ │ [[A, X], []] │ │
217+
/// │ ├──────────┤ │ │ ├──────────┤ │ │ ├────────────────────────┤ │
218+
/// │ │[NULL, Y] │ │ │ │[Q, R, S] │ │───────▶│ │ [[NULL, Y], [Q, R, S]] │ │
219+
/// │ ├──────────┤ │ │ ├──────────┤ │ │ ├────────────────────────│ │
220+
/// │ │ [C, Z] │ │ │ │ NULL │ │ │ │ [[C, Z], NULL] │ │
221+
/// │ └──────────┘ │ │ └──────────┘ │ │ └────────────────────────┘ │
222+
/// └──────────────┘ └──────────────┘ └─────────────────────────────┘
223+
/// col1 col2 output
224+
/// ```
225+
fn array_array<O: OffsetSizeTrait>(
226+
args: &[ArrayRef],
227+
data_type: DataType,
228+
) -> Result<ArrayRef> {
229+
// do not accept 0 arguments.
230+
if args.is_empty() {
231+
return plan_err!("Array requires at least one argument");
232+
}
233+
234+
let mut data = vec![];
235+
let mut total_len = 0;
236+
for arg in args {
237+
let arg_data = if arg.as_any().is::<NullArray>() {
238+
ArrayData::new_empty(&data_type)
239+
} else {
240+
arg.to_data()
241+
};
242+
total_len += arg_data.len();
243+
data.push(arg_data);
244+
}
245+
246+
let mut offsets: Vec<O> = Vec::with_capacity(total_len);
247+
offsets.push(O::usize_as(0));
248+
249+
let capacity = Capacities::Array(total_len);
250+
let data_ref = data.iter().collect::<Vec<_>>();
251+
let mut mutable = MutableArrayData::with_capacities(data_ref, true, capacity);
252+
253+
let num_rows = args[0].len();
254+
for row_idx in 0..num_rows {
255+
for (arr_idx, arg) in args.iter().enumerate() {
256+
if !arg.as_any().is::<NullArray>()
257+
&& !arg.is_null(row_idx)
258+
&& arg.is_valid(row_idx)
259+
{
260+
mutable.extend(arr_idx, row_idx, row_idx + 1);
261+
} else {
262+
mutable.extend_nulls(1);
263+
}
264+
}
265+
offsets.push(O::usize_as(mutable.len()));
266+
}
267+
let data = mutable.freeze();
268+
269+
Ok(Arc::new(GenericListArray::<O>::try_new(
270+
Arc::new(Field::new(ARRAY_FIELD_DEFAULT_NAME, data_type, true)),
271+
OffsetBuffer::new(offsets.into()),
272+
make_array(data),
273+
None,
274+
)?))
275+
}
Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
// Licensed to the Apache Software Foundation (ASF) under one
2+
// or more contributor license agreements. See the NOTICE file
3+
// distributed with this work for additional information
4+
// regarding copyright ownership. The ASF licenses this file
5+
// to you under the Apache License, Version 2.0 (the
6+
// "License"); you may not use this file except in compliance
7+
// with the License. You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing,
12+
// software distributed under the License is distributed on an
13+
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
// KIND, either express or implied. See the License for the
15+
// specific language governing permissions and limitations
16+
// under the License.
17+
18+
use arrow::array::{Array, ArrayRef};
19+
use datafusion_common::{Result, ScalarValue};
20+
use datafusion_expr::ColumnarValue;
21+
22+
/// array function wrapper that differentiates between scalar (length 1) and array.
23+
pub(crate) fn make_scalar_function<F>(
24+
inner: F,
25+
) -> impl Fn(&[ColumnarValue]) -> Result<ColumnarValue>
26+
where
27+
F: Fn(&[ArrayRef]) -> Result<ArrayRef>,
28+
{
29+
move |args: &[ColumnarValue]| {
30+
// first, identify if any of the arguments is an Array. If yes, store its `len`,
31+
// as any scalar will need to be converted to an array of len `len`.
32+
let len = args
33+
.iter()
34+
.fold(Option::<usize>::None, |acc, arg| match arg {
35+
ColumnarValue::Scalar(_) => acc,
36+
ColumnarValue::Array(a) => Some(a.len()),
37+
});
38+
39+
let is_scalar = len.is_none();
40+
41+
let args = ColumnarValue::values_to_arrays(args)?;
42+
43+
let result = (inner)(&args);
44+
45+
if is_scalar {
46+
// If all inputs are scalar, keeps output as scalar
47+
let result = result.and_then(|arr| ScalarValue::try_from_array(&arr, 0));
48+
result.map(ColumnarValue::Scalar)
49+
} else {
50+
result.map(ColumnarValue::Array)
51+
}
52+
}
53+
}

datafusion/spark/src/function/mod.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ pub mod conversion;
2424
pub mod csv;
2525
pub mod datetime;
2626
pub mod error_utils;
27+
pub mod functions_nested_utils;
2728
pub mod generator;
2829
pub mod hash;
2930
pub mod json;

0 commit comments

Comments
 (0)