Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
492 changes: 12 additions & 480 deletions datafusion/physical-expr/src/expressions/in_list.rs

Large diffs are not rendered by default.

170 changes: 170 additions & 0 deletions datafusion/physical-expr/src/expressions/in_list/byte_filter.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,170 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

//! Hash-based filter for byte array types (Utf8, Binary, and their variants)

use std::marker::PhantomData;

use ahash::RandomState;
use arrow::array::{Array, ArrayRef, AsArray, BooleanArray};
use arrow::buffer::NullBuffer;
use arrow::datatypes::ByteArrayType;
use datafusion_common::Result;
use datafusion_common::hash_utils::with_hashes;
use hashbrown::HashTable;

use super::result::{build_in_list_result_with_null_shortcircuit, handle_dictionary};
use super::static_filter::StaticFilter;

// =============================================================================
// BYTE ACCESS TRAIT
// =============================================================================

/// Abstraction over element access for byte-valued arrays (GenericByteArray types).
///
/// Implementors are used purely as type-level selectors: every item is an
/// associated function taking the array as `&dyn Array`, so no instance of the
/// implementing type is ever needed. `ByteFilter<A>` is generic over this trait.
pub(crate) trait ByteAccess: 'static {
/// Element type yielded by the array. It may be unsized and must be
/// comparable and viewable as a byte slice.
type Native: PartialEq + AsRef<[u8]> + ?Sized;

/// Get a value from the array at the given index (unchecked for performance).
///
/// # Safety
/// `idx` must be in bounds for the array.
unsafe fn get_unchecked(arr: &dyn Array, idx: usize) -> &Self::Native;

/// Returns the array's null (validity) buffer, or `None` if it has none.
fn nulls(arr: &dyn Array) -> Option<&NullBuffer>;
}

/// Marker type for GenericByteArray access (Utf8, LargeUtf8, Binary, LargeBinary).
///
/// Holds only a `PhantomData<T>`, so it carries no data of its own; the type
/// parameter `T` selects which byte array layout the `ByteAccess` impl reads.
pub(crate) struct ByteArrayAccess<T: ByteArrayType>(PhantomData<T>);

impl<T> ByteAccess for ByteArrayAccess<T>
where
    T: ByteArrayType + 'static,
    T::Native: PartialEq,
{
    type Native = T::Native;

    /// Reads the element at `idx` from `arr`, viewed as a byte array of type `T`,
    /// without a bounds check.
    ///
    /// # Safety
    /// `idx` must be in bounds for `arr`.
    ///
    /// NOTE(review): `arr` is downcast with `as_bytes::<T>()`, which presumably
    /// panics when `arr` is not a byte array of type `T` — confirm against arrow.
    #[inline(always)]
    unsafe fn get_unchecked(arr: &dyn Array, idx: usize) -> &Self::Native {
        let typed = arr.as_bytes::<T>();
        // SAFETY: the caller guarantees that `idx` is in bounds for `arr`.
        unsafe { typed.value_unchecked(idx) }
    }

    /// Returns the null buffer of `arr`, viewed as a byte array of type `T`, if any.
    #[inline(always)]
    fn nulls(arr: &dyn Array) -> Option<&NullBuffer> {
        let typed = arr.as_bytes::<T>();
        typed.nulls()
    }
}

// =============================================================================
// BYTE FILTER
// =============================================================================

/// Hash-based membership filter for byte array types (Utf8, Binary, and their variants).
///
/// The filter keeps the haystack array itself and indexes it with a `HashTable`
/// of row indices instead of copying the values. Hashes are computed for a whole
/// array at once via `with_hashes`, and lookups compare the bytes of the probed
/// value against the bytes stored at the indexed haystack row.
pub(crate) struct ByteFilter<A: ByteAccess> {
/// The haystack array containing values to match against.
in_array: ArrayRef,
/// Row indices into `in_array`, keyed by the hash of the value at that row.
/// `try_new` inserts only valid (non-null) rows and skips duplicate values.
table: HashTable<usize>,
/// Random state for consistent hashing between haystack and needles.
state: RandomState,
/// Ties the filter to its `ByteAccess` strategy without storing a value of it.
_phantom: PhantomData<A>,
}

