
//! The module contains the file writer for parquet file format.

20+ use std:: collections:: hash_map:: Entry ;
2021use std:: collections:: HashMap ;
2122use std:: sync:: atomic:: AtomicI64 ;
2223use std:: sync:: Arc ;
2324
24- use arrow_schema:: SchemaRef as ArrowSchemaRef ;
25+ use arrow_array:: Float32Array ;
26+ use arrow_schema:: { DataType , SchemaRef as ArrowSchemaRef } ;
2527use bytes:: Bytes ;
2628use futures:: future:: BoxFuture ;
2729use itertools:: Itertools ;
@@ -97,6 +99,7 @@ impl<T: LocationGenerator, F: FileNameGenerator> FileWriterBuilder for ParquetWr
9799 written_size,
98100 current_row_num : 0 ,
99101 out_file,
102+ nan_value_counts : HashMap :: new ( ) ,
100103 } )
101104 }
102105}
@@ -222,6 +225,7 @@ pub struct ParquetWriter {
222225 writer_properties : WriterProperties ,
223226 written_size : Arc < AtomicI64 > ,
224227 current_row_num : usize ,
228+ nan_value_counts : HashMap < i32 , u64 > ,
225229}
226230
227231/// Used to aggregate min and max value of each column.
@@ -357,6 +361,7 @@ impl ParquetWriter {
357361 metadata : FileMetaData ,
358362 written_size : usize ,
359363 file_path : String ,
364+ nan_value_counts : HashMap < i32 , u64 > ,
360365 ) -> Result < DataFileBuilder > {
361366 let index_by_parquet_path = {
362367 let mut visitor = IndexByParquetPathName :: new ( ) ;
@@ -423,8 +428,8 @@ impl ParquetWriter {
423428 . null_value_counts ( null_value_counts)
424429 . lower_bounds ( lower_bounds)
425430 . upper_bounds ( upper_bounds)
431+ . nan_value_counts ( nan_value_counts)
426432 // # TODO(#417)
427- // - nan_value_counts
428433 // - distinct_counts
429434 . key_metadata ( metadata. footer_signing_key_metadata )
430435 . split_offsets (
@@ -541,13 +546,45 @@ impl FileWriter for ParquetWriter {
541546 self . inner_writer . as_mut ( ) . unwrap ( )
542547 } ;
543548
549+
550+ for ( col, field) in batch
551+ . columns ( )
552+ . iter ( )
553+ . zip ( self . schema . as_struct ( ) . fields ( ) . iter ( ) )
554+ {
555+ let dt = col. data_type ( ) ;
556+
557+ let nan_val_cnt: u64 = match dt {
558+ DataType :: Float32 => {
559+ let float_array = col. as_any ( ) . downcast_ref :: < Float32Array > ( ) . unwrap ( ) ;
560+
561+ float_array
562+ . iter ( )
563+ . filter ( |value| value. map_or ( false , |v| v. is_nan ( ) ) )
564+ . count ( ) as u64
565+ }
566+ _ => 0 ,
567+ } ;
568+
569+ match self . nan_value_counts . entry ( field. id ) {
570+ Entry :: Occupied ( mut ele) => {
571+ let total_nan_val_cnt = ele. get ( ) + nan_val_cnt;
572+ ele. insert ( total_nan_val_cnt) ;
573+ }
574+ Entry :: Vacant ( v) => {
575+ v. insert ( nan_val_cnt) ;
576+ }
577+ }
578+ }
579+
544580 writer. write ( batch) . await . map_err ( |err| {
545581 Error :: new (
546582 ErrorKind :: Unexpected ,
547583 "Failed to write using parquet writer." ,
548584 )
549585 . with_source ( err)
550586 } ) ?;
587+
551588 Ok ( ( ) )
552589 }
553590
@@ -566,6 +603,7 @@ impl FileWriter for ParquetWriter {
566603 metadata,
567604 written_size as usize ,
568605 self . out_file. location( ) . to_string( ) ,
606+ self . nan_value_counts,
569607 ) ?] )
570608 }
571609}
@@ -626,8 +664,8 @@ mod tests {
626664 use anyhow:: Result ;
627665 use arrow_array:: types:: Int64Type ;
628666 use arrow_array:: {
629- Array , ArrayRef , BooleanArray , Decimal128Array , Int32Array , Int64Array , ListArray ,
630- RecordBatch , StructArray ,
667+ Array , ArrayRef , BooleanArray , Decimal128Array , Float32Array , Int32Array , Int64Array ,
668+ ListArray , RecordBatch , StructArray ,
631669 } ;
632670 use arrow_schema:: { DataType , SchemaRef as ArrowSchemaRef } ;
633671 use arrow_select:: concat:: concat_batches;
@@ -807,13 +845,27 @@ mod tests {
807845 arrow_schema:: Field :: new( "col" , arrow_schema:: DataType :: Int64 , true ) . with_metadata(
808846 HashMap :: from( [ ( PARQUET_FIELD_ID_META_KEY . to_string( ) , "0" . to_string( ) ) ] ) ,
809847 ) ,
848+ arrow_schema:: Field :: new( "col1" , arrow_schema:: DataType :: Float32 , true )
849+ . with_metadata( HashMap :: from( [ (
850+ PARQUET_FIELD_ID_META_KEY . to_string( ) ,
851+ "1" . to_string( ) ,
852+ ) ] ) ) ,
810853 ] ;
811854 Arc :: new ( arrow_schema:: Schema :: new ( fields) )
812855 } ;
813856 let col = Arc :: new ( Int64Array :: from_iter_values ( 0 ..1024 ) ) as ArrayRef ;
814857 let null_col = Arc :: new ( Int64Array :: new_null ( 1024 ) ) as ArrayRef ;
815- let to_write = RecordBatch :: try_new ( schema. clone ( ) , vec ! [ col] ) . unwrap ( ) ;
816- let to_write_null = RecordBatch :: try_new ( schema. clone ( ) , vec ! [ null_col] ) . unwrap ( ) ;
858+ let float_col = Arc :: new ( Float32Array :: from_iter_values ( ( 0 ..1024 ) . map ( |x| {
859+ if x % 100 == 0 {
860+ // There will be 11 NANs as there are 1024 entries
861+ f32:: NAN
862+ } else {
863+ x as f32
864+ }
865+ } ) ) ) as ArrayRef ;
866+ let to_write = RecordBatch :: try_new ( schema. clone ( ) , vec ! [ col, float_col. clone( ) ] ) . unwrap ( ) ;
867+ let to_write_null =
868+ RecordBatch :: try_new ( schema. clone ( ) , vec ! [ null_col, float_col] ) . unwrap ( ) ;
817869
818870 // write data
819871 let mut pw = ParquetWriterBuilder :: new (
@@ -825,6 +877,7 @@ mod tests {
825877 )
826878 . build ( )
827879 . await ?;
880+
828881 pw. write ( & to_write) . await ?;
829882 pw. write ( & to_write_null) . await ?;
830883 let res = pw. close ( ) . await ?;
@@ -841,16 +894,26 @@ mod tests {
841894
842895 // check data file
843896 assert_eq ! ( data_file. record_count( ) , 2048 ) ;
844- assert_eq ! ( * data_file. value_counts( ) , HashMap :: from( [ ( 0 , 2048 ) ] ) ) ;
897+ assert_eq ! (
898+ * data_file. value_counts( ) ,
899+ HashMap :: from( [ ( 0 , 2048 ) , ( 1 , 2048 ) ] )
900+ ) ;
845901 assert_eq ! (
846902 * data_file. lower_bounds( ) ,
847- HashMap :: from( [ ( 0 , Datum :: long( 0 ) ) ] )
903+ HashMap :: from( [ ( 0 , Datum :: long( 0 ) ) , ( 1 , Datum :: float ( 1.0 ) ) ] )
848904 ) ;
849905 assert_eq ! (
850906 * data_file. upper_bounds( ) ,
851- HashMap :: from( [ ( 0 , Datum :: long( 1023 ) ) ] )
907+ HashMap :: from( [ ( 0 , Datum :: long( 1023 ) ) , ( 1 , Datum :: float( 1023.0 ) ) ] )
908+ ) ;
909+ assert_eq ! (
910+ * data_file. null_value_counts( ) ,
911+ HashMap :: from( [ ( 0 , 1024 ) , ( 1 , 0 ) ] )
912+ ) ;
913+ assert_eq ! (
914+ * data_file. nan_value_counts( ) ,
915+ HashMap :: from( [ ( 0 , 0 ) , ( 1 , 22 ) ] ) // 22, cause we wrote float column twice
852916 ) ;
853- assert_eq ! ( * data_file. null_value_counts( ) , HashMap :: from( [ ( 0 , 1024 ) ] ) ) ;
854917
855918 // check the written file
856919 let expect_batch = concat_batches ( & schema, vec ! [ & to_write, & to_write_null] ) . unwrap ( ) ;