diff --git a/datafusion/physical-plan/src/aggregates/group_values/mod.rs b/datafusion/physical-plan/src/aggregates/group_values/mod.rs
index 2f3b1a19e7d73..95077f36c9f4f 100644
--- a/datafusion/physical-plan/src/aggregates/group_values/mod.rs
+++ b/datafusion/physical-plan/src/aggregates/group_values/mod.rs
@@ -18,14 +18,19 @@
 //! [`GroupValues`] trait for storing and interning group keys
 
 use arrow::array::types::{
-    Date32Type, Date64Type, Decimal128Type, Time32MillisecondType, Time32SecondType,
-    Time64MicrosecondType, Time64NanosecondType, TimestampMicrosecondType,
-    TimestampMillisecondType, TimestampNanosecondType, TimestampSecondType,
+    Date32Type, Date64Type, Decimal128Type, Float16Type, Int8Type, Int16Type, Int32Type,
+    Int64Type, Time32MillisecondType, Time32SecondType, Time64MicrosecondType,
+    Time64NanosecondType, TimestampMicrosecondType, TimestampMillisecondType,
+    TimestampNanosecondType, TimestampSecondType, UInt8Type, UInt16Type, UInt32Type,
+    UInt64Type,
 };
 use arrow::array::{ArrayRef, downcast_primitive};
 use arrow::datatypes::{DataType, SchemaRef, TimeUnit};
-use datafusion_common::Result;
+use crate::Statistics;
+use datafusion_common::stats::Precision;
+use datafusion_common::{Result, ScalarValue};
+use half::f16;
 use datafusion_expr::EmitTo;
 
 pub mod multi_group_by;
@@ -42,6 +47,7 @@ use crate::aggregates::{
     group_values::single_group_by::{
         boolean::GroupValuesBoolean, bytes::GroupValuesBytes,
         bytes_view::GroupValuesBytesView, primitive::GroupValuesPrimitive,
+        primitive::GroupValuesSmallPrimitive,
     },
     order::GroupOrdering,
 };
