diff --git a/src/common.rs b/src/common.rs
index fe92b25..66505cb 100644
--- a/src/common.rs
+++ b/src/common.rs
@@ -3,7 +3,7 @@ use std::sync::Arc;
 
 use datafusion::arrow::array::{
     downcast_array, AnyDictionaryArray, Array, ArrayAccessor, ArrayRef, AsArray, DictionaryArray, LargeStringArray,
-    PrimitiveArray, RunArray, StringArray, StringViewArray,
+    PrimitiveArray, PrimitiveBuilder, RunArray, StringArray, StringViewArray,
 };
 use datafusion::arrow::compute::kernels::cast;
 use datafusion::arrow::compute::take;
@@ -245,6 +245,34 @@ fn invoke_array_array(
     }
 }
 
+/// Transform keys that point to null dictionary values into null keys.
+/// keys = `[0, 1, 2, 3]`, values = `[null, "a", null, "b"]`
+/// into (only the keys are rewritten, the values array is passed through untouched)
+/// keys = `[null, 1, null, 3]`, values = `[null, "a", null, "b"]`
+///
+/// Arrow / `DataFusion` assumes that nulls are encoded by the keys, never by a key pointing at a null value.
+/// Not following this invariant causes invalid dictionary arrays to be built later on inside of `DataFusion`
+/// when arrays are concatenated and such.
+fn remap_dictionary_key_nulls(keys: PrimitiveArray<Int64Type>, values: ArrayRef) -> DictionaryArray<Int64Type> {
+    // fast path: no nulls in values, so no key can point at a null value
+    if values.null_count() == 0 {
+        return DictionaryArray::new(keys, values);
+    }
+
+    let mut new_keys_builder = PrimitiveBuilder::<Int64Type>::with_capacity(keys.len());
+    // a key pointing at a null value becomes null itself; every other key is copied through unchanged
+    for key in &keys {
+        match key {
+            Some(k) if values.is_null(k.as_usize()) => new_keys_builder.append_null(),
+            Some(k) => new_keys_builder.append_value(k),
+            None => new_keys_builder.append_null(),
+        }
+    }
+
+    let new_keys = new_keys_builder.finish();
+    DictionaryArray::new(new_keys, values)
+}
+
 fn invoke_array_scalars(
     json_array: &ArrayRef,
     path: &[JsonPath],
@@ -281,7 +309,7 @@ fn invoke_array_scalars(
                 let type_ids = values.as_union().type_ids();
                 keys = mask_dictionary_keys(&keys, type_ids);
             }
-            Ok(Arc::new(DictionaryArray::new(keys, values)))
+            Ok(Arc::new(remap_dictionary_key_nulls(keys, values)))
         } else {
             // this is what cast would do under the hood to unpack a dictionary into an array of its values
             Ok(take(&values, json_array.keys(), None)?)
diff --git a/tests/main.rs b/tests/main.rs
index d71a910..0019e1e 100644
--- a/tests/main.rs
+++ b/tests/main.rs
@@ -1,7 +1,7 @@
 use std::sync::Arc;
 
-use datafusion::arrow::array::{ArrayRef, RecordBatch};
-use datafusion::arrow::datatypes::{Field, Int8Type, Schema};
+use datafusion::arrow::array::{Array, ArrayRef, DictionaryArray, RecordBatch};
+use datafusion::arrow::datatypes::{Field, Int64Type, Int8Type, Schema};
 use datafusion::arrow::{array::StringDictionaryBuilder, datatypes::DataType};
 use datafusion::assert_batches_eq;
 use datafusion::common::ScalarValue;
@@ -1280,6 +1280,66 @@ async fn test_dict_haystack() {
     assert_batches_eq!(expected, &batches);
 }
 
+/// Assert that no non-null key of a dictionary array points at a null dictionary value.
+///
+/// Panics if `array` is not a `DictionaryArray<Int64Type>` or if such a key is found.
+fn check_for_null_dictionary_values(array: &dyn Array) {
+    let array = array.as_any().downcast_ref::<DictionaryArray<Int64Type>>().unwrap();
+    let values_array = array.values();
+    // `flatten` skips the null keys, so only keys that actually reference a value are checked
+    for key in array.keys().iter().flatten() {
+        let index = usize::try_from(key).unwrap();
+        assert!(
+            !values_array.is_null(index),
+            "non-null key {} points to a null dictionary value\nkeys: {:?}\nvalues: {:?}",
+            index,
+            array.keys(),
+            values_array,
+        );
+    }
+}
+
+/// Test that we don't output nulls in dictionary values.
+/// This can cause issues with arrow-rs and DataFusion; they expect nulls to be in keys.
+#[tokio::test]
+async fn test_dict_get_no_null_values() {
+    let ctx = build_dict_schema().await;
+    // `json_get`: check the rendered rows, then that no key in the result points at a null value
+    let sql = "select json_get(x, 'baz') v from data";
+    let expected = [
+        "+------------+",
+        "| v          |",
+        "+------------+",
+        "|            |",
+        "| {str=fizz} |",
+        "|            |",
+        "| {str=abcd} |",
+        "|            |",
+        "| {str=fizz} |",
+        "| {str=fizz} |",
+        "| {str=fizz} |",
+        "| {str=fizz} |",
+        "|            |",
+        "+------------+",
+    ];
+    let batches = ctx.sql(sql).await.unwrap().collect().await.unwrap();
+    assert_batches_eq!(expected, &batches);
+    for batch in batches {
+        check_for_null_dictionary_values(batch.column(0).as_ref());
+    }
+    // `json_get_str`: same two checks on the string-returning variant
+    let sql = "select json_get_str(x, 'baz') v from data";
+    let expected = [
+        "+------+", "| v    |", "+------+", "|      |", "| fizz |", "|      |", "| abcd |", "|      |", "| fizz |",
+        "| fizz |", "| fizz |", "| fizz |", "|      |", "+------+",
+    ];
+    let batches = ctx.sql(sql).await.unwrap().collect().await.unwrap();
+    assert_batches_eq!(expected, &batches);
+    for batch in batches {
+        check_for_null_dictionary_values(batch.column(0).as_ref());
+    }
+}
+
 #[tokio::test]
 async fn test_dict_haystack_filter() {
     let sql = "select json_data v from dicts where json_get(json_data, 'foo') is not null";