Commit 37ce4eb

feat(spark): implement array_repeat function
1 parent 1d5d63c commit 37ce4eb

6 files changed: +306 -114 lines changed

datafusion/spark/src/function/array/mod.rs

Lines changed: 8 additions & 1 deletion
@@ -15,6 +15,7 @@
 // specific language governing permissions and limitations
 // under the License.
 
+pub mod repeat;
 pub mod shuffle;
 pub mod spark_array;
 
@@ -24,6 +25,7 @@ use std::sync::Arc;
 
 make_udf_function!(spark_array::SparkArray, array);
 make_udf_function!(shuffle::SparkShuffle, shuffle);
+make_udf_function!(repeat::SparkArrayRepeat, array_repeat);
 
 pub mod expr_fn {
     use datafusion_functions::export_functions;
@@ -34,8 +36,13 @@ pub mod expr_fn {
         "Returns a random permutation of the given array.",
         args
     ));
+    export_functions!((
+        array_repeat,
+        "returns an array containing element count times.",
+        element count
+    ));
 }
 
 pub fn functions() -> Vec<Arc<ScalarUDF>> {
-    vec![array(), shuffle()]
+    vec![array(), shuffle(), array_repeat()]
 }
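
For context, the `export_functions!` call above generates an `expr_fn::array_repeat` helper for building logical expressions. A minimal, hypothetical usage sketch (not part of this commit), assuming the crate is `datafusion_spark` and `function::array::expr_fn` is reachable from outside the crate:

// Hypothetical sketch: build a Spark-style array_repeat expression.
// The import path is an assumption based on the module layout above.
use datafusion_expr::{col, lit, Expr};
use datafusion_spark::function::array::expr_fn::array_repeat;

fn build_expr() -> Expr {
    // Roughly equivalent to the SQL `array_repeat(name, 3)`:
    // repeat the value of column `name` three times per row.
    array_repeat(col("name"), lit(3i64))
}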
datafusion/spark/src/function/array/repeat.rs

Lines changed: 128 additions & 0 deletions
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

use arrow::datatypes::{DataType, Field};
use datafusion_common::utils::take_function_args;
use datafusion_common::{Result, ScalarValue, exec_err};
use datafusion_expr::{
    ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility,
};
use datafusion_functions_nested::repeat::ArrayRepeat;
use std::any::Any;
use std::sync::Arc;

use crate::function::null_utils::{
    NullMaskResolution, apply_null_mask, compute_null_mask,
};

/// Spark-compatible `array_repeat` expression. The difference from DataFusion's
/// `array_repeat` is the handling of NULL inputs: in Spark, if any input is NULL,
/// the result is NULL.
/// <https://spark.apache.org/docs/latest/api/sql/index.html#array_repeat>
#[derive(Debug, PartialEq, Eq, Hash)]
pub struct SparkArrayRepeat {
    signature: Signature,
}

impl Default for SparkArrayRepeat {
    fn default() -> Self {
        Self::new()
    }
}

impl SparkArrayRepeat {
    pub fn new() -> Self {
        Self {
            signature: Signature::user_defined(Volatility::Immutable),
        }
    }
}

impl ScalarUDFImpl for SparkArrayRepeat {
    fn as_any(&self) -> &dyn Any {
        self
    }

    fn name(&self) -> &str {
        "array_repeat"
    }

    fn signature(&self) -> &Signature {
        &self.signature
    }

    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
        Ok(DataType::List(Arc::new(Field::new_list_field(
            arg_types[0].clone(),
            true,
        ))))
    }

    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
        spark_array_repeat(args)
    }

    fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
        let [first_type, second_type] = take_function_args(self.name(), arg_types)?;

        // Coerce the second argument to Int64/UInt64 if it's a numeric type
        let second = match second_type {
            DataType::Int8 | DataType::Int16 | DataType::Int32 | DataType::Int64 => {
                DataType::Int64
            }
            DataType::UInt8 | DataType::UInt16 | DataType::UInt32 | DataType::UInt64 => {
                DataType::UInt64
            }
            _ => return exec_err!("count must be an integer type"),
        };

        Ok(vec![first_type.clone(), second])
    }
}

/// This is a Spark-specific wrapper around DataFusion's `array_repeat` that returns NULL
/// if any argument is NULL (Spark behavior), whereas DataFusion's `array_repeat` ignores NULLs.
fn spark_array_repeat(args: ScalarFunctionArgs) -> Result<ColumnarValue> {
    let ScalarFunctionArgs {
        args: arg_values,
        arg_fields,
        number_rows,
        return_field,
        config_options,
    } = args;
    let return_type = return_field.data_type().clone();

    // Step 1: Check for NULL mask in incoming args
    let null_mask = compute_null_mask(&arg_values, number_rows)?;

    // If any argument is NULL then return NULL immediately
    if matches!(null_mask, NullMaskResolution::ReturnNull) {
        return Ok(ColumnarValue::Scalar(ScalarValue::try_from(return_type)?));
    }

    // Step 2: Delegate to DataFusion's array_repeat
    let array_repeat_func = ArrayRepeat::new();
    let func_args = ScalarFunctionArgs {
        args: arg_values,
        arg_fields,
        number_rows,
        return_field,
        config_options,
    };
    let result = array_repeat_func.invoke_with_args(func_args)?;

    // Step 3: Apply NULL mask to result
    apply_null_mask(result, null_mask, &return_type)
}
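
To make the Spark NULL semantics concrete, here is a small usage sketch that is not part of the commit: it registers `SparkArrayRepeat` with a `SessionContext` and exercises it through SQL. The `datafusion_spark` import path, and using `register_udf` to shadow DataFusion's built-in `array_repeat`, are assumptions.

// Hypothetical sketch: register the Spark-style UDF and observe its NULL handling.
use datafusion::error::Result;
use datafusion::prelude::SessionContext;
use datafusion_expr::ScalarUDF;
use datafusion_spark::function::array::repeat::SparkArrayRepeat; // assumed path

#[tokio::main]
async fn main() -> Result<()> {
    let ctx = SessionContext::new();
    // Register the UDF; by name it takes precedence over the built-in array_repeat.
    ctx.register_udf(ScalarUDF::new_from_impl(SparkArrayRepeat::new()));

    // Non-NULL inputs: [x, x, x]
    ctx.sql("SELECT array_repeat('x', 3)").await?.show().await?;

    // Spark semantics: any NULL argument makes the whole result NULL,
    // rather than being ignored as in DataFusion's native array_repeat.
    ctx.sql("SELECT array_repeat('x', CAST(NULL AS INT))")
        .await?
        .show()
        .await?;

    Ok(())
}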

datafusion/spark/src/function/mod.rs

Lines changed: 1 addition & 0 deletions
@@ -33,6 +33,7 @@ pub mod lambda;
 pub mod map;
 pub mod math;
 pub mod misc;
+pub mod null_utils;
 pub mod predicate;
 pub mod string;
 pub mod r#struct;
datafusion/spark/src/function/null_utils.rs

Lines changed: 122 additions & 0 deletions
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

use arrow::array::Array;
use arrow::buffer::NullBuffer;
use arrow::datatypes::DataType;
use datafusion_common::{Result, ScalarValue};
use datafusion_expr::ColumnarValue;
use std::sync::Arc;

pub(crate) enum NullMaskResolution {
    /// Return NULL as the result (e.g., scalar inputs with at least one NULL)
    ReturnNull,
    /// No null mask needed (e.g., all scalar inputs are non-NULL)
    NoMask,
    /// Null mask to apply for arrays
    Apply(NullBuffer),
}

/// Compute NULL mask for the arguments using NullBuffer::union
pub(crate) fn compute_null_mask(
    args: &[ColumnarValue],
    number_rows: usize,
) -> Result<NullMaskResolution> {
    // Check if all arguments are scalars
    let all_scalars = args
        .iter()
        .all(|arg| matches!(arg, ColumnarValue::Scalar(_)));

    if all_scalars {
        // For scalars, check if any is NULL
        for arg in args {
            if let ColumnarValue::Scalar(scalar) = arg
                && scalar.is_null()
            {
                return Ok(NullMaskResolution::ReturnNull);
            }
        }
        // No NULLs in scalars
        Ok(NullMaskResolution::NoMask)
    } else {
        // For arrays, compute NULL mask for each row using NullBuffer::union
        let array_len = args
            .iter()
            .find_map(|arg| match arg {
                ColumnarValue::Array(array) => Some(array.len()),
                _ => None,
            })
            .unwrap_or(number_rows);

        // Convert all scalars to arrays for uniform processing
        let arrays: Result<Vec<_>> = args
            .iter()
            .map(|arg| match arg {
                ColumnarValue::Array(array) => Ok(Arc::clone(array)),
                ColumnarValue::Scalar(scalar) => scalar.to_array_of_size(array_len),
            })
            .collect();
        let arrays = arrays?;

        // Use NullBuffer::union to combine all null buffers
        let combined_nulls = arrays
            .iter()
            .map(|arr| arr.nulls())
            .fold(None, |acc, nulls| NullBuffer::union(acc.as_ref(), nulls));

        match combined_nulls {
            Some(nulls) => Ok(NullMaskResolution::Apply(nulls)),
            None => Ok(NullMaskResolution::NoMask),
        }
    }
}

/// Apply NULL mask to the result using NullBuffer::union
pub(crate) fn apply_null_mask(
    result: ColumnarValue,
    null_mask: NullMaskResolution,
    return_type: &DataType,
) -> Result<ColumnarValue> {
    match (result, null_mask) {
        // Scalar with ReturnNull mask means return NULL of the correct type
        (ColumnarValue::Scalar(_), NullMaskResolution::ReturnNull) => {
            Ok(ColumnarValue::Scalar(ScalarValue::try_from(return_type)?))
        }
        // Scalar without mask, return as-is
        (scalar @ ColumnarValue::Scalar(_), NullMaskResolution::NoMask) => Ok(scalar),
        // Array with NULL mask - use NullBuffer::union to combine nulls
        (ColumnarValue::Array(array), NullMaskResolution::Apply(null_mask)) => {
            // Combine the result's existing nulls with our computed null mask
            let combined_nulls = NullBuffer::union(array.nulls(), Some(&null_mask));

            // Create new array with combined nulls
            let new_array = array
                .into_data()
                .into_builder()
                .nulls(combined_nulls)
                .build()?;

            Ok(ColumnarValue::Array(Arc::new(arrow::array::make_array(
                new_array,
            ))))
        }
        // Array without NULL mask, return as-is
        (array @ ColumnarValue::Array(_), NullMaskResolution::NoMask) => Ok(array),
        // Edge cases that shouldn't happen in practice
        (scalar, _) => Ok(scalar),
    }
}
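
The per-row rule in these helpers comes down to `NullBuffer::union`, which keeps a row valid only when it is valid in every input. A small standalone sketch, assuming only the `arrow` crate, illustrating the behavior the helpers rely on:

// Standalone illustration of the per-row rule: a row in the combined mask is
// NULL if it is NULL in any argument column.
use arrow::array::{Array, Int64Array, StringArray};
use arrow::buffer::NullBuffer;

fn main() {
    // Two argument columns, each with a NULL in a different row.
    let elements = StringArray::from(vec![Some("a"), None, Some("c")]);
    let counts = Int64Array::from(vec![Some(2), Some(1), None]);

    // NullBuffer::union combines validity: valid only where both inputs are valid.
    let combined = NullBuffer::union(elements.nulls(), counts.nulls())
        .expect("at least one input has nulls");

    assert!(combined.is_valid(0));  // both non-NULL -> row stays valid
    assert!(!combined.is_valid(1)); // element is NULL -> row becomes NULL
    assert!(!combined.is_valid(2)); // count is NULL   -> row becomes NULL
    assert_eq!(combined.null_count(), 2);
}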
