Skip to content

Commit 3f106aa

Browse files
committed
avoid take() allocation
1 parent 2020e13 commit 3f106aa

File tree

1 file changed

+37
-42
lines changed

1 file changed

+37
-42
lines changed

src/common.rs

Lines changed: 37 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,13 @@ use std::str::Utf8Error;
22
use std::sync::Arc;
33

44
use datafusion::arrow::array::{
5-
Array, ArrayRef, AsArray, DictionaryArray, Int64Array, LargeStringArray, PrimitiveArray, StringArray,
6-
StringViewArray, UInt64Array, UnionArray,
5+
Array, ArrayAccessor, ArrayRef, AsArray, DictionaryArray, Int64Array, LargeStringArray, PrimitiveArray,
6+
StringArray, StringViewArray, UInt64Array, UnionArray,
77
};
88
use datafusion::arrow::compute::take;
9-
use datafusion::arrow::datatypes::{ArrowDictionaryKeyType, ArrowNativeType, ArrowPrimitiveType, DataType};
9+
use datafusion::arrow::datatypes::{
10+
ArrowDictionaryKeyType, ArrowNativeType, ArrowPrimitiveType, DataType, Int64Type, UInt64Type,
11+
};
1012
use datafusion::arrow::downcast_dictionary_array;
1113
use datafusion::common::{exec_err, plan_err, Result as DataFusionResult, ScalarValue};
1214
use datafusion::logical_expr::ColumnarValue;
@@ -72,6 +74,12 @@ pub enum JsonPath<'s> {
7274
None,
7375
}
7476

77+
/// Converts a borrowed string slice into a [`JsonPath::Key`] path element.
///
/// The resulting `JsonPath` borrows the slice for the same lifetime `'a`;
/// the string is not copied.
impl<'a> From<&'a str> for JsonPath<'a> {
78+
/// Wraps `key` in the `Key` variant without any validation or transformation.
fn from(key: &'a str) -> Self {
79+
JsonPath::Key(key)
80+
}
81+
}
82+
7583
impl From<u64> for JsonPath<'_> {
7684
fn from(index: u64) -> Self {
7785
JsonPath::Index(usize::try_from(index).unwrap())
@@ -145,41 +153,27 @@ fn invoke_array<C: FromIterator<Option<I>> + 'static, I>(
145153
jiter_find: impl Fn(Option<&str>, &[JsonPath]) -> Result<I, GetError>,
146154
return_dict: bool,
147155
) -> DataFusionResult<ArrayRef> {
148-
if let Some(d) = needle_array.as_any_dictionary_opt() {
149-
// this is the (very rare) case where the needle is a dictionary, it shouldn't affect what we return
150-
invoke_array(
151-
json_array,
152-
// Unpack the dictionary array into a values array, so that we can then use it as input.
153-
// There's probably a way to do this with iterators to avoid exploding the input data,
154-
// but due to possible nested dictionaries, it's not trivial.
155-
&take(d.values(), d.keys(), None)?,
156-
to_array,
157-
jiter_find,
158-
return_dict,
159-
)
160-
} else if let Some(str_path_array) = needle_array.as_any().downcast_ref::<StringArray>() {
161-
let paths = str_path_array.iter().map(|opt_key| opt_key.map(JsonPath::Key));
162-
zip_apply(json_array, paths, to_array, jiter_find, true, return_dict)
163-
} else if let Some(str_path_array) = needle_array.as_any().downcast_ref::<LargeStringArray>() {
164-
let paths = str_path_array.iter().map(|opt_key| opt_key.map(JsonPath::Key));
165-
zip_apply(json_array, paths, to_array, jiter_find, true, return_dict)
166-
} else if let Some(str_path_array) = needle_array.as_any().downcast_ref::<StringViewArray>() {
167-
let paths = str_path_array.iter().map(|opt_key| opt_key.map(JsonPath::Key));
168-
zip_apply(json_array, paths, to_array, jiter_find, true, return_dict)
169-
} else if let Some(int_path_array) = needle_array.as_any().downcast_ref::<Int64Array>() {
170-
let paths = int_path_array.iter().map(|opt_index| opt_index.map(Into::into));
171-
zip_apply(json_array, paths, to_array, jiter_find, false, return_dict)
172-
} else if let Some(int_path_array) = needle_array.as_any().downcast_ref::<UInt64Array>() {
173-
let paths = int_path_array.iter().map(|opt_index| opt_index.map(Into::into));
174-
zip_apply(json_array, paths, to_array, jiter_find, false, return_dict)
175-
} else {
176-
exec_err!("unexpected second argument type, expected string or int array")
177-
}
156+
downcast_dictionary_array!(
157+
needle_array => match needle_array.values().data_type() {
158+
DataType::Utf8 => zip_apply(json_array, needle_array.downcast_dict::<StringArray>().unwrap(), to_array, jiter_find, true, return_dict),
159+
DataType::LargeUtf8 => zip_apply(json_array, needle_array.downcast_dict::<LargeStringArray>().unwrap(), to_array, jiter_find, true, return_dict),
160+
DataType::Utf8View => zip_apply(json_array, needle_array.downcast_dict::<StringViewArray>().unwrap(), to_array, jiter_find, true, return_dict),
161+
DataType::Int64 => zip_apply(json_array, needle_array.downcast_dict::<Int64Array>().unwrap(), to_array, jiter_find, false, return_dict),
162+
DataType::UInt64 => zip_apply(json_array, needle_array.downcast_dict::<UInt64Array>().unwrap(), to_array, jiter_find, false, return_dict),
163+
other => exec_err!("unexpected second argument type, expected string or int array, got {:?}", other),
164+
},
165+
DataType::Utf8 => zip_apply(json_array, needle_array.as_string::<i32>(), to_array, jiter_find, true, return_dict),
166+
DataType::LargeUtf8 => zip_apply(json_array, needle_array.as_string::<i64>(), to_array, jiter_find, true, return_dict),
167+
DataType::Utf8View => zip_apply(json_array, needle_array.as_string_view(), to_array, jiter_find, true, return_dict),
168+
DataType::Int64 => zip_apply(json_array, needle_array.as_primitive::<Int64Type>(), to_array, jiter_find, false, return_dict),
169+
DataType::UInt64 => zip_apply(json_array, needle_array.as_primitive::<UInt64Type>(), to_array, jiter_find, false, return_dict),
170+
other => exec_err!("unexpected second argument type, expected string or int array, got {:?}", other)
171+
)
178172
}
179173

