Skip to content

Commit e222bbe

Browse files
committed
feat: nan_value_counts support
1 parent e4ca871 commit e222bbe

File tree

1 file changed

+73
-10
lines changed

1 file changed

+73
-10
lines changed

crates/iceberg/src/writer/file_writer/parquet_writer.rs

Lines changed: 73 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -17,11 +17,13 @@
1717

1818
//! The module contains the file writer for parquet file format.
1919
20+
use std::collections::hash_map::Entry;
2021
use std::collections::HashMap;
2122
use std::sync::atomic::AtomicI64;
2223
use std::sync::Arc;
2324

24-
use arrow_schema::SchemaRef as ArrowSchemaRef;
25+
use arrow_array::Float32Array;
26+
use arrow_schema::{DataType, SchemaRef as ArrowSchemaRef};
2527
use bytes::Bytes;
2628
use futures::future::BoxFuture;
2729
use itertools::Itertools;
@@ -97,6 +99,7 @@ impl<T: LocationGenerator, F: FileNameGenerator> FileWriterBuilder for ParquetWr
9799
written_size,
98100
current_row_num: 0,
99101
out_file,
102+
nan_value_counts: HashMap::new(),
100103
})
101104
}
102105
}
@@ -222,6 +225,7 @@ pub struct ParquetWriter {
222225
writer_properties: WriterProperties,
223226
written_size: Arc<AtomicI64>,
224227
current_row_num: usize,
228+
nan_value_counts: HashMap<i32, u64>,
225229
}
226230

227231
/// Used to aggregate min and max value of each column.
@@ -357,6 +361,7 @@ impl ParquetWriter {
357361
metadata: FileMetaData,
358362
written_size: usize,
359363
file_path: String,
364+
nan_value_counts: HashMap<i32, u64>,
360365
) -> Result<DataFileBuilder> {
361366
let index_by_parquet_path = {
362367
let mut visitor = IndexByParquetPathName::new();
@@ -423,8 +428,8 @@ impl ParquetWriter {
423428
.null_value_counts(null_value_counts)
424429
.lower_bounds(lower_bounds)
425430
.upper_bounds(upper_bounds)
431+
.nan_value_counts(nan_value_counts)
426432
// # TODO(#417)
427-
// - nan_value_counts
428433
// - distinct_counts
429434
.key_metadata(metadata.footer_signing_key_metadata)
430435
.split_offsets(
@@ -541,13 +546,45 @@ impl FileWriter for ParquetWriter {
541546
self.inner_writer.as_mut().unwrap()
542547
};
543548

549+
550+
for (col, field) in batch
551+
.columns()
552+
.iter()
553+
.zip(self.schema.as_struct().fields().iter())
554+
{
555+
let dt = col.data_type();
556+
557+
let nan_val_cnt: u64 = match dt {
558+
DataType::Float32 => {
559+
let float_array = col.as_any().downcast_ref::<Float32Array>().unwrap();
560+
561+
float_array
562+
.iter()
563+
.filter(|value| value.map_or(false, |v| v.is_nan()))
564+
.count() as u64
565+
}
566+
_ => 0,
567+
};
568+
569+
match self.nan_value_counts.entry(field.id) {
570+
Entry::Occupied(mut ele) => {
571+
let total_nan_val_cnt = ele.get() + nan_val_cnt;
572+
ele.insert(total_nan_val_cnt);
573+
}
574+
Entry::Vacant(v) => {
575+
v.insert(nan_val_cnt);
576+
}
577+
}
578+
}
579+
544580
writer.write(batch).await.map_err(|err| {
545581
Error::new(
546582
ErrorKind::Unexpected,
547583
"Failed to write using parquet writer.",
548584
)
549585
.with_source(err)
550586
})?;
587+
551588
Ok(())
552589
}
553590

@@ -566,6 +603,7 @@ impl FileWriter for ParquetWriter {
566603
metadata,
567604
written_size as usize,
568605
self.out_file.location().to_string(),
606+
self.nan_value_counts,
569607
)?])
570608
}
571609
}
@@ -626,8 +664,8 @@ mod tests {
626664
use anyhow::Result;
627665
use arrow_array::types::Int64Type;
628666
use arrow_array::{
629-
Array, ArrayRef, BooleanArray, Decimal128Array, Int32Array, Int64Array, ListArray,
630-
RecordBatch, StructArray,
667+
Array, ArrayRef, BooleanArray, Decimal128Array, Float32Array, Int32Array, Int64Array,
668+
ListArray, RecordBatch, StructArray,
631669
};
632670
use arrow_schema::{DataType, SchemaRef as ArrowSchemaRef};
633671
use arrow_select::concat::concat_batches;
@@ -807,13 +845,27 @@ mod tests {
807845
arrow_schema::Field::new("col", arrow_schema::DataType::Int64, true).with_metadata(
808846
HashMap::from([(PARQUET_FIELD_ID_META_KEY.to_string(), "0".to_string())]),
809847
),
848+
arrow_schema::Field::new("col1", arrow_schema::DataType::Float32, true)
849+
.with_metadata(HashMap::from([(
850+
PARQUET_FIELD_ID_META_KEY.to_string(),
851+
"1".to_string(),
852+
)])),
810853
];
811854
Arc::new(arrow_schema::Schema::new(fields))
812855
};
813856
let col = Arc::new(Int64Array::from_iter_values(0..1024)) as ArrayRef;
814857
let null_col = Arc::new(Int64Array::new_null(1024)) as ArrayRef;
815-
let to_write = RecordBatch::try_new(schema.clone(), vec![col]).unwrap();
816-
let to_write_null = RecordBatch::try_new(schema.clone(), vec![null_col]).unwrap();
858+
let float_col = Arc::new(Float32Array::from_iter_values((0..1024).map(|x| {
859+
if x % 100 == 0 {
860+
// There will be 11 NaNs (x = 0, 100, ..., 1000) among the 1024 entries
861+
f32::NAN
862+
} else {
863+
x as f32
864+
}
865+
}))) as ArrayRef;
866+
let to_write = RecordBatch::try_new(schema.clone(), vec![col, float_col.clone()]).unwrap();
867+
let to_write_null =
868+
RecordBatch::try_new(schema.clone(), vec![null_col, float_col]).unwrap();
817869

