Skip to content

Commit 565d6b8

Browse files
Implement Legacy String Optimization (Utf8TwoStageFilter)
Port of the two-stage View optimization to standard Utf8 and LargeUtf8 types. Encodes strings as i128 (len + prefix) for fast O(1) pre-filtering before falling back to full string comparison. Triggers for Utf8 and LargeUtf8.
1 parent 058e35a commit 565d6b8

File tree

3 files changed

+283
-11
lines changed

3 files changed

+283
-11
lines changed

datafusion/physical-expr/src/expressions/in_list.rs

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,13 @@
1717

1818
//! Implementation of `InList` expressions: [`InListExpr`]
1919
20+
mod nested_filter;
21+
mod primitive_filter;
22+
mod result;
23+
mod static_filter;
24+
mod strategy;
25+
mod transform;
26+
2027
use std::any::Any;
2128
use std::fmt::Debug;
2229
use std::hash::{Hash, Hasher};
@@ -30,19 +37,11 @@ use arrow::buffer::{BooleanBuffer, NullBuffer};
3037
use arrow::compute::SortOptions;
3138
use arrow::compute::kernels::boolean::{not, or_kleene};
3239
use arrow::datatypes::*;
33-
3440
use datafusion_common::{
3541
DFSchema, Result, ScalarValue, assert_or_internal_err, exec_err,
3642
};
3743
use datafusion_expr::{ColumnarValue, expr_vec_fmt};
3844

39-
mod nested_filter;
40-
mod primitive_filter;
41-
mod result;
42-
mod static_filter;
43-
mod strategy;
44-
mod transform;
45-
4645
use static_filter::StaticFilter;
4746
use strategy::instantiate_static_filter;
4847

datafusion/physical-expr/src/expressions/in_list/strategy.rs

Lines changed: 51 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,17 @@
1616
// under the License.
1717

1818
//! Filter selection strategy for InList expressions
19+
//!
20+
//! Selects the optimal lookup strategy based on data type and list size:
21+
//!
22+
//! - 1-byte types (Int8/UInt8): bitmap (32 bytes, O(1) bit test)
23+
//! - 2-byte types (Int16/UInt16): bitmap (8 KB, O(1) bit test)
24+
//! - 4-byte types (Int32/Float32): branchless (≤32) or hash (>32)
25+
//! - 8-byte types (Int64/Float64): branchless (≤16) or hash (>16)
26+
//! - 16-byte types (Decimal128): branchless (≤4) or hash (>4)
27+
//! - Utf8View (short strings): branchless (≤4) or hash (>4)
28+
//! - Byte arrays: Utf8/LargeUtf8 use Utf8TwoStageFilter; Utf8View/BinaryView use ByteViewMaskedFilter; Binary/LargeBinary use NestedTypeFilter
29+
//! - Other types: NestedTypeFilter (fallback for List, Struct, Map, etc.)
1930
2031
use std::sync::Arc;
2132

@@ -29,13 +40,25 @@ use super::result::handle_dictionary;
2940
use super::static_filter::StaticFilter;
3041
use super::transform::{
3142
make_bitmap_filter, make_branchless_filter, make_byte_view_masked_filter,
32-
make_utf8view_branchless_filter, make_utf8view_hash_filter,
33-
reinterpret_any_primitive_to, utf8view_all_short_strings,
43+
make_utf8_two_stage_filter, make_utf8view_branchless_filter,
44+
make_utf8view_hash_filter, utf8view_all_short_strings,
3445
};
3546

3647
// =============================================================================
3748
// LOOKUP STRATEGY THRESHOLDS (tuned via microbenchmarks)
3849
// =============================================================================
50+
//
51+
// Based on minimum batch time (8192 lookups per batch):
52+
// - Int8 (1 byte): BITMAP (32 bytes, always fastest)
53+
// - Int16 (2 bytes): BITMAP (8 KB, always fastest)
54+
// - Int32 (4 bytes): branchless up to 32, then hashset
55+
// - Int64 (8 bytes): branchless up to 16, then hashset
56+
// - Int128 (16 bytes): branchless up to 4, then hashset
57+
// - Byte arrays: Utf8/LargeUtf8 use Utf8TwoStageFilter; Utf8View/BinaryView use ByteViewMaskedFilter; Binary/LargeBinary use NestedTypeFilter
58+
// - Other types: NestedTypeFilter (fallback for List, Struct, Map, etc.)
59+
//
60+
// NOTE: Binary search and linear scan were benchmarked but consistently
61+
// lost to the strategies above at all tested list sizes.
3962

4063
/// Maximum list size for branchless lookup on 4-byte primitives (Int32, UInt32, Float32).
4164
const BRANCHLESS_MAX_4B: usize = 32;
@@ -65,6 +88,10 @@ enum FilterStrategy {
6588
}
6689

