@@ -3,18 +3,20 @@ use std::sync::Arc;
33
44use datafusion:: arrow:: array:: {
55 Array , ArrayAccessor , ArrayRef , AsArray , DictionaryArray , Int64Array , LargeStringArray , PrimitiveArray ,
6- StringArray , StringViewArray , UInt64Array , UnionArray ,
6+ StringArray , StringViewArray , UInt64Array ,
77} ;
88use datafusion:: arrow:: compute:: take;
99use datafusion:: arrow:: datatypes:: {
10- ArrowDictionaryKeyType , ArrowNativeType , ArrowPrimitiveType , DataType , Int64Type , UInt64Type ,
10+ ArrowDictionaryKeyType , ArrowNativeType , ArrowNativeTypeOp , DataType , Int64Type , UInt64Type ,
1111} ;
1212use datafusion:: arrow:: downcast_dictionary_array;
1313use datafusion:: common:: { exec_err, plan_err, Result as DataFusionResult , ScalarValue } ;
1414use datafusion:: logical_expr:: ColumnarValue ;
1515use jiter:: { Jiter , JiterError , Peek } ;
1616
17- use crate :: common_union:: { is_json_union, json_from_union_scalar, nested_json_array, TYPE_ID_NULL } ;
17+ use crate :: common_union:: {
18+ is_json_union, json_from_union_scalar, nested_json_array, nested_json_array_ref, TYPE_ID_NULL ,
19+ } ;
1820
1921/// General implementation of `ScalarUDFImpl::return_type`.
2022///
@@ -95,6 +97,7 @@ impl From<i64> for JsonPath<'_> {
9597 }
9698}
9799
100+ #[ derive( Debug ) ]
98101enum JsonPathArgs < ' a > {
99102 Array ( & ' a ArrayRef ) ,
100103 Scalars ( Vec < JsonPath < ' a > > ) ,
@@ -175,9 +178,48 @@ fn invoke_array_array<C: FromIterator<Option<I>> + 'static, I>(
175178) -> DataFusionResult < ArrayRef > {
176179 downcast_dictionary_array ! (
177180 json_array => {
178- let values = invoke_array_array( json_array. values( ) , path_array, to_array, jiter_find, return_dict) ?;
179- post_process_dict( json_array, values, return_dict)
180- }
181+ fn wrap_as_dictionary<K : ArrowDictionaryKeyType >( original: & DictionaryArray <K >, new_values: ArrayRef ) -> DictionaryArray <K > {
182+ assert_eq!( original. keys( ) . len( ) , new_values. len( ) ) ;
183+ let mut key = K :: Native :: ZERO ;
184+ let key_range = std:: iter:: from_fn( move || {
185+ let next = key;
186+ key = key. add_checked( K :: Native :: ONE ) . expect( "keys exhausted" ) ;
187+ Some ( next)
188+ } ) . take( new_values. len( ) ) ;
189+ let mut keys = PrimitiveArray :: <K >:: from_iter_values( key_range) ;
190+ if is_json_union( new_values. data_type( ) ) {
191+ // JSON union: post-process the array to set keys to null where the union member is null
192+ let type_ids = new_values. as_union( ) . type_ids( ) ;
193+ keys = mask_dictionary_keys( & keys, type_ids) ;
194+ }
195+ DictionaryArray :: <K >:: new( keys, new_values)
196+ }
197+
198+ // TODO: in theory if path_array is _also_ a dictionary we could work out the unique key
199+ // combinations and do less work, but this can be left as a future optimization
200+ let output = match json_array. values( ) . data_type( ) {
201+ DataType :: Utf8 => zip_apply( json_array. downcast_dict:: <StringArray >( ) . unwrap( ) , path_array, to_array, jiter_find) ,
202+ DataType :: LargeUtf8 => zip_apply( json_array. downcast_dict:: <LargeStringArray >( ) . unwrap( ) , path_array, to_array, jiter_find) ,
203+ DataType :: Utf8View => zip_apply( json_array. downcast_dict:: <StringViewArray >( ) . unwrap( ) , path_array, to_array, jiter_find) ,
204+ other => if let Some ( child_array) = nested_json_array_ref( json_array. values( ) , is_object_lookup_array( path_array. data_type( ) ) ) {
205+ // Horrible case: dict containing union as input with array for paths, figure
206+ // out from the path type which union members we should access, repack the
207+ // dictionary and then recurse.
208+ //
209+ // Use direct return because if return_dict applies, the recursion will handle it.
210+ return invoke_array_array( & ( Arc :: new( json_array. with_values( child_array. clone( ) ) ) as _) , path_array, to_array, jiter_find, return_dict)
211+ } else {
212+ exec_err!( "unexpected json array type {:?}" , other)
213+ }
214+ } ?;
215+
216+ if return_dict {
217+ // ensure return is a dictionary to satisfy the declaration above in return_type_check
218+ Ok ( Arc :: new( wrap_as_dictionary( json_array, output) ) )
219+ } else {
220+ Ok ( output)
221+ }
222+ } ,
181223 DataType :: Utf8 => zip_apply( json_array. as_string:: <i32 >( ) . iter( ) , path_array, to_array, jiter_find) ,
182224 DataType :: LargeUtf8 => zip_apply( json_array. as_string:: <i64 >( ) . iter( ) , path_array, to_array, jiter_find) ,
183225 DataType :: Utf8View => zip_apply( json_array. as_string_view( ) . iter( ) , path_array, to_array, jiter_find) ,
@@ -239,6 +281,7 @@ fn invoke_scalar_array<C: FromIterator<Option<I>> + 'static, I>(
239281 to_array,
240282 jiter_find,
241283 )
284+ // FIXME edge cases where scalar is wrapped in a dictionary, should return a dictionary?
242285 . map ( ColumnarValue :: Array )
243286}
244287
@@ -250,6 +293,7 @@ fn invoke_scalar_scalars<I>(
250293) -> DataFusionResult < ColumnarValue > {
251294 let s = extract_json_scalar ( scalar) ?;
252295 let v = jiter_find ( s, path) . ok ( ) ;
296+ // FIXME edge cases where scalar is wrapped in a dictionary, should return a dictionary?
253297 Ok ( ColumnarValue :: Scalar ( to_scalar ( v) ) )
254298}
255299
@@ -321,7 +365,7 @@ fn post_process_dict<T: ArrowDictionaryKeyType>(
321365 if return_dict {
322366 if is_json_union ( result_values. data_type ( ) ) {
323367 // JSON union: post-process the array to set keys to null where the union member is null
324- let type_ids = result_values. as_any ( ) . downcast_ref :: < UnionArray > ( ) . unwrap ( ) . type_ids ( ) ;
368+ let type_ids = result_values. as_union ( ) . type_ids ( ) ;
325369 Ok ( Arc :: new ( DictionaryArray :: new (
326370 mask_dictionary_keys ( dict_array. keys ( ) , type_ids) ,
327371 result_values,
@@ -413,7 +457,7 @@ impl From<Utf8Error> for GetError {
413457///
414458/// That said, doing this might also be an optimization for cases like null-checking without needing
415459/// to check the value union array.
416- fn mask_dictionary_keys < K : ArrowPrimitiveType > ( keys : & PrimitiveArray < K > , type_ids : & [ i8 ] ) -> PrimitiveArray < K > {
460+ fn mask_dictionary_keys < K : ArrowDictionaryKeyType > ( keys : & PrimitiveArray < K > , type_ids : & [ i8 ] ) -> PrimitiveArray < K > {
417461 let mut null_mask = vec ! [ true ; keys. len( ) ] ;
418462 for ( i, k) in keys. iter ( ) . enumerate ( ) {
419463 match k {
0 commit comments