diff --git a/datafusion/physical-plan/src/joins/hash_join/partitioned_hash_eval.rs b/datafusion/physical-plan/src/joins/hash_join/partitioned_hash_eval.rs index 4c437e813139d..e1831ad149347 100644 --- a/datafusion/physical-plan/src/joins/hash_join/partitioned_hash_eval.rs +++ b/datafusion/physical-plan/src/joins/hash_join/partitioned_hash_eval.rs @@ -21,18 +21,18 @@ use std::{any::Any, fmt::Display, hash::Hash, sync::Arc}; use ahash::RandomState; use arrow::{ - array::{BooleanArray, UInt64Array}, - buffer::MutableBuffer, + array::{ArrayRef, UInt64Array}, datatypes::{DataType, Schema}, - util::bit_util, + record_batch::RecordBatch, }; -use datafusion_common::{Result, internal_datafusion_err, internal_err}; +use datafusion_common::Result; +use datafusion_common::hash_utils::{create_hashes, with_hashes}; use datafusion_expr::ColumnarValue; use datafusion_physical_expr_common::physical_expr::{ DynHash, PhysicalExpr, PhysicalExprRef, }; -use crate::{hash_utils::create_hashes, joins::utils::JoinHashMapType}; +use crate::joins::utils::JoinHashMapType; /// RandomState wrapper that preserves the seeds used to create it. /// @@ -181,18 +181,11 @@ impl PhysicalExpr for HashExpr { Ok(false) } - fn evaluate( - &self, - batch: &arrow::record_batch::RecordBatch, - ) -> Result { + fn evaluate(&self, batch: &RecordBatch) -> Result { let num_rows = batch.num_rows(); // Evaluate columns - let keys_values = self - .on_columns - .iter() - .map(|c| c.evaluate(batch)?.into_array(num_rows)) - .collect::>>()?; + let keys_values = evaluate_columns(&self.on_columns, batch)?; // Compute hashes let mut hashes_buffer = vec![0; num_rows]; @@ -217,8 +210,10 @@ impl PhysicalExpr for HashExpr { /// Takes a UInt64Array of hash values and checks membership in a hash table. /// Returns a BooleanArray indicating which hashes exist. pub struct HashTableLookupExpr { - /// Expression that computes hash values (should be a HashExpr) - hash_expr: PhysicalExprRef, + /// Columns to hash + on_columns: Vec, + /// Random state for hashing (with seeds preserved for serialization) + random_state: SeededRandomState, /// Hash table to check against hash_map: Arc, /// Description for display @@ -229,7 +224,8 @@ impl HashTableLookupExpr { /// Create a new HashTableLookupExpr /// /// # Arguments - /// * `hash_expr` - Expression that computes hash values + /// * `on_columns` - Columns to hash + /// * `random_state` - SeededRandomState for hashing /// * `hash_map` - Hash table to check membership /// * `description` - Description for debugging /// @@ -237,12 +233,14 @@ impl HashTableLookupExpr { /// This is public for internal testing purposes only and is not /// guaranteed to be stable across versions. pub fn new( - hash_expr: PhysicalExprRef, + on_columns: Vec, + random_state: SeededRandomState, hash_map: Arc, description: String, ) -> Self { Self { - hash_expr, + on_columns, + random_state, hash_map, description, } @@ -251,14 +249,22 @@ impl HashTableLookupExpr { impl std::fmt::Debug for HashTableLookupExpr { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "{}({:?})", self.description, self.hash_expr) + let cols = self + .on_columns + .iter() + .map(|e| e.to_string()) + .collect::>() + .join(", "); + let (s1, s2, s3, s4) = self.random_state.seeds(); + write!(f, "{}({cols}, [{s1},{s2},{s3},{s4}])", self.description) } } impl Hash for HashTableLookupExpr { fn hash(&self, state: &mut H) { - self.hash_expr.dyn_hash(state); + self.on_columns.dyn_hash(state); self.description.hash(state); + self.random_state.seeds().hash(state); // Note that we compare hash_map by pointer equality. // Actually comparing the contents of the hash maps would be expensive. // The way these hash maps are used in actuality is that HashJoinExec creates @@ -279,8 +285,9 @@ impl PartialEq for HashTableLookupExpr { // hash maps to have the same content in practice. // Theoretically this is a public API and users could create identical hash maps, // but that seems unlikely and not worth paying the cost of deep comparison all the time. - self.hash_expr.as_ref() == other.hash_expr.as_ref() + self.on_columns == other.on_columns && self.description == other.description + && self.random_state.seeds() == other.random_state.seeds() && Arc::ptr_eq(&self.hash_map, &other.hash_map) } } @@ -299,21 +306,16 @@ impl PhysicalExpr for HashTableLookupExpr { } fn children(&self) -> Vec<&Arc> { - vec![&self.hash_expr] + self.on_columns.iter().collect() } fn with_new_children( self: Arc, children: Vec>, ) -> Result> { - if children.len() != 1 { - return internal_err!( - "HashTableLookupExpr expects exactly 1 child, got {}", - children.len() - ); - } Ok(Arc::new(HashTableLookupExpr::new( - Arc::clone(&children[0]), + children, + self.random_state.clone(), Arc::clone(&self.hash_map), self.description.clone(), ))) @@ -327,36 +329,14 @@ impl PhysicalExpr for HashTableLookupExpr { Ok(false) } - fn evaluate( - &self, - batch: &arrow::record_batch::RecordBatch, - ) -> Result { - let num_rows = batch.num_rows(); - - // Evaluate hash expression to get hash values - let hash_array = self.hash_expr.evaluate(batch)?.into_array(num_rows)?; - let hash_array = hash_array.as_any().downcast_ref::().ok_or( - internal_datafusion_err!( - "HashTableLookupExpr expects UInt64Array from hash expression" - ), - )?; - - // Check each hash against the hash table - let mut buf = MutableBuffer::from_len_zeroed(bit_util::ceil(num_rows, 8)); - for (idx, hash_value) in hash_array.values().iter().enumerate() { - // Use get_matched_indices to check - if it returns any indices, the hash exists - let (matched_indices, _) = self - .hash_map - .get_matched_indices(Box::new(std::iter::once((idx, hash_value))), None); - - if !matched_indices.is_empty() { - bit_util::set_bit(buf.as_slice_mut(), idx); - } - } + fn evaluate(&self, batch: &RecordBatch) -> Result { + // Evaluate columns + let keys_values = evaluate_columns(&self.on_columns, batch)?; - Ok(ColumnarValue::Array(Arc::new( - BooleanArray::new_from_packed(buf, 0, num_rows), - ))) + with_hashes(&keys_values, self.random_state.random_state(), |hashes| { + let array = self.hash_map.contain_hashes(hashes); + Ok(ColumnarValue::Array(Arc::new(array))) + }) } fn fmt_sql(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -364,6 +344,17 @@ impl PhysicalExpr for HashTableLookupExpr { } } +fn evaluate_columns( + columns: &[PhysicalExprRef], + batch: &RecordBatch, +) -> Result> { + let num_rows = batch.num_rows(); + columns + .iter() + .map(|c| c.evaluate(batch)?.into_array(num_rows)) + .collect() +} + #[cfg(test)] mod tests { use super::*; @@ -482,22 +473,19 @@ mod tests { #[test] fn test_hash_table_lookup_expr_eq_same() { let col_a: PhysicalExprRef = Arc::new(Column::new("a", 0)); - let hash_expr: PhysicalExprRef = Arc::new(HashExpr::new( - vec![Arc::clone(&col_a)], - SeededRandomState::with_seeds(1, 2, 3, 4), - "inner_hash".to_string(), - )); let hash_map: Arc = Arc::new(JoinHashMapU32::with_capacity(10)); let expr1 = HashTableLookupExpr::new( - Arc::clone(&hash_expr), + vec![Arc::clone(&col_a)], + SeededRandomState::with_seeds(1, 2, 3, 4), Arc::clone(&hash_map), "lookup".to_string(), ); let expr2 = HashTableLookupExpr::new( - Arc::clone(&hash_expr), + vec![Arc::clone(&col_a)], + SeededRandomState::with_seeds(1, 2, 3, 4), Arc::clone(&hash_map), "lookup".to_string(), ); @@ -506,33 +494,23 @@ mod tests { } #[test] - fn test_hash_table_lookup_expr_eq_different_hash_expr() { + fn test_hash_table_lookup_expr_eq_different_columns() { let col_a: PhysicalExprRef = Arc::new(Column::new("a", 0)); let col_b: PhysicalExprRef = Arc::new(Column::new("b", 1)); - let hash_expr1: PhysicalExprRef = Arc::new(HashExpr::new( - vec![Arc::clone(&col_a)], - SeededRandomState::with_seeds(1, 2, 3, 4), - "inner_hash".to_string(), - )); - - let hash_expr2: PhysicalExprRef = Arc::new(HashExpr::new( - vec![Arc::clone(&col_b)], - SeededRandomState::with_seeds(1, 2, 3, 4), - "inner_hash".to_string(), - )); - let hash_map: Arc = Arc::new(JoinHashMapU32::with_capacity(10)); let expr1 = HashTableLookupExpr::new( - Arc::clone(&hash_expr1), + vec![Arc::clone(&col_a)], + SeededRandomState::with_seeds(1, 2, 3, 4), Arc::clone(&hash_map), "lookup".to_string(), ); let expr2 = HashTableLookupExpr::new( - Arc::clone(&hash_expr2), + vec![Arc::clone(&col_b)], + SeededRandomState::with_seeds(1, 2, 3, 4), Arc::clone(&hash_map), "lookup".to_string(), ); @@ -543,22 +521,19 @@ mod tests { #[test] fn test_hash_table_lookup_expr_eq_different_description() { let col_a: PhysicalExprRef = Arc::new(Column::new("a", 0)); - let hash_expr: PhysicalExprRef = Arc::new(HashExpr::new( - vec![Arc::clone(&col_a)], - SeededRandomState::with_seeds(1, 2, 3, 4), - "inner_hash".to_string(), - )); let hash_map: Arc = Arc::new(JoinHashMapU32::with_capacity(10)); let expr1 = HashTableLookupExpr::new( - Arc::clone(&hash_expr), + vec![Arc::clone(&col_a)], + SeededRandomState::with_seeds(1, 2, 3, 4), Arc::clone(&hash_map), "lookup_one".to_string(), ); let expr2 = HashTableLookupExpr::new( - Arc::clone(&hash_expr), + vec![Arc::clone(&col_a)], + SeededRandomState::with_seeds(1, 2, 3, 4), Arc::clone(&hash_map), "lookup_two".to_string(), ); @@ -569,11 +544,6 @@ mod tests { #[test] fn test_hash_table_lookup_expr_eq_different_hash_map() { let col_a: PhysicalExprRef = Arc::new(Column::new("a", 0)); - let hash_expr: PhysicalExprRef = Arc::new(HashExpr::new( - vec![Arc::clone(&col_a)], - SeededRandomState::with_seeds(1, 2, 3, 4), - "inner_hash".to_string(), - )); // Two different Arc pointers (even with same content) should not be equal let hash_map1: Arc = @@ -582,13 +552,15 @@ mod tests { Arc::new(JoinHashMapU32::with_capacity(10)); let expr1 = HashTableLookupExpr::new( - Arc::clone(&hash_expr), + vec![Arc::clone(&col_a)], + SeededRandomState::with_seeds(1, 2, 3, 4), hash_map1, "lookup".to_string(), ); let expr2 = HashTableLookupExpr::new( - Arc::clone(&hash_expr), + vec![Arc::clone(&col_a)], + SeededRandomState::with_seeds(1, 2, 3, 4), hash_map2, "lookup".to_string(), ); @@ -600,22 +572,19 @@ mod tests { #[test] fn test_hash_table_lookup_expr_hash_consistency() { let col_a: PhysicalExprRef = Arc::new(Column::new("a", 0)); - let hash_expr: PhysicalExprRef = Arc::new(HashExpr::new( - vec![Arc::clone(&col_a)], - SeededRandomState::with_seeds(1, 2, 3, 4), - "inner_hash".to_string(), - )); let hash_map: Arc = Arc::new(JoinHashMapU32::with_capacity(10)); let expr1 = HashTableLookupExpr::new( - Arc::clone(&hash_expr), + vec![Arc::clone(&col_a)], + SeededRandomState::with_seeds(1, 2, 3, 4), Arc::clone(&hash_map), "lookup".to_string(), ); let expr2 = HashTableLookupExpr::new( - Arc::clone(&hash_expr), + vec![Arc::clone(&col_a)], + SeededRandomState::with_seeds(1, 2, 3, 4), Arc::clone(&hash_map), "lookup".to_string(), ); diff --git a/datafusion/physical-plan/src/joins/hash_join/shared_bounds.rs b/datafusion/physical-plan/src/joins/hash_join/shared_bounds.rs index 7d34ce9acbd57..447caf51dc725 100644 --- a/datafusion/physical-plan/src/joins/hash_join/shared_bounds.rs +++ b/datafusion/physical-plan/src/joins/hash_join/shared_bounds.rs @@ -129,14 +129,9 @@ fn create_membership_predicate( } // Use hash table lookup for large build sides PushdownStrategy::HashTable(hash_map) => { - let lookup_hash_expr = Arc::new(HashExpr::new( + Ok(Some(Arc::new(HashTableLookupExpr::new( on_right.to_vec(), random_state.clone(), - "hash_join".to_string(), - )) as Arc; - - Ok(Some(Arc::new(HashTableLookupExpr::new( - lookup_hash_expr, hash_map, "hash_lookup".to_string(), )) as Arc)) diff --git a/datafusion/physical-plan/src/joins/join_hash_map.rs b/datafusion/physical-plan/src/joins/join_hash_map.rs index b0ed6dcc7c255..6a07fefaaabdb 100644 --- a/datafusion/physical-plan/src/joins/join_hash_map.rs +++ b/datafusion/physical-plan/src/joins/join_hash_map.rs @@ -22,6 +22,8 @@ use std::fmt::{self, Debug}; use std::ops::Sub; +use arrow::array::BooleanArray; +use arrow::buffer::BooleanBuffer; use arrow::datatypes::ArrowNativeType; use hashbrown::HashTable; use hashbrown::hash_table::Entry::{Occupied, Vacant}; @@ -124,6 +126,9 @@ pub trait JoinHashMapType: Send + Sync { match_indices: &mut Vec, ) -> Option; + /// Returns a BooleanArray indicating which of the provided hashes exist in the map. + fn contain_hashes(&self, hash_values: &[u64]) -> BooleanArray; + /// Returns `true` if the join hash map contains no entries. fn is_empty(&self) -> bool; @@ -196,6 +201,10 @@ impl JoinHashMapType for JoinHashMapU32 { ) } + fn contain_hashes(&self, hash_values: &[u64]) -> BooleanArray { + contain_hashes(&self.map, hash_values) + } + fn is_empty(&self) -> bool { self.map.is_empty() } @@ -270,6 +279,10 @@ impl JoinHashMapType for JoinHashMapU64 { ) } + fn contain_hashes(&self, hash_values: &[u64]) -> BooleanArray { + contain_hashes(&self.map, hash_values) + } + fn is_empty(&self) -> bool { self.map.is_empty() } @@ -496,3 +509,35 @@ where } None } + +pub fn contain_hashes(map: &HashTable<(u64, T)>, hash_values: &[u64]) -> BooleanArray { + let buffer = BooleanBuffer::collect_bool(hash_values.len(), |i| { + let hash = hash_values[i]; + map.find(hash, |(h, _)| hash == *h).is_some() + }); + BooleanArray::new(buffer, None) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_contain_hashes() { + let mut hash_map = JoinHashMapU32::with_capacity(10); + hash_map.update_from_iter(Box::new([10u64, 20u64, 30u64].iter().enumerate()), 0); + + let probe_hashes = vec![10, 11, 20, 21, 30, 31]; + let array = hash_map.contain_hashes(&probe_hashes); + + assert_eq!(array.len(), probe_hashes.len()); + + for (i, &hash) in probe_hashes.iter().enumerate() { + if matches!(hash, 10 | 20 | 30) { + assert!(array.value(i), "Hash {hash} should exist in the map"); + } else { + assert!(!array.value(i), "Hash {hash} should NOT exist in the map"); + } + } + } +} diff --git a/datafusion/physical-plan/src/joins/stream_join_utils.rs b/datafusion/physical-plan/src/joins/stream_join_utils.rs index 22cc82a22db5f..e0b045efc3ff7 100644 --- a/datafusion/physical-plan/src/joins/stream_join_utils.rs +++ b/datafusion/physical-plan/src/joins/stream_join_utils.rs @@ -23,15 +23,16 @@ use std::mem::size_of; use std::sync::Arc; use crate::joins::join_hash_map::{ - JoinHashMapOffset, get_matched_indices, get_matched_indices_with_limit_offset, - update_from_iter, + JoinHashMapOffset, contain_hashes, get_matched_indices, + get_matched_indices_with_limit_offset, update_from_iter, }; use crate::joins::utils::{JoinFilter, JoinHashMapType}; use crate::metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricBuilder}; use crate::{ExecutionPlan, metrics}; use arrow::array::{ - ArrowPrimitiveType, BooleanBufferBuilder, NativeAdapter, PrimitiveArray, RecordBatch, + ArrowPrimitiveType, BooleanArray, BooleanBufferBuilder, NativeAdapter, + PrimitiveArray, RecordBatch, }; use arrow::compute::concat_batches; use arrow::datatypes::{ArrowNativeType, Schema, SchemaRef}; @@ -94,6 +95,10 @@ impl JoinHashMapType for PruningJoinHashMap { ) } + fn contain_hashes(&self, hash_values: &[u64]) -> BooleanArray { + contain_hashes(&self.map, hash_values) + } + fn is_empty(&self) -> bool { self.map.is_empty() } diff --git a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs index aa5458849330f..4754e96c5232f 100644 --- a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs @@ -2340,9 +2340,10 @@ fn roundtrip_hash_table_lookup_expr_to_lit() -> Result<()> { // Create a HashTableLookupExpr - it will be replaced with lit(true) during serialization let hash_map = Arc::new(JoinHashMapU32::with_capacity(0)); - let hash_expr: Arc = Arc::new(Column::new("col", 0)); + let on_columns = vec![Arc::new(Column::new("col", 0)) as Arc]; let lookup_expr: Arc = Arc::new(HashTableLookupExpr::new( - hash_expr, + on_columns, + datafusion::physical_plan::joins::SeededRandomState::with_seeds(0, 0, 0, 0), hash_map, "test_lookup".to_string(), ));