Skip to content

Commit 73c8da2

Browse files
authored
cast dictionaries to i64 key type to avoid generic explosion (#66)
1 parent 82521b6 commit 73c8da2

File tree

1 file changed

+155
-95
lines changed

1 file changed

+155
-95
lines changed

src/common.rs

Lines changed: 155 additions & 95 deletions
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,12 @@ use std::str::Utf8Error;
22
use std::sync::Arc;
33

44
use datafusion::arrow::array::{
5-
Array, ArrayAccessor, ArrayRef, AsArray, DictionaryArray, Int64Array, LargeStringArray, PrimitiveArray,
6-
StringArray, StringViewArray, UInt64Array,
5+
downcast_array, AnyDictionaryArray, Array, ArrayAccessor, ArrayRef, AsArray, DictionaryArray, LargeStringArray,
6+
PrimitiveArray, StringArray, StringViewArray,
77
};
8+
use datafusion::arrow::compute::kernels::cast;
89
use datafusion::arrow::compute::take;
9-
use datafusion::arrow::datatypes::{
10-
ArrowDictionaryKeyType, ArrowNativeType, ArrowNativeTypeOp, DataType, Int64Type, UInt64Type,
11-
};
12-
use datafusion::arrow::downcast_dictionary_array;
10+
use datafusion::arrow::datatypes::{ArrowNativeType, DataType, Int64Type, UInt64Type};
1311
use datafusion::common::{exec_err, plan_err, Result as DataFusionResult, ScalarValue};
1412
use datafusion::logical_expr::ColumnarValue;
1513
use jiter::{Jiter, JiterError, Peek};
@@ -45,9 +43,10 @@ pub fn return_type_check(args: &[DataType], fn_name: &str, value_type: DataType)
4543
)
4644
}
4745
})?;
48-
match first_dict_key_type {
49-
Some(t) => Ok(DataType::Dictionary(Box::new(t), Box::new(value_type))),
50-
None => Ok(value_type),
46+
if first_dict_key_type.is_some() {
47+
Ok(DataType::Dictionary(Box::new(DataType::Int64), Box::new(value_type)))
48+
} else {
49+
Ok(value_type)
5150
}
5251
}
5352