6790
/// Determines the optimal lookup strategy based on data type and list size.
91+
///
92+
/// For 1-byte and 2-byte types, bitmap is always used (benchmarks show it's
93+
/// faster than both branchless and hashed at all list sizes).
94+
/// For larger types, cutoffs are tuned per byte-width.
6895
fn select_strategy(dt: &DataType, len: usize) -> FilterStrategy {
6996
match dt.primitive_width() {
7097
Some(1) => FilterStrategy::Bitmap1B,
@@ -99,6 +126,9 @@ fn select_strategy(dt: &DataType, len: usize) -> FilterStrategy {
99126
// =============================================================================
100127

101128
/// Creates the optimal static filter for the given array.
129+
///
130+
/// This is the main entry point for filter creation. It analyzes the array's
131+
/// data type and size to select the best lookup strategy.
102132
pub(crate) fn instantiate_static_filter(
103133
in_array: ArrayRef,
104134
) -> Result<Arc<dyn StaticFilter + Send + Sync>> {
@@ -136,15 +166,28 @@ pub(crate) fn instantiate_static_filter(
136166
exec_datafusion_err!("Hashed strategy selected but no filter for {:?}", dt)
137167
})?,
138168

169+
// Utf8/LargeUtf8: Two-stage filter with length+prefix quick rejection
170+
// Stage 1: O(1) lookup of encoded i128 (length + first 12 bytes)
171+
// Stage 2: Full string comparison only for long strings that pass Stage 1
172+
(DataType::Utf8 | DataType::LargeUtf8, Generic) => {
173+
make_utf8_two_stage_filter(in_array)
174+
}
175+
176+
// Binary variants: Use NestedTypeFilter (make_comparator)
177+
(DataType::Binary | DataType::LargeBinary, Generic) => {
178+
Ok(Arc::new(NestedTypeFilter::try_new(in_array)?))
179+
}
180+
139181
// Byte view filters (Utf8View, BinaryView)
182+
// Both use two-stage filter: masked view pre-check + full verification
140183
(DataType::Utf8View, Generic) => {
141184
make_byte_view_masked_filter::<StringViewType>(in_array)
142185
}
143186
(DataType::BinaryView, Generic) => {
144187
make_byte_view_masked_filter::<BinaryViewType>(in_array)
145188
}
146189

147-
// Fallback for nested/complex types and strings (Phase 4: Strings use fallback)
190+
// Fallback for nested/complex types (List, Struct, Map, Union, etc.)
148191
(_, Generic) => Ok(Arc::new(NestedTypeFilter::try_new(in_array)?)),
149192
}
150193
}
@@ -157,6 +200,7 @@ fn dispatch_branchless(
157200
arr: &ArrayRef,
158201
) -> Option<Result<Arc<dyn StaticFilter + Send + Sync>>> {
159202
// Dispatch to width-specific branchless filter.
203+
// Each width has its own max size: 4B→32, 8B→16, 16B→4
160204
match arr.data_type().primitive_width() {
161205
Some(4) => Some(make_branchless_filter::<UInt32Type>(arr, 4)),
162206
Some(8) => Some(make_branchless_filter::<UInt64Type>(arr, 8)),
@@ -192,6 +236,8 @@ fn dispatch_hashed(
192236
Some(16) => Some(make_direct_probe_filter_reinterpreted::<Decimal128Type>(
193237
arr,
194238
)),
239+
// Other widths (1, 2) use Bitmap strategy and never reach here.
240+
// Unknown widths fall through to Generic strategy.
195241
_ => None,
196242
}
197243
}
@@ -204,6 +250,8 @@ where
204250
D: ArrowPrimitiveType + 'static,
205251
D::Native: Send + Sync + DirectProbeHashable + 'static,
206252
{
253+
use super::transform::reinterpret_any_primitive_to;
254+
207255
// Fast path: already the right type
208256
if in_array.data_type() == &D::DATA_TYPE {
209257
return Ok(Arc::new(DirectProbeFilter::<D>::try_new(in_array)?));

datafusion/physical-expr/src/expressions/in_list/transform.rs

Lines changed: 225 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -523,3 +523,228 @@ where
523523
{
524524
Ok(Arc::new(ByteViewMaskedFilter::<T>::try_new(in_array)?))
525525
}
526+
527+
// =============================================================================
528+
// UTF8 TWO-STAGE FILTER (length+prefix pre-check + full verification)
529+
// =============================================================================
530+
//
531+
// Similar to ByteViewMaskedFilter but for regular Utf8/LargeUtf8 arrays.
532+
// Encodes strings as i128 with length + prefix for quick rejection.
533+
//
534+
// Encoding (Little Endian):
535+
// - Bytes 0-3: length (u32)
536+
// - Bytes 4-15: data (12 bytes)
537+
//
538+
// This naturally distinguishes short from long strings via the length field.
539+
// For short strings (≤12 bytes), the i128 contains all data → match is definitive.
540+
// For long strings (>12 bytes), a match requires full string comparison.
541+
542+
/// Encodes a string as i128 with length + prefix.
543+
/// Format: [len:u32][data:12 bytes] (Little Endian)
544+
#[inline(always)]
545+
fn encode_string_as_i128(s: &[u8]) -> i128 {
546+
let len = s.len();
547+
548+
// Optimization: Construct the i128 directly using arithmetic and pointer copy
549+
// to avoid Store-to-Load Forwarding (STLF) stalls on x64 and minimize LSU pressure on ARM.
550+
//
551+
// The layout in memory must match Utf8View: [4 bytes len][12 bytes data]
552+
let mut val: u128 = len as u128; // Length in bytes 0-3
553+
554+
// Safety: writing to the remaining bytes of an initialized u128.
555+
// We use a pointer copy for the string data as it is variable length (0-12 bytes).
556+
unsafe {
557+
let dst = (&mut val as *mut u128 as *mut u8).add(4);
558+
std::ptr::copy_nonoverlapping(s.as_ptr(), dst, len.min(INLINE_STRING_LEN));
559+
}
560+
561+
val as i128
562+
}
563+
564+
/// Two-stage IN-list filter for Utf8/LargeUtf8 arrays.
///
/// Stage 1: quick rejection using the length+prefix `i128` encoding
/// - Needles whose encoding is absent from `encoded_filter` are rejected
/// - Short needles (≤12 bytes) whose encoding is present are accepted
///   immediately, because the encoding contains the whole string
///
/// Stage 2: full verification for long needles
/// - Only reached when the encoding matches AND the needle is >12 bytes
/// - Looks the needle up in `long_string_table` using full byte comparison
pub(crate) struct Utf8TwoStageFilter<O: arrow::array::OffsetSizeTrait> {
    /// The haystack (IN-list values); indices in `long_string_table` refer
    /// to positions in this array
    in_array: ArrayRef,
    /// DirectProbeFilter over the `i128` encodings of the haystack values,
    /// used for the Stage 1 lookup
    encoded_filter: DirectProbeFilter<Decimal128Type>,
    /// Deduplicated haystack indices of long strings (>12 bytes) for Stage 2
    long_string_table: HashTable<usize>,
    /// Hasher state shared by table construction and lookup, so both sides
    /// hash a given string to the same value
    state: RandomState,
    /// True when no non-null haystack string is longer than 12 bytes;
    /// enables the bulk fast path that skips Stage 2 entirely
    all_short: bool,
    /// Ties the filter to the offset type `O` (i32 for Utf8, i64 for LargeUtf8)
    _phantom: PhantomData<O>,
}
586+
587+
impl<O: arrow::array::OffsetSizeTrait + 'static> Utf8TwoStageFilter<O> {
588+
pub(crate) fn try_new(in_array: ArrayRef) -> Result<Self> {
589+
use arrow::array::GenericStringArray;
590+
591+
let arr = in_array
592+
.as_any()
593+
.downcast_ref::<GenericStringArray<O>>()
594+
.expect("Utf8TwoStageFilter requires GenericStringArray");
595+
596+
let len = arr.len();
597+
let mut encoded_values = Vec::with_capacity(len);
598+
let state = RandomState::new();
599+
let mut long_string_table = HashTable::new();
600+
let mut all_short = true;
601+
602+
// Build encoded values and long string table
603+
for i in 0..len {
604+
if arr.is_null(i) {
605+
encoded_values.push(0);
606+
continue;
607+
}
608+
609+
let s = arr.value(i);
610+
let bytes = s.as_bytes();
611+
encoded_values.push(encode_string_as_i128(bytes));
612+
613+
if bytes.len() > INLINE_STRING_LEN {
614+
all_short = false;
615+
// Add to long string table for Stage 2 verification (with deduplication)
616+
let hash = state.hash_one(bytes);
617+
if long_string_table
618+
.find(hash, |&stored_idx| {
619+
arr.value(stored_idx).as_bytes() == bytes
620+
})
621+
.is_none()
622+
{
623+
long_string_table.insert_unique(hash, i, |&idx| {
624+
state.hash_one(arr.value(idx).as_bytes())
625+
});
626+
}
627+
}
628+
}
629+
630+
// Build DirectProbeFilter from encoded values
631+
let nulls = arr
632+
.nulls()
633+
.map(|n| arrow::buffer::NullBuffer::new(n.inner().clone()));
634+
let encoded_array: ArrayRef = Arc::new(PrimitiveArray::<Decimal128Type>::new(
635+
ScalarBuffer::from(encoded_values),
636+
nulls,
637+
));
638+
let encoded_filter =
639+
DirectProbeFilter::<Decimal128Type>::try_new(&encoded_array)?;
640+
641+
Ok(Self {
642+
in_array,
643+
encoded_filter,
644+
long_string_table,
645+
state,
646+
all_short,
647+
_phantom: PhantomData,
648+
})
649+
}
650+
}
651+
652+
impl<O: arrow::array::OffsetSizeTrait + 'static> StaticFilter for Utf8TwoStageFilter<O> {
653+
fn null_count(&self) -> usize {
654+
self.in_array.null_count()
655+
}
656+
657+
fn contains(&self, v: &dyn Array, negated: bool) -> Result<BooleanArray> {
658+
use arrow::array::GenericStringArray;
659+
660+
handle_dictionary!(self, v, negated);
661+
662+
let needle_arr = v
663+
.as_any()
664+
.downcast_ref::<GenericStringArray<O>>()
665+
.expect("needle array type mismatch in Utf8TwoStageFilter");
666+
let haystack_arr = self
667+
.in_array
668+
.as_any()
669+
.downcast_ref::<GenericStringArray<O>>()
670+
.expect("haystack array type mismatch in Utf8TwoStageFilter");
671+
672+
let haystack_has_nulls = self.in_array.null_count() > 0;
673+
674+
if self.all_short {
675+
// Fast path: all haystack strings are short
676+
// Batch-encode all needles and do bulk lookup
677+
let needle_encoded: Vec<i128> = (0..needle_arr.len())
678+
.map(|i| {
679+
if needle_arr.is_null(i) {
680+
0
681+
} else {
682+
encode_string_as_i128(needle_arr.value(i).as_bytes())
683+
}
684+
})
685+
.collect();
686+
687+
// For short haystack, encoded match is definitive for short needles.
688+
// Long needles (>12 bytes) can never match, but their encoded form
689+
// won't match any short haystack encoding (different length field).
690+
return Ok(self.encoded_filter.contains_slice(
691+
&needle_encoded,
692+
needle_arr.nulls(),
693+
negated,
694+
));
695+
}
696+
697+
// Two-stage path: haystack has long strings
698+
Ok(super::result::build_in_list_result(
699+
v.len(),
700+
needle_arr.nulls(),
701+
haystack_has_nulls,
702+
negated,
703+
|i| {
704+
// SAFETY: i is in bounds [0, v.len()), guaranteed by build_in_list_result
705+
let needle_bytes = unsafe { needle_arr.value_unchecked(i) }.as_bytes();
706+
let encoded = encode_string_as_i128(needle_bytes);
707+
708+
// Stage 1: Quick rejection via encoded i128
709+
if !self.encoded_filter.contains_single(encoded) {
710+
return false;
711+
}
712+
713+
// Encoded match found
714+
let needle_len = needle_bytes.len();
715+
if needle_len <= INLINE_STRING_LEN {
716+
// Short needle: encoded contains all data, match is definitive
717+
// (If haystack had a long string with same prefix, its length
718+
// field would differ, so encoded wouldn't match)
719+
return true;
720+
}
721+
722+
// Stage 2: Long needle - verify with full string comparison
723+
let hash = self.state.hash_one(needle_bytes);
724+
self.long_string_table
725+
.find(hash, |&idx| {
726+
// SAFETY: idx was stored in try_new from valid indices into in_array
727+
unsafe { haystack_arr.value_unchecked(idx) }.as_bytes()
728+
== needle_bytes
729+
})
730+
.is_some()
731+
},
732+
))
733+
}
734+
}
735+
736+
/// Creates a two-stage filter for Utf8/LargeUtf8 arrays.
737+
pub(crate) fn make_utf8_two_stage_filter(
738+
in_array: ArrayRef,
739+
) -> Result<Arc<dyn StaticFilter + Send + Sync>> {
740+
use arrow::datatypes::DataType;
741+
match in_array.data_type() {
742+
DataType::Utf8 => Ok(Arc::new(Utf8TwoStageFilter::<i32>::try_new(in_array)?)),
743+
DataType::LargeUtf8 => {
744+
Ok(Arc::new(Utf8TwoStageFilter::<i64>::try_new(in_array)?))
745+
}
746+
dt => datafusion_common::exec_err!(
747+
"Unsupported data type for Utf8 two-stage filter: {dt}"
748+
),
749+
}
750+
}

0 commit comments

Comments
 (0)