Commit 54dc054

perf: Optimize contains expression with SIMD-based scalar pattern search (#2972)
Parent: dca45ea

File tree: 7 files changed (+310, -3 lines)
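
The technique named in the commit title — precomputing one SIMD-accelerated searcher for the constant pattern and reusing it across all rows — can be sketched with the memchr crate alone. This is an illustration, not code from the commit; the helper name and row values are made up:

// Minimal sketch (not part of this commit): build the memmem::Finder once
// for the constant needle, then reuse it for every row in the batch.
use memchr::memmem::Finder;

fn contains_all(rows: &[&str], needle: &str) -> Vec<bool> {
    // Finder precomputes its searcher state once; memmem uses SIMD where available.
    let finder = Finder::new(needle);
    rows.iter()
        .map(|row| finder.find(row.as_bytes()).is_some())
        .collect()
}

fn main() {
    let rows = ["primrose", "tulip", "rose garden"];
    assert_eq!(contains_all(&rows, "rose"), vec![true, false, true]);
}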

native/Cargo.lock

Lines changed: 2 additions & 0 deletions
Some generated files are not rendered by default.

native/spark-expr/Cargo.toml

Lines changed: 2 additions & 0 deletions
@@ -28,9 +28,11 @@ edition = { workspace = true }
 
 [dependencies]
 arrow = { workspace = true }
+arrow-string = "57.0.0"
 chrono = { workspace = true }
 datafusion = { workspace = true }
 chrono-tz = { workspace = true }
+memchr = "2.7"
 num = { workspace = true }
 regex = { workspace = true }
 serde_json = "1.0"

native/spark-expr/src/comet_scalar_funcs.rs

Lines changed: 3 additions & 2 deletions
@@ -22,8 +22,8 @@ use crate::math_funcs::modulo_expr::spark_modulo;
 use crate::{
     spark_array_repeat, spark_ceil, spark_decimal_div, spark_decimal_integral_div, spark_floor,
     spark_isnan, spark_lpad, spark_make_decimal, spark_read_side_padding, spark_round, spark_rpad,
-    spark_unhex, spark_unscaled_value, EvalMode, SparkBitwiseCount, SparkDateTrunc, SparkSizeFunc,
-    SparkStringSpace,
+    spark_unhex, spark_unscaled_value, EvalMode, SparkBitwiseCount, SparkContains, SparkDateTrunc,
+    SparkSizeFunc, SparkStringSpace,
 };
 use arrow::datatypes::DataType;
 use datafusion::common::{DataFusionError, Result as DataFusionResult};
@@ -192,6 +192,7 @@ pub fn create_comet_physical_fun_with_eval_mode(
 fn all_scalar_functions() -> Vec<Arc<ScalarUDF>> {
     vec![
         Arc::new(ScalarUDF::new_from_impl(SparkBitwiseCount::default())),
+        Arc::new(ScalarUDF::new_from_impl(SparkContains::default())),
         Arc::new(ScalarUDF::new_from_impl(SparkDateTrunc::default())),
         Arc::new(ScalarUDF::new_from_impl(SparkStringSpace::default())),
         Arc::new(ScalarUDF::new_from_impl(SparkSizeFunc::default())),
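
Registering SparkContains in all_scalar_functions() is what lets Comet resolve `contains` by name at plan time. As a rough sketch of the same wiring against a plain DataFusion SessionContext (illustrative only; the `datafusion_comet_spark_expr` import path is an assumption, not something this commit defines):

// Illustrative only; assumes the spark-expr crate re-exports SparkContains
// at its root, as the `pub use` in string_funcs/mod.rs below suggests.
use datafusion::execution::context::SessionContext;
use datafusion::logical_expr::ScalarUDF;
use datafusion_comet_spark_expr::SparkContains;

fn main() {
    let ctx = SessionContext::new();
    // Wrap the ScalarUDFImpl and register it, much as all_scalar_functions() does.
    let udf = ScalarUDF::new_from_impl(SparkContains::default());
    assert_eq!(udf.name(), "contains");
    ctx.register_udf(udf);
}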
native/spark-expr/src/string_funcs/contains.rs (new file)

Lines changed: 282 additions & 0 deletions
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

//! Optimized `contains` string function for Spark compatibility.
//!
//! This implementation is optimized for the common case where the pattern
//! (second argument) is a scalar value. In this case, we use `memchr::memmem::Finder`
//! which is SIMD-optimized and reuses a single finder instance across all rows.
//!
//! The DataFusion built-in `contains` function uses `make_scalar_function` which
//! expands scalar values to arrays, losing the performance benefit of the optimized
//! scalar path in arrow-rs.

use arrow::array::{Array, ArrayRef, AsArray, BooleanArray};
use arrow::datatypes::DataType;
use arrow_string::like::contains as arrow_contains;
use datafusion::common::{exec_err, Result, ScalarValue};
use datafusion::logical_expr::{
    ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility,
};
use memchr::memmem::Finder;
use std::any::Any;
use std::sync::Arc;

/// Spark-optimized contains function.
///
/// Returns true if the first string argument contains the second string argument.
/// Optimized for the common case where the pattern is a scalar constant.
#[derive(Debug, PartialEq, Eq, Hash)]
pub struct SparkContains {
    signature: Signature,
}

impl Default for SparkContains {
    fn default() -> Self {
        Self::new()
    }
}

impl SparkContains {
    pub fn new() -> Self {
        Self {
            signature: Signature::variadic_any(Volatility::Immutable),
        }
    }
}

impl ScalarUDFImpl for SparkContains {
    fn as_any(&self) -> &dyn Any {
        self
    }

    fn name(&self) -> &str {
        "contains"
    }

    fn signature(&self) -> &Signature {
        &self.signature
    }

    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
        Ok(DataType::Boolean)
    }

    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
        if args.args.len() != 2 {
            return exec_err!("contains function requires exactly 2 arguments");
        }
        spark_contains(&args.args[0], &args.args[1])
    }
}

/// Execute the contains function with optimized scalar pattern handling.
fn spark_contains(haystack: &ColumnarValue, needle: &ColumnarValue) -> Result<ColumnarValue> {
    match (haystack, needle) {
        // Case 1: Both are arrays - use arrow's contains directly
        (ColumnarValue::Array(haystack_array), ColumnarValue::Array(needle_array)) => {
            let result = arrow_contains(haystack_array, needle_array)?;
            Ok(ColumnarValue::Array(Arc::new(result)))
        }

        // Case 2: Haystack is array, needle is scalar - OPTIMIZED PATH
        // This is the common case in SQL like: WHERE col CONTAINS 'pattern'
        (ColumnarValue::Array(haystack_array), ColumnarValue::Scalar(needle_scalar)) => {
            let result = contains_with_scalar_pattern(haystack_array, needle_scalar)?;
            Ok(ColumnarValue::Array(result))
        }

        // Case 3: Haystack is scalar, needle is array - less common
        (ColumnarValue::Scalar(haystack_scalar), ColumnarValue::Array(needle_array)) => {
            // Convert scalar to array and use arrow's contains
            let haystack_array = haystack_scalar.to_array_of_size(needle_array.len())?;
            let result = arrow_contains(&haystack_array, needle_array)?;
            Ok(ColumnarValue::Array(Arc::new(result)))
        }

        // Case 4: Both are scalars - compute single result
        (ColumnarValue::Scalar(haystack_scalar), ColumnarValue::Scalar(needle_scalar)) => {
            let result = contains_scalar_scalar(haystack_scalar, needle_scalar)?;
            Ok(ColumnarValue::Scalar(result))
        }
    }
}

/// Optimized contains for array haystack with scalar needle pattern.
/// Uses memchr's SIMD-optimized Finder for efficient repeated searches.
fn contains_with_scalar_pattern(
    haystack_array: &ArrayRef,
    needle_scalar: &ScalarValue,
) -> Result<ArrayRef> {
    // Handle null needle
    if needle_scalar.is_null() {
        return Ok(Arc::new(BooleanArray::new_null(haystack_array.len())));
    }

    // Extract the needle string
    let needle_str = match needle_scalar {
        ScalarValue::Utf8(Some(s))
        | ScalarValue::LargeUtf8(Some(s))
        | ScalarValue::Utf8View(Some(s)) => s.as_str(),
        _ => {
            return exec_err!(
                "contains function requires string type for needle, got {:?}",
                needle_scalar.data_type()
            )
        }
    };

    // Create a reusable Finder for efficient SIMD-optimized searching
    let finder = Finder::new(needle_str.as_bytes());

    match haystack_array.data_type() {
        DataType::Utf8 => {
            let array = haystack_array.as_string::<i32>();
            let result: BooleanArray = array
                .iter()
                .map(|opt_haystack| opt_haystack.map(|h| finder.find(h.as_bytes()).is_some()))
                .collect();
            Ok(Arc::new(result))
        }
        DataType::LargeUtf8 => {
            let array = haystack_array.as_string::<i64>();
            let result: BooleanArray = array
                .iter()
                .map(|opt_haystack| opt_haystack.map(|h| finder.find(h.as_bytes()).is_some()))
                .collect();
            Ok(Arc::new(result))
        }
        DataType::Utf8View => {
            let array = haystack_array.as_string_view();
            let result: BooleanArray = array
                .iter()
                .map(|opt_haystack| opt_haystack.map(|h| finder.find(h.as_bytes()).is_some()))
                .collect();
            Ok(Arc::new(result))
        }
        other => exec_err!(
            "contains function requires string type for haystack, got {:?}",
            other
        ),
    }
}

/// Contains for two scalar values.
fn contains_scalar_scalar(
    haystack_scalar: &ScalarValue,
    needle_scalar: &ScalarValue,
) -> Result<ScalarValue> {
    // Handle nulls
    if haystack_scalar.is_null() || needle_scalar.is_null() {
        return Ok(ScalarValue::Boolean(None));
    }

    let haystack_str = match haystack_scalar {
        ScalarValue::Utf8(Some(s))
        | ScalarValue::LargeUtf8(Some(s))
        | ScalarValue::Utf8View(Some(s)) => s.as_str(),
        _ => {
            return exec_err!(
                "contains function requires string type for haystack, got {:?}",
                haystack_scalar.data_type()
            )
        }
    };

    let needle_str = match needle_scalar {
        ScalarValue::Utf8(Some(s))
        | ScalarValue::LargeUtf8(Some(s))
        | ScalarValue::Utf8View(Some(s)) => s.as_str(),
        _ => {
            return exec_err!(
                "contains function requires string type for needle, got {:?}",
                needle_scalar.data_type()
            )
        }
    };

    Ok(ScalarValue::Boolean(Some(
        haystack_str.contains(needle_str),
    )))
}

#[cfg(test)]
mod tests {
    use super::*;
    use arrow::array::StringArray;

    #[test]
    fn test_contains_array_scalar() {
        let haystack = Arc::new(StringArray::from(vec![
            Some("hello world"),
            Some("foo bar"),
            Some("testing"),
            None,
        ])) as ArrayRef;
        let needle = ScalarValue::Utf8(Some("world".to_string()));

        let result = contains_with_scalar_pattern(&haystack, &needle).unwrap();
        let bool_array = result.as_any().downcast_ref::<BooleanArray>().unwrap();

        assert_eq!(bool_array.value(0), true); // "hello world" contains "world"
        assert_eq!(bool_array.value(1), false); // "foo bar" does not contain "world"
        assert_eq!(bool_array.value(2), false); // "testing" does not contain "world"
        assert!(bool_array.is_null(3)); // null input => null output
    }

    #[test]
    fn test_contains_scalar_scalar() {
        let haystack = ScalarValue::Utf8(Some("hello world".to_string()));
        let needle = ScalarValue::Utf8(Some("world".to_string()));

        let result = contains_scalar_scalar(&haystack, &needle).unwrap();
        assert_eq!(result, ScalarValue::Boolean(Some(true)));

        let needle_not_found = ScalarValue::Utf8(Some("xyz".to_string()));
        let result = contains_scalar_scalar(&haystack, &needle_not_found).unwrap();
        assert_eq!(result, ScalarValue::Boolean(Some(false)));
    }

    #[test]
    fn test_contains_null_needle() {
        let haystack = Arc::new(StringArray::from(vec![
            Some("hello world"),
            Some("foo bar"),
        ])) as ArrayRef;
        let needle = ScalarValue::Utf8(None);

        let result = contains_with_scalar_pattern(&haystack, &needle).unwrap();
        let bool_array = result.as_any().downcast_ref::<BooleanArray>().unwrap();

        // Null needle should produce null results
        assert!(bool_array.is_null(0));
        assert!(bool_array.is_null(1));
    }

    #[test]
    fn test_contains_empty_needle() {
        let haystack = Arc::new(StringArray::from(vec![Some("hello world"), Some("")])) as ArrayRef;
        let needle = ScalarValue::Utf8(Some("".to_string()));

        let result = contains_with_scalar_pattern(&haystack, &needle).unwrap();
        let bool_array = result.as_any().downcast_ref::<BooleanArray>().unwrap();

        // Empty string is contained in any string
        assert_eq!(bool_array.value(0), true);
        assert_eq!(bool_array.value(1), true);
    }
}
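
The module docs above note that DataFusion's built-in `contains` goes through `make_scalar_function`, which turns a scalar needle into a full column before the kernel runs. A minimal sketch of the materialization the optimized scalar path avoids, using only `ScalarValue::to_array_of_size` (the 8192-row batch size is a made-up figure for illustration):

// Sketch of the expansion the scalar path avoids (not code from this commit).
use arrow::array::{Array, ArrayRef};
use datafusion::common::{Result, ScalarValue};

fn main() -> Result<()> {
    let needle = ScalarValue::Utf8(Some("rose".to_string()));
    // A make_scalar_function-style wrapper turns the scalar argument into a
    // full column before calling the kernel: one copy of "rose" per row.
    let batch_rows = 8192; // illustrative batch size
    let expanded: ArrayRef = needle.to_array_of_size(batch_rows)?;
    assert_eq!(expanded.len(), batch_rows);
    // The optimized path keeps the needle as a ScalarValue and builds one
    // memmem::Finder instead of materializing this array.
    Ok(())
}

Skipping that expansion is why Case 2 in spark_contains dispatches to contains_with_scalar_pattern rather than calling the arrow kernel on two arrays.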

native/spark-expr/src/string_funcs/mod.rs

Lines changed: 2 additions & 0 deletions
@@ -15,8 +15,10 @@
 // specific language governing permissions and limitations
 // under the License.
 
+mod contains;
 mod string_space;
 mod substring;
 
+pub use contains::SparkContains;
 pub use string_space::SparkStringSpace;
 pub use substring::SubstringExpr;

spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala

Lines changed: 18 additions & 1 deletion
@@ -1107,7 +1107,24 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper {
 
       // Filter rows that contains 'rose' in 'name' column
       val queryContains = sql(s"select id from $table where contains (name, 'rose')")
-      checkAnswer(queryContains, Row(5) :: Nil)
+      checkSparkAnswerAndOperator(queryContains)
+
+      // Additional test cases for optimized contains implementation
+      // Test with empty pattern (should match all non-null rows)
+      val queryEmptyPattern = sql(s"select id from $table where contains (name, '')")
+      checkSparkAnswerAndOperator(queryEmptyPattern)
+
+      // Test with pattern not found
+      val queryNotFound = sql(s"select id from $table where contains (name, 'xyz')")
+      checkSparkAnswerAndOperator(queryNotFound)
+
+      // Test with pattern at start
+      val queryStart = sql(s"select id from $table where contains (name, 'James')")
+      checkSparkAnswerAndOperator(queryStart)
+
+      // Test with pattern at end
+      val queryEnd = sql(s"select id from $table where contains (name, 'Smith')")
+      checkSparkAnswerAndOperator(queryEnd)
     }
   }
 

spark/src/test/scala/org/apache/spark/sql/benchmark/CometStringExpressionBenchmark.scala

Lines changed: 1 addition & 0 deletions
@@ -72,6 +72,7 @@ object CometStringExpressionBenchmark extends CometBenchmarkBase {
     StringExprConfig("initCap", "select initCap(c1) from parquetV1Table"),
     StringExprConfig("trim", "select trim(c1) from parquetV1Table"),
     StringExprConfig("concatws", "select concat_ws(' ', c1, c1) from parquetV1Table"),
+    StringExprConfig("contains", "select contains(c1, '123') from parquetV1Table"),
     StringExprConfig("length", "select length(c1) from parquetV1Table"),
     StringExprConfig("repeat", "select repeat(c1, 3) from parquetV1Table"),
     StringExprConfig("reverse", "select reverse(c1) from parquetV1Table"),
