@@ -69,11 +69,11 @@ pub struct FFI_AggregateUDF {
69
69
/// FFI equivalent to the `volatility` of a [`AggregateUDF`]
70
70
pub volatility : FFI_Volatility ,
71
71
72
- /// Determines the return type of the underlying [`AggregateUDF`] based on the
73
- /// argument types .
74
- pub return_type : unsafe extern "C" fn (
72
+ /// Determines the return field of the underlying [`AggregateUDF`] based on the
73
+ /// argument fields .
74
+ pub return_field : unsafe extern "C" fn (
75
75
udaf : & Self ,
76
- arg_types : RVec < WrappedSchema > ,
76
+ arg_fields : RVec < WrappedSchema > ,
77
77
) -> RResult < WrappedSchema , RString > ,
78
78
79
79
/// FFI equivalent to the `is_nullable` of a [`AggregateUDF`]
@@ -160,20 +160,22 @@ impl FFI_AggregateUDF {
160
160
}
161
161
}
162
162
163
- unsafe extern "C" fn return_type_fn_wrapper (
163
+ unsafe extern "C" fn return_field_fn_wrapper (
164
164
udaf : & FFI_AggregateUDF ,
165
- arg_types : RVec < WrappedSchema > ,
165
+ arg_fields : RVec < WrappedSchema > ,
166
166
) -> RResult < WrappedSchema , RString > {
167
167
let udaf = udaf. inner ( ) ;
168
168
169
- let arg_types = rresult_return ! ( rvec_wrapped_to_vec_datatype ( & arg_types ) ) ;
169
+ let arg_fields = rresult_return ! ( rvec_wrapped_to_vec_fieldref ( & arg_fields ) ) ;
170
170
171
- let return_type = udaf
172
- . return_type ( & arg_types)
173
- . and_then ( |v| FFI_ArrowSchema :: try_from ( v) . map_err ( DataFusionError :: from) )
171
+ let return_field = udaf
172
+ . return_field ( & arg_fields)
173
+ . and_then ( |v| {
174
+ FFI_ArrowSchema :: try_from ( v. as_ref ( ) ) . map_err ( DataFusionError :: from)
175
+ } )
174
176
. map ( WrappedSchema ) ;
175
177
176
- rresult ! ( return_type )
178
+ rresult ! ( return_field )
177
179
}
178
180
179
181
unsafe extern "C" fn accumulator_fn_wrapper (
@@ -346,7 +348,7 @@ impl From<Arc<AggregateUDF>> for FFI_AggregateUDF {
346
348
is_nullable,
347
349
volatility,
348
350
aliases,
349
- return_type : return_type_fn_wrapper ,
351
+ return_field : return_field_fn_wrapper ,
350
352
accumulator : accumulator_fn_wrapper,
351
353
create_sliding_accumulator : create_sliding_accumulator_fn_wrapper,
352
354
create_groups_accumulator : create_groups_accumulator_fn_wrapper,
@@ -425,14 +427,22 @@ impl AggregateUDFImpl for ForeignAggregateUDF {
425
427
& self . signature
426
428
}
427
429
428
- fn return_type ( & self , arg_types : & [ DataType ] ) -> Result < DataType > {
429
- let arg_types = vec_datatype_to_rvec_wrapped ( arg_types) ?;
430
+ fn return_type ( & self , _arg_types : & [ DataType ] ) -> Result < DataType > {
431
+ unimplemented ! ( )
432
+ }
433
+
434
+ fn return_field ( & self , arg_fields : & [ FieldRef ] ) -> Result < FieldRef > {
435
+ let arg_fields = vec_fieldref_to_rvec_wrapped ( arg_fields) ?;
430
436
431
- let result = unsafe { ( self . udaf . return_type ) ( & self . udaf , arg_types ) } ;
437
+ let result = unsafe { ( self . udaf . return_field ) ( & self . udaf , arg_fields ) } ;
432
438
433
439
let result = df_result ! ( result) ;
434
440
435
- result. and_then ( |r| ( & r. 0 ) . try_into ( ) . map_err ( DataFusionError :: from) )
441
+ result. and_then ( |r| {
442
+ Field :: try_from ( & r. 0 )
443
+ . map ( Arc :: new)
444
+ . map_err ( DataFusionError :: from)
445
+ } )
436
446
}
437
447
438
448
fn is_nullable ( & self ) -> bool {
@@ -608,9 +618,43 @@ mod tests {
608
618
physical_expr:: PhysicalSortExpr , physical_plan:: expressions:: col,
609
619
scalar:: ScalarValue ,
610
620
} ;
621
+ use std:: any:: Any ;
622
+ use std:: collections:: HashMap ;
611
623
612
624
use super :: * ;
613
625
626
+ #[ derive( Default , Debug , Hash , Eq , PartialEq ) ]
627
+ struct SumWithCopiedMetadata {
628
+ inner : Sum ,
629
+ }
630
+
631
+ impl AggregateUDFImpl for SumWithCopiedMetadata {
632
+ fn as_any ( & self ) -> & dyn Any {
633
+ self
634
+ }
635
+
636
+ fn name ( & self ) -> & str {
637
+ self . inner . name ( )
638
+ }
639
+
640
+ fn signature ( & self ) -> & Signature {
641
+ self . inner . signature ( )
642
+ }
643
+
644
+ fn return_type ( & self , _arg_types : & [ DataType ] ) -> Result < DataType > {
645
+ unimplemented ! ( )
646
+ }
647
+
648
+ fn return_field ( & self , arg_fields : & [ FieldRef ] ) -> Result < FieldRef > {
649
+ // Copy the input field, so any metadata gets returned
650
+ Ok ( Arc :: clone ( & arg_fields[ 0 ] ) )
651
+ }
652
+
653
+ fn accumulator ( & self , acc_args : AccumulatorArgs ) -> Result < Box < dyn Accumulator > > {
654
+ self . inner . accumulator ( acc_args)
655
+ }
656
+ }
657
+
614
658
fn create_test_foreign_udaf (
615
659
original_udaf : impl AggregateUDFImpl + ' static ,
616
660
) -> Result < AggregateUDF > {
@@ -644,8 +688,11 @@ mod tests {
644
688
let foreign_udaf =
645
689
create_test_foreign_udaf ( Sum :: new ( ) ) ?. with_aliases ( [ "my_function" ] ) ;
646
690
647
- let return_type = foreign_udaf. return_type ( & [ DataType :: Float64 ] ) ?;
648
- assert_eq ! ( return_type, DataType :: Float64 ) ;
691
+ let return_field =
692
+ foreign_udaf
693
+ . return_field ( & [ Field :: new ( "a" , DataType :: Float64 , true ) . into ( ) ] ) ?;
694
+ let return_type = return_field. data_type ( ) ;
695
+ assert_eq ! ( return_type, & DataType :: Float64 ) ;
649
696
Ok ( ( ) )
650
697
}
651
698
@@ -673,6 +720,31 @@ mod tests {
673
720
Ok ( ( ) )
674
721
}
675
722
723
+ #[ test]
724
+ fn test_round_trip_udaf_metadata ( ) -> Result < ( ) > {
725
+ let original_udaf = SumWithCopiedMetadata :: default ( ) ;
726
+ let original_udaf = Arc :: new ( AggregateUDF :: from ( original_udaf) ) ;
727
+
728
+ // Convert to FFI format
729
+ let local_udaf: FFI_AggregateUDF = Arc :: clone ( & original_udaf) . into ( ) ;
730
+
731
+ // Convert back to native format
732
+ let foreign_udaf: ForeignAggregateUDF = ( & local_udaf) . try_into ( ) ?;
733
+ let foreign_udaf: AggregateUDF = foreign_udaf. into ( ) ;
734
+
735
+ let metadata: HashMap < String , String > =
736
+ [ ( "a_key" . to_string ( ) , "a_value" . to_string ( ) ) ]
737
+ . into_iter ( )
738
+ . collect ( ) ;
739
+ let input_field = Arc :: new (
740
+ Field :: new ( "a" , DataType :: Float64 , false ) . with_metadata ( metadata. clone ( ) ) ,
741
+ ) ;
742
+ let return_field = foreign_udaf. return_field ( & [ input_field] ) ?;
743
+
744
+ assert_eq ! ( & metadata, return_field. metadata( ) ) ;
745
+ Ok ( ( ) )
746
+ }
747
+
676
748
#[ test]
677
749
fn test_beneficial_ordering ( ) -> Result < ( ) > {
678
750
let foreign_udaf = create_test_foreign_udaf (
0 commit comments