@@ -176,59 +175,68 @@ fn invoke_array_array<C: FromIterator<Option<I>> + 'static, I>(
176175
jiter_find: impl Fn(Option<&str>, &[JsonPath]) -> Result<I, GetError>,
177176
return_dict: bool,
178177
) -> DataFusionResult<ArrayRef> {
179-
downcast_dictionary_array!(
180-
json_array => {
181-
fn wrap_as_dictionary<K: ArrowDictionaryKeyType>(original: &DictionaryArray<K>, new_values: ArrayRef) -> DictionaryArray<K> {
182-
assert_eq!(original.keys().len(), new_values.len());
183-
let mut key = K::Native::ZERO;
184-
let key_range = std::iter::from_fn(move || {
185-
let next = key;
186-
key = key.add_checked(K::Native::ONE).expect("keys exhausted");
187-
Some(next)
188-
}).take(new_values.len());
189-
let mut keys = PrimitiveArray::<K>::from_iter_values(key_range);
190-
if is_json_union(new_values.data_type()) {
191-
// JSON union: post-process the array to set keys to null where the union member is null
192-
let type_ids = new_values.as_union().type_ids();
193-
keys = mask_dictionary_keys(&keys, type_ids);
194-
}
195-
DictionaryArray::<K>::new(keys, new_values)
178+
match json_array.data_type() {
179+
// for string dictionaries, cast dictionary keys to larger types to avoid generic explosion
180+
DataType::Dictionary(_, value_type) if value_type.as_ref() == &DataType::Utf8 => {
181+
let json_array = cast_to_large_dictionary(json_array.as_any_dictionary())?;
182+
let output = zip_apply(
183+
json_array.downcast_dict::<StringArray>().unwrap(),
184+
path_array,
185+
to_array,
186+
jiter_find,
187+
)?;
188+
if return_dict {
189+
// ensure return is a dictionary to satisfy the declaration above in return_type_check
190+
Ok(Arc::new(wrap_as_large_dictionary(&json_array, output)))
191+
} else {
192+
Ok(output)
196193
}
197-
198-
// TODO: in theory if path_array is _also_ a dictionary we could work out the unique key
199-
// combinations and do less work, but this can be left as a future optimization
200-
let output = match json_array.values().data_type() {
201-
DataType::Utf8 => zip_apply(json_array.downcast_dict::<StringArray>().unwrap(), path_array, to_array, jiter_find),
202-
DataType::LargeUtf8 => zip_apply(json_array.downcast_dict::<LargeStringArray>().unwrap(), path_array, to_array, jiter_find),
203-
DataType::Utf8View => zip_apply(json_array.downcast_dict::<StringViewArray>().unwrap(), path_array, to_array, jiter_find),
204-
other => if let Some(child_array) = nested_json_array_ref(json_array.values(), is_object_lookup_array(path_array.data_type())) {
205-
// Horrible case: dict containing union as input with array for paths, figure
206-
// out from the path type which union members we should access, repack the
207-
// dictionary and then recurse.
208-
//
209-
// Use direct return because if return_dict applies, the recursion will handle it.
210-
return invoke_array_array(&(Arc::new(json_array.with_values(child_array.clone())) as _), path_array, to_array, jiter_find, return_dict)
211-
} else {
212-
exec_err!("unexpected json array type {:?}", other)
213-
}
214-
}?;
215-
194+
}
195+
DataType::Dictionary(_, value_type) if value_type.as_ref() == &DataType::LargeUtf8 => {
196+
let json_array = cast_to_large_dictionary(json_array.as_any_dictionary())?;
197+
let output = zip_apply(
198+
json_array.downcast_dict::<LargeStringArray>().unwrap(),
199+
path_array,
200+
to_array,
201+
jiter_find,
202+
)?;
216203
if return_dict {
217204
// ensure return is a dictionary to satisfy the declaration above in return_type_check
218-
Ok(Arc::new(wrap_as_dictionary(json_array, output)))
205+
Ok(Arc::new(wrap_as_large_dictionary(&json_array, output)))
219206
} else {
220207
Ok(output)
221208
}
222-
},
209+
}
210+
other_dict_type @ DataType::Dictionary(_, _) => {
211+
// Horrible case: dict containing union as input with array for paths, figure
212+
// out from the path type which union members we should access, repack the
213+
// dictionary and then recurse.
214+
if let Some(child_array) = nested_json_array_ref(
215+
json_array.as_any_dictionary().values(),
216+
is_object_lookup_array(path_array.data_type()),
217+
) {
218+
invoke_array_array(
219+
&(Arc::new(json_array.as_any_dictionary().with_values(child_array.clone())) as _),
220+
path_array,
221+
to_array,
222+
jiter_find,
223+
return_dict,
224+
)
225+
} else {
226+
exec_err!("unexpected json array type {:?}", other_dict_type)
227+
}
228+
}
223229
DataType::Utf8 => zip_apply(json_array.as_string::<i32>().iter(), path_array, to_array, jiter_find),
224230
DataType::LargeUtf8 => zip_apply(json_array.as_string::<i64>().iter(), path_array, to_array, jiter_find),
225231
DataType::Utf8View => zip_apply(json_array.as_string_view().iter(), path_array, to_array, jiter_find),
226-
other => if let Some(string_array) = nested_json_array(json_array, is_object_lookup_array(path_array.data_type())) {
227-
zip_apply(string_array.iter(), path_array, to_array, jiter_find)
228-
} else {
229-
exec_err!("unexpected json array type {:?}", other)
232+
other => {
233+
if let Some(string_array) = nested_json_array(json_array, is_object_lookup_array(path_array.data_type())) {
234+
zip_apply(string_array.iter(), path_array, to_array, jiter_find)
235+
} else {
236+
exec_err!("unexpected json array type {:?}", other)
237+
}
230238
}
231-
)
239+
}
232240
}
233241

234242
fn invoke_array_scalars<C: FromIterator<Option<I>>, I>(
@@ -249,20 +257,35 @@ fn invoke_array_scalars<C: FromIterator<Option<I>>, I>(
249257
.collect::<C>()
250258
}
251259

252-
let c = downcast_dictionary_array!(
253-
json_array => {
260+
let c = match json_array.data_type() {
261+
DataType::Dictionary(_, _) => {
262+
let json_array = json_array.as_any_dictionary();
254263
let values = invoke_array_scalars(json_array.values(), path, to_array, jiter_find, false)?;
255-
return post_process_dict(json_array, values, return_dict);
264+
return if return_dict {
265+
// make the keys into i64 to avoid generic bloat here
266+
let mut keys: PrimitiveArray<Int64Type> = downcast_array(&cast(json_array.keys(), &DataType::Int64)?);
267+
if is_json_union(values.data_type()) {
268+
// JSON union: post-process the array to set keys to null where the union member is null
269+
let type_ids = values.as_union().type_ids();
270+
keys = mask_dictionary_keys(&keys, type_ids);
271+
}
272+
Ok(Arc::new(DictionaryArray::new(keys, values)))
273+
} else {
274+
// this is what cast would do under the hood to unpack a dictionary into an array of its values
275+
Ok(take(&values, json_array.keys(), None)?)
276+
};
256277
}
257278
DataType::Utf8 => inner(json_array.as_string::<i32>(), path, jiter_find),
258279
DataType::LargeUtf8 => inner(json_array.as_string::<i64>(), path, jiter_find),
259280
DataType::Utf8View => inner(json_array.as_string_view(), path, jiter_find),
260-
other => if let Some(string_array) = nested_json_array(json_array, is_object_lookup(path)) {
261-
inner(string_array, path, jiter_find)
262-
} else {
263-
return exec_err!("unexpected json array type {:?}", other);
281+
other => {
282+
if let Some(string_array) = nested_json_array(json_array, is_object_lookup(path)) {
283+
inner(string_array, path, jiter_find)
284+
} else {
285+
return exec_err!("unexpected json array type {:?}", other);
286+
}
264287
}
265-
);
288+
};
266289
to_array(c)
267290
}
268291

@@ -323,22 +346,57 @@ fn zip_apply<'a, C: FromIterator<Option<I>> + 'static, I>(
323346
.collect::<C>()
324347
}
325348

326-
let c = downcast_dictionary_array!(
327-
path_array => match path_array.values().data_type() {
328-
DataType::Utf8 => inner(json_array, path_array.downcast_dict::<StringArray>().unwrap(), jiter_find),
329-
DataType::LargeUtf8 => inner(json_array, path_array.downcast_dict::<LargeStringArray>().unwrap(), jiter_find),
330-
DataType::Utf8View => inner(json_array, path_array.downcast_dict::<StringViewArray>().unwrap(), jiter_find),
331-
DataType::Int64 => inner(json_array, path_array.downcast_dict::<Int64Array>().unwrap(), jiter_find),
332-
DataType::UInt64 => inner(json_array, path_array.downcast_dict::<UInt64Array>().unwrap(), jiter_find),
333-
other => return exec_err!("unexpected second argument type, expected string or int array, got {:?}", other),
334-
},
349+
let c = match path_array.data_type() {
350+
// for string dictionaries, cast dictionary keys to larger types to avoid generic explosion
351+
DataType::Dictionary(_, value_type) if value_type.as_ref() == &DataType::Utf8 => {
352+
let path_array = cast_to_large_dictionary(path_array.as_any_dictionary())?;
353+
inner(
354+
json_array,
355+
path_array.downcast_dict::<StringArray>().unwrap(),
356+
jiter_find,
357+
)
358+
}
359+
DataType::Dictionary(_, value_type) if value_type.as_ref() == &DataType::LargeUtf8 => {
360+
let path_array = cast_to_large_dictionary(path_array.as_any_dictionary())?;
361+
inner(
362+
json_array,
363+
path_array.downcast_dict::<LargeStringArray>().unwrap(),
364+
jiter_find,
365+
)
366+
}
367+
DataType::Dictionary(_, value_type) if value_type.as_ref() == &DataType::Utf8View => {
368+
let path_array = cast_to_large_dictionary(path_array.as_any_dictionary())?;
369+
inner(
370+
json_array,
371+
path_array.downcast_dict::<StringViewArray>().unwrap(),
372+
jiter_find,
373+
)
374+
}
375+
// for integer dictionaries, cast them directly to the inner type because it basically costs
376+
// the same as building a new key array anyway
377+
DataType::Dictionary(_, value_type) if value_type.as_ref() == &DataType::Int64 => inner(
378+
json_array,
379+
cast(path_array, &DataType::Int64)?.as_primitive::<Int64Type>(),
380+
jiter_find,
381+
),
382+
DataType::Dictionary(_, value_type) if value_type.as_ref() == &DataType::UInt64 => inner(
383+
json_array,
384+
cast(path_array, &DataType::UInt64)?.as_primitive::<UInt64Type>(),
385+
jiter_find,
386+
),
387+
// for basic types, just consume directly
335388
DataType::Utf8 => inner(json_array, path_array.as_string::<i32>(), jiter_find),
336389
DataType::LargeUtf8 => inner(json_array, path_array.as_string::<i64>(), jiter_find),
337390
DataType::Utf8View => inner(json_array, path_array.as_string_view(), jiter_find),
338391
DataType::Int64 => inner(json_array, path_array.as_primitive::<Int64Type>(), jiter_find),
339392
DataType::UInt64 => inner(json_array, path_array.as_primitive::<UInt64Type>(), jiter_find),
340-
other => return exec_err!("unexpected second argument type, expected string or int array, got {:?}", other)
341-
);
393+
other => {
394+
return exec_err!(
395+
"unexpected second argument type, expected string or int array, got {:?}",
396+
other
397+
)
398+
}
399+
};
342400

343401
to_array(c)
344402
}
@@ -356,29 +414,6 @@ fn extract_json_scalar(scalar: &ScalarValue) -> DataFusionResult<Option<&str>> {
356414
}
357415
}
358416

359-
/// Take a dictionary array of JSON data and an array of result values and combine them.
360-
fn post_process_dict<T: ArrowDictionaryKeyType>(
361-
dict_array: &DictionaryArray<T>,
362-
result_values: ArrayRef,
363-
return_dict: bool,
364-
) -> DataFusionResult<ArrayRef> {
365-
if return_dict {
366-
if is_json_union(result_values.data_type()) {
367-
// JSON union: post-process the array to set keys to null where the union member is null
368-
let type_ids = result_values.as_union().type_ids();
369-
Ok(Arc::new(DictionaryArray::new(
370-
mask_dictionary_keys(dict_array.keys(), type_ids),
371-
result_values,
372-
)))
373-
} else {
374-
Ok(Arc::new(dict_array.with_values(result_values)))
375-
}
376-
} else {
377-
// this is what cast would do under the hood to unpack a dictionary into an array of its values
378-
Ok(take(&result_values, dict_array.keys(), None)?)
379-
}
380-
}
381-
382417
fn is_object_lookup(path: &[JsonPath]) -> bool {
383418
if let Some(first) = path.first() {
384419
matches!(first, JsonPath::Key(_))
@@ -395,6 +430,31 @@ fn is_object_lookup_array(data_type: &DataType) -> bool {
395430
}
396431
}
397432

433+
/// Cast an array to a dictionary with i64 indices.
434+
///
435+
/// According to <https://arrow.apache.org/docs/format/Columnar.html#dictionary-encoded-layout> the
436+
/// recommendation is to avoid unsigned indices due to technologies like the JVM making it harder to
437+
/// support unsigned integers.
438+
///
439+
/// So we'll just use i64 as the largest signed integer type.
440+
fn cast_to_large_dictionary(dict_array: &dyn AnyDictionaryArray) -> DataFusionResult<DictionaryArray<Int64Type>> {
441+
let keys = downcast_array(&cast(dict_array.keys(), &DataType::Int64)?);
442+
Ok(DictionaryArray::<Int64Type>::new(keys, dict_array.values().clone()))
443+
}
444+
445+
/// Wrap an array as a dictionary with i64 indices.
446+
fn wrap_as_large_dictionary(original: &dyn AnyDictionaryArray, new_values: ArrayRef) -> DictionaryArray<Int64Type> {
447+
assert_eq!(original.keys().len(), new_values.len());
448+
let mut keys =
449+
PrimitiveArray::from_iter_values(0i64..original.keys().len().try_into().expect("keys out of i64 range"));
450+
if is_json_union(new_values.data_type()) {
451+
// JSON union: post-process the array to set keys to null where the union member is null
452+
let type_ids = new_values.as_union().type_ids();
453+
keys = mask_dictionary_keys(&keys, type_ids);
454+
}
455+
DictionaryArray::new(keys, new_values)
456+
}
457+
398458
pub fn jiter_json_find<'j>(opt_json: Option<&'j str>, path: &[JsonPath]) -> Option<(Jiter<'j>, Peek)> {
399459
let json_str = opt_json?;
400460
let mut jiter = Jiter::new(json_str.as_bytes());
@@ -457,7 +517,7 @@ impl From<Utf8Error> for GetError {
457517
///
458518
/// That said, doing this might also be an optimization for cases like null-checking without needing
459519
/// to check the value union array.
460-
fn mask_dictionary_keys<K: ArrowDictionaryKeyType>(keys: &PrimitiveArray<K>, type_ids: &[i8]) -> PrimitiveArray<K> {
520+
fn mask_dictionary_keys(keys: &PrimitiveArray<Int64Type>, type_ids: &[i8]) -> PrimitiveArray<Int64Type> {
461521
let mut null_mask = vec![true; keys.len()];
462522
for (i, k) in keys.iter().enumerate() {
463523
match k {

0 commit comments

Comments
 (0)