Skip to content

Commit b7ccb53

Browse files
authored
[branch-53] backport: Support Spark array_contains builtin function (apache#20685) (apache#20914)
## Which issue does this PR close? <!-- We generally require a GitHub issue to be filed for all bug fixes and enhancements and this helps us generate change logs for our releases. You can link an issue to this PR using the GitHub syntax. For example `Closes apache#123` indicates that this PR will close issue apache#123. --> - Closes apache#20611. ## Rationale for this change <!-- Why are you proposing this change? If this is already explained clearly in the issue then this section is not needed. Explaining clearly why changes are proposed helps reviewers understand your changes and offer better suggestions for fixes. --> ## What changes are included in this PR? The Spark function is a thin wrapper on top of the `array_has` function. After the result is produced, the null mask is updated so that every output row where no match was found and whose input array contains null elements becomes null. <!-- There is no need to duplicate the description in the issue here but it is sometimes worth providing a summary of the individual changes in this PR. --> ## Are these changes tested? <!-- We typically require tests for all PRs in order to: 1. Prevent the code from being accidentally broken by subsequent changes 2. Serve as another way to document the expected behavior of the code If tests are not included in your PR, please explain why (for example, are they covered by existing tests)? --> ## Are there any user-facing changes? <!-- If there are user-facing changes then we may require documentation to be updated before approving the PR. --> <!-- If there are any breaking changes to public APIs, please add the `api change` label. --> ## Which issue does this PR close? <!-- We generally require a GitHub issue to be filed for all bug fixes and enhancements and this helps us generate change logs for our releases. You can link an issue to this PR using the GitHub syntax. For example `Closes apache#123` indicates that this PR will close issue apache#123. --> - Closes #. 
## Rationale for this change <!-- Why are you proposing this change? If this is already explained clearly in the issue then this section is not needed. Explaining clearly why changes are proposed helps reviewers understand your changes and offer better suggestions for fixes. --> ## What changes are included in this PR? <!-- There is no need to duplicate the description in the issue here but it is sometimes worth providing a summary of the individual changes in this PR. --> ## Are these changes tested? <!-- We typically require tests for all PRs in order to: 1. Prevent the code from being accidentally broken by subsequent changes 2. Serve as another way to document the expected behavior of the code If tests are not included in your PR, please explain why (for example, are they covered by existing tests)? --> ## Are there any user-facing changes? <!-- If there are user-facing changes then we may require documentation to be updated before approving the PR. --> <!-- If there are any breaking changes to public APIs, please add the `api change` label. -->
1 parent ed25cc2 commit b7ccb53

File tree

3 files changed

+322
-1
lines changed

3 files changed

+322
-1
lines changed
Lines changed: 168 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,168 @@
1+
// Licensed to the Apache Software Foundation (ASF) under one
2+
// or more contributor license agreements. See the NOTICE file
3+
// distributed with this work for additional information
4+
// regarding copyright ownership. The ASF licenses this file
5+
// to you under the Apache License, Version 2.0 (the
6+
// "License"); you may not use this file except in compliance
7+
// with the License. You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing,
12+
// software distributed under the License is distributed on an
13+
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
// KIND, either express or implied. See the License for the
15+
// specific language governing permissions and limitations
16+
// under the License.
17+
18+
use arrow::array::{
19+
Array, AsArray, BooleanArray, BooleanBufferBuilder, GenericListArray, OffsetSizeTrait,
20+
};
21+
use arrow::buffer::{BooleanBuffer, NullBuffer};
22+
use arrow::datatypes::DataType;
23+
use datafusion_common::{Result, exec_err};
24+
use datafusion_expr::{
25+
ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility,
26+
};
27+
use datafusion_functions_nested::array_has::array_has_udf;
28+
use std::any::Any;
29+
use std::sync::Arc;
30+
31+
/// Spark-compatible `array_contains` function.
32+
///
33+
/// Calls DataFusion's `array_has` and then applies Spark's null semantics:
34+
/// - If the result from `array_has` is `true`, return `true`.
35+
/// - If the result is `false` and the input array row contains any null elements,
36+
/// return `null` (because the element might have been the null).
37+
/// - If the result is `false` and the input array row has no null elements,
38+
/// return `false`.
39+
#[derive(Debug, PartialEq, Eq, Hash)]
40+
pub struct SparkArrayContains {
41+
signature: Signature,
42+
}
43+
44+
impl Default for SparkArrayContains {
45+
fn default() -> Self {
46+
Self::new()
47+
}
48+
}
49+
50+
impl SparkArrayContains {
51+
pub fn new() -> Self {
52+
Self {
53+
signature: Signature::array_and_element(Volatility::Immutable),
54+
}
55+
}
56+
}
57+
58+
impl ScalarUDFImpl for SparkArrayContains {
59+
fn as_any(&self) -> &dyn Any {
60+
self
61+
}
62+
63+
fn name(&self) -> &str {
64+
"array_contains"
65+
}
66+
67+
fn signature(&self) -> &Signature {
68+
&self.signature
69+
}
70+
71+
fn return_type(&self, _: &[DataType]) -> Result<DataType> {
72+
Ok(DataType::Boolean)
73+
}
74+
75+
fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
76+
let haystack = args.args[0].clone();
77+
let array_has_result = array_has_udf().invoke_with_args(args)?;
78+
79+
let result_array = array_has_result.to_array(1)?;
80+
let patched = apply_spark_null_semantics(result_array.as_boolean(), &haystack)?;
81+
Ok(ColumnarValue::Array(Arc::new(patched)))
82+
}
83+
}
84+
85+
/// For each row where `array_has` returned `false`, set the output to null
86+
/// if that row's input array contains any null elements.
87+
fn apply_spark_null_semantics(
88+
result: &BooleanArray,
89+
haystack_arg: &ColumnarValue,
90+
) -> Result<BooleanArray> {
91+
// happy path
92+
if result.false_count() == 0 || haystack_arg.data_type() == DataType::Null {
93+
return Ok(result.clone());
94+
}
95+
96+
let haystack = haystack_arg.to_array_of_size(result.len())?;
97+
98+
let row_has_nulls = compute_row_has_nulls(&haystack)?;
99+
100+
// A row keeps its validity when result is true OR the row has no nulls.
101+
let keep_mask = result.values() | &!&row_has_nulls;
102+
let new_validity = match result.nulls() {
103+
Some(n) => n.inner() & &keep_mask,
104+
None => keep_mask,
105+
};
106+
107+
Ok(BooleanArray::new(
108+
result.values().clone(),
109+
Some(NullBuffer::new(new_validity)),
110+
))
111+
}
112+
113+
/// Returns a per-row bitmap where bit i is set if row i's list contains any null element.
114+
fn compute_row_has_nulls(haystack: &dyn Array) -> Result<BooleanBuffer> {
115+
match haystack.data_type() {
116+
DataType::List(_) => generic_list_row_has_nulls(haystack.as_list::<i32>()),
117+
DataType::LargeList(_) => generic_list_row_has_nulls(haystack.as_list::<i64>()),
118+
DataType::FixedSizeList(_, _) => {
119+
let list = haystack.as_fixed_size_list();
120+
let buf = match list.values().nulls() {
121+
Some(nulls) => {
122+
let validity = nulls.inner();
123+
let vl = list.value_length() as usize;
124+
let mut builder = BooleanBufferBuilder::new(list.len());
125+
for i in 0..list.len() {
126+
builder.append(validity.slice(i * vl, vl).count_set_bits() < vl);
127+
}
128+
builder.finish()
129+
}
130+
None => BooleanBuffer::new_unset(list.len()),
131+
};
132+
Ok(mask_with_list_nulls(buf, list.nulls()))
133+
}
134+
dt => exec_err!("compute_row_has_nulls: unsupported data type {dt}"),
135+
}
136+
}
137+
138+
/// Computes per-row null presence for `List` and `LargeList` arrays.
139+
fn generic_list_row_has_nulls<O: OffsetSizeTrait>(
140+
list: &GenericListArray<O>,
141+
) -> Result<BooleanBuffer> {
142+
let buf = match list.values().nulls() {
143+
Some(nulls) => {
144+
let validity = nulls.inner();
145+
let offsets = list.offsets();
146+
let mut builder = BooleanBufferBuilder::new(list.len());
147+
for i in 0..list.len() {
148+
let s = offsets[i].as_usize();
149+
let len = offsets[i + 1].as_usize() - s;
150+
builder.append(validity.slice(s, len).count_set_bits() < len);
151+
}
152+
builder.finish()
153+
}
154+
None => BooleanBuffer::new_unset(list.len()),
155+
};
156+
Ok(mask_with_list_nulls(buf, list.nulls()))
157+
}
158+
159+
/// Rows where the list itself is null should not be marked as "has nulls".
160+
fn mask_with_list_nulls(
161+
buf: BooleanBuffer,
162+
list_nulls: Option<&NullBuffer>,
163+
) -> BooleanBuffer {
164+
match list_nulls {
165+
Some(n) => &buf & n.inner(),
166+
None => buf,
167+
}
168+
}

