Skip to content

Commit f2e6748

Browse files
committed
Allow Spark partial / Comet final for compatible aggregates
1 parent dca45ea commit f2e6748

File tree

6 files changed

+377
-6
lines changed

6 files changed

+377
-6
lines changed
Lines changed: 282 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,282 @@
1+
// Licensed to the Apache Software Foundation (ASF) under one
2+
// or more contributor license agreements. See the NOTICE file
3+
// distributed with this work for additional information
4+
// regarding copyright ownership. The ASF licenses this file
5+
// to you under the Apache License, Version 2.0 (the
6+
// "License"); you may not use this file except in compliance
7+
// with the License. You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing,
12+
// software distributed under the License is distributed on an
13+
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
// KIND, either express or implied. See the License for the
15+
// specific language governing permissions and limitations
16+
// under the License.
17+
18+
//! Optimized `contains` string function for Spark compatibility.
19+
//!
20+
//! This implementation is optimized for the common case where the pattern
21+
//! (second argument) is a scalar value. In this case, we use `memchr::memmem::Finder`
22+
//! which is SIMD-optimized and reuses a single finder instance across all rows.
23+
//!
24+
//! The DataFusion built-in `contains` function uses `make_scalar_function` which
25+
//! expands scalar values to arrays, losing the performance benefit of the optimized
26+
//! scalar path in arrow-rs.
27+
28+
use arrow::array::{Array, ArrayRef, AsArray, BooleanArray};
29+
use arrow::datatypes::DataType;
30+
use arrow_string::like::contains as arrow_contains;
31+
use datafusion::common::{exec_err, Result, ScalarValue};
32+
use datafusion::logical_expr::{
33+
ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility,
34+
};
35+
use memchr::memmem::Finder;
36+
use std::any::Any;
37+
use std::sync::Arc;
38+
39+
/// Spark-optimized contains function.
40+
///
41+
/// Returns true if the first string argument contains the second string argument.
42+
/// Optimized for the common case where the pattern is a scalar constant.
43+
#[derive(Debug, PartialEq, Eq, Hash)]
44+
pub struct SparkContains {
45+
signature: Signature,
46+
}
47+
48+
impl Default for SparkContains {
49+
fn default() -> Self {
50+
Self::new()
51+
}
52+
}
53+
54+
impl SparkContains {
55+
pub fn new() -> Self {
56+
Self {
57+
signature: Signature::variadic_any(Volatility::Immutable),
58+
}
59+
}
60+
}
61+
62+
impl ScalarUDFImpl for SparkContains {
63+
fn as_any(&self) -> &dyn Any {
64+
self
65+
}
66+
67+
fn name(&self) -> &str {
68+
"contains"
69+
}
70+
71+
fn signature(&self) -> &Signature {
72+
&self.signature
73+
}
74+
75+
fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
76+
Ok(DataType::Boolean)
77+
}
78+
79+
fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
80+
if args.args.len() != 2 {
81+
return exec_err!("contains function requires exactly 2 arguments");
82+
}
83+
spark_contains(&args.args[0], &args.args[1])
84+
}
85+
}
86+
87+
/// Execute the contains function with optimized scalar pattern handling.
88+
fn spark_contains(haystack: &ColumnarValue, needle: &ColumnarValue) -> Result<ColumnarValue> {
89+
match (haystack, needle) {
90+
// Case 1: Both are arrays - use arrow's contains directly
91+
(ColumnarValue::Array(haystack_array), ColumnarValue::Array(needle_array)) => {
92+
let result = arrow_contains(haystack_array, needle_array)?;
93+
Ok(ColumnarValue::Array(Arc::new(result)))
94+
}
95+
96+
// Case 2: Haystack is array, needle is scalar - OPTIMIZED PATH
97+
// This is the common case in SQL like: WHERE col CONTAINS 'pattern'
98+
(ColumnarValue::Array(haystack_array), ColumnarValue::Scalar(needle_scalar)) => {
99+
let result = contains_with_scalar_pattern(haystack_array, needle_scalar)?;
100+
Ok(ColumnarValue::Array(result))
101+
}
102+
103+
// Case 3: Haystack is scalar, needle is array - less common
104+
(ColumnarValue::Scalar(haystack_scalar), ColumnarValue::Array(needle_array)) => {
105+
// Convert scalar to array and use arrow's contains
106+
let haystack_array = haystack_scalar.to_array_of_size(needle_array.len())?;
107+
let result = arrow_contains(&haystack_array, needle_array)?;
108+
Ok(ColumnarValue::Array(Arc::new(result)))
109+
}
110+
111+
// Case 4: Both are scalars - compute single result
112+
(ColumnarValue::Scalar(haystack_scalar), ColumnarValue::Scalar(needle_scalar)) => {
113+
let result = contains_scalar_scalar(haystack_scalar, needle_scalar)?;
114+
Ok(ColumnarValue::Scalar(result))
115+
}
116+
}
117+
}
118+
119+
/// Optimized contains for array haystack with scalar needle pattern.
120+
/// Uses memchr's SIMD-optimized Finder for efficient repeated searches.
121+
fn contains_with_scalar_pattern(
122+
haystack_array: &ArrayRef,
123+
needle_scalar: &ScalarValue,
124+
) -> Result<ArrayRef> {
125+
// Handle null needle
126+
if needle_scalar.is_null() {
127+
return Ok(Arc::new(BooleanArray::new_null(haystack_array.len())));
128+
}
129+
130+
// Extract the needle string
131+
let needle_str = match needle_scalar {
132+
ScalarValue::Utf8(Some(s))
133+
| ScalarValue::LargeUtf8(Some(s))
134+
| ScalarValue::Utf8View(Some(s)) => s.as_str(),
135+
_ => {
136+
return exec_err!(
137+
"contains function requires string type for needle, got {:?}",
138+
needle_scalar.data_type()
139+
)
140+
}
141+
};
142+
143+
// Create a reusable Finder for efficient SIMD-optimized searching
144+
let finder = Finder::new(needle_str.as_bytes());
145+
146+
match haystack_array.data_type() {
147+
DataType::Utf8 => {
148+
let array = haystack_array.as_string::<i32>();
149+
let result: BooleanArray = array
150+
.iter()
151+
.map(|opt_haystack| opt_haystack.map(|h| finder.find(h.as_bytes()).is_some()))
152+
.collect();
153+
Ok(Arc::new(result))
154+
}
155+
DataType::LargeUtf8 => {
156+
let array = haystack_array.as_string::<i64>();
157+
let result: BooleanArray = array
158+
.iter()
159+
.map(|opt_haystack| opt_haystack.map(|h| finder.find(h.as_bytes()).is_some()))
160+
.collect();
161+
Ok(Arc::new(result))
162+
}
163+
DataType::Utf8View => {
164+
let array = haystack_array.as_string_view();
165+
let result: BooleanArray = array
166+
.iter()
167+
.map(|opt_haystack| opt_haystack.map(|h| finder.find(h.as_bytes()).is_some()))
168+
.collect();
169+
Ok(Arc::new(result))
170+
}
171+
other => exec_err!(
172+
"contains function requires string type for haystack, got {:?}",
173+
other
174+
),
175+
}
176+
}
177+
178+
/// Contains for two scalar values.
179+
fn contains_scalar_scalar(
180+
haystack_scalar: &ScalarValue,
181+
needle_scalar: &ScalarValue,
182+
) -> Result<ScalarValue> {
183+
// Handle nulls
184+
if haystack_scalar.is_null() || needle_scalar.is_null() {
185+
return Ok(ScalarValue::Boolean(None));
186+
}
187+
188+
let haystack_str = match haystack_scalar {
189+
ScalarValue::Utf8(Some(s))
190+
| ScalarValue::LargeUtf8(Some(s))
191+
| ScalarValue::Utf8View(Some(s)) => s.as_str(),
192+
_ => {
193+
return exec_err!(
194+
"contains function requires string type for haystack, got {:?}",
195+
haystack_scalar.data_type()
196+
)
197+
}
198+
};
199+
200+
let needle_str = match needle_scalar {
201+
ScalarValue::Utf8(Some(s))
202+
| ScalarValue::LargeUtf8(Some(s))
203+
| ScalarValue::Utf8View(Some(s)) => s.as_str(),
204+
_ => {
205+
return exec_err!(
206+
"contains function requires string type for needle, got {:?}",
207+
needle_scalar.data_type()
208+
)
209+
}
210+
};
211+
212+
Ok(ScalarValue::Boolean(Some(
213+
haystack_str.contains(needle_str),
214+
)))
215+
}
216+
217+
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::array::StringArray;

    /// Runs the array-haystack / scalar-needle path and returns the concrete
    /// boolean result so individual slots can be inspected.
    fn run_array_scalar(haystack: ArrayRef, needle: &ScalarValue) -> BooleanArray {
        let result = contains_with_scalar_pattern(&haystack, needle).unwrap();
        result
            .as_any()
            .downcast_ref::<BooleanArray>()
            .expect("contains must produce a BooleanArray")
            .clone()
    }

    #[test]
    fn test_contains_array_scalar() {
        let haystack = Arc::new(StringArray::from(vec![
            Some("hello world"),
            Some("foo bar"),
            Some("testing"),
            None,
        ])) as ArrayRef;
        let needle = ScalarValue::Utf8(Some("world".to_string()));

        let bool_array = run_array_scalar(haystack, &needle);
        assert_eq!(bool_array.len(), 4);

        // Check validity first: `value(i)` on a null slot is not meaningful.
        assert!(bool_array.is_valid(0));
        assert!(bool_array.value(0)); // "hello world" contains "world"
        assert!(bool_array.is_valid(1));
        assert!(!bool_array.value(1)); // "foo bar" does not contain "world"
        assert!(bool_array.is_valid(2));
        assert!(!bool_array.value(2)); // "testing" does not contain "world"
        assert!(bool_array.is_null(3)); // null input => null output
    }

    #[test]
    fn test_contains_scalar_scalar() {
        let haystack = ScalarValue::Utf8(Some("hello world".to_string()));
        let needle = ScalarValue::Utf8(Some("world".to_string()));

        let result = contains_scalar_scalar(&haystack, &needle).unwrap();
        assert_eq!(result, ScalarValue::Boolean(Some(true)));

        let needle_not_found = ScalarValue::Utf8(Some("xyz".to_string()));
        let result = contains_scalar_scalar(&haystack, &needle_not_found).unwrap();
        assert_eq!(result, ScalarValue::Boolean(Some(false)));
    }

    #[test]
    fn test_contains_null_needle() {
        let haystack = Arc::new(StringArray::from(vec![
            Some("hello world"),
            Some("foo bar"),
        ])) as ArrayRef;
        let needle = ScalarValue::Utf8(None);

        let bool_array = run_array_scalar(haystack, &needle);

        // Null needle should produce null results, one per haystack row
        assert_eq!(bool_array.len(), 2);
        assert!(bool_array.is_null(0));
        assert!(bool_array.is_null(1));
    }

    #[test]
    fn test_contains_empty_needle() {
        let haystack = Arc::new(StringArray::from(vec![Some("hello world"), Some("")])) as ArrayRef;
        let needle = ScalarValue::Utf8(Some("".to_string()));

        let bool_array = run_array_scalar(haystack, &needle);

        // Empty string is contained in any string, including the empty string
        assert!(bool_array.is_valid(0));
        assert!(bool_array.value(0));
        assert!(bool_array.is_valid(1));
        assert!(bool_array.value(1));
    }
}

