@@ -758,13 +758,58 @@ impl Accumulator for DistinctCountAccumulator {
#[cfg(test)]
mod tests {
    //! Unit tests for the count accumulators, covering all-null input and
    //! dictionary-encoded input whose keys are valid but point at null values.

    use super::*;
    use arrow::{
        array::{DictionaryArray, Int32Array, NullArray, StringArray},
        datatypes::Int32Type,
    };

    /// Builds a five-element dictionary array in which no key is null, but
    /// key `1` refers to a null entry in the values array.
    ///
    /// Logical contents: `"a", null, "c", "a", null`.
    fn dictionary_with_null_values() -> Result<DictionaryArray<Int32Type>> {
        let dictionary_values = StringArray::from(vec![Some("a"), None, Some("c")]);
        let dictionary_keys = Int32Array::from(vec![0, 1, 2, 0, 1]);
        let array = DictionaryArray::<Int32Type>::try_new(
            dictionary_keys,
            Arc::new(dictionary_values),
        )?;
        Ok(array)
    }

    /// Feeding a `NullArray` of length 10 must leave the count at zero.
    #[test]
    fn count_accumulator_nulls() -> Result<()> {
        let mut accumulator = CountAccumulator::new();
        accumulator.update_batch(&[Arc::new(NullArray::new(10))])?;

        let counted = accumulator.evaluate()?;
        assert_eq!(counted, ScalarValue::Int64(Some(0)));
        Ok(())
    }

    /// Distinct count over a dictionary array must skip entries whose key
    /// resolves to a null value.
    #[test]
    fn count_distinct_accumulator_dictionary_with_null_values() -> Result<()> {
        let dict_array = dictionary_with_null_values()?;

        // The accumulator's state type is taken from the input array itself.
        let mut accumulator = DistinctCountAccumulator {
            values: HashSet::default(),
            state_data_type: dict_array.data_type().clone(),
        };
        accumulator.update_batch(&[Arc::new(dict_array)])?;

        // Only "a" and "c" (keys 0 and 2) are distinct non-null values.
        let distinct = accumulator.evaluate()?;
        assert_eq!(distinct, ScalarValue::Int64(Some(2)));
        Ok(())
    }

    /// Plain count over a dictionary array must skip entries whose key
    /// resolves to a null value.
    #[test]
    fn count_accumulator_dictionary_with_null_values() -> Result<()> {
        let dict_array = dictionary_with_null_values()?;

        let mut accumulator = CountAccumulator::new();
        accumulator.update_batch(&[Arc::new(dict_array)])?;

        // Five elements in total; the two keys equal to 1 reference the null
        // value, which leaves three non-null elements to count.
        let counted = accumulator.evaluate()?;
        assert_eq!(counted, ScalarValue::Int64(Some(3)));
        Ok(())
    }
}
0 commit comments