Skip to content

Commit 92a239a

Browse files
authored
Implement min, max, sum for run-end-encoded arrays. (#9409)
Efficient implementations: * min & max work directly on the values child array. * sum folds over run lengths & values, without decompressing the array. In particular, those implementations takes care of the logical offset & len of the run-end-encoded arrays. This is non-trivial: * We get the physical start & end indices in O(log(#runs)), but those are incorrect for empty arrays. * Slicing can happen in the middle of a run. For sum, we need to track the logical start & end and reduce the run length accordingly. Finally, one caveat: the aggregation functions only work when the child values array is a primitive array. That's fine ~always, but some client might store the values in an unexpected type. They'll either get None or an Error, depending on the aggregation function used. This feature is tracked in #3520.
1 parent 6931d88 commit 92a239a

File tree

1 file changed

+292
-4
lines changed

1 file changed

+292
-4
lines changed

arrow-arith/src/aggregate.rs

Lines changed: 292 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -540,7 +540,7 @@ pub fn min_string_view(array: &StringViewArray) -> Option<&str> {
540540
/// Returns the sum of values in the array.
541541
///
542542
/// This doesn't detect overflow. Once overflowing, the result will wrap around.
543-
/// For an overflow-checking variant, use `sum_array_checked` instead.
543+
/// For an overflow-checking variant, use [`sum_array_checked`] instead.
544544
pub fn sum_array<T, A: ArrayAccessor<Item = T::Native>>(array: A) -> Option<T::Native>
545545
where
546546
T: ArrowNumericType,
@@ -567,14 +567,22 @@ where
567567

568568
Some(sum)
569569
}
570+
DataType::RunEndEncoded(run_ends, _) => match run_ends.data_type() {
571+
DataType::Int16 => ree::sum_wrapping::<types::Int16Type, T>(&array),
572+
DataType::Int32 => ree::sum_wrapping::<types::Int32Type, T>(&array),
573+
DataType::Int64 => ree::sum_wrapping::<types::Int64Type, T>(&array),
574+
_ => unreachable!(),
575+
},
570576
_ => sum::<T>(as_primitive_array(&array)),
571577
}
572578
}
573579

574580
/// Returns the sum of values in the array.
575581
///
576582
/// This detects overflow and returns an `Err` for that. For an non-overflow-checking variant,
577-
/// use `sum_array` instead.
583+
/// use [`sum_array`] instead.
584+
/// Additionally returns an `Err` on run-end-encoded arrays with a provided
585+
/// values type parameter that is incorrect.
578586
pub fn sum_array_checked<T, A: ArrayAccessor<Item = T::Native>>(
579587
array: A,
580588
) -> Result<Option<T::Native>, ArrowError>
@@ -603,10 +611,110 @@ where
603611

604612
Ok(Some(sum))
605613
}
614+
DataType::RunEndEncoded(run_ends, _) => match run_ends.data_type() {
615+
DataType::Int16 => ree::sum_checked::<types::Int16Type, T>(&array),
616+
DataType::Int32 => ree::sum_checked::<types::Int32Type, T>(&array),
617+
DataType::Int64 => ree::sum_checked::<types::Int64Type, T>(&array),
618+
_ => unreachable!(),
619+
},
606620
_ => sum_checked::<T>(as_primitive_array(&array)),
607621
}
608622
}
609623