spark/src/main/scala/org/apache/comet/serde/CometAggregateExpressionSerde.scala

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,20 @@ trait CometAggregateExpressionSerde[T <: AggregateFunction] {
4949
*/
5050
def getSupportLevel(expr: T): SupportLevel = Compatible(None)
5151

52+
/**
53+
* Indicates whether this aggregate function supports "Spark partial / Comet final" mixed
54+
* execution. This requires the intermediate buffer format to be compatible between Spark and
55+
* Comet.
56+
*
57+
* Only aggregates with simple, compatible intermediate buffers should return true. Aggregates
58+
* with complex buffers or those with known incompatibilities (e.g., decimal overflow handling
59+
* differences) should return false.
60+
*
61+
* @return
62+
* true if the aggregate can safely run with Spark partial and Comet final, false otherwise
63+
*/
64+
def supportsSparkPartialCometFinal: Boolean = false
65+
5266
/**
5367
* Convert a Spark expression into a protocol buffer representation that can be passed into
5468
* native code.

spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -253,6 +253,23 @@ object QueryPlanSerde extends Logging with CometExprShim {
253253
classOf[VariancePop] -> CometVariancePop,
254254
classOf[VarianceSamp] -> CometVarianceSamp)
255255

256+
/**
 * Reports whether the given aggregate function supports "Spark partial / Comet final" mixed
 * execution, i.e. whether Comet can process a final aggregate even when the partial aggregate
 * was performed by Spark.
 *
 * The answer is taken from the serde handler registered for the function's class in
 * `aggrSerdeMap`; a function with no registered handler is reported as unsupported.
 *
 * @param fn
 *   The aggregate function to check
 * @return
 *   true if a handler is registered and it opts in to mixed execution, false otherwise
 */
def aggSupportsMixedExecution(fn: AggregateFunction): Boolean =
  aggrSerdeMap.get(fn.getClass).exists(_.supportsSparkPartialCometFinal)
272+
256273
def supportedDataType(dt: DataType, allowComplex: Boolean = false): Boolean = dt match {
257274
case _: ByteType | _: ShortType | _: IntegerType | _: LongType | _: FloatType |
258275
_: DoubleType | _: StringType | _: BinaryType | _: TimestampType | _: TimestampNTZType |

spark/src/main/scala/org/apache/comet/serde/aggregates.scala

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,9 @@ import org.apache.comet.shims.CometEvalModeUtil
3434

3535
object CometMin extends CometAggregateExpressionSerde[Min] {
3636

37+
// Min has a simple intermediate buffer (single value) compatible between Spark and Comet
38+
override def supportsSparkPartialCometFinal: Boolean = true
39+
3740
override def convert(
3841
aggExpr: AggregateExpression,
3942
expr: Min,
@@ -81,6 +84,9 @@ object CometMin extends CometAggregateExpressionSerde[Min] {
8184

8285
object CometMax extends CometAggregateExpressionSerde[Max] {
8386

87+
// Max has a simple intermediate buffer (single value) compatible between Spark and Comet
88+
override def supportsSparkPartialCometFinal: Boolean = true
89+
8490
override def convert(
8591
aggExpr: AggregateExpression,
8692
expr: Max,
@@ -127,6 +133,10 @@ object CometMax extends CometAggregateExpressionSerde[Max] {
127133
}
128134

129135
object CometCount extends CometAggregateExpressionSerde[Count] {
136+
137+
// Count has a simple intermediate buffer (single Long) compatible between Spark and Comet
138+
override def supportsSparkPartialCometFinal: Boolean = true
139+
130140
override def convert(
131141
aggExpr: AggregateExpression,
132142
expr: Count,
@@ -317,6 +327,11 @@ object CometLast extends CometAggregateExpressionSerde[Last] {
317327
}
318328

319329
object CometBitAndAgg extends CometAggregateExpressionSerde[BitAndAgg] {
330+
331+
// BitAnd has a simple intermediate buffer (single integral value)
332+
// compatible between Spark and Comet
333+
override def supportsSparkPartialCometFinal: Boolean = true
334+
320335
override def convert(
321336
aggExpr: AggregateExpression,
322337
bitAnd: BitAndAgg,
@@ -351,6 +366,11 @@ object CometBitAndAgg extends CometAggregateExpressionSerde[BitAndAgg] {
351366
}
352367

353368
object CometBitOrAgg extends CometAggregateExpressionSerde[BitOrAgg] {
369+
370+
// BitOr has a simple intermediate buffer (single integral value)
371+
// compatible between Spark and Comet
372+
override def supportsSparkPartialCometFinal: Boolean = true
373+
354374
override def convert(
355375
aggExpr: AggregateExpression,
356376
bitOr: BitOrAgg,
@@ -385,6 +405,11 @@ object CometBitOrAgg extends CometAggregateExpressionSerde[BitOrAgg] {
385405
}
386406

387407
object CometBitXOrAgg extends CometAggregateExpressionSerde[BitXorAgg] {
408+
409+
// BitXor has a simple intermediate buffer (single integral value)
410+
// compatible between Spark and Comet
411+
override def supportsSparkPartialCometFinal: Boolean = true
412+
388413
override def convert(
389414
aggExpr: AggregateExpression,
390415
bitXor: BitXorAgg,

0 commit comments

Comments
 (0)