180-
fn zip_apply<'a, P: Iterator<Item = Option<JsonPath<'a>>>, C: FromIterator<Option<I>> + 'static, I>(
174+
fn zip_apply<'a, P: Into<JsonPath<'a>>, C: FromIterator<Option<I>> + 'static, I>(
181175
json_array: &ArrayRef,
182-
path_array: P,
176+
path_array: impl ArrayAccessor<Item = P>,
183177
to_array: impl Fn(C) -> DataFusionResult<ArrayRef>,
184178
jiter_find: impl Fn(Option<&str>, &[JsonPath]) -> Result<I, GetError>,
185179
object_lookup: bool,
@@ -203,18 +197,19 @@ fn zip_apply<'a, P: Iterator<Item = Option<JsonPath<'a>>>, C: FromIterator<Optio
203197
to_array(c)
204198
}
205199

206-
fn zip_apply_iter<'a, 'j, P: Iterator<Item = Option<JsonPath<'a>>>, C: FromIterator<Option<I>> + 'static, I>(
200+
fn zip_apply_iter<'a, 'j, P: Into<JsonPath<'a>>, C: FromIterator<Option<I>> + 'static, I>(
207201
json_iter: impl Iterator<Item = Option<&'j str>>,
208-
path_array: P,
202+
path_array: impl ArrayAccessor<Item = P>,
209203
jiter_find: impl Fn(Option<&str>, &[JsonPath]) -> Result<I, GetError>,
210204
) -> C {
211205
json_iter
212-
.zip(path_array)
213-
.map(|(opt_json, opt_path)| {
214-
if let Some(path) = opt_path {
215-
jiter_find(opt_json, &[path]).ok()
216-
} else {
206+
.enumerate()
207+
.map(|(i, opt_json)| {
208+
if path_array.is_null(i) {
217209
None
210+
} else {
211+
let path = path_array.value(i).into();
212+
jiter_find(opt_json, &[path]).ok()
218213
}
219214
})
220215
.collect::<C>()

0 commit comments

Comments
 (0)