Skip to content

Commit eececb2

Browse files
committed
fix: Support scalar/array args for rpad/read_side_padding
1 parent 0d2c5fc commit eececb2

File tree

4 files changed

+163
-210
lines changed

4 files changed

+163
-210
lines changed

native/spark-expr/src/array_funcs/array_repeat.rs

Lines changed: 2 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
// specific language governing permissions and limitations
1616
// under the License.
1717

18+
use crate::utils::make_scalar_function;
1819
use arrow::array::{
1920
new_null_array, Array, ArrayRef, Capacities, GenericListArray, ListArray, MutableArrayData,
2021
NullBufferBuilder, OffsetSizeTrait, UInt64Array,
@@ -29,44 +30,12 @@ use datafusion::common::{exec_err, DataFusionError, ScalarValue};
2930
use datafusion::logical_expr::ColumnarValue;
3031
use std::sync::Arc;
3132

32-
pub fn make_scalar_function<F>(
33-
inner: F,
34-
) -> impl Fn(&[ColumnarValue]) -> Result<ColumnarValue, DataFusionError>
35-
where
36-
F: Fn(&[ArrayRef]) -> Result<ArrayRef, DataFusionError>,
37-
{
38-
move |args: &[ColumnarValue]| {
39-
// first, identify if any of the arguments is an Array. If yes, store its `len`,
40-
// as any scalar will need to be converted to an array of len `len`.
41-
let len = args
42-
.iter()
43-
.fold(Option::<usize>::None, |acc, arg| match arg {
44-
ColumnarValue::Scalar(_) => acc,
45-
ColumnarValue::Array(a) => Some(a.len()),
46-
});
47-
48-
let is_scalar = len.is_none();
49-
50-
let args = ColumnarValue::values_to_arrays(args)?;
51-
52-
let result = (inner)(&args);
53-
54-
if is_scalar {
55-
// If all inputs are scalar, keeps output as scalar
56-
let result = result.and_then(|arr| ScalarValue::try_from_array(&arr, 0));
57-
result.map(ColumnarValue::Scalar)
58-
} else {
59-
result.map(ColumnarValue::Array)
60-
}
61-
}
62-
}
63-
6433
pub fn spark_array_repeat(args: &[ColumnarValue]) -> Result<ColumnarValue, DataFusionError> {
6534
make_scalar_function(spark_array_repeat_inner)(args)
6635
}
6736

6837
/// Array_repeat SQL function
69-
fn spark_array_repeat_inner(args: &[ArrayRef]) -> datafusion::common::Result<ArrayRef> {
38+
fn spark_array_repeat_inner(args: &[ArrayRef]) -> Result<ArrayRef, DataFusionError> {
7039
let element = &args[0];
7140
let count_array = &args[1];
7241

native/spark-expr/src/static_invoke/char_varchar_utils/read_side_padding.rs

Lines changed: 101 additions & 157 deletions
Original file line numberDiff line numberDiff line change
@@ -15,148 +15,85 @@
1515
// specific language governing permissions and limitations
1616
// under the License.
1717

18+
use crate::utils::make_scalar_function;
1819
use arrow::array::builder::GenericStringBuilder;
19-
use arrow::array::cast::as_dictionary_array;
2020
use arrow::array::types::Int32Type;
21-
use arrow::array::{make_array, Array, AsArray, DictionaryArray};
21+
use arrow::array::{Array, AsArray};
2222
use arrow::array::{ArrayRef, OffsetSizeTrait};
2323
use arrow::datatypes::DataType;
24-
use datafusion::common::{cast::as_generic_string_array, DataFusionError, ScalarValue};
24+
use datafusion::common::{cast::as_generic_string_array, DataFusionError};
2525
use datafusion::physical_plan::ColumnarValue;
2626
use std::sync::Arc;
2727

2828
const SPACE: &str = " ";
2929
/// Similar to DataFusion `rpad`, but not to truncate when the string is already longer than length
3030
pub fn spark_read_side_padding(args: &[ColumnarValue]) -> Result<ColumnarValue, DataFusionError> {
31-
spark_read_side_padding2(args, false)
31+
make_scalar_function(spark_read_side_padding_no_truncate)(args)
3232
}
3333

3434
/// Custom `rpad` because DataFusion's `rpad` has differences in unicode handling
3535
pub fn spark_rpad(args: &[ColumnarValue]) -> Result<ColumnarValue, DataFusionError> {
36+
make_scalar_function(spark_read_side_padding_truncate)(args)
37+
}
38+
39+
pub fn spark_read_side_padding_truncate(args: &[ArrayRef]) -> Result<ArrayRef, DataFusionError> {
3640
spark_read_side_padding2(args, true)
3741
}
3842

43+
pub fn spark_read_side_padding_no_truncate(args: &[ArrayRef]) -> Result<ArrayRef, DataFusionError> {
44+
spark_read_side_padding2(args, false)
45+
}
46+
3947
fn spark_read_side_padding2(
40-
args: &[ColumnarValue],
48+
args: &[ArrayRef],
4149
truncate: bool,
42-
) -> Result<ColumnarValue, DataFusionError> {
50+
) -> Result<ArrayRef, DataFusionError> {
4351
match args {
44-
[ColumnarValue::Array(array), ColumnarValue::Scalar(ScalarValue::Int32(Some(length)))] => {
45-
match array.data_type() {
46-
DataType::Utf8 => spark_read_side_padding_internal::<i32>(
47-
array,
48-
truncate,
49-
ColumnarValue::Scalar(ScalarValue::Int32(Some(*length))),
50-
SPACE,
51-
),
52-
DataType::LargeUtf8 => spark_read_side_padding_internal::<i64>(
53-
array,
54-
truncate,
55-
ColumnarValue::Scalar(ScalarValue::Int32(Some(*length))),
56-
SPACE,
57-
),
58-
// Dictionary support required for SPARK-48498
59-
DataType::Dictionary(_, value_type) => {
60-
let dict = as_dictionary_array::<Int32Type>(array);
61-
let col = if value_type.as_ref() == &DataType::Utf8 {
62-
spark_read_side_padding_internal::<i32>(
63-
dict.values(),
64-
truncate,
65-
ColumnarValue::Scalar(ScalarValue::Int32(Some(*length))),
66-
SPACE,
67-
)?
68-
} else {
69-
spark_read_side_padding_internal::<i64>(
70-
dict.values(),
71-
truncate,
72-
ColumnarValue::Scalar(ScalarValue::Int32(Some(*length))),
73-
SPACE,
74-
)?
75-
};
76-
// col consists of an array, so arg of to_array() is not used. Can be anything
77-
let values = col.to_array(0)?;
78-
let result = DictionaryArray::try_new(dict.keys().clone(), values)?;
79-
Ok(ColumnarValue::Array(make_array(result.into())))
80-
}
81-
other => Err(DataFusionError::Internal(format!(
82-
"Unsupported data type {other:?} for function rpad/read_side_padding",
83-
))),
52+
[array, array_int] => match array.data_type() {
53+
DataType::Utf8 => {
54+
spark_read_side_padding_space_internal::<i32>(array, truncate, array_int)
8455
}
85-
}
86-
[ColumnarValue::Array(array), ColumnarValue::Scalar(ScalarValue::Int32(Some(length))), ColumnarValue::Scalar(ScalarValue::Utf8(Some(string)))] =>
87-
{
88-
match array.data_type() {
89-
DataType::Utf8 => spark_read_side_padding_internal::<i32>(
90-
array,
91-
truncate,
92-
ColumnarValue::Scalar(ScalarValue::Int32(Some(*length))),
93-
string,
94-
),
95-
DataType::LargeUtf8 => spark_read_side_padding_internal::<i64>(
96-
array,
97-
truncate,
98-
ColumnarValue::Scalar(ScalarValue::Int32(Some(*length))),
99-
string,
100-
),
101-
// Dictionary support required for SPARK-48498
102-
DataType::Dictionary(_, value_type) => {
103-
let dict = as_dictionary_array::<Int32Type>(array);
104-
let col = if value_type.as_ref() == &DataType::Utf8 {
105-
spark_read_side_padding_internal::<i32>(
106-
dict.values(),
107-
truncate,
108-
ColumnarValue::Scalar(ScalarValue::Int32(Some(*length))),
109-
SPACE,
110-
)?
111-
} else {
112-
spark_read_side_padding_internal::<i64>(
113-
dict.values(),
114-
truncate,
115-
ColumnarValue::Scalar(ScalarValue::Int32(Some(*length))),
116-
SPACE,
117-
)?
118-
};
119-
// col consists of an array, so arg of to_array() is not used. Can be anything
120-
let values = col.to_array(0)?;
121-
let result = DictionaryArray::try_new(dict.keys().clone(), values)?;
122-
Ok(ColumnarValue::Array(make_array(result.into())))
123-
}
124-
other => Err(DataFusionError::Internal(format!(
125-
"Unsupported data type {other:?} for function rpad/read_side_padding",
126-
))),
56+
DataType::LargeUtf8 => {
57+
spark_read_side_padding_space_internal::<i64>(array, truncate, array_int)
12758
}
128-
}
129-
[ColumnarValue::Array(array), ColumnarValue::Array(array_int)] => match array.data_type() {
130-
DataType::Utf8 => spark_read_side_padding_internal::<i32>(
131-
array,
132-
truncate,
133-
ColumnarValue::Array(Arc::<dyn Array>::clone(array_int)),
134-
SPACE,
135-
),
136-
DataType::LargeUtf8 => spark_read_side_padding_internal::<i64>(
137-
array,
138-
truncate,
139-
ColumnarValue::Array(Arc::<dyn Array>::clone(array_int)),
140-
SPACE,
141-
),
14259
other => Err(DataFusionError::Internal(format!(
14360
"Unsupported data type {other:?} for function rpad/read_side_padding",
14461
))),
14562
},
146-
[ColumnarValue::Array(array), ColumnarValue::Array(array_int), ColumnarValue::Scalar(ScalarValue::Utf8(Some(string)))] => {
147-
match array.data_type() {
148-
DataType::Utf8 => spark_read_side_padding_internal::<i32>(
149-
array,
150-
truncate,
151-
ColumnarValue::Array(Arc::<dyn Array>::clone(array_int)),
152-
string,
153-
),
154-
DataType::LargeUtf8 => spark_read_side_padding_internal::<i64>(
155-
array,
156-
truncate,
157-
ColumnarValue::Array(Arc::<dyn Array>::clone(array_int)),
158-
string,
159-
),
63+
[array, array_int, array_pad_string] => {
64+
match (array.data_type(), array_pad_string.data_type()) {
65+
(DataType::Utf8, DataType::Utf8) => {
66+
spark_read_side_padding_internal::<i32, i32, i32>(
67+
array,
68+
truncate,
69+
array_int,
70+
array_pad_string,
71+
)
72+
}
73+
(DataType::Utf8, DataType::LargeUtf8) => {
74+
spark_read_side_padding_internal::<i32, i64, i64>(
75+
array,
76+
truncate,
77+
array_int,
78+
array_pad_string,
79+
)
80+
}
81+
(DataType::LargeUtf8, DataType::Utf8) => {
82+
spark_read_side_padding_internal::<i64, i32, i64>(
83+
array,
84+
truncate,
85+
array_int,
86+
array_pad_string,
87+
)
88+
}
89+
(DataType::LargeUtf8, DataType::LargeUtf8) => {
90+
spark_read_side_padding_internal::<i64, i64, i64>(
91+
array,
92+
truncate,
93+
array_int,
94+
array_pad_string,
95+
)
96+
}
16097
other => Err(DataFusionError::Internal(format!(
16198
"Unsupported data type {other:?} for function rpad/read_side_padding",
16299
))),
@@ -168,57 +105,64 @@ fn spark_read_side_padding2(
168105
}
169106
}
170107

171-
fn spark_read_side_padding_internal<T: OffsetSizeTrait>(
108+
fn spark_read_side_padding_space_internal<T: OffsetSizeTrait>(
172109
array: &ArrayRef,
173110
truncate: bool,
174-
pad_type: ColumnarValue,
175-
pad_string: &str,
176-
) -> Result<ColumnarValue, DataFusionError> {
111+
array_int: &ArrayRef,
112+
) -> Result<ArrayRef, DataFusionError> {
177113
let string_array = as_generic_string_array::<T>(array)?;
178-
match pad_type {
179-
ColumnarValue::Array(array_int) => {
180-
let int_pad_array = array_int.as_primitive::<Int32Type>();
114+
let int_pad_array = array_int.as_primitive::<Int32Type>();
181115

182-
let mut builder = GenericStringBuilder::<T>::with_capacity(
183-
string_array.len(),
184-
string_array.len() * int_pad_array.len(),
185-
);
116+
let mut builder = GenericStringBuilder::<T>::with_capacity(
117+
string_array.len(),
118+
string_array.len() * int_pad_array.len(),
119+
);
186120

187-
for (string, length) in string_array.iter().zip(int_pad_array) {
188-
match string {
189-
Some(string) => builder.append_value(add_padding_string(
190-
string.parse().unwrap(),
191-
length.unwrap() as usize,
192-
truncate,
193-
pad_string,
194-
)?),
195-
_ => builder.append_null(),
196-
}
197-
}
198-
Ok(ColumnarValue::Array(Arc::new(builder.finish())))
121+
for (string, length) in string_array.iter().zip(int_pad_array) {
122+
match string {
123+
Some(string) => builder.append_value(add_padding_string(
124+
string.parse().unwrap(),
125+
length.unwrap() as usize,
126+
truncate,
127+
SPACE,
128+
)?),
129+
_ => builder.append_null(),
199130
}
200-
ColumnarValue::Scalar(const_pad_length) => {
201-
let length = 0.max(i32::try_from(const_pad_length)?) as usize;
131+
}
132+
Ok(Arc::new(builder.finish()))
133+
}
202134

203-
let mut builder = GenericStringBuilder::<T>::with_capacity(
204-
string_array.len(),
205-
string_array.len() * length,
206-
);
135+
fn spark_read_side_padding_internal<T: OffsetSizeTrait, O: OffsetSizeTrait, S: OffsetSizeTrait>(
136+
array: &ArrayRef,
137+
truncate: bool,
138+
array_int: &ArrayRef,
139+
pad_string_array: &ArrayRef,
140+
) -> Result<ArrayRef, DataFusionError> {
141+
let string_array = as_generic_string_array::<T>(array)?;
142+
let int_pad_array = array_int.as_primitive::<Int32Type>();
143+
let pad_string_array = as_generic_string_array::<O>(pad_string_array)?;
207144

208-
for string in string_array.iter() {
209-
match string {
210-
Some(string) => builder.append_value(add_padding_string(
211-
string.parse().unwrap(),
212-
length,
213-
truncate,
214-
pad_string,
215-
)?),
216-
_ => builder.append_null(),
217-
}
218-
}
219-
Ok(ColumnarValue::Array(Arc::new(builder.finish())))
145+
let mut builder = GenericStringBuilder::<S>::with_capacity(
146+
string_array.len(),
147+
string_array.len() * int_pad_array.len(),
148+
);
149+
150+
for ((string, length), pad_string) in string_array
151+
.iter()
152+
.zip(int_pad_array)
153+
.zip(pad_string_array.iter())
154+
{
155+
match string {
156+
Some(string) => builder.append_value(add_padding_string(
157+
string.parse().unwrap(),
158+
length.unwrap() as usize,
159+
truncate,
160+
pad_string.unwrap_or(SPACE),
161+
)?),
162+
_ => builder.append_null(),
220163
}
221164
}
165+
Ok(Arc::new(builder.finish()))
222166
}
223167

224168
fn add_padding_string(

0 commit comments

Comments
 (0)