624+
// Logic for summing run-end-encoded arrays.
625+
mod ree {
626+
use std::convert::Infallible;
627+
628+
use arrow_array::cast::AsArray;
629+
use arrow_array::types::RunEndIndexType;
630+
use arrow_array::{Array, ArrowNativeTypeOp, ArrowNumericType, PrimitiveArray, TypedRunArray};
631+
use arrow_buffer::ArrowNativeType;
632+
use arrow_schema::ArrowError;
633+
634+
/// Downcasts an array to a TypedRunArray.
635+
fn downcast<'a, I: RunEndIndexType, V: ArrowNumericType>(
636+
array: &'a dyn Array,
637+
) -> Option<TypedRunArray<'a, I, PrimitiveArray<V>>> {
638+
let array = array.as_run_opt::<I>()?;
639+
// We only support RunArray wrapping primitive types.
640+
array.downcast::<PrimitiveArray<V>>()
641+
}
642+
643+
/// Computes the sum (wrapping) of the array values.
644+
pub(super) fn sum_wrapping<I: RunEndIndexType, V: ArrowNumericType>(
645+
array: &dyn Array,
646+
) -> Option<V::Native> {
647+
let ree = downcast::<I, V>(array)?;
648+
let Ok(sum) = fold(ree, |acc, val, len| -> Result<V::Native, Infallible> {
649+
Ok(acc.add_wrapping(val.mul_wrapping(V::Native::usize_as(len))))
650+
});
651+
sum
652+
}
653+
654+
/// Computes the sum (erroring on overflow) of the array values.
655+
pub(super) fn sum_checked<I: RunEndIndexType, V: ArrowNumericType>(
656+
array: &dyn Array,
657+
) -> Result<Option<V::Native>, ArrowError> {
658+
let Some(ree) = downcast::<I, V>(array) else {
659+
return Err(ArrowError::InvalidArgumentError(
660+
"Input run array values are not a PrimitiveArray".to_string(),
661+
));
662+
};
663+
fold(ree, |acc, val, len| -> Result<V::Native, ArrowError> {
664+
let Some(len) = V::Native::from_usize(len) else {
665+
return Err(ArrowError::ArithmeticOverflow(format!(
666+
"Cannot convert a run-end index ({:?}) to the value type ({})",
667+
len,
668+
std::any::type_name::<V::Native>()
669+
)));
670+
};
671+
acc.add_checked(val.mul_checked(len)?)
672+
})
673+
}
674+
675+
/// Folds over the values in a run-end-encoded array.
676+
fn fold<'a, I: RunEndIndexType, V: ArrowNumericType, F, E>(
677+
array: TypedRunArray<'a, I, PrimitiveArray<V>>,
678+
mut f: F,
679+
) -> Result<Option<V::Native>, E>
680+
where
681+
F: FnMut(V::Native, V::Native, usize) -> Result<V::Native, E>,
682+
{
683+
let run_ends = array.run_ends();
684+
let logical_start = run_ends.offset();
685+
let logical_end = run_ends.offset() + run_ends.len();
686+
let run_ends = run_ends.sliced_values();
687+
688+
let values_slice = array.run_array().values_slice();
689+
let values = values_slice
690+
.as_any()
691+
.downcast_ref::<PrimitiveArray<V>>()
692+
// Safety: we know the values array is PrimitiveArray<V>.
693+
.unwrap();
694+
695+
let mut prev_end = 0;
696+
let mut acc = V::Native::ZERO;
697+
let mut has_non_null_value = false;
698+
699+
for (run_end, value) in run_ends.zip(values) {
700+
let current_run_end = run_end.as_usize().clamp(logical_start, logical_end);
701+
let run_length = current_run_end - prev_end;
702+
703+
if let Some(value) = value {
704+
has_non_null_value = true;
705+
acc = f(acc, value, run_length)?;
706+
}
707+
708+
prev_end = current_run_end;
709+
if current_run_end == logical_end {
710+
break;
711+
}
712+
}
713+
714+
Ok(if has_non_null_value { Some(acc) } else { None })
715+
}
716+
}
717+
610718
/// Returns the min of values in the array of `ArrowNumericType` type, or dictionary
611719
/// array with value of `ArrowNumericType` type.
612720
pub fn min_array<T, A: ArrayAccessor<Item = T::Native>>(array: A) -> Option<T::Native>
@@ -639,6 +747,20 @@ where
639747
{
640748
match array.data_type() {
641749
DataType::Dictionary(_, _) => min_max_helper::<T::Native, _, _>(array, cmp),
750+
DataType::RunEndEncoded(run_ends, _) => {
751+
// We can directly perform min/max on the values child array, as any
752+
// run must have non-zero length.
753+
let array: &dyn Array = &array;
754+
let values = match run_ends.data_type() {
755+
DataType::Int16 => array.as_run_opt::<types::Int16Type>()?.values_slice(),
756+
DataType::Int32 => array.as_run_opt::<types::Int32Type>()?.values_slice(),
757+
DataType::Int64 => array.as_run_opt::<types::Int64Type>()?.values_slice(),
758+
_ => return None,
759+
};
760+
// We only support RunArray wrapping primitive types.
761+
let values = values.as_any().downcast_ref::<PrimitiveArray<T>>()?;
762+
m(values)
763+
}
642764
_ => m(as_primitive_array(&array)),
643765
}
644766
}
@@ -751,7 +873,7 @@ pub fn bool_or(array: &BooleanArray) -> Option<bool> {
751873
/// Returns `Ok(None)` if the array is empty or only contains null values.
752874
///
753875
/// This detects overflow and returns an `Err` for that. For an non-overflow-checking variant,
754-
/// use `sum` instead.
876+
/// use [`sum`] instead.
755877
pub fn sum_checked<T>(array: &PrimitiveArray<T>) -> Result<Option<T::Native>, ArrowError>
756878
where
757879
T: ArrowNumericType,
@@ -799,7 +921,7 @@ where
799921
/// Returns `None` if the array is empty or only contains null values.
800922
///
801923
/// This doesn't detect overflow in release mode by default. Once overflowing, the result will
802-
/// wrap around. For an overflow-checking variant, use `sum_checked` instead.
924+
/// wrap around. For an overflow-checking variant, use [`sum_checked`] instead.
803925
pub fn sum<T: ArrowNumericType>(array: &PrimitiveArray<T>) -> Option<T::Native>
804926
where
805927
T::Native: ArrowNativeTypeOp,
@@ -1750,4 +1872,170 @@ mod tests {
17501872
sum_checked(&a).expect_err("overflow should be detected");
17511873
sum_array_checked::<Int32Type, _>(&a).expect_err("overflow should be detected");
17521874
}
1875+
1876+
/// Helper for building a RunArray.
1877+
fn make_run_array<'a, I: RunEndIndexType, V: ArrowNumericType, ItemType>(
1878+
values: impl IntoIterator<Item = &'a ItemType>,
1879+
) -> RunArray<I>
1880+
where
1881+
ItemType: Clone + Into<Option<V::Native>> + 'static,
1882+
{
1883+
let mut builder = arrow_array::builder::PrimitiveRunBuilder::<I, V>::new();
1884+
for v in values.into_iter() {
1885+
builder.append_option((*v).clone().into());
1886+
}
1887+
builder.finish()
1888+
}
1889+
1890+
#[test]
1891+
fn test_ree_sum_array_basic() {
1892+
let run_array = make_run_array::<Int16Type, Int32Type, _>(&[10, 10, 20, 30, 30, 30]);
1893+
let typed_array = run_array.downcast::<Int32Array>().unwrap();
1894+
1895+
let result = sum_array::<Int32Type, _>(typed_array);
1896+
assert_eq!(result, Some(130));
1897+
1898+
let result = sum_array_checked::<Int32Type, _>(typed_array).unwrap();
1899+
assert_eq!(result, Some(130));
1900+
}
1901+
1902+
#[test]
1903+
fn test_ree_sum_array_empty() {
1904+
let run_array = make_run_array::<Int16Type, Int32Type, i32>(&[]);
1905+
let typed_array = run_array.downcast::<Int32Array>().unwrap();
1906+
1907+
let result = sum_array::<Int32Type, _>(typed_array);
1908+
assert_eq!(result, None);
1909+
1910+
let result = sum_array_checked::<Int32Type, _>(typed_array).unwrap();
1911+
assert_eq!(result, None);
1912+
}
1913+
1914+
#[test]
1915+
fn test_ree_sum_array_with_nulls() {
1916+
let run_array =
1917+
make_run_array::<Int16Type, Int32Type, _>(&[Some(10), None, Some(20), None, Some(30)]);
1918+
let typed_array = run_array.downcast::<Int32Array>().unwrap();
1919+
1920+
let result = sum_array::<Int32Type, _>(typed_array);
1921+
assert_eq!(result, Some(60));
1922+
1923+
let result = sum_array_checked::<Int32Type, _>(typed_array).unwrap();
1924+
assert_eq!(result, Some(60));
1925+
}
1926+
1927+
#[test]
1928+
fn test_ree_sum_array_with_only_nulls() {
1929+
let run_array = make_run_array::<Int16Type, Int16Type, _>(&[None, None, None, None, None]);
1930+
let typed_array = run_array.downcast::<Int16Array>().unwrap();
1931+
1932+
let result = sum_array::<Int16Type, _>(typed_array);
1933+
assert_eq!(result, None);
1934+
1935+
let result = sum_array_checked::<Int16Type, _>(typed_array).unwrap();
1936+
assert_eq!(result, None);
1937+
}
1938+
1939+
#[test]
1940+
fn test_ree_sum_array_overflow() {
1941+
let run_array = make_run_array::<Int16Type, Int8Type, _>(&[126, 2]);
1942+
let typed_array = run_array.downcast::<Int8Array>().unwrap();
1943+
1944+
// i8 range is -128..=127. 126+2 overflows to -128.
1945+
let result = sum_array::<Int8Type, _>(typed_array);
1946+
assert_eq!(result, Some(-128));
1947+
1948+
let result = sum_array_checked::<Int8Type, _>(typed_array);
1949+
assert!(result.is_err());
1950+
}
1951+
1952+
#[test]
1953+
fn test_ree_sum_array_sliced() {
1954+
let run_array = make_run_array::<Int16Type, UInt8Type, _>(&[0, 10, 10, 10, 20, 30, 30, 30]);
1955+
// Skip 2 values at the start and 1 at the end.
1956+
let sliced = run_array.slice(2, 5);
1957+
let typed_array = sliced.downcast::<UInt8Array>().unwrap();
1958+
1959+
let result = sum_array::<UInt8Type, _>(typed_array);
1960+
assert_eq!(result, Some(100));
1961+
1962+
let result = sum_array_checked::<UInt8Type, _>(typed_array).unwrap();
1963+
assert_eq!(result, Some(100));
1964+
}
1965+
1966+
#[test]
1967+
fn test_ree_min_max_array_basic() {
1968+
let run_array = make_run_array::<Int16Type, Int32Type, _>(&[30, 30, 10, 20, 20]);
1969+
let typed_array = run_array.downcast::<Int32Array>().unwrap();
1970+
1971+
let result = min_array::<Int32Type, _>(typed_array);
1972+
assert_eq!(result, Some(10));
1973+
1974+
let result = max_array::<Int32Type, _>(typed_array);
1975+
assert_eq!(result, Some(30));
1976+
}
1977+
1978+
#[test]
1979+
fn test_ree_min_max_array_empty() {
1980+
let run_array = make_run_array::<Int16Type, Int32Type, i32>(&[]);
1981+
let typed_array = run_array.downcast::<Int32Array>().unwrap();
1982+
1983+
let result = min_array::<Int32Type, _>(typed_array);
1984+
assert_eq!(result, None);
1985+
1986+
let result = max_array::<Int32Type, _>(typed_array);
1987+
assert_eq!(result, None);
1988+
}
1989+
1990+
#[test]
1991+
fn test_ree_min_max_array_float() {
1992+
let run_array = make_run_array::<Int16Type, Float64Type, _>(&[5.5, 5.5, 2.1, 8.9, 8.9]);
1993+
let typed_array = run_array.downcast::<Float64Array>().unwrap();
1994+
1995+
let result = min_array::<Float64Type, _>(typed_array);
1996+
assert_eq!(result, Some(2.1));
1997+
1998+
let result = max_array::<Float64Type, _>(typed_array);
1999+
assert_eq!(result, Some(8.9));
2000+
}
2001+
2002+
#[test]
2003+
fn test_ree_min_max_array_with_nulls() {
2004+
let run_array = make_run_array::<Int16Type, UInt8Type, _>(&[None, Some(10)]);
2005+
let typed_array = run_array.downcast::<UInt8Array>().unwrap();
2006+
2007+
let result = min_array::<UInt8Type, _>(typed_array);
2008+
assert_eq!(result, Some(10));
2009+
2010+
let result = max_array::<UInt8Type, _>(typed_array);
2011+
assert_eq!(result, Some(10));
2012+
}
2013+
2014+
#[test]
2015+
fn test_ree_min_max_array_sliced() {
2016+
let run_array = make_run_array::<Int16Type, Int32Type, _>(&[0, 30, 30, 10, 20, 20, 100]);
2017+
// Skip 1 value at the start and 1 at the end.
2018+
let sliced = run_array.slice(1, 5);
2019+
let typed_array = sliced.downcast::<Int32Array>().unwrap();
2020+
2021+
let result = min_array::<Int32Type, _>(typed_array);
2022+
assert_eq!(result, Some(10));
2023+
2024+
let result = max_array::<Int32Type, _>(typed_array);
2025+
assert_eq!(result, Some(30));
2026+
}
2027+
2028+
#[test]
2029+
fn test_ree_min_max_array_sliced_mid_run() {
2030+
let run_array = make_run_array::<Int16Type, Int32Type, _>(&[0, 0, 30, 10, 20, 100, 100]);
2031+
// Skip 1 value at the start and 1 at the end.
2032+
let sliced = run_array.slice(1, 5);
2033+
let typed_array = sliced.downcast::<Int32Array>().unwrap();
2034+
2035+
let result = min_array::<Int32Type, _>(typed_array);
2036+
assert_eq!(result, Some(0));
2037+
2038+
let result = max_array::<Int32Type, _>(typed_array);
2039+
assert_eq!(result, Some(100));
2040+
}
17532041
}

0 commit comments

Comments
 (0)