818870
// write data
819871
let mut pw = ParquetWriterBuilder::new(
@@ -825,6 +877,7 @@ mod tests {
825877
)
826878
.build()
827879
.await?;
880+
828881
pw.write(&to_write).await?;
829882
pw.write(&to_write_null).await?;
830883
let res = pw.close().await?;
@@ -841,16 +894,26 @@ mod tests {
841894

842895
// check data file
843896
assert_eq!(data_file.record_count(), 2048);
844-
assert_eq!(*data_file.value_counts(), HashMap::from([(0, 2048)]));
897+
assert_eq!(
898+
*data_file.value_counts(),
899+
HashMap::from([(0, 2048), (1, 2048)])
900+
);
845901
assert_eq!(
846902
*data_file.lower_bounds(),
847-
HashMap::from([(0, Datum::long(0))])
903+
HashMap::from([(0, Datum::long(0)), (1, Datum::float(1.0))])
848904
);
849905
assert_eq!(
850906
*data_file.upper_bounds(),
851-
HashMap::from([(0, Datum::long(1023))])
907+
HashMap::from([(0, Datum::long(1023)), (1, Datum::float(1023.0))])
908+
);
909+
assert_eq!(
910+
*data_file.null_value_counts(),
911+
HashMap::from([(0, 1024), (1, 0)])
912+
);
913+
assert_eq!(
914+
*data_file.nan_value_counts(),
915+
HashMap::from([(0, 0), (1, 22)]) // 22 NaNs total, because the float column was written twice (11 each)
852916
);
853-
assert_eq!(*data_file.null_value_counts(), HashMap::from([(0, 1024)]));
854917

855918
// check the written file
856919
let expect_batch = concat_batches(&schema, vec![&to_write, &to_write_null]).unwrap();

0 commit comments

Comments
 (0)