@@ -134,10 +140,112 @@ pub trait GroupValues: Send {
 pub fn new_group_values(
     schema: SchemaRef,
     group_ordering: &GroupOrdering,
+    statistics: Option<&Statistics>,
 ) -> Result<Box<dyn GroupValues>> {
     if schema.fields.len() == 1 {
         let d = schema.fields[0].data_type();
 
+        match d {
+            DataType::Int8 => { // whole domain fits the direct-addressed table
+                return Ok(Box::new(GroupValuesSmallPrimitive::<Int8Type>::new(
+                    d.clone(),
+                    i8::MIN,
+                    i8::MAX,
+                )));
+            }
+            DataType::UInt8 => {
+                return Ok(Box::new(GroupValuesSmallPrimitive::<UInt8Type>::new(
+                    d.clone(),
+                    u8::MIN,
+                    u8::MAX,
+                )));
+            }
+            DataType::Int16 => {
+                return Ok(Box::new(GroupValuesSmallPrimitive::<Int16Type>::new(
+                    d.clone(),
+                    i16::MIN,
+                    i16::MAX,
+                )));
+            }
+            DataType::UInt16 => {
+                return Ok(Box::new(GroupValuesSmallPrimitive::<UInt16Type>::new(
+                    d.clone(),
+                    u16::MIN,
+                    u16::MAX,
+                )));
+            }
+            DataType::Float16 => { // f16 is indexed by raw bits: sentinel bounds, not numeric limits
+                return Ok(Box::new(GroupValuesSmallPrimitive::<Float16Type>::new(
+                    d.clone(),
+                    f16::from_bits(0),
+                    f16::from_bits(65535), // covers all 2^16 bit patterns (incl. NaNs)
+                )));
+            }
+            _ => {}
+        }
+
+        if let Some(stats) = statistics // wider ints qualify only with exact, narrow stats
+            && stats.column_statistics.len() == 1
+            && let Precision::Exact(min) = &stats.column_statistics[0].min_value
+            && let Precision::Exact(max) = &stats.column_statistics[0].max_value
+        {
+            match (d, min, max) {
+                (
+                    DataType::Int32,
+                    ScalarValue::Int32(Some(min)),
+                    ScalarValue::Int32(Some(max)),
+                ) => {
+                    // widen before subtracting so `max - min` cannot overflow
+                    if *max >= *min && (*max as i64 - *min as i64) < 65536 {
+                        return Ok(Box::new(GroupValuesSmallPrimitive::<Int32Type>::new(
+                            d.clone(),
+                            *min,
+                            *max,
+                        )));
+                    }
+                }
+                (
+                    DataType::UInt32,
+                    ScalarValue::UInt32(Some(min)),
+                    ScalarValue::UInt32(Some(max)),
+                ) => {
+                    if *max >= *min && (*max as u64 - *min as u64) < 65536 {
+                        return Ok(Box::new(GroupValuesSmallPrimitive::<UInt32Type>::new(
+                            d.clone(),
+                            *min,
+                            *max,
+                        )));
+                    }
+                }
+                (
+                    DataType::Int64,
+                    ScalarValue::Int64(Some(min)),
+                    ScalarValue::Int64(Some(max)),
+                ) => {
+                    if *max >= *min && (*max as i128 - *min as i128) < 65536 {
+                        return Ok(Box::new(GroupValuesSmallPrimitive::<Int64Type>::new(
+                            d.clone(),
+                            *min,
+                            *max,
+                        )));
+                    }
+                }
+                (
+                    DataType::UInt64,
+                    ScalarValue::UInt64(Some(min)),
+                    ScalarValue::UInt64(Some(max)),
+                ) => {
+                    if *max >= *min && (*max - *min) < 65536 {
+                        return Ok(Box::new(GroupValuesSmallPrimitive::<UInt64Type>::new(
+                            d.clone(),
+                            *min,
+                            *max,
+                        )));
+                    }
+                }
+                _ => {}
+            }
+        }
+
         macro_rules! downcast_helper {
             ($t:ty, $d:ident) => {
                 return Ok(Box::new(GroupValuesPrimitive::<$t>::new($d.clone())))
diff --git a/datafusion/physical-plan/src/aggregates/group_values/single_group_by/primitive.rs b/datafusion/physical-plan/src/aggregates/group_values/single_group_by/primitive.rs
index c46cde8786eb4..b5f4ef41f85ef 100644
--- a/datafusion/physical-plan/src/aggregates/group_values/single_group_by/primitive.rs
+++ b/datafusion/physical-plan/src/aggregates/group_values/single_group_by/primitive.rs
@@ -73,6 +73,72 @@ macro_rules! hash_float {
 
 hash_float!(f16, f32, f64);
 
+/// Maps a native value to a dense table slot, given the table's `min` bound.
+pub(crate) trait SmallValue: arrow::datatypes::ArrowNativeType {
+    fn to_index(&self, min: Self) -> usize;
+}
+
+macro_rules! impl_small_value_uint {
+    ($($t:ty),+) => {
+        $(impl SmallValue for $t {
+            fn to_index(&self, min: Self) -> usize {
+                if *self < min {
+                    usize::MAX
+                } else {
+                    (*self - min) as usize
+                }
+            }
+        })+
+    };
+}
+
+impl_small_value_uint!(u8, u16, u32, u64);
+
+impl SmallValue for i8 {
+    fn to_index(&self, min: Self) -> usize {
+        if *self < min {
+            usize::MAX
+        } else {
+            (*self as i16 - min as i16) as usize
+        }
+    }
+}
+
+impl SmallValue for i16 {
+    fn to_index(&self, min: Self) -> usize {
+        if *self < min {
+            usize::MAX
+        } else {
+            (*self as i32 - min as i32) as usize
+        }
+    }
+}
+
+impl SmallValue for i32 {
+    fn to_index(&self, min: Self) -> usize {
+        if *self < min {
+            usize::MAX
+        } else {
+            (*self as i64 - min as i64) as usize
+        }
+    }
+}
+
+impl SmallValue for i64 {
+    fn to_index(&self, min: Self) -> usize {
+        if *self < min {
+            usize::MAX
+        } else {
+            (*self as i128 - min as i128) as usize
+        }
+    }
+}
+
+impl SmallValue for f16 {
+    fn to_index(&self, _min: Self) -> usize {
+        self.to_bits() as usize // index by raw bits; `min` is deliberately ignored
+    }
+}
+
 /// A [`GroupValues`] storing a single column of primitive values
 ///
 /// This specialization is significantly faster than using the more general
@@ -161,25 +227,13 @@ where
     }
 
     fn emit(&mut self, emit_to: EmitTo) -> Result<Vec<ArrayRef>> {
-        fn build_primitive<T: ArrowPrimitiveType>(
-            values: Vec<T::Native>,
-            null_idx: Option<usize>,
-        ) -> PrimitiveArray<T> {
-            let nulls = null_idx.map(|null_idx| {
-                let mut buffer = NullBufferBuilder::new(values.len());
-                buffer.append_n_non_nulls(null_idx);
-                buffer.append_null();
-                buffer.append_n_non_nulls(values.len() - null_idx - 1);
-                // NOTE: The inner builder must be constructed as there is at least one null
-                buffer.finish().unwrap()
-            });
-            PrimitiveArray::<T>::new(values.into(), nulls)
-        }
-
         let array: PrimitiveArray<T> = match emit_to {
             EmitTo::All => {
                 self.map.clear();
-                build_primitive(std::mem::take(&mut self.values), self.null_group.take())
+                build_primitive::<T>(
+                    std::mem::take(&mut self.values),
+                    self.null_group.take(),
+                )
             }
             EmitTo::First(n) => {
                 self.map.retain(|entry| {
@@ -205,7 +259,7 @@ where
                 };
                 let mut split = self.values.split_off(n);
                 std::mem::swap(&mut self.values, &mut split);
-                build_primitive(split, null_group)
+                build_primitive::<T>(split, null_group)
             }
         };
 
@@ -219,3 +273,242 @@ where
         self.map.shrink_to(num_rows, |_| 0); // hasher does not matter since the map is cleared
     }
 }
+
+/// A [`GroupValues`] storing a single column of small primitive values (i.e. <=16 bits)
+///
+/// This specialization uses a flat `Vec` as a lookup table instead of a `HashTable`
+pub struct GroupValuesSmallPrimitive<T: ArrowPrimitiveType> {
+    /// The data type of the output array
+    data_type: DataType,
+    /// Stores `group_index + 1` for each possible value of the primitive type.
+    /// 0 means the value has not been seen yet.
+    map: Vec<usize>,
+    /// The group index of the null value if any
+    null_group: Option<usize>,
+    /// The values for each group index
+    values: Vec<T::Native>,
+    /// The minimum value (offset)
+    min_value: T::Native,
+}
+
+impl<T: ArrowPrimitiveType> GroupValuesSmallPrimitive<T>
+where
+    T::Native: SmallValue,
+{
+    pub fn new(data_type: DataType, min: T::Native, max: T::Native) -> Self {
+        // No `max >= min` check: f16 bounds are raw bit patterns (0xFFFF is NaN, and
+        // `NaN >= x` is always false); inverted int ranges yield usize::MAX, caught below.
+        let range = max.to_index(min);
+        assert!(
+            range < 1_000_000,
+            "GroupValuesSmallPrimitive: range too large ({})",
+            range
+        );
+        Self {
+            data_type,
+            map: vec![0; range + 1],
+            values: Vec::with_capacity(128),
+            null_group: None,
+            min_value: min,
+        }
+    }
+}
+
+impl<T: ArrowPrimitiveType> GroupValues for GroupValuesSmallPrimitive<T>
+where
+    T::Native: SmallValue,
+{
+    fn intern(&mut self, cols: &[ArrayRef], groups: &mut Vec<usize>) -> Result<()> {
+        assert_eq!(cols.len(), 1);
+        groups.clear();
+
+        for v in cols[0].as_primitive::<T>() {
+            let group_id = match v {
+                None => *self.null_group.get_or_insert_with(|| {
+                    let group_id = self.values.len();
+                    self.values.push(Default::default());
+                    group_id
+                }),
+                Some(key) => {
+                    let index = key.to_index(self.min_value);
+                    let entry = self
+                        .map
+                        .get_mut(index)
+                        .expect("GroupValuesSmallPrimitive: value out of range");
+                    if *entry == 0 {
+                        let g = self.values.len();
+                        self.values.push(key);
+                        *entry = g + 1;
+                        g
+                    } else {
+                        *entry - 1
+                    }
+                }
+            };
+            groups.push(group_id)
+        }
+        Ok(())
+    }
+
+    fn size(&self) -> usize {
+        self.map.allocated_size() + self.values.allocated_size()
+    }
+
+    fn is_empty(&self) -> bool {
+        self.values.is_empty()
+    }
+
+    fn len(&self) -> usize {
+        self.values.len()
+    }
+
+    fn emit(&mut self, emit_to: EmitTo) -> Result<Vec<ArrayRef>> {
+        let array: PrimitiveArray<T> = match emit_to {
+            EmitTo::All => {
+                self.map.fill(0);
+                build_primitive::<T>(
+                    std::mem::take(&mut self.values),
+                    self.null_group.take(),
+                )
+            }
+            EmitTo::First(n) => {
+                // Shift surviving group indices down by `n`; emitted groups
+                // (index < n) are forgotten so re-interning assigns new ids.
+                for entry in self.map.iter_mut() {
+                    if *entry == 0 {
+                        continue;
+                    }
+                    let group_idx = *entry - 1;
+                    match group_idx.checked_sub(n) {
+                        Some(sub) => {
+                            *entry = sub + 1;
+                        }
+                        None => {
+                            *entry = 0;
+                        }
+                    }
+                }
+                let null_group = match &mut self.null_group {
+                    Some(v) if *v >= n => {
+                        *v -= n;
+                        None
+                    }
+                    Some(_) => self.null_group.take(),
+                    None => None,
+                };
+                let mut split = self.values.split_off(n);
+                std::mem::swap(&mut self.values, &mut split);
+                build_primitive::<T>(split, null_group)
+            }
+        };
+
+        Ok(vec![Arc::new(array.with_data_type(self.data_type.clone()))])
+    }
+
+    fn clear_shrink(&mut self, num_rows: usize) {
+        self.values.clear();
+        self.values.shrink_to(num_rows);
+        self.map.fill(0);
+    }
+}
+
+/// Builds a [`PrimitiveArray`] from group values, inserting a single null
+/// slot at `null_idx` when present.
+fn build_primitive<T: ArrowPrimitiveType>(
+    values: Vec<T::Native>,
+    null_idx: Option<usize>,
+) -> PrimitiveArray<T> {
+    let nulls = null_idx.map(|null_idx| {
+        let mut buffer = NullBufferBuilder::new(values.len());
+        buffer.append_n_non_nulls(null_idx);
+        buffer.append_null();
+        buffer.append_n_non_nulls(values.len() - null_idx - 1);
+        // NOTE: The inner builder must be constructed as there is at least one null
+        buffer.finish().unwrap()
+    });
+    PrimitiveArray::<T>::new(values.into(), nulls)
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use arrow::array::types::{Int8Type, Int16Type};
+    use arrow::array::{Array, Int8Array, Int16Array};
+
+    #[test]
+    fn test_intern_int16() {
+        let mut group_values =
+            GroupValuesSmallPrimitive::<Int16Type>::new(DataType::Int16, -32768, 32767);
+        let array = Arc::new(Int16Array::from(vec![
+            Some(1000),
+            Some(2000),
+            Some(1000),
+            None,
+            Some(3000),
+        ])) as ArrayRef;
+        let mut groups = vec![];
+        group_values.intern(&[array], &mut groups).unwrap();
+
+        assert_eq!(groups, vec![0, 1, 0, 2, 3]);
+        assert_eq!(group_values.len(), 4);
+
+        let emitted = group_values.emit(EmitTo::All).unwrap();
+        let emitted_array = emitted[0].as_primitive::<Int16Type>();
+
+        // Group 0: 1000, Group 1: 2000, Group 2: None, Group 3: 3000
+        assert_eq!(emitted_array.len(), 4);
+        assert_eq!(emitted_array.value(0), 1000);
+        assert_eq!(emitted_array.value(1), 2000);
+        assert!(emitted_array.is_null(2));
+        assert_eq!(emitted_array.value(3), 3000);
+    }
+
+    #[test]
+    fn test_intern_int8() {
+        let mut group_values =
+            GroupValuesSmallPrimitive::<Int8Type>::new(DataType::Int8, -128, 127);
+        let array = Arc::new(Int8Array::from(vec![
+            Some(1),
+            Some(2),
+            Some(1),
+            None,
+            Some(3),
+        ])) as ArrayRef;
+        let mut groups = vec![];
+        group_values.intern(&[array], &mut groups).unwrap();
+
+        assert_eq!(groups, vec![0, 1, 0, 2, 3]);
+        assert_eq!(group_values.len(), 4);
+
+        let emitted = group_values.emit(EmitTo::All).unwrap();
+        let emitted_array = emitted[0].as_primitive::<Int8Type>();
+
+        // Group 0: 1, Group 1: 2, Group 2: None, Group 3: 3
+        assert_eq!(emitted_array.len(), 4);
+        assert_eq!(emitted_array.value(0), 1);
+        assert_eq!(emitted_array.value(1), 2);
+        assert!(emitted_array.is_null(2));
+        assert_eq!(emitted_array.value(3), 3);
+    }
+
+    #[test]
+    fn test_emit_first_int8() {
+        let mut group_values =
+            GroupValuesSmallPrimitive::<Int8Type>::new(DataType::Int8, -128, 127);
+        let array =
+            Arc::new(Int8Array::from(vec![Some(10), Some(20), Some(10), None]))
+                as ArrayRef;
+        let mut groups = vec![];
+        group_values.intern(&[array], &mut groups).unwrap();
+        assert_eq!(groups, vec![0, 1, 0, 2]);
+
+        // Emit first 2 groups (10 and 20)
+        let emitted = group_values.emit(EmitTo::First(2)).unwrap();
+        let emitted_array = emitted[0].as_primitive::<Int8Type>();
+        assert_eq!(emitted_array.len(), 2);
+        assert_eq!(emitted_array.value(0), 10);
+        assert_eq!(emitted_array.value(1), 20);
+
+        // Remaining should be just the null group at index 0
+        assert_eq!(group_values.len(), 1);
+        let array2 =
+            Arc::new(Int8Array::from(vec![Some(10), None, Some(30)])) as ArrayRef;
+        group_values.intern(&[array2], &mut groups).unwrap();
+
+        // 10 is new (index 1), None is old (index 0), 30 is new (index 2)
+        assert_eq!(groups, vec![1, 0, 2]);
+        assert_eq!(group_values.len(), 3);
+    }
+}
diff --git a/datafusion/physical-plan/src/aggregates/row_hash.rs b/datafusion/physical-plan/src/aggregates/row_hash.rs
index 1ae7202711112..ee6516bd48b46 100644
--- a/datafusion/physical-plan/src/aggregates/row_hash.rs
+++ b/datafusion/physical-plan/src/aggregates/row_hash.rs
@@ -590,7 +590,9 @@ impl GroupedHashAggregateStream {
             _ => OutOfMemoryMode::ReportError,
         };
 
-        let group_values = new_group_values(group_schema, &group_ordering)?;
+        let stats = agg.input().partition_statistics(None).ok();
+        let group_values =
+            new_group_values(group_schema, &group_ordering, stats.as_ref())?;
         let reservation = MemoryConsumer::new(name)
             // We interpret 'can spill' as 'can handle memory back pressure'.
             // This value needs to be set to true for the default memory pool implementations
diff --git a/datafusion/physical-plan/src/recursive_query.rs b/datafusion/physical-plan/src/recursive_query.rs
index 683dbb4e49765..ca493673d788b 100644
--- a/datafusion/physical-plan/src/recursive_query.rs
+++ b/datafusion/physical-plan/src/recursive_query.rs
@@ -454,7 +454,7 @@ struct DistinctDeduplicator {
 
 impl DistinctDeduplicator {
     fn new(schema: SchemaRef, task_context: &TaskContext) -> Result<Self> {
-        let group_values = new_group_values(schema, &GroupOrdering::None)?;
+        let group_values = new_group_values(schema, &GroupOrdering::None, None)?;
         let reservation = MemoryConsumer::new("RecursiveQueryHashTable")
             .register(task_context.memory_pool());
         Ok(Self {