Skip to content

Commit 54b848c

Browse files
authored
feat(spark): implement substring function (#19805)
## Which issue does this PR close? <!-- We generally require a GitHub issue to be filed for all bug fixes and enhancements and this helps us generate change logs for our releases. You can link an issue to this PR using the GitHub syntax. For example `Closes #123` indicates that this PR will close issue #123. --> - Part of #15914 - Closes #19803. ## Rationale for this change <!-- Why are you proposing this change? If this is already explained clearly in the issue then this section is not needed. Explaining clearly why changes are proposed helps reviewers understand your changes and offer better suggestions for fixes. --> ## What changes are included in this PR? Implementation of spark substring function. Spark implementation: - https://github.com/apache/spark/blob/6831481fd7a2d30dfa16b4b70c8e6296b4deeb8c/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java#L663 - https://github.com/apache/spark/blob/6831481fd7a2d30dfa16b4b70c8e6296b4deeb8c/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala#L2300 ## Are these changes tested? yes ## Are there any user-facing changes? <!-- If there are user-facing changes then we may require documentation to be updated before approving the PR. --> <!-- If there are any breaking changes to public APIs, please add the `api change` label. -->
1 parent de40f0c commit 54b848c

File tree

10 files changed

+713
-63
lines changed

10 files changed

+713
-63
lines changed

datafusion/functions/src/unicode/substr.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -176,7 +176,7 @@ fn substr(args: &[ArrayRef]) -> Result<ArrayRef> {
176176
// `get_true_start_end('Hi🌏', 1, None) -> (0, 6)`
177177
// `get_true_start_end('Hi🌏', 1, 1) -> (0, 1)`
178178
// `get_true_start_end('Hi🌏', -10, 2) -> (0, 0)`
179-
fn get_true_start_end(
179+
pub fn get_true_start_end(
180180
input: &str,
181181
start: i64,
182182
count: Option<u64>,
@@ -235,7 +235,7 @@ fn get_true_start_end(
235235
// string, such as `substr(long_str_with_1k_chars, 1, 32)`.
236236
// In such case the overhead of ASCII-validation may not be worth it, so
237237
// skip the validation for short prefix for now.
238-
fn enable_ascii_fast_path<'a, V: StringArrayType<'a>>(
238+
pub fn enable_ascii_fast_path<'a, V: StringArrayType<'a>>(
239239
string_array: &V,
240240
start: &Int64Array,
241241
count: Option<&Int64Array>,

datafusion/spark/Cargo.toml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@ sha1 = "0.10"
5757
url = { workspace = true }
5858

5959
[dev-dependencies]
60+
arrow = { workspace = true, features = ["test_utils"] }
6061
criterion = { workspace = true }
6162

6263
[[bench]]
@@ -74,3 +75,7 @@ name = "hex"
7475
[[bench]]
7576
harness = false
7677
name = "slice"
78+
79+
[[bench]]
80+
harness = false
81+
name = "substring"
Lines changed: 207 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,207 @@
1+
// Licensed to the Apache Software Foundation (ASF) under one
2+
// or more contributor license agreements. See the NOTICE file
3+
// distributed with this work for additional information
4+
// regarding copyright ownership. The ASF licenses this file
5+
// to you under the Apache License, Version 2.0 (the
6+
// "License"); you may not use this file except in compliance
7+
// with the License. You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing,
12+
// software distributed under the License is distributed on an
13+
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
// KIND, either express or implied. See the License for the
15+
// specific language governing permissions and limitations
16+
// under the License.
17+
18+
extern crate criterion;
19+
20+
use arrow::array::{ArrayRef, Int64Array, OffsetSizeTrait};
21+
use arrow::datatypes::{DataType, Field};
22+
use arrow::util::bench_util::{
23+
create_string_array_with_len, create_string_view_array_with_len,
24+
};
25+
use criterion::{Criterion, SamplingMode, criterion_group, criterion_main};
26+
use datafusion_common::DataFusionError;
27+
use datafusion_common::config::ConfigOptions;
28+
use datafusion_expr::{ColumnarValue, ScalarFunctionArgs};
29+
use datafusion_spark::function::string::substring;
30+
use std::hint::black_box;
31+
use std::sync::Arc;
32+
33+
fn create_args_without_count<O: OffsetSizeTrait>(
34+
size: usize,
35+
str_len: usize,
36+
start_half_way: bool,
37+
force_view_types: bool,
38+
) -> Vec<ColumnarValue> {
39+
let start_array = Arc::new(Int64Array::from(
40+
(0..size)
41+
.map(|_| {
42+
if start_half_way {
43+
(str_len / 2) as i64
44+
} else {
45+
1i64
46+
}
47+
})
48+
.collect::<Vec<_>>(),
49+
));
50+
51+
if force_view_types {
52+
let string_array =
53+
Arc::new(create_string_view_array_with_len(size, 0.1, str_len, false));
54+
vec![
55+
ColumnarValue::Array(string_array),
56+
ColumnarValue::Array(start_array),
57+
]
58+
} else {
59+
let string_array =
60+
Arc::new(create_string_array_with_len::<O>(size, 0.1, str_len));
61+
62+
vec![
63+
ColumnarValue::Array(string_array),
64+
ColumnarValue::Array(Arc::clone(&start_array) as ArrayRef),
65+
]
66+
}
67+
}
68+
69+
fn create_args_with_count<O: OffsetSizeTrait>(
70+
size: usize,
71+
str_len: usize,
72+
count_max: usize,
73+
force_view_types: bool,
74+
) -> Vec<ColumnarValue> {
75+
let start_array =
76+
Arc::new(Int64Array::from((0..size).map(|_| 1).collect::<Vec<_>>()));
77+
let count = count_max.min(str_len) as i64;
78+
let count_array = Arc::new(Int64Array::from(
79+
(0..size).map(|_| count).collect::<Vec<_>>(),
80+
));
81+
82+
if force_view_types {
83+
let string_array =
84+
Arc::new(create_string_view_array_with_len(size, 0.1, str_len, false));
85+
vec![
86+
ColumnarValue::Array(string_array),
87+
ColumnarValue::Array(start_array),
88+
ColumnarValue::Array(count_array),
89+
]
90+
} else {
91+
let string_array =
92+
Arc::new(create_string_array_with_len::<O>(size, 0.1, str_len));
93+
94+
vec![
95+
ColumnarValue::Array(string_array),
96+
ColumnarValue::Array(Arc::clone(&start_array) as ArrayRef),
97+
ColumnarValue::Array(Arc::clone(&count_array) as ArrayRef),
98+
]
99+
}
100+
}
101+
102+
#[expect(clippy::needless_pass_by_value)]
103+
fn invoke_substr_with_args(
104+
args: Vec<ColumnarValue>,
105+
number_rows: usize,
106+
) -> Result<ColumnarValue, DataFusionError> {
107+
let arg_fields = args
108+
.iter()
109+
.enumerate()
110+
.map(|(idx, arg)| Field::new(format!("arg_{idx}"), arg.data_type(), true).into())
111+
.collect::<Vec<_>>();
112+
let config_options = Arc::new(ConfigOptions::default());
113+
114+
substring().invoke_with_args(ScalarFunctionArgs {
115+
args: args.clone(),
116+
arg_fields,
117+
number_rows,
118+
return_field: Field::new("f", DataType::Utf8View, true).into(),
119+
config_options: Arc::clone(&config_options),
120+
})
121+
}
122+
123+
fn criterion_benchmark(c: &mut Criterion) {
124+
for size in [1024, 4096] {
125+
// string_len = 12, substring_len=6 (see `create_args_without_count`)
126+
let len = 12;
127+
let mut group = c.benchmark_group("SHORTER THAN 12");
128+
group.sampling_mode(SamplingMode::Flat);
129+
group.sample_size(10);
130+
131+
let args = create_args_without_count::<i32>(size, len, true, true);
132+
group.bench_function(
133+
format!("substr_string_view [size={size}, strlen={len}]"),
134+
|b| b.iter(|| black_box(invoke_substr_with_args(args.clone(), size))),
135+
);
136+
137+
let args = create_args_without_count::<i32>(size, len, false, false);
138+
group.bench_function(format!("substr_string [size={size}, strlen={len}]"), |b| {
139+
b.iter(|| black_box(invoke_substr_with_args(args.clone(), size)))
140+
});
141+
142+
let args = create_args_without_count::<i64>(size, len, true, false);
143+
group.bench_function(
144+
format!("substr_large_string [size={size}, strlen={len}]"),
145+
|b| b.iter(|| black_box(invoke_substr_with_args(args.clone(), size))),
146+
);
147+
148+
group.finish();
149+
150+
// string_len = 128, start=1, count=64, substring_len=64
151+
let len = 128;
152+
let count = 64;
153+
let mut group = c.benchmark_group("LONGER THAN 12");
154+
group.sampling_mode(SamplingMode::Flat);
155+
group.sample_size(10);
156+
157+
let args = create_args_with_count::<i32>(size, len, count, true);
158+
group.bench_function(
159+
format!("substr_string_view [size={size}, count={count}, strlen={len}]",),
160+
|b| b.iter(|| black_box(invoke_substr_with_args(args.clone(), size))),
161+
);
162+
163+
let args = create_args_with_count::<i32>(size, len, count, false);
164+
group.bench_function(
165+
format!("substr_string [size={size}, count={count}, strlen={len}]",),
166+
|b| b.iter(|| black_box(invoke_substr_with_args(args.clone(), size))),
167+
);
168+
169+
let args = create_args_with_count::<i64>(size, len, count, false);
170+
group.bench_function(
171+
format!("substr_large_string [size={size}, count={count}, strlen={len}]",),
172+
|b| b.iter(|| black_box(invoke_substr_with_args(args.clone(), size))),
173+
);
174+
175+
group.finish();
176+
177+
// string_len = 128, start=1, count=6, substring_len=6
178+
let len = 128;
179+
let count = 6;
180+
let mut group = c.benchmark_group("SRC_LEN > 12, SUB_LEN < 12");
181+
group.sampling_mode(SamplingMode::Flat);
182+
group.sample_size(10);
183+
184+
let args = create_args_with_count::<i32>(size, len, count, true);
185+
group.bench_function(
186+
format!("substr_string_view [size={size}, count={count}, strlen={len}]",),
187+
|b| b.iter(|| black_box(invoke_substr_with_args(args.clone(), size))),
188+
);
189+
190+
let args = create_args_with_count::<i32>(size, len, count, false);
191+
group.bench_function(
192+
format!("substr_string [size={size}, count={count}, strlen={len}]",),
193+
|b| b.iter(|| black_box(invoke_substr_with_args(args.clone(), size))),
194+
);
195+
196+
let args = create_args_with_count::<i64>(size, len, count, false);
197+
group.bench_function(
198+
format!("substr_large_string [size={size}, count={count}, strlen={len}]",),
199+
|b| b.iter(|| black_box(invoke_substr_with_args(args.clone(), size))),
200+
);
201+
202+
group.finish();
203+
}
204+
}
205+
206+
criterion_group!(benches, criterion_benchmark);
207+
criterion_main!(benches);

datafusion/spark/src/function/string/mod.rs

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ pub mod length;
2525
pub mod like;
2626
pub mod luhn_check;
2727
pub mod space;
28+
pub mod substring;
2829

2930
use datafusion_expr::ScalarUDF;
3031
use datafusion_functions::make_udf_function;
@@ -40,6 +41,7 @@ make_udf_function!(like::SparkLike, like);
4041
make_udf_function!(luhn_check::SparkLuhnCheck, luhn_check);
4142
make_udf_function!(format_string::FormatStringFunc, format_string);
4243
make_udf_function!(space::SparkSpace, space);
44+
make_udf_function!(substring::SparkSubstring, substring);
4345

4446
pub mod expr_fn {
4547
use datafusion_functions::export_functions;
@@ -90,6 +92,11 @@ pub mod expr_fn {
9092
strfmt args
9193
));
9294
export_functions!((space, "Returns a string consisting of n spaces.", arg1));
95+
// Exports the `substring` expression builder; the string literal is the
// description handed to `export_functions!`, followed by the argument names.
export_functions!((
    substring,
    "Returns the substring from string `str` starting at position `pos` with length `length`.",
    str pos length
));
93100
}
94101

95102
pub fn functions() -> Vec<Arc<ScalarUDF>> {
@@ -104,5 +111,6 @@ pub fn functions() -> Vec<Arc<ScalarUDF>> {
104111
luhn_check(),
105112
format_string(),
106113
space(),
114+
substring(),
107115
]
108116
}

0 commit comments

Comments
 (0)