impl<A: ByteAccess> ByteFilter<A> {
pub(crate) fn try_new(in_array: ArrayRef) -> Result<Self> {
let state = RandomState::new();
let mut table = HashTable::new();

// Build haystack table using batch hashing
with_hashes([in_array.as_ref()], &state, |hashes| {
for i in 0..in_array.len() {
if in_array.is_valid(i) {
let hash = hashes[i];

// Only insert if not already present (deduplication)
// SAFETY: i is in bounds and we checked validity
let val: &[u8] =
unsafe { A::get_unchecked(in_array.as_ref(), i) }.as_ref();
if table
.find(hash, |&idx| {
let stored: &[u8] =
unsafe { A::get_unchecked(in_array.as_ref(), idx) }
.as_ref();
stored == val
})
.is_none()
{
table.insert_unique(hash, i, |&idx| hashes[idx]);
}
}
}
Ok::<_, datafusion_common::DataFusionError>(())
})?;

Ok(Self {
in_array,
table,
state,
_phantom: PhantomData,
})
}
}

impl<A: ByteAccess> StaticFilter for ByteFilter<A> {
    /// Number of null entries in the haystack array.
    fn null_count(&self) -> usize {
        self.in_array.null_count()
    }

    /// Evaluates, for every row of `v`, whether its value occurs in the haystack.
    ///
    /// Dictionary-typed input is delegated to `handle_dictionary!` first. For
    /// all other input the needle rows are hashed in one batch with the filter's
    /// own `RandomState`, and the final `BooleanArray` (including null handling
    /// and `negated`) is assembled by
    /// `build_in_list_result_with_null_shortcircuit`.
    ///
    /// # Errors
    /// Propagates any error returned by `with_hashes`.
    fn contains(&self, v: &dyn Array, negated: bool) -> Result<BooleanArray> {
        handle_dictionary!(self, v, negated);

        let needle_nulls = A::nulls(v);
        let needle_null_count = v.null_count();
        let haystack: &dyn Array = self.in_array.as_ref();
        let haystack_has_nulls = haystack.null_count() != 0;

        // Hash every needle row up front so each probe reuses its precomputed hash.
        with_hashes([v], &self.state, |needle_hashes| {
            // The null-shortcircuit builder is used because comparing byte
            // strings is expensive, so skipping lookups for null positions is
            // worth the extra branch.
            let matches = build_in_list_result_with_null_shortcircuit(
                v.len(),
                needle_nulls,
                needle_null_count,
                haystack_has_nulls,
                negated,
                #[inline(always)]
                |needle_row| {
                    // SAFETY: relies on the result builder only calling this
                    // with row indices below `v.len()`.
                    let wanted: &[u8] =
                        unsafe { A::get_unchecked(v, needle_row) }.as_ref();
                    let hash = needle_hashes[needle_row];

                    // Probe by hash, then confirm by comparing the actual bytes
                    // stored at the indexed haystack row.
                    let hit = self.table.find(hash, |&haystack_row| {
                        // SAFETY: the table only holds indices that `try_new`
                        // took from `0..in_array.len()`.
                        let stored: &[u8] =
                            unsafe { A::get_unchecked(haystack, haystack_row) }.as_ref();
                        stored == wanted
                    });
                    hit.is_some()
                },
            );
            Ok(matches)
        })
    }
}
168 changes: 168 additions & 0 deletions datafusion/physical-expr/src/expressions/in_list/nested_filter.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,168 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

//! Fallback filter for nested/complex types (List, Struct, Map, Union, etc.)

use arrow::array::{
Array, ArrayRef, BooleanArray, downcast_array, downcast_dictionary_array,
make_comparator,
};
use arrow::buffer::{BooleanBuffer, NullBuffer};
use arrow::compute::{SortOptions, take};
use arrow::datatypes::DataType;
use arrow::util::bit_iterator::BitIndexIterator;
use datafusion_common::Result;
use datafusion_common::hash_utils::with_hashes;

use ahash::RandomState;
use hashbrown::HashTable;

use super::result::build_in_list_result;
use super::static_filter::StaticFilter;

/// Fallback filter for nested/complex types (List, Struct, Map, Union, etc.)
///
/// Uses a dynamic comparator via `make_comparator` since these types don't have
/// a simple typed comparison. For primitive and byte array types, use the
/// specialized filters instead (PrimitiveFilter, ByteFilter, etc.)
#[derive(Debug, Clone)]
pub(crate) struct NestedTypeFilter {
/// The haystack array containing the values to match against.
in_array: ArrayRef,
/// Random state shared by haystack and needle hashing so hashes are comparable.
state: RandomState,
/// Stores indices into `in_array` for O(1) lookups.
table: HashTable<usize>,
}

impl NestedTypeFilter {
    /// Creates a filter for nested/complex array types.
    ///
    /// This filter compares values dynamically and is meant for types without a
    /// specialized filter (List, Struct, Map, Union). A haystack of
    /// `DataType::Null` gets an empty lookup table and is never hashed.
    ///
    /// # Errors
    /// Propagates any error from building the haystack table.
    pub(crate) fn try_new(in_array: ArrayRef) -> Result<Self> {
        let state = RandomState::new();

        // Null type has no natural order, so its table is simply left empty.
        let table = if matches!(in_array.data_type(), DataType::Null) {
            HashTable::new()
        } else {
            Self::build_haystack_table(&in_array, &state)?
        };

        Ok(Self {
            in_array,
            state,
            table,
        })
    }

