Skip to content

Commit 6654325

Browse files
committed
feat: add float64 nan value counts support
1 parent e222bbe commit 6654325

File tree

1 file changed

+106
-1
lines changed

1 file changed

+106
-1
lines changed

crates/iceberg/src/writer/file_writer/parquet_writer.rs

Lines changed: 106 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -22,7 +22,7 @@ use std::collections::HashMap;
2222
use std::sync::atomic::AtomicI64;
2323
use std::sync::Arc;
2424

25-
use arrow_array::Float32Array;
25+
use arrow_array::{Float32Array, Float64Array};
2626
use arrow_schema::{DataType, SchemaRef as ArrowSchemaRef};
2727
use bytes::Bytes;
2828
use futures::future::BoxFuture;
@@ -563,6 +563,14 @@ impl FileWriter for ParquetWriter {
563563
.filter(|value| value.map_or(false, |v| v.is_nan()))
564564
.count() as u64
565565
}
566+
DataType::Float64 => {
567+
let float_array = col.as_any().downcast_ref::<Float64Array>().unwrap();
568+
569+
float_array
570+
.iter()
571+
.filter(|value| value.map_or(false, |v| v.is_nan()))
572+
.count() as u64
573+
}
566574
_ => 0,
567575
};
568576

@@ -830,6 +838,7 @@ mod tests {
830838
assert_eq!(visitor.name_to_id, expect);
831839
}
832840

841+
// TODO(feniljain): Remove nan value count test from here
833842
#[tokio::test]
834843
async fn test_parquet_writer() -> Result<()> {
835844
let temp_dir = TempDir::new().unwrap();
@@ -922,6 +931,102 @@ mod tests {
922931
Ok(())
923932
}
924933

934+
#[tokio::test]
935+
async fn test_parquet_writer_for_nan_value_counts() -> Result<()> {
936+
let temp_dir = TempDir::new().unwrap();
937+
let file_io = FileIOBuilder::new_fs_io().build().unwrap();
938+
let location_gen =
939+
MockLocationGenerator::new(temp_dir.path().to_str().unwrap().to_string());
940+
let file_name_gen =
941+
DefaultFileNameGenerator::new("test".to_string(), None, DataFileFormat::Parquet);
942+
943+
// prepare data
944+
let schema = {
945+
let fields = vec![
946+
// TODO(feniljain):
947+
// Types:
948+
// [X] Primitive
949+
// [ ] Struct
950+
// [ ] List
951+
// [ ] Map
952+
arrow_schema::Field::new("col", arrow_schema::DataType::Float32, true)
953+
.with_metadata(HashMap::from([(
954+
PARQUET_FIELD_ID_META_KEY.to_string(),
955+
"0".to_string(),
956+
)])),
957+
arrow_schema::Field::new("col1", arrow_schema::DataType::Float64, true)
958+
.with_metadata(HashMap::from([(
959+
PARQUET_FIELD_ID_META_KEY.to_string(),
960+
"1".to_string(),
961+
)])),
962+
];
963+
Arc::new(arrow_schema::Schema::new(fields))
964+
};
965+
966+
let float_32_col = Arc::new(Float32Array::from_iter_values_with_nulls(
967+
[1.0_f32, f32::NAN, 2.0, 2.0].into_iter(),
968+
None,
969+
)) as ArrayRef;
970+
971+
let float_64_col = Arc::new(Float64Array::from_iter_values_with_nulls(
972+
[1.0_f64, f64::NAN, 2.0, 2.0].into_iter(),
973+
None,
974+
)) as ArrayRef;
975+
976+
let to_write =
977+
RecordBatch::try_new(schema.clone(), vec![float_32_col, float_64_col]).unwrap();
978+
979+
// write data
980+
let mut pw = ParquetWriterBuilder::new(
981+
WriterProperties::builder().build(),
982+
Arc::new(to_write.schema().as_ref().try_into().unwrap()),
983+
file_io.clone(),
984+
location_gen,
985+
file_name_gen,
986+
)
987+
.build()
988+
.await?;
989+
990+
pw.write(&to_write).await?;
991+
let res = pw.close().await?;
992+
assert_eq!(res.len(), 1);
993+
let data_file = res
994+
.into_iter()
995+
.next()
996+
.unwrap()
997+
// Put dummy field for build successfully.
998+
.content(crate::spec::DataContentType::Data)
999+
.partition(Struct::empty())
1000+
.build()
1001+
.unwrap();
1002+
1003+
// check data file
1004+
assert_eq!(data_file.record_count(), 4);
1005+
assert_eq!(*data_file.value_counts(), HashMap::from([(0, 4), (1, 4)]));
1006+
assert_eq!(
1007+
*data_file.lower_bounds(),
1008+
HashMap::from([(0, Datum::float(1.0)), (1, Datum::double(1.0))])
1009+
);
1010+
assert_eq!(
1011+
*data_file.upper_bounds(),
1012+
HashMap::from([(0, Datum::float(2.0)), (1, Datum::double(2.0))])
1013+
);
1014+
assert_eq!(
1015+
*data_file.null_value_counts(),
1016+
HashMap::from([(0, 0), (1, 0)])
1017+
);
1018+
assert_eq!(
1019+
*data_file.nan_value_counts(),
1020+
HashMap::from([(0, 1), (1, 1)])
1021+
);
1022+
1023+
// check the written file
1024+
let expect_batch = concat_batches(&schema, vec![&to_write]).unwrap();
1025+
check_parquet_data_file(&file_io, &data_file, &expect_batch).await;
1026+
1027+
Ok(())
1028+
}
1029+
9251030
#[tokio::test]
9261031
async fn test_parquet_writer_with_complex_schema() -> Result<()> {
9271032
let temp_dir = TempDir::new().unwrap();

0 commit comments

Comments (0)