Skip to content

Commit 3de7526

Browse files
committed
Add tests for count and count_distinct with dictionary arrays containing null values
1 parent ee5ee4d commit 3de7526

File tree

1 file changed

+47
-2
lines changed
  • datafusion/functions-aggregate/src

1 file changed

+47
-2
lines changed

datafusion/functions-aggregate/src/count.rs

Lines changed: 47 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -758,13 +758,58 @@ impl Accumulator for DistinctCountAccumulator {
758758
#[cfg(test)]
759759
mod tests {
760760
use super::*;
761-
use arrow::array::NullArray;
762-
761+
use arrow::{
762+
array::{DictionaryArray, Int32Array, NullArray, StringArray},
763+
datatypes::Int32Type,
764+
};
763765
#[test]
764766
fn count_accumulator_nulls() -> Result<()> {
765767
let mut accumulator = CountAccumulator::new();
766768
accumulator.update_batch(&[Arc::new(NullArray::new(10))])?;
767769
assert_eq!(accumulator.evaluate()?, ScalarValue::Int64(Some(0)));
768770
Ok(())
769771
}
772+
773+
#[test]
774+
fn count_distinct_accumulator_dictionary_with_null_values() -> Result<()> {
775+
// Create a dictionary array where:
776+
// - keys aren't null
777+
// - but values referenced by some keys are null
778+
let values = StringArray::from(vec![Some("a"), None, Some("c")]);
779+
let keys = Int32Array::from(vec![0, 1, 2, 0, 1]); // references "a", null, "c", "a", null
780+
let dict_array = DictionaryArray::<Int32Type>::try_new(keys, Arc::new(values))?;
781+
782+
// The expected behavior is that count_distinct should count only non-null values
783+
// which in this case are "a" and "c" (appearing as 0 and 2 in keys)
784+
let mut accumulator = DistinctCountAccumulator {
785+
values: HashSet::default(),
786+
state_data_type: dict_array.data_type().clone(),
787+
};
788+
789+
accumulator.update_batch(&[Arc::new(dict_array)])?;
790+
791+
// Should have 2 distinct non-null values ("a" and "c")
792+
assert_eq!(accumulator.evaluate()?, ScalarValue::Int64(Some(2)));
793+
Ok(())
794+
}
795+
796+
#[test]
797+
fn count_accumulator_dictionary_with_null_values() -> Result<()> {
798+
// Create a dictionary array where:
799+
// - keys aren't null
800+
// - but values referenced by some keys are null
801+
let values = StringArray::from(vec![Some("a"), None, Some("c")]);
802+
let keys = Int32Array::from(vec![0, 1, 2, 0, 1]); // references "a", null, "c", "a", null
803+
let dict_array = DictionaryArray::<Int32Type>::try_new(keys, Arc::new(values))?;
804+
805+
// The expected behavior is that count should only count non-null values
806+
let mut accumulator = CountAccumulator::new();
807+
808+
accumulator.update_batch(&[Arc::new(dict_array)])?;
809+
810+
// 5 elements in the array, of which 2 reference null values (the two 1s in the keys)
811+
// So we should count 3 non-null values
812+
assert_eq!(accumulator.evaluate()?, ScalarValue::Int64(Some(3)));
813+
Ok(())
814+
}
770815
}

0 commit comments

Comments
 (0)