datafusion/spark/src/function/array/mod.rs

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
// specific language governing permissions and limitations
1616
// under the License.
1717

18+
pub mod array_contains;
1819
pub mod repeat;
1920
pub mod shuffle;
2021
pub mod slice;
@@ -24,6 +25,7 @@ use datafusion_expr::ScalarUDF;
2425
use datafusion_functions::make_udf_function;
2526
use std::sync::Arc;
2627

28+
make_udf_function!(array_contains::SparkArrayContains, spark_array_contains);
2729
make_udf_function!(spark_array::SparkArray, array);
2830
make_udf_function!(shuffle::SparkShuffle, shuffle);
2931
make_udf_function!(repeat::SparkArrayRepeat, array_repeat);
@@ -32,6 +34,11 @@ make_udf_function!(slice::SparkSlice, slice);
3234
pub mod expr_fn {
3335
use datafusion_functions::export_functions;
3436

37+
export_functions!((
38+
spark_array_contains,
39+
"Returns true if the array contains the element (Spark semantics).",
40+
array element
41+
));
3542
export_functions!((array, "Returns an array with the given elements.", args));
3643
export_functions!((
3744
shuffle,
@@ -51,5 +58,11 @@ pub mod expr_fn {
5158
}
5259

5360
pub fn functions() -> Vec<Arc<ScalarUDF>> {
54-
vec![array(), shuffle(), array_repeat(), slice()]
61+
vec![
62+
spark_array_contains(),
63+
array(),
64+
shuffle(),
65+
array_repeat(),
66+
slice(),
67+
]
5568
}
Lines changed: 140 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,140 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
18+
# Tests for Spark-compatible array_contains function.
19+
# Spark semantics: if element is found -> true; if not found and array has nulls -> null; if not found and no nulls -> false.
20+
21+
###
22+
### Scalar tests
23+
###
24+
25+
# Element found in array
26+
query B
27+
SELECT array_contains(array(1, 2, 3), 2);
28+
----
29+
true
30+
31+
# Element not found, no nulls in array
32+
query B
33+
SELECT array_contains(array(1, 2, 3), 4);
34+
----
35+
false
36+
37+
# Element not found, array has null elements -> null
38+
query B
39+
SELECT array_contains(array(1, NULL, 3), 2);
40+
----
41+
NULL
42+
43+
# Element found, array has null elements -> true (nulls don't matter)
44+
query B
45+
SELECT array_contains(array(1, NULL, 3), 1);
46+
----
47+
true
48+
49+
# Element found at the end, array has null elements -> true
50+
query B
51+
SELECT array_contains(array(1, NULL, 3), 3);
52+
----
53+
true
54+
55+
# Null array -> null
56+
query B
57+
SELECT array_contains(NULL, 1);
58+
----
59+
NULL
60+
61+
# Null element -> null
62+
query B
63+
SELECT array_contains(array(1, 2, 3), NULL);
64+
----
65+
NULL
66+
67+
# Empty array, element not found -> false
68+
query B
69+
SELECT array_contains(array(), 1);
70+
----
71+
false
72+
73+
# Array with only nulls, element not found -> null
74+
query B
75+
SELECT array_contains(array(NULL, NULL), 1);
76+
----
77+
NULL
78+
79+
# String array, element found
80+
query B
81+
SELECT array_contains(array('a', 'b', 'c'), 'b');
82+
----
83+
true
84+
85+
# String array, element not found, no nulls
86+
query B
87+
SELECT array_contains(array('a', 'b', 'c'), 'd');
88+
----
89+
false
90+
91+
# String array, element not found, has null
92+
query B
93+
SELECT array_contains(array('a', NULL, 'c'), 'd');
94+
----
95+
NULL
96+
97+
###
98+
### Columnar tests with a table
99+
###
100+
101+
statement ok
102+
CREATE TABLE test_arrays AS VALUES
103+
(1, make_array(1, 2, 3), 10),
104+
(2, make_array(4, NULL, 6), 5),
105+
(3, make_array(7, 8, 9), 10),
106+
(4, NULL, 1),
107+
(5, make_array(10, NULL, NULL), 10);
108+
109+
# Column needle against column array
110+
query IBB
111+
SELECT column1,
112+
array_contains(column2, column3),
113+
array_contains(column2, 10)
114+
FROM test_arrays
115+
ORDER BY column1;
116+
----
117+
1 false false
118+
2 NULL NULL
119+
3 false false
120+
4 NULL NULL
121+
5 true true
122+
123+
statement ok
124+
DROP TABLE test_arrays;
125+
126+
###
127+
### Nested array tests
128+
###
129+
130+
# Nested array element found
131+
query B
132+
SELECT array_contains(array(array(1, 2), array(3, 4)), array(3, 4));
133+
----
134+
true
135+
136+
# Nested array element not found, no nulls
137+
query B
138+
SELECT array_contains(array(array(1, 2), array(3, 4)), array(5, 6));
139+
----
140+
false

0 commit comments

Comments
 (0)