@@ -22,7 +22,7 @@ use std::collections::HashMap;
2222use std:: sync:: atomic:: AtomicI64 ;
2323use std:: sync:: Arc ;
2424
25- use arrow_array:: Float32Array ;
25+ use arrow_array:: { Float32Array , Float64Array } ;
2626use arrow_schema:: { DataType , SchemaRef as ArrowSchemaRef } ;
2727use bytes:: Bytes ;
2828use futures:: future:: BoxFuture ;
@@ -563,6 +563,14 @@ impl FileWriter for ParquetWriter {
563563 . filter ( |value| value. map_or ( false , |v| v. is_nan ( ) ) )
564564 . count ( ) as u64
565565 }
566+ DataType :: Float64 => {
567+ let float_array = col. as_any ( ) . downcast_ref :: < Float64Array > ( ) . unwrap ( ) ;
568+
569+ float_array
570+ . iter ( )
571+ . filter ( |value| value. map_or ( false , |v| v. is_nan ( ) ) )
572+ . count ( ) as u64
573+ }
566574 _ => 0 ,
567575 } ;
568576
@@ -830,6 +838,7 @@ mod tests {
830838 assert_eq ! ( visitor. name_to_id, expect) ;
831839 }
832840
841+ // TODO(feniljain): Remove the NaN-value-count assertions from this test once
841+ // `test_parquet_writer_for_nan_value_counts` fully covers them.
833842 #[ tokio:: test]
834843 async fn test_parquet_writer ( ) -> Result < ( ) > {
835844 let temp_dir = TempDir :: new ( ) . unwrap ( ) ;
@@ -922,6 +931,102 @@ mod tests {
922931 Ok ( ( ) )
923932 }
924933
934+ #[ tokio:: test]
935+ async fn test_parquet_writer_for_nan_value_counts ( ) -> Result < ( ) > {
936+ let temp_dir = TempDir :: new ( ) . unwrap ( ) ;
937+ let file_io = FileIOBuilder :: new_fs_io ( ) . build ( ) . unwrap ( ) ;
938+ let location_gen =
939+ MockLocationGenerator :: new ( temp_dir. path ( ) . to_str ( ) . unwrap ( ) . to_string ( ) ) ;
940+ let file_name_gen =
941+ DefaultFileNameGenerator :: new ( "test" . to_string ( ) , None , DataFileFormat :: Parquet ) ;
942+
943+ // prepare data
944+ let schema = {
945+ let fields = vec ! [
946+ // TODO(feniljain):
947+ // Types:
948+ // [X] Primitive
949+ // [ ] Struct
950+ // [ ] List
951+ // [ ] Map
952+ arrow_schema:: Field :: new( "col" , arrow_schema:: DataType :: Float32 , true )
953+ . with_metadata( HashMap :: from( [ (
954+ PARQUET_FIELD_ID_META_KEY . to_string( ) ,
955+ "0" . to_string( ) ,
956+ ) ] ) ) ,
957+ arrow_schema:: Field :: new( "col1" , arrow_schema:: DataType :: Float64 , true )
958+ . with_metadata( HashMap :: from( [ (
959+ PARQUET_FIELD_ID_META_KEY . to_string( ) ,
960+ "1" . to_string( ) ,
961+ ) ] ) ) ,
962+ ] ;
963+ Arc :: new ( arrow_schema:: Schema :: new ( fields) )
964+ } ;
965+
966+ let float_32_col = Arc :: new ( Float32Array :: from_iter_values_with_nulls (
967+ [ 1.0_f32 , f32:: NAN , 2.0 , 2.0 ] . into_iter ( ) ,
968+ None ,
969+ ) ) as ArrayRef ;
970+
971+ let float_64_col = Arc :: new ( Float64Array :: from_iter_values_with_nulls (
972+ [ 1.0_f64 , f64:: NAN , 2.0 , 2.0 ] . into_iter ( ) ,
973+ None ,
974+ ) ) as ArrayRef ;
975+
976+ let to_write =
977+ RecordBatch :: try_new ( schema. clone ( ) , vec ! [ float_32_col, float_64_col] ) . unwrap ( ) ;
978+
979+ // write data
980+ let mut pw = ParquetWriterBuilder :: new (
981+ WriterProperties :: builder ( ) . build ( ) ,
982+ Arc :: new ( to_write. schema ( ) . as_ref ( ) . try_into ( ) . unwrap ( ) ) ,
983+ file_io. clone ( ) ,
984+ location_gen,
985+ file_name_gen,
986+ )
987+ . build ( )
988+ . await ?;
989+
990+ pw. write ( & to_write) . await ?;
991+ let res = pw. close ( ) . await ?;
992+ assert_eq ! ( res. len( ) , 1 ) ;
993+ let data_file = res
994+ . into_iter ( )
995+ . next ( )
996+ . unwrap ( )
997+ // Put dummy field for build successfully.
998+ . content ( crate :: spec:: DataContentType :: Data )
999+ . partition ( Struct :: empty ( ) )
1000+ . build ( )
1001+ . unwrap ( ) ;
1002+
1003+ // check data file
1004+ assert_eq ! ( data_file. record_count( ) , 4 ) ;
1005+ assert_eq ! ( * data_file. value_counts( ) , HashMap :: from( [ ( 0 , 4 ) , ( 1 , 4 ) ] ) ) ;
1006+ assert_eq ! (
1007+ * data_file. lower_bounds( ) ,
1008+ HashMap :: from( [ ( 0 , Datum :: float( 1.0 ) ) , ( 1 , Datum :: double( 1.0 ) ) ] )
1009+ ) ;
1010+ assert_eq ! (
1011+ * data_file. upper_bounds( ) ,
1012+ HashMap :: from( [ ( 0 , Datum :: float( 2.0 ) ) , ( 1 , Datum :: double( 2.0 ) ) ] )
1013+ ) ;
1014+ assert_eq ! (
1015+ * data_file. null_value_counts( ) ,
1016+ HashMap :: from( [ ( 0 , 0 ) , ( 1 , 0 ) ] )
1017+ ) ;
1018+ assert_eq ! (
1019+ * data_file. nan_value_counts( ) ,
1020+ HashMap :: from( [ ( 0 , 1 ) , ( 1 , 1 ) ] )
1021+ ) ;
1022+
1023+ // check the written file
1024+ let expect_batch = concat_batches ( & schema, vec ! [ & to_write] ) . unwrap ( ) ;
1025+ check_parquet_data_file ( & file_io, & data_file, & expect_batch) . await ;
1026+
1027+ Ok ( ( ) )
1028+ }
1029+
9251030 #[ tokio:: test]
9261031 async fn test_parquet_writer_with_complex_schema ( ) -> Result < ( ) > {
9271032 let temp_dir = TempDir :: new ( ) . unwrap ( ) ;
0 commit comments