
Commit e9138cd

Add tests for non-identity transform residuals and fix scalastyle errors
1 parent db2fa02 commit e9138cd

File tree

3 files changed (+440, -2 lines)

Lines changed: 282 additions & 0 deletions
@@ -0,0 +1,282 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

//! Optimized `contains` string function for Spark compatibility.
//!
//! This implementation is optimized for the common case where the pattern
//! (second argument) is a scalar value. In this case, we use `memchr::memmem::Finder`
//! which is SIMD-optimized and reuses a single finder instance across all rows.
//!
//! The DataFusion built-in `contains` function uses `make_scalar_function` which
//! expands scalar values to arrays, losing the performance benefit of the optimized
//! scalar path in arrow-rs.

use arrow::array::{Array, ArrayRef, AsArray, BooleanArray};
use arrow::compute::kernels::comparison::contains as arrow_contains;
use arrow::datatypes::DataType;
use datafusion::common::{exec_err, Result, ScalarValue};
use datafusion::logical_expr::{
    ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility,
};
use memchr::memmem::Finder;
use std::any::Any;
use std::sync::Arc;

/// Spark-optimized contains function.
///
/// Returns true if the first string argument contains the second string argument.
/// Optimized for the common case where the pattern is a scalar constant.
#[derive(Debug, PartialEq, Eq, Hash)]
pub struct SparkContains {
    signature: Signature,
}

impl Default for SparkContains {
    fn default() -> Self {
        Self::new()
    }
}

impl SparkContains {
    pub fn new() -> Self {
        Self {
            signature: Signature::variadic_any(Volatility::Immutable),
        }
    }
}

impl ScalarUDFImpl for SparkContains {
    fn as_any(&self) -> &dyn Any {
        self
    }

    fn name(&self) -> &str {
        "contains"
    }

    fn signature(&self) -> &Signature {
        &self.signature
    }

    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
        Ok(DataType::Boolean)
    }

    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
        if args.args.len() != 2 {
            return exec_err!("contains function requires exactly 2 arguments");
        }
        spark_contains(&args.args[0], &args.args[1])
    }
}

/// Execute the contains function with optimized scalar pattern handling.
fn spark_contains(haystack: &ColumnarValue, needle: &ColumnarValue) -> Result<ColumnarValue> {
    match (haystack, needle) {
        // Case 1: Both are arrays - use arrow's contains directly
        (ColumnarValue::Array(haystack_array), ColumnarValue::Array(needle_array)) => {
            let result = arrow_contains(haystack_array, needle_array)?;
            Ok(ColumnarValue::Array(Arc::new(result)))
        }

        // Case 2: Haystack is array, needle is scalar - OPTIMIZED PATH
        // This is the common case in SQL like: WHERE col CONTAINS 'pattern'
        (ColumnarValue::Array(haystack_array), ColumnarValue::Scalar(needle_scalar)) => {
            let result = contains_with_scalar_pattern(haystack_array, needle_scalar)?;
            Ok(ColumnarValue::Array(result))
        }

        // Case 3: Haystack is scalar, needle is array - less common
        (ColumnarValue::Scalar(haystack_scalar), ColumnarValue::Array(needle_array)) => {
            // Convert scalar to array and use arrow's contains
            let haystack_array = haystack_scalar.to_array_of_size(needle_array.len())?;
            let result = arrow_contains(&haystack_array, needle_array)?;
            Ok(ColumnarValue::Array(Arc::new(result)))
        }

        // Case 4: Both are scalars - compute single result
        (ColumnarValue::Scalar(haystack_scalar), ColumnarValue::Scalar(needle_scalar)) => {
            let result = contains_scalar_scalar(haystack_scalar, needle_scalar)?;
            Ok(ColumnarValue::Scalar(result))
        }
    }
}

/// Optimized contains for array haystack with scalar needle pattern.
/// Uses memchr's SIMD-optimized Finder for efficient repeated searches.
fn contains_with_scalar_pattern(
    haystack_array: &ArrayRef,
    needle_scalar: &ScalarValue,
) -> Result<ArrayRef> {
    // Handle null needle
    if needle_scalar.is_null() {
        return Ok(Arc::new(BooleanArray::new_null(haystack_array.len())));
    }

    // Extract the needle string
    let needle_str = match needle_scalar {
        ScalarValue::Utf8(Some(s))
        | ScalarValue::LargeUtf8(Some(s))
        | ScalarValue::Utf8View(Some(s)) => s.as_str(),
        _ => {
            return exec_err!(
                "contains function requires string type for needle, got {:?}",
                needle_scalar.data_type()
            )
        }
    };

    // Create a reusable Finder for efficient SIMD-optimized searching
    let finder = Finder::new(needle_str.as_bytes());

    match haystack_array.data_type() {
        DataType::Utf8 => {
            let array = haystack_array.as_string::<i32>();
            let result: BooleanArray = array
                .iter()
                .map(|opt_haystack| opt_haystack.map(|h| finder.find(h.as_bytes()).is_some()))
                .collect();
            Ok(Arc::new(result))
        }
        DataType::LargeUtf8 => {
            let array = haystack_array.as_string::<i64>();
            let result: BooleanArray = array
                .iter()
                .map(|opt_haystack| opt_haystack.map(|h| finder.find(h.as_bytes()).is_some()))
                .collect();
            Ok(Arc::new(result))
        }
        DataType::Utf8View => {
            let array = haystack_array.as_string_view();
            let result: BooleanArray = array
                .iter()
                .map(|opt_haystack| opt_haystack.map(|h| finder.find(h.as_bytes()).is_some()))
                .collect();
            Ok(Arc::new(result))
        }
        other => exec_err!(
            "contains function requires string type for haystack, got {:?}",
            other
        ),
    }
}

/// Contains for two scalar values.
fn contains_scalar_scalar(
    haystack_scalar: &ScalarValue,
    needle_scalar: &ScalarValue,
) -> Result<ScalarValue> {
    // Handle nulls
    if haystack_scalar.is_null() || needle_scalar.is_null() {
        return Ok(ScalarValue::Boolean(None));
    }

    let haystack_str = match haystack_scalar {
        ScalarValue::Utf8(Some(s))
        | ScalarValue::LargeUtf8(Some(s))
        | ScalarValue::Utf8View(Some(s)) => s.as_str(),
        _ => {
            return exec_err!(
                "contains function requires string type for haystack, got {:?}",
                haystack_scalar.data_type()
            )
        }
    };

    let needle_str = match needle_scalar {
        ScalarValue::Utf8(Some(s))
        | ScalarValue::LargeUtf8(Some(s))
        | ScalarValue::Utf8View(Some(s)) => s.as_str(),
        _ => {
            return exec_err!(
                "contains function requires string type for needle, got {:?}",
                needle_scalar.data_type()
            )
        }
    };

    Ok(ScalarValue::Boolean(Some(
        haystack_str.contains(needle_str),
    )))
}

#[cfg(test)]
mod tests {
    use super::*;
    use arrow::array::StringArray;

    #[test]
    fn test_contains_array_scalar() {
        let haystack = Arc::new(StringArray::from(vec![
            Some("hello world"),
            Some("foo bar"),
            Some("testing"),
            None,
        ])) as ArrayRef;
        let needle = ScalarValue::Utf8(Some("world".to_string()));

        let result = contains_with_scalar_pattern(&haystack, &needle).unwrap();
        let bool_array = result.as_any().downcast_ref::<BooleanArray>().unwrap();

        assert!(bool_array.value(0)); // "hello world" contains "world"
        assert!(!bool_array.value(1)); // "foo bar" does not contain "world"
        assert!(!bool_array.value(2)); // "testing" does not contain "world"
        assert!(bool_array.is_null(3)); // null input => null output
    }

    #[test]
    fn test_contains_scalar_scalar() {
        let haystack = ScalarValue::Utf8(Some("hello world".to_string()));
        let needle = ScalarValue::Utf8(Some("world".to_string()));

        let result = contains_scalar_scalar(&haystack, &needle).unwrap();
        assert_eq!(result, ScalarValue::Boolean(Some(true)));

        let needle_not_found = ScalarValue::Utf8(Some("xyz".to_string()));
        let result = contains_scalar_scalar(&haystack, &needle_not_found).unwrap();
        assert_eq!(result, ScalarValue::Boolean(Some(false)));
    }

    #[test]
    fn test_contains_null_needle() {
        let haystack = Arc::new(StringArray::from(vec![
            Some("hello world"),
            Some("foo bar"),
        ])) as ArrayRef;
        let needle = ScalarValue::Utf8(None);

        let result = contains_with_scalar_pattern(&haystack, &needle).unwrap();
        let bool_array = result.as_any().downcast_ref::<BooleanArray>().unwrap();

        // Null needle should produce null results
        assert!(bool_array.is_null(0));
        assert!(bool_array.is_null(1));
    }

    #[test]
    fn test_contains_empty_needle() {
        let haystack = Arc::new(StringArray::from(vec![Some("hello world"), Some("")])) as ArrayRef;
        let needle = ScalarValue::Utf8(Some("".to_string()));

        let result = contains_with_scalar_pattern(&haystack, &needle).unwrap();
        let bool_array = result.as_any().downcast_ref::<BooleanArray>().unwrap();

        // Empty string is contained in any string
        assert!(bool_array.value(0));
        assert!(bool_array.value(1));
    }
}
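For context on the fast path above: the whole optimization amounts to building one `memmem::Finder` for the scalar needle and reusing it for every haystack row, instead of re-deriving the searcher per value. The following is a minimal standalone sketch of that pattern (not part of this commit; it assumes only the `memchr` crate as a dependency):

// Standalone sketch (not from this commit): reuse one SIMD-backed Finder
// across many haystack rows, mirroring contains_with_scalar_pattern above.
use memchr::memmem::Finder;

fn main() {
    let rows = ["hello world", "foo bar", "worldwide"];
    // Build the searcher once for the scalar needle "world"...
    let finder = Finder::new("world");
    // ...then apply it to every row; needle preprocessing is amortized
    // over the whole batch.
    let hits: Vec<bool> = rows
        .iter()
        .map(|row| finder.find(row.as_bytes()).is_some())
        .collect();
    assert_eq!(hits, vec![true, false, true]);
}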

spark/src/main/scala/org/apache/comet/rules/CometScanRule.scala

Lines changed: 4 additions & 2 deletions
@@ -480,14 +480,16 @@ case class CometScanRule(session: SparkSession) extends Rule[SparkPlan] with Com

       // Check for transform functions in residual expressions
       // Non-identity transforms (truncate, bucket, year, month, day, hour) in residuals
-      // are now supported - they skip row-group filtering and are handled post-scan by CometFilter.
+      // are now supported - they skip row-group filtering and are handled
+      // post-scan by CometFilter.
       // This is less optimal than row-group filtering but still allows native execution.
       val transformFunctionsSupported =
         try {
           IcebergReflection.findNonIdentityTransformInResiduals(metadata.tasks) match {
             case Some(transformType) =>
               // Found non-identity transform - log info and continue with native scan
-              // Row-group filtering will skip these predicates, but post-scan filtering will apply
+              // Row-group filtering will skip these predicates, but post-scan
+              // filtering will apply
               logInfo(
                 s"Iceberg residual contains transform '$transformType' - " +
                   "row-group filtering will skip this predicate, " +
