Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
127 changes: 89 additions & 38 deletions datafusion/functions/benches/substr_index.rs
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,10 @@ where
}
}

fn data() -> (StringArray, StringArray, Int64Array) {
fn data(
batch_size: usize,
single_char_delimiter: bool,
) -> (StringArray, StringArray, Int64Array) {
let dist = Filter {
dist: Uniform::new(-4, 5),
test: |x: &i64| x != &0,
Expand All @@ -60,19 +63,39 @@ fn data() -> (StringArray, StringArray, Int64Array) {
let mut delimiters: Vec<String> = vec![];
let mut counts: Vec<i64> = vec![];

for _ in 0..1000 {
for _ in 0..batch_size {
let length = rng.random_range(20..50);
let text: String = (&mut rng)
let base: String = (&mut rng)
.sample_iter(&Alphanumeric)
.take(length)
.map(char::from)
.collect();
let char = rng.random_range(0..text.len());
let delimiter = &text.chars().nth(char).unwrap();

let (string_value, delimiter): (String, String) = if single_char_delimiter {
let char_idx = rng.random_range(0..base.chars().count());
let delimiter = base.chars().nth(char_idx).unwrap().to_string();
(base, delimiter)
} else {
let long_delimiters = ["|||", "***", "&&&", "###", "@@@", "$$$"];
let delimiter =
long_delimiters[rng.random_range(0..long_delimiters.len())].to_string();

let delimiter_count = rng.random_range(1..4);
let mut result = String::new();

for i in 0..delimiter_count {
result.push_str(&base);
if i < delimiter_count - 1 {
result.push_str(&delimiter);
}
}
(result, delimiter)
};

let count = rng.sample(dist.dist.unwrap());

strings.push(text);
delimiters.push(delimiter.to_string());
strings.push(string_value);
delimiters.push(delimiter);
counts.push(count);
}

Expand All @@ -83,38 +106,66 @@ fn data() -> (StringArray, StringArray, Int64Array) {
)
}

fn criterion_benchmark(c: &mut Criterion) {
c.bench_function("substr_index_array_array_1000", |b| {
let (strings, delimiters, counts) = data();
let batch_len = counts.len();
let strings = ColumnarValue::Array(Arc::new(strings) as ArrayRef);
let delimiters = ColumnarValue::Array(Arc::new(delimiters) as ArrayRef);
let counts = ColumnarValue::Array(Arc::new(counts) as ArrayRef);

let args = vec![strings, delimiters, counts];
let arg_fields = args
.iter()
.enumerate()
.map(|(idx, arg)| {
Field::new(format!("arg_{idx}"), arg.data_type(), true).into()
})
.collect::<Vec<_>>();
let config_options = Arc::new(ConfigOptions::default());

b.iter(|| {
black_box(
substr_index()
.invoke_with_args(ScalarFunctionArgs {
args: args.clone(),
arg_fields: arg_fields.clone(),
number_rows: batch_len,
return_field: Field::new("f", DataType::Utf8, true).into(),
config_options: Arc::clone(&config_options),
})
.expect("substr_index should work on valid values"),
)
fn run_benchmark(
b: &mut criterion::Bencher,
strings: StringArray,
delimiters: StringArray,
counts: Int64Array,
batch_size: usize,
) {
let strings = ColumnarValue::Array(Arc::new(strings) as ArrayRef);
let delimiters = ColumnarValue::Array(Arc::new(delimiters) as ArrayRef);
let counts = ColumnarValue::Array(Arc::new(counts) as ArrayRef);

let args = vec![strings, delimiters, counts];
let arg_fields = args
.iter()
.enumerate()
.map(|(idx, arg)| {
Field::new(format!("arg_{idx}"), arg.data_type().clone(), true).into()
})
});
.collect::<Vec<_>>();
let config_options = Arc::new(ConfigOptions::default());

b.iter(|| {
black_box(
substr_index()
.invoke_with_args(ScalarFunctionArgs {
args: args.clone(),
arg_fields: arg_fields.clone(),
number_rows: batch_size,
return_field: Field::new("f", DataType::Utf8, true).into(),
config_options: Arc::clone(&config_options),
})
.expect("substr_index should work on valid values"),
)
})
}

fn criterion_benchmark(c: &mut Criterion) {
let mut group = c.benchmark_group("substr_index");

let batch_sizes = [100, 1000, 10_000];

for batch_size in batch_sizes {
group.bench_function(
&format!("substr_index_{}_single_delimiter", batch_size),
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

variables can be used directly in the format! string

|b| {
let (strings, delimiters, counts) = data(batch_size, true);
run_benchmark(b, strings, delimiters, counts, batch_size);
},
);

group.bench_function(
&format!("substr_index_{}_long_delimiter", batch_size),
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

variables can be used directly in the format! string

|b| {
let (strings, delimiters, counts) = data(batch_size, false);
run_benchmark(b, strings, delimiters, counts, batch_size);
},
);
}

group.finish();
}

criterion_group!(benches, criterion_benchmark);
Expand Down
68 changes: 43 additions & 25 deletions datafusion/functions/src/unicode/substrindex.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@ use std::any::Any;
use std::sync::Arc;

use arrow::array::{
ArrayAccessor, ArrayIter, ArrayRef, ArrowPrimitiveType, AsArray, OffsetSizeTrait,
PrimitiveArray, StringBuilder,
ArrayAccessor, ArrayIter, ArrayRef, ArrowPrimitiveType, AsArray,
GenericStringBuilder, OffsetSizeTrait, PrimitiveArray,
};
use arrow::datatypes::{DataType, Int32Type, Int64Type};

Expand Down Expand Up @@ -182,7 +182,8 @@ fn substr_index_general<
where
T::Native: OffsetSizeTrait,
{
let mut builder = StringBuilder::new();
let num_rows = string_array.len();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If you wanted to make this really fast, you could also implement special case code for the common cases where delimiter and count were scalar values -- I suspect you could make it quite fast.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good point! I’ll try adding a special-case implementation for that. Thanks!

let mut builder = GenericStringBuilder::<T::Native>::with_capacity(num_rows, 0);
let string_iter = ArrayIter::new(string_array);
let delimiter_array_iter = ArrayIter::new(delimiter_array);
let count_array_iter = ArrayIter::new(count_array);
Expand All @@ -198,31 +199,48 @@ where
}

let occurrences = usize::try_from(n.unsigned_abs()).unwrap_or(usize::MAX);
let length = if n > 0 {
let split = string.split(delimiter);
split
.take(occurrences)
.map(|s| s.len() + delimiter.len())
.sum::<usize>()
- delimiter.len()
} else {
let split = string.rsplit(delimiter);
split
.take(occurrences)
.map(|s| s.len() + delimiter.len())
.sum::<usize>()
- delimiter.len()
};
if n > 0 {
match string.get(..length) {
Some(substring) => builder.append_value(substring),
None => builder.append_null(),
let result_idx = if delimiter.len() == 1 {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I double checked that this checks for length in bytes

https://doc.rust-lang.org/std/string/struct.String.html#method.len

let d_byte = delimiter.as_bytes()[0];
let bytes = string.as_bytes();

if n > 0 {
bytes
.iter()
.enumerate()
.filter(|&(_, &b)| b == d_byte)
.nth(occurrences - 1)
.map(|(idx, _)| idx)
} else {
bytes
.iter()
.enumerate()
.rev()
.filter(|&(_, &b)| b == d_byte)
.nth(occurrences - 1)
.map(|(idx, _)| idx + 1)
}
} else {
match string.get(string.len().saturating_sub(length)..) {
Some(substring) => builder.append_value(substring),
None => builder.append_null(),
if n > 0 {
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this else { if .. } block can be collapsed

string
.match_indices(delimiter)
.nth(occurrences - 1)
.map(|(idx, _)| idx)
} else {
string
.rmatch_indices(delimiter)
.nth(occurrences - 1)
.map(|(idx, _)| idx + delimiter.len())
}
};
match result_idx {
Some(idx) => {
if n > 0 {
builder.append_value(&string[..idx]);
} else {
builder.append_value(&string[idx..]);
}
}
None => builder.append_value(string),
}
}
_ => builder.append_null(),
Expand Down