|
19 | 19 | //! |
20 | 20 | //! This module provides high-performance membership testing for Arrow primitive types. |
21 | 21 |
|
22 | | -use std::hash::Hash; |
23 | | - |
24 | 22 | use arrow::array::{Array, ArrayRef, AsArray, BooleanArray}; |
25 | 23 | use arrow::datatypes::ArrowPrimitiveType; |
26 | 24 | use datafusion_common::{HashSet, Result, exec_datafusion_err}; |
@@ -256,76 +254,218 @@ where |
256 | 254 | } |
257 | 255 |
|
258 | 256 | // ============================================================================= |
259 | | -// PRIMITIVE FILTER (Hash-based) |
| 257 | +// DIRECT PROBE HASH FILTER (O(1) lookup with open addressing) |
260 | 258 | // ============================================================================= |
261 | 259 |
|
262 | | -/// Hash-based filter for primitive types with larger IN lists. |
263 | | -pub(crate) struct PrimitiveFilter<T: ArrowPrimitiveType> { |
| 260 | +/// Inverse of the maximum load factor of the `DirectProbeFilter` hash table. |
| 261 | +/// A value of 4 gives every stored value at least 4 slots, i.e. a load factor of at most 25%. |
| 262 | +const LOAD_FACTOR_INVERSE: usize = 4; |
| 263 | + |
| 264 | +/// Minimum number of slots in a `DirectProbeFilter` table. |
| 265 | +/// Applied as a lower bound after sizing, so even an empty or very small IN list gets 16 slots. |
| 266 | +const MIN_TABLE_SIZE: usize = 16; |
| 267 | + |
| 268 | +/// Golden ratio multiplier for 32-bit hash mixing (used by the `i32` hash). |
| 269 | +/// Derived from (2^32 / phi) where phi = (1 + sqrt(5)) / 2. |
| 270 | +const GOLDEN_RATIO_32: u32 = 0x9e3779b9; |
| 271 | + |
| 272 | +/// Golden ratio multiplier for 64-bit hash mixing (used by the `i64` and `i128` hashes). |
| 273 | +/// Derived from (2^64 / phi) where phi = (1 + sqrt(5)) / 2. |
| 274 | +const GOLDEN_RATIO_64: u64 = 0x9e3779b97f4a7c15; |
| 275 | + |
| 276 | +/// Secondary multiplier for 128-bit hashing (from SplitMix64), applied to the high 64 bits. |
| 277 | +/// Using a different constant for hi/lo avoids collisions when lo = hi * C. |
| 278 | +const SPLITMIX_CONSTANT: u64 = 0xbf58476d1ce4e5b9; |
| 279 | + |
| 280 | +/// Fast membership filter using open addressing with linear probing. |
| 281 | +/// |
| 282 | +/// Uses a power-of-2 sized hash table for O(1) average-case lookups. |
| 283 | +/// Optimized for the IN list use case with: |
| 284 | +/// - Simple/fast hash function (golden ratio multiply + xor-shift) |
| 285 | +/// - At most 25% load factor for minimal collisions |
| 286 | +/// - Direct array storage for cache-friendly access |
///
/// The table is filled once by the constructors; none of the methods in this
/// file mutate it afterwards, which is what lets lookups stop at the first
/// empty slot.
| 287 | +pub(crate) struct DirectProbeFilter<T: ArrowPrimitiveType> |
| 288 | +where |
| 289 | + T::Native: DirectProbeHashable, |
| 290 | +{ |
/// Number of nulls in the IN list the filter was built from
/// (`try_new` records the array's null count, `from_values` records 0).
264 | 291 | null_count: usize, |
265 | | - set: HashSet<T::Native>, |
| 292 | + /// Hash table with open addressing. None = empty slot, Some(v) = value present |
| 293 | + table: Box<[Option<T::Native>]>, |
| 294 | + /// Mask for slot index: table.len() - 1, where table.len() is always a power of 2 |
| 295 | + mask: usize, |
| 296 | +} |
| 297 | + |
| 298 | +/// Trait for value types that can be stored in and looked up from a `DirectProbeFilter`. |
| 299 | +/// |
| 300 | +/// Requires `Hash + Eq` for deduplication via `HashSet`, even though we use |
| 301 | +/// a custom `probe_hash()` for the actual hash table lookups. |
/// `Copy` is required because values are passed and stored by value.
// NOTE(review): `PartialEq` is already implied by `Eq`; listing it is harmless but redundant.
| 302 | +pub(crate) trait DirectProbeHashable: |
| 303 | + Copy + PartialEq + std::hash::Hash + Eq |
| 304 | +{ |
/// Returns the hash used to choose the starting slot for this value.
///
/// The filter keeps only the low bits (`probe_hash() & mask`), so
/// implementations should mix entropy into the low bits. Values that compare
/// equal must return the same hash, otherwise a lookup would start probing
/// from a different slot than the one used at insertion.
| 305 | + fn probe_hash(self) -> usize; |
| 306 | +} |
| 307 | + |
| 308 | +// Simple but fast hash for 32-bit values - golden ratio multiply + xor-shift |
| 309 | +impl DirectProbeHashable for i32 { |
| 310 | + #[inline(always)] |
| 311 | + fn probe_hash(self) -> usize { |
// Reinterpret the bits as unsigned so the multiply and shift operate on u32.
| 312 | + let x = self as u32; |
// Wrapping multiply: overflow is intended, only the low 32 bits of the product are kept.
| 313 | + let x = x.wrapping_mul(GOLDEN_RATIO_32); |
// Fold the high 16 bits into the low 16 bits, because the table index is
// taken from the low bits of the result (`hash & mask`).
| 314 | + (x ^ (x >> 16)) as usize |
| 315 | + } |
266 | 316 | } |
267 | 317 |
|
268 | | -impl<T: ArrowPrimitiveType> PrimitiveFilter<T> |
// 64-bit variant of the multiply + xor-shift hash.
| 318 | +impl DirectProbeHashable for i64 { |
| 319 | + #[inline(always)] |
| 320 | + fn probe_hash(self) -> usize { |
// Reinterpret the bits as unsigned so the multiply and shift operate on u64.
| 321 | + let x = self as u64; |
// Wrapping multiply: overflow is intended, only the low 64 bits of the product are kept.
| 322 | + let x = x.wrapping_mul(GOLDEN_RATIO_64); |
// Fold the high 32 bits into the low 32 bits; the cast to usize keeps the
// low bits (all 64 of them where usize is 64 bits wide).
| 323 | + (x ^ (x >> 32)) as usize |
| 324 | + } |
| 325 | +} |
| 326 | + |
// Unsigned 32-bit values reuse the i32 hash.
| 327 | +impl DirectProbeHashable for u32 { |
| 328 | + #[inline(always)] |
| 329 | + fn probe_hash(self) -> usize { |
// `as i32` only reinterprets the bit pattern, and the i32 hash casts straight
// back to u32, so the result depends on the bits alone.
| 330 | + (self as i32).probe_hash() |
| 331 | + } |
| 332 | +} |
| 333 | + |
// Unsigned 64-bit values reuse the i64 hash.
| 334 | +impl DirectProbeHashable for u64 { |
| 335 | + #[inline(always)] |
| 336 | + fn probe_hash(self) -> usize { |
// `as i64` only reinterprets the bit pattern, and the i64 hash casts straight
// back to u64, so the result depends on the bits alone.
| 337 | + (self as i64).probe_hash() |
| 338 | + } |
| 339 | +} |
| 340 | + |
// 128-bit values: hash each 64-bit half with its own multiplier, then combine.
| 341 | +impl DirectProbeHashable for i128 { |
| 342 | + #[inline(always)] |
| 343 | + fn probe_hash(self) -> usize { |
| 344 | + // Mix both halves with different constants to avoid collisions when lo = hi * C |
// Truncating cast: keeps the low 64 bits.
| 345 | + let lo = self as u64; |
// Shift the high 64 bits down, then truncate to u64.
| 346 | + let hi = (self >> 64) as u64; |
| 347 | + let x = lo.wrapping_mul(GOLDEN_RATIO_64) ^ hi.wrapping_mul(SPLITMIX_CONSTANT); |
// Same final fold as the i64 hash: bring the high 32 bits into the low 32 bits.
| 348 | + (x ^ (x >> 32)) as usize |
| 349 | + } |
| 350 | +} |
| 351 | + |
// Constructors and lookup routines for the open-addressing filter.
| 352 | +impl<T: ArrowPrimitiveType> DirectProbeFilter<T> |
269 | 353 | where |
270 | | - T::Native: Hash + Eq, |
| 354 | + T::Native: DirectProbeHashable, |
271 | 355 | { |
/// Builds a filter from an Arrow array holding the IN-list values.
///
/// Returns the error built by `exec_datafusion_err!` when `in_array` is not a
/// primitive array of type `T`. Null entries are skipped when collecting the
/// values; their count is stored in `null_count`.
272 | 356 | pub(crate) fn try_new(in_array: &ArrayRef) -> Result<Self> { |
273 | 357 | let arr = in_array.as_primitive_opt::<T>().ok_or_else(|| { |
274 | 358 | exec_datafusion_err!( |
275 | | - "PrimitiveFilter: expected {} array", |
| 359 | + "DirectProbeFilter: expected {} array", |
276 | 360 | std::any::type_name::<T>() |
277 | 361 | ) |
278 | 362 | })?; |
279 | | - Ok(Self { |
280 | | - null_count: arr.null_count(), |
281 | | - set: arr.iter().flatten().collect(), |
282 | | - }) |
| 363 | + |
| 364 | + // Skip nulls (`flatten`) and deduplicate through a HashSet before building the table |
| 365 | + let unique_values: HashSet<_> = arr.iter().flatten().collect(); |
| 366 | + |
| 367 | + Ok(Self::from_values_inner( |
| 368 | + unique_values.into_iter(), |
| 369 | + arr.null_count(), |
| 370 | + )) |
283 | 371 | } |
284 | 372 |
|
285 | | - /// Check membership using a raw values slice (zero-copy path for type reinterpretation). |
| 373 | + /// Creates a DirectProbeFilter from an iterator of values. |
| 374 | + /// |
| 375 | + /// This is useful when building the filter from pre-processed values |
| 376 | + /// (e.g., masked views for Utf8View). |
/// Duplicates in `values` are removed first, and the resulting filter
/// records a null count of 0.
| 377 | + pub(crate) fn from_values(values: impl Iterator<Item = T::Native>) -> Self { |
| 378 | + // Collect into HashSet for deduplication |
| 379 | + let unique_values: HashSet<_> = values.collect(); |
| 380 | + Self::from_values_inner(unique_values.into_iter(), 0) |
| 381 | + } |
| 382 | + |
| 383 | + /// Internal constructor from deduplicated values |
/// Both callers in this impl deduplicate first. The table is sized from
/// the number of values received, and `null_count` is stored unchanged.
| 384 | + fn from_values_inner( |
| 385 | + unique_values: impl Iterator<Item = T::Native>, |
| 386 | + null_count: usize, |
| 387 | + ) -> Self { |
// Materialize the iterator so the number of values is known before sizing.
| 388 | + let unique_values: Vec<_> = unique_values.collect(); |
| 389 | + |
| 390 | + // Size table to at most 25% load: 4 slots per value, rounded up to a power of two, never below MIN_TABLE_SIZE |
// An empty input is sized as if it held one value.
| 391 | + let n = unique_values.len().max(1); |
| 392 | + let table_size = (n * LOAD_FACTOR_INVERSE) |
| 393 | + .next_power_of_two() |
| 394 | + .max(MIN_TABLE_SIZE); |
// table_size is a power of two, so `x & mask` equals `x % table_size`.
| 395 | + let mask = table_size - 1; |
| 396 | + |
| 397 | + let mut table: Box<[Option<T::Native>]> = |
| 398 | + vec![None; table_size].into_boxed_slice(); |
| 399 | + |
| 400 | + // Insert all values using linear probing |
// The inner loop always ends: the table has more slots than values, so an
// empty slot is reachable from any starting position.
| 401 | + for v in unique_values { |
| 402 | + let mut slot = v.probe_hash() & mask; |
| 403 | + loop { |
| 404 | + if table[slot].is_none() { |
| 405 | + table[slot] = Some(v); |
| 406 | + break; |
| 407 | + } |
// Occupied: step to the next slot, wrapping at the end of the table.
| 408 | + slot = (slot + 1) & mask; |
| 409 | + } |
| 410 | + } |
| 411 | + |
| 412 | + Self { |
| 413 | + null_count, |
| 414 | + table, |
| 415 | + mask, |
| 416 | + } |
| 417 | + } |
| 418 | + |
| 419 | + /// Single-value lookup with linear probing (O(1) on average). |
| 420 | + /// |
| 421 | + /// Returns true if the value is in the set. |
/// Probing starts at `probe_hash() & mask` and stops at the first empty
/// slot. That is sufficient because insertion places each value in the
/// first empty slot after its starting position and this file never
/// clears a slot afterwards.
| 422 | + #[inline(always)] |
| 423 | + pub(crate) fn contains_single(&self, needle: T::Native) -> bool { |
| 424 | + let mut slot = needle.probe_hash() & self.mask; |
| 425 | + loop { |
| 426 | + // SAFETY: `slot` is always < table.len() because: |
| 427 | + // - `slot = hash & mask` where `mask = table.len() - 1` |
| 428 | + // - table size is always a power of 2 |
| 429 | + // - `(slot + 1) & mask` wraps around within bounds |
| 430 | + match unsafe { self.table.get_unchecked(slot) } { |
// Empty slot reached: the needle was never inserted.
| 431 | + None => return false, |
| 432 | + Some(v) if *v == needle => return true, |
// Slot holds a different value: keep probing.
| 433 | + _ => slot = (slot + 1) & self.mask, |
| 434 | + } |
| 435 | + } |
| 436 | + } |
| 437 | + |
| 438 | + /// Check membership for every element of a raw values slice. |
///
/// `input` holds the values to test, `nulls` is the optional validity
/// buffer for those values, and `negated` selects NOT IN. The per-index
/// result of `contains_single` is handed to `build_in_list_result`
/// together with `nulls`, `negated` and whether the IN list contained
/// nulls; that helper (defined outside this view) assembles the returned
/// `BooleanArray`.
286 | 439 | #[inline] |
287 | 440 | pub(crate) fn contains_slice( |
288 | 441 | &self, |
289 | | - values: &[T::Native], |
| 442 | + input: &[T::Native], |
290 | 443 | nulls: Option<&arrow::buffer::NullBuffer>, |
291 | 444 | negated: bool, |
292 | 445 | ) -> BooleanArray { |
293 | | - build_in_list_result( |
294 | | - values.len(), |
295 | | - nulls, |
296 | | - self.null_count > 0, |
297 | | - negated, |
298 | | - // SAFETY: i is in bounds since we iterate 0..values.len() |
299 | | - |i| self.set.contains(unsafe { values.get_unchecked(i) }), |
300 | | - ) |
| 446 | + build_in_list_result(input.len(), nulls, self.null_count > 0, negated, |i| { |
| 447 | + // SAFETY: relies on build_in_list_result only calling this closure with i in 0..input.len() (the length passed as its first argument); that helper is not shown here |
| 448 | + self.contains_single(unsafe { *input.get_unchecked(i) }) |
| 449 | + }) |
301 | 450 | } |
302 | 451 | } |
303 | 452 |
|
304 | | -impl<T> StaticFilter for PrimitiveFilter<T> |
// `StaticFilter` implementation backed by the open-addressing table.
| 453 | +impl<T> StaticFilter for DirectProbeFilter<T> |
305 | 454 | where |
306 | 455 | T: ArrowPrimitiveType + 'static, |
307 | | - T::Native: Hash + Eq + Send + Sync + 'static, |
| 456 | + T::Native: DirectProbeHashable + Send + Sync + 'static, |
308 | 457 | { |
// Returns the null count recorded when the filter was constructed.
| 458 | + #[inline] |
309 | 459 | fn null_count(&self) -> usize { |
310 | 460 | self.null_count |
311 | 461 | } |
312 | 462 |
|
// Tests each element of `v` for membership and returns the resulting
// BooleanArray; `negated` selects NOT IN semantics.
// `handle_dictionary!` runs first; it is defined outside this view and
// presumably short-circuits for dictionary-encoded input - confirm.
| 463 | + #[inline] |
313 | 464 | fn contains(&self, v: &dyn Array, negated: bool) -> Result<BooleanArray> { |
314 | 465 | handle_dictionary!(self, v, negated); |
315 | | - let v = v.as_primitive_opt::<T>().ok_or_else(|| { |
316 | | - exec_datafusion_err!( |
317 | | - "PrimitiveFilter: expected {} array", |
318 | | - std::any::type_name::<T>() |
319 | | - ) |
320 | | - })?; |
321 | | - let values = v.values(); |
322 | | - Ok(build_in_list_result( |
323 | | - v.len(), |
324 | | - v.nulls(), |
325 | | - self.null_count > 0, |
326 | | - negated, |
327 | | - // SAFETY: i is in bounds since we iterate 0..v.len() |
328 | | - |i| self.set.contains(unsafe { values.get_unchecked(i) }), |
329 | | - )) |
| 466 | + // Use raw buffer access: view buffer 0 of the array data as a slice of T::Native |
// NOTE(review): unlike the removed `as_primitive_opt` path, no check that `v`
// is a `T` array is visible here; presumably callers only pass arrays whose
// first buffer has the layout of `T::Native` - confirm, since a mismatch would
// no longer surface as an `Err`.
// NOTE(review): `to_data()` is called on every invocation - confirm its cost is
// acceptable on this path.
| 467 | + let data = v.to_data(); |
| 468 | + let values: &[T::Native] = data.buffer::<T::Native>(0); |
// NOTE(review): assumes `values.len()` equals `v.len()` so it lines up with
// `v.nulls()` - TODO confirm for sliced arrays.
| 469 | + Ok(self.contains_slice(values, v.nulls(), negated)) |
330 | 470 | } |
331 | 471 | } |
0 commit comments