-
Notifications
You must be signed in to change notification settings - Fork 1.9k
Perf: Optimize substring_index via single-byte fast path and direct indexing
#19590
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 3 commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -50,7 +50,10 @@ where | |
| } | ||
| } | ||
|
|
||
| fn data() -> (StringArray, StringArray, Int64Array) { | ||
| fn data( | ||
| batch_size: usize, | ||
| single_char_delimiter: bool, | ||
| ) -> (StringArray, StringArray, Int64Array) { | ||
| let dist = Filter { | ||
| dist: Uniform::new(-4, 5), | ||
| test: |x: &i64| x != &0, | ||
|
|
@@ -60,19 +63,39 @@ fn data() -> (StringArray, StringArray, Int64Array) { | |
| let mut delimiters: Vec<String> = vec![]; | ||
| let mut counts: Vec<i64> = vec![]; | ||
|
|
||
| for _ in 0..1000 { | ||
| for _ in 0..batch_size { | ||
| let length = rng.random_range(20..50); | ||
| let text: String = (&mut rng) | ||
| let base: String = (&mut rng) | ||
| .sample_iter(&Alphanumeric) | ||
| .take(length) | ||
| .map(char::from) | ||
| .collect(); | ||
| let char = rng.random_range(0..text.len()); | ||
| let delimiter = &text.chars().nth(char).unwrap(); | ||
|
|
||
| let (string_value, delimiter): (String, String) = if single_char_delimiter { | ||
| let char_idx = rng.random_range(0..base.chars().count()); | ||
| let delimiter = base.chars().nth(char_idx).unwrap().to_string(); | ||
| (base, delimiter) | ||
| } else { | ||
| let long_delimiters = ["|||", "***", "&&&", "###", "@@@", "$$$"]; | ||
| let delimiter = | ||
| long_delimiters[rng.random_range(0..long_delimiters.len())].to_string(); | ||
|
|
||
| let delimiter_count = rng.random_range(1..4); | ||
| let mut result = String::new(); | ||
|
|
||
| for i in 0..delimiter_count { | ||
| result.push_str(&base); | ||
| if i < delimiter_count - 1 { | ||
| result.push_str(&delimiter); | ||
| } | ||
| } | ||
| (result, delimiter) | ||
| }; | ||
|
|
||
| let count = rng.sample(dist.dist.unwrap()); | ||
|
|
||
| strings.push(text); | ||
| delimiters.push(delimiter.to_string()); | ||
| strings.push(string_value); | ||
| delimiters.push(delimiter); | ||
| counts.push(count); | ||
| } | ||
|
|
||
|
|
@@ -83,38 +106,66 @@ fn data() -> (StringArray, StringArray, Int64Array) { | |
| ) | ||
| } | ||
|
|
||
| fn criterion_benchmark(c: &mut Criterion) { | ||
| c.bench_function("substr_index_array_array_1000", |b| { | ||
| let (strings, delimiters, counts) = data(); | ||
| let batch_len = counts.len(); | ||
| let strings = ColumnarValue::Array(Arc::new(strings) as ArrayRef); | ||
| let delimiters = ColumnarValue::Array(Arc::new(delimiters) as ArrayRef); | ||
| let counts = ColumnarValue::Array(Arc::new(counts) as ArrayRef); | ||
|
|
||
| let args = vec![strings, delimiters, counts]; | ||
| let arg_fields = args | ||
| .iter() | ||
| .enumerate() | ||
| .map(|(idx, arg)| { | ||
| Field::new(format!("arg_{idx}"), arg.data_type(), true).into() | ||
| }) | ||
| .collect::<Vec<_>>(); | ||
| let config_options = Arc::new(ConfigOptions::default()); | ||
|
|
||
| b.iter(|| { | ||
| black_box( | ||
| substr_index() | ||
| .invoke_with_args(ScalarFunctionArgs { | ||
| args: args.clone(), | ||
| arg_fields: arg_fields.clone(), | ||
| number_rows: batch_len, | ||
| return_field: Field::new("f", DataType::Utf8, true).into(), | ||
| config_options: Arc::clone(&config_options), | ||
| }) | ||
| .expect("substr_index should work on valid values"), | ||
| ) | ||
| fn run_benchmark( | ||
| b: &mut criterion::Bencher, | ||
| strings: StringArray, | ||
| delimiters: StringArray, | ||
| counts: Int64Array, | ||
| batch_size: usize, | ||
| ) { | ||
| let strings = ColumnarValue::Array(Arc::new(strings) as ArrayRef); | ||
| let delimiters = ColumnarValue::Array(Arc::new(delimiters) as ArrayRef); | ||
| let counts = ColumnarValue::Array(Arc::new(counts) as ArrayRef); | ||
|
|
||
| let args = vec![strings, delimiters, counts]; | ||
| let arg_fields = args | ||
| .iter() | ||
| .enumerate() | ||
| .map(|(idx, arg)| { | ||
| Field::new(format!("arg_{idx}"), arg.data_type().clone(), true).into() | ||
| }) | ||
| }); | ||
| .collect::<Vec<_>>(); | ||
| let config_options = Arc::new(ConfigOptions::default()); | ||
|
|
||
| b.iter(|| { | ||
| black_box( | ||
| substr_index() | ||
| .invoke_with_args(ScalarFunctionArgs { | ||
| args: args.clone(), | ||
| arg_fields: arg_fields.clone(), | ||
| number_rows: batch_size, | ||
| return_field: Field::new("f", DataType::Utf8, true).into(), | ||
| config_options: Arc::clone(&config_options), | ||
| }) | ||
| .expect("substr_index should work on valid values"), | ||
| ) | ||
| }) | ||
| } | ||
|
|
||
| fn criterion_benchmark(c: &mut Criterion) { | ||
| let mut group = c.benchmark_group("substr_index"); | ||
|
|
||
| let batch_sizes = [100, 1000, 10_000]; | ||
|
|
||
| for batch_size in batch_sizes { | ||
| group.bench_function( | ||
| &format!("substr_index_{}_single_delimiter", batch_size), | ||
| |b| { | ||
| let (strings, delimiters, counts) = data(batch_size, true); | ||
| run_benchmark(b, strings, delimiters, counts, batch_size); | ||
| }, | ||
| ); | ||
|
|
||
| group.bench_function( | ||
| &format!("substr_index_{}_long_delimiter", batch_size), | ||
|
||
| |b| { | ||
| let (strings, delimiters, counts) = data(batch_size, false); | ||
| run_benchmark(b, strings, delimiters, counts, batch_size); | ||
| }, | ||
| ); | ||
| } | ||
|
|
||
| group.finish(); | ||
| } | ||
|
|
||
| criterion_group!(benches, criterion_benchmark); | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -19,8 +19,8 @@ use std::any::Any; | |
| use std::sync::Arc; | ||
|
|
||
| use arrow::array::{ | ||
| ArrayAccessor, ArrayIter, ArrayRef, ArrowPrimitiveType, AsArray, OffsetSizeTrait, | ||
| PrimitiveArray, StringBuilder, | ||
| ArrayAccessor, ArrayIter, ArrayRef, ArrowPrimitiveType, AsArray, | ||
| GenericStringBuilder, OffsetSizeTrait, PrimitiveArray, | ||
| }; | ||
| use arrow::datatypes::{DataType, Int32Type, Int64Type}; | ||
|
|
||
|
|
@@ -182,7 +182,8 @@ fn substr_index_general< | |
| where | ||
| T::Native: OffsetSizeTrait, | ||
| { | ||
| let mut builder = StringBuilder::new(); | ||
| let num_rows = string_array.len(); | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. If you wanted to make this really fast, you could also implement special case code for the common cases where
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Good point! I’ll try adding a special-case implementation for that. Thanks! |
||
| let mut builder = GenericStringBuilder::<T::Native>::with_capacity(num_rows, 0); | ||
| let string_iter = ArrayIter::new(string_array); | ||
| let delimiter_array_iter = ArrayIter::new(delimiter_array); | ||
| let count_array_iter = ArrayIter::new(count_array); | ||
|
|
@@ -198,31 +199,48 @@ where | |
| } | ||
|
|
||
| let occurrences = usize::try_from(n.unsigned_abs()).unwrap_or(usize::MAX); | ||
| let length = if n > 0 { | ||
| let split = string.split(delimiter); | ||
| split | ||
| .take(occurrences) | ||
| .map(|s| s.len() + delimiter.len()) | ||
| .sum::<usize>() | ||
| - delimiter.len() | ||
| } else { | ||
| let split = string.rsplit(delimiter); | ||
| split | ||
| .take(occurrences) | ||
| .map(|s| s.len() + delimiter.len()) | ||
| .sum::<usize>() | ||
| - delimiter.len() | ||
| }; | ||
| if n > 0 { | ||
| match string.get(..length) { | ||
| Some(substring) => builder.append_value(substring), | ||
| None => builder.append_null(), | ||
| let result_idx = if delimiter.len() == 1 { | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I double checked that this checks for length in bytes https://doc.rust-lang.org/std/string/struct.String.html#method.len |
||
| let d_byte = delimiter.as_bytes()[0]; | ||
| let bytes = string.as_bytes(); | ||
|
|
||
| if n > 0 { | ||
| bytes | ||
| .iter() | ||
| .enumerate() | ||
| .filter(|&(_, &b)| b == d_byte) | ||
| .nth(occurrences - 1) | ||
| .map(|(idx, _)| idx) | ||
| } else { | ||
| bytes | ||
| .iter() | ||
| .enumerate() | ||
| .rev() | ||
| .filter(|&(_, &b)| b == d_byte) | ||
| .nth(occurrences - 1) | ||
| .map(|(idx, _)| idx + 1) | ||
| } | ||
| } else { | ||
| match string.get(string.len().saturating_sub(length)..) { | ||
| Some(substring) => builder.append_value(substring), | ||
| None => builder.append_null(), | ||
| if n > 0 { | ||
|
||
| string | ||
| .match_indices(delimiter) | ||
| .nth(occurrences - 1) | ||
| .map(|(idx, _)| idx) | ||
| } else { | ||
| string | ||
| .rmatch_indices(delimiter) | ||
| .nth(occurrences - 1) | ||
| .map(|(idx, _)| idx + delimiter.len()) | ||
| } | ||
| }; | ||
| match result_idx { | ||
| Some(idx) => { | ||
| if n > 0 { | ||
| builder.append_value(&string[..idx]); | ||
| } else { | ||
| builder.append_value(&string[idx..]); | ||
| } | ||
| } | ||
| None => builder.append_value(string), | ||
| } | ||
| } | ||
| _ => builder.append_null(), | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
variables can be used directly in the
format!string