diff --git a/datafusion/functions/benches/substr_index.rs b/datafusion/functions/benches/substr_index.rs index 88600317c9967..28ce6e444eb5c 100644 --- a/datafusion/functions/benches/substr_index.rs +++ b/datafusion/functions/benches/substr_index.rs @@ -50,7 +50,10 @@ where } } -fn data() -> (StringArray, StringArray, Int64Array) { +fn data( + batch_size: usize, + single_char_delimiter: bool, +) -> (StringArray, StringArray, Int64Array) { let dist = Filter { dist: Uniform::new(-4, 5), test: |x: &i64| x != &0, @@ -60,19 +63,39 @@ fn data() -> (StringArray, StringArray, Int64Array) { let mut delimiters: Vec = vec![]; let mut counts: Vec = vec![]; - for _ in 0..1000 { + for _ in 0..batch_size { let length = rng.random_range(20..50); - let text: String = (&mut rng) + let base: String = (&mut rng) .sample_iter(&Alphanumeric) .take(length) .map(char::from) .collect(); - let char = rng.random_range(0..text.len()); - let delimiter = &text.chars().nth(char).unwrap(); + + let (string_value, delimiter): (String, String) = if single_char_delimiter { + let char_idx = rng.random_range(0..base.chars().count()); + let delimiter = base.chars().nth(char_idx).unwrap().to_string(); + (base, delimiter) + } else { + let long_delimiters = ["|||", "***", "&&&", "###", "@@@", "$$$"]; + let delimiter = + long_delimiters[rng.random_range(0..long_delimiters.len())].to_string(); + + let delimiter_count = rng.random_range(1..4); + let mut result = String::new(); + + for i in 0..delimiter_count { + result.push_str(&base); + if i < delimiter_count - 1 { + result.push_str(&delimiter); + } + } + (result, delimiter) + }; + let count = rng.sample(dist.dist.unwrap()); - strings.push(text); - delimiters.push(delimiter.to_string()); + strings.push(string_value); + delimiters.push(delimiter); counts.push(count); } @@ -83,38 +106,63 @@ fn data() -> (StringArray, StringArray, Int64Array) { ) } -fn criterion_benchmark(c: &mut Criterion) { - c.bench_function("substr_index_array_array_1000", |b| { - let (strings, delimiters, counts) = data(); - let batch_len = counts.len(); - let strings = ColumnarValue::Array(Arc::new(strings) as ArrayRef); - let delimiters = ColumnarValue::Array(Arc::new(delimiters) as ArrayRef); - let counts = ColumnarValue::Array(Arc::new(counts) as ArrayRef); - - let args = vec![strings, delimiters, counts]; - let arg_fields = args - .iter() - .enumerate() - .map(|(idx, arg)| { - Field::new(format!("arg_{idx}"), arg.data_type(), true).into() - }) - .collect::>(); - let config_options = Arc::new(ConfigOptions::default()); - - b.iter(|| { - black_box( - substr_index() - .invoke_with_args(ScalarFunctionArgs { - args: args.clone(), - arg_fields: arg_fields.clone(), - number_rows: batch_len, - return_field: Field::new("f", DataType::Utf8, true).into(), - config_options: Arc::clone(&config_options), - }) - .expect("substr_index should work on valid values"), - ) +fn run_benchmark( + b: &mut criterion::Bencher, + strings: StringArray, + delimiters: StringArray, + counts: Int64Array, + batch_size: usize, +) { + let strings = ColumnarValue::Array(Arc::new(strings) as ArrayRef); + let delimiters = ColumnarValue::Array(Arc::new(delimiters) as ArrayRef); + let counts = ColumnarValue::Array(Arc::new(counts) as ArrayRef); + + let args = vec![strings, delimiters, counts]; + let arg_fields = args + .iter() + .enumerate() + .map(|(idx, arg)| { + Field::new(format!("arg_{idx}"), arg.data_type().clone(), true).into() }) - }); + .collect::>(); + let config_options = Arc::new(ConfigOptions::default()); + + b.iter(|| { + black_box( + substr_index() + .invoke_with_args(ScalarFunctionArgs { + args: args.clone(), + arg_fields: arg_fields.clone(), + number_rows: batch_size, + return_field: Field::new("f", DataType::Utf8, true).into(), + config_options: Arc::clone(&config_options), + }) + .expect("substr_index should work on valid values"), + ) + }) +} + +fn criterion_benchmark(c: &mut Criterion) { + let mut group = c.benchmark_group("substr_index"); + + let batch_sizes = [100, 1000, 10_000]; + + for batch_size in batch_sizes { + group.bench_function( + format!("substr_index_{batch_size}_single_delimiter"), + |b| { + let (strings, delimiters, counts) = data(batch_size, true); + run_benchmark(b, strings, delimiters, counts, batch_size); + }, + ); + + group.bench_function(format!("substr_index_{batch_size}_long_delimiter"), |b| { + let (strings, delimiters, counts) = data(batch_size, false); + run_benchmark(b, strings, delimiters, counts, batch_size); + }); + } + + group.finish(); } criterion_group!(benches, criterion_benchmark); diff --git a/datafusion/functions/src/unicode/substrindex.rs b/datafusion/functions/src/unicode/substrindex.rs index cd9d0702b4976..6389dc92c2380 100644 --- a/datafusion/functions/src/unicode/substrindex.rs +++ b/datafusion/functions/src/unicode/substrindex.rs @@ -19,8 +19,8 @@ use std::any::Any; use std::sync::Arc; use arrow::array::{ - ArrayAccessor, ArrayIter, ArrayRef, ArrowPrimitiveType, AsArray, OffsetSizeTrait, - PrimitiveArray, StringBuilder, + ArrayAccessor, ArrayIter, ArrayRef, ArrowPrimitiveType, AsArray, + GenericStringBuilder, OffsetSizeTrait, PrimitiveArray, }; use arrow::datatypes::{DataType, Int32Type, Int64Type}; @@ -182,7 +182,8 @@ fn substr_index_general< where T::Native: OffsetSizeTrait, { - let mut builder = StringBuilder::new(); + let num_rows = string_array.len(); + let mut builder = GenericStringBuilder::::with_capacity(num_rows, 0); let string_iter = ArrayIter::new(string_array); let delimiter_array_iter = ArrayIter::new(delimiter_array); let count_array_iter = ArrayIter::new(count_array); @@ -198,31 +199,49 @@ where } let occurrences = usize::try_from(n.unsigned_abs()).unwrap_or(usize::MAX); - let length = if n > 0 { - let split = string.split(delimiter); - split - .take(occurrences) - .map(|s| s.len() + delimiter.len()) - .sum::() - - delimiter.len() - } else { - let split = string.rsplit(delimiter); - split - .take(occurrences) - .map(|s| s.len() + delimiter.len()) - .sum::() - - delimiter.len() - }; - if n > 0 { - match string.get(..length) { - Some(substring) => builder.append_value(substring), - None => builder.append_null(), + let result_idx = if delimiter.len() == 1 { + // Fast path: use byte-level search for single-character delimiters + let d_byte = delimiter.as_bytes()[0]; + let bytes = string.as_bytes(); + + if n > 0 { + bytes + .iter() + .enumerate() + .filter(|&(_, &b)| b == d_byte) + .nth(occurrences - 1) + .map(|(idx, _)| idx) + } else { + bytes + .iter() + .enumerate() + .rev() + .filter(|&(_, &b)| b == d_byte) + .nth(occurrences - 1) + .map(|(idx, _)| idx + 1) } + } else if n > 0 { + // Multi-byte path: forward search for n-th occurrence + string + .match_indices(delimiter) + .nth(occurrences - 1) + .map(|(idx, _)| idx) } else { - match string.get(string.len().saturating_sub(length)..) { - Some(substring) => builder.append_value(substring), - None => builder.append_null(), + // Multi-byte path: backward search for n-th occurrence from the right + string + .rmatch_indices(delimiter) + .nth(occurrences - 1) + .map(|(idx, _)| idx + delimiter.len()) + }; + match result_idx { + Some(idx) => { + if n > 0 { + builder.append_value(&string[..idx]); + } else { + builder.append_value(&string[idx..]); + } } + None => builder.append_value(string), } } _ => builder.append_null(), @@ -328,7 +347,6 @@ mod tests { Utf8, StringArray ); - Ok(()) } }