    /// Builds the lookup table over the haystack.
    ///
    /// Every valid (non-null) row is visited once; its index is stored under the
    /// row's hash unless an equal value is already present. Equality is decided
    /// by a comparator obtained from `make_comparator`.
    ///
    /// # Errors
    /// Propagates errors from `with_hashes` and `make_comparator`.
    fn build_haystack_table(
        haystack: &ArrayRef,
        state: &RandomState,
    ) -> Result<HashTable<usize>> {
        let mut table = HashTable::new();

        with_hashes([haystack.as_ref()], state, |hashes| -> Result<()> {
            let cmp = make_comparator(haystack, haystack, SortOptions::default())?;

            let mut insert_if_new = |row: usize| {
                let hash = hashes[row];
                // De-duplicate: skip rows whose value is already in the table.
                let already_present = table
                    .find(hash, |&other| cmp(other, row).is_eq())
                    .is_some();
                if !already_present {
                    table.insert_unique(hash, row, |&other| hashes[other]);
                }
            };

            match haystack.nulls() {
                // With a null buffer, walk only the rows whose validity bit is set.
                Some(nulls) => {
                    let valid_rows = BitIndexIterator::new(
                        nulls.validity(),
                        nulls.offset(),
                        nulls.len(),
                    );
                    for row in valid_rows {
                        insert_if_new(row);
                    }
                }
                // Without a null buffer every row is a candidate.
                None => {
                    for row in 0..haystack.len() {
                        insert_if_new(row);
                    }
                }
            }

            Ok(())
        })?;

        Ok(table)
    }

    /// Checks which needle values exist in the haystack.
    ///
    /// Needle rows are hashed in one batch with the filter's `RandomState`, then
    /// each row is probed against the pre-built table, confirming candidates
    /// with a needle-vs-haystack comparator from `make_comparator`. Null handling
    /// and `negated` are applied by `build_in_list_result`.
    ///
    /// # Errors
    /// Propagates errors from `with_hashes` and `make_comparator`.
    fn find_needles_in_haystack(
        &self,
        needles: &dyn Array,
        negated: bool,
    ) -> Result<BooleanArray> {
        let needle_nulls = needles.logical_nulls();
        let haystack_has_nulls = self.in_array.null_count() > 0;

        with_hashes([needles], &self.state, |needle_hashes| {
            let cmp = make_comparator(needles, &self.in_array, SortOptions::default())?;

            let matches = build_in_list_result(
                needles.len(),
                needle_nulls.as_ref(),
                haystack_has_nulls,
                negated,
                #[inline(always)]
                |needle_row| {
                    let hash = needle_hashes[needle_row];
                    let hit = self.table.find(hash, |&haystack_row| {
                        cmp(needle_row, haystack_row).is_eq()
                    });
                    hit.is_some()
                },
            );
            Ok(matches)
        })
    }
}

impl StaticFilter for NestedTypeFilter {
    /// Number of null entries in the haystack array.
    fn null_count(&self) -> usize {
        self.in_array.null_count()
    }

    /// Evaluates, for every row of `v`, whether its value occurs in the haystack.
    ///
    /// * If either `v` or the haystack has `DataType::Null`, the result is an
    ///   all-null array of length `v.len()`.
    /// * If `v` is a dictionary array, its values are evaluated recursively and
    ///   the per-value results are gathered through the dictionary keys.
    /// * Otherwise the hash-table lookup path is used.
    ///
    /// # Errors
    /// Propagates errors from the recursive call, from `take`, and from the
    /// lookup path.
    fn contains(&self, v: &dyn Array, negated: bool) -> Result<BooleanArray> {
        // Null type comparisons always return null (SQL three-valued logic).
        let involves_null_type = matches!(v.data_type(), DataType::Null)
            || matches!(self.in_array.data_type(), DataType::Null);
        if involves_null_type {
            let len = v.len();
            return Ok(BooleanArray::new(
                BooleanBuffer::new_unset(len),
                Some(NullBuffer::new_null(len)),
            ));
        }

        downcast_dictionary_array! {
            v => {
                // Evaluate each distinct dictionary value once, then expand by key.
                let per_value = self.contains(v.values().as_ref(), negated)?;
                let expanded = take(&per_value, v.keys(), None)?;
                return Ok(downcast_array(expanded.as_ref()))
            }
            _ => {}
        }

        self.find_needles_in_haystack(v, negated)
    }
}
Loading