Skip to content

Commit baf6f60

Browse files
authored
Use return_field instead of return_type for calling aggregates via FFI (apache#17407)
* Use return_field instead of return_type for calling aggregates via FFI * Add note in upgrade guide * typo
1 parent 6a21b67 commit baf6f60

File tree

2 files changed

+108
-18
lines changed

2 files changed

+108
-18
lines changed

datafusion/ffi/src/udaf/mod.rs

Lines changed: 90 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -69,11 +69,11 @@ pub struct FFI_AggregateUDF {
6969
/// FFI equivalent to the `volatility` of a [`AggregateUDF`]
7070
pub volatility: FFI_Volatility,
7171

72-
/// Determines the return type of the underlying [`AggregateUDF`] based on the
73-
/// argument types.
74-
pub return_type: unsafe extern "C" fn(
72+
/// Determines the return field of the underlying [`AggregateUDF`] based on the
73+
/// argument fields.
74+
pub return_field: unsafe extern "C" fn(
7575
udaf: &Self,
76-
arg_types: RVec<WrappedSchema>,
76+
arg_fields: RVec<WrappedSchema>,
7777
) -> RResult<WrappedSchema, RString>,
7878

7979
/// FFI equivalent to the `is_nullable` of a [`AggregateUDF`]
@@ -160,20 +160,22 @@ impl FFI_AggregateUDF {
160160
}
161161
}
162162

163-
unsafe extern "C" fn return_type_fn_wrapper(
163+
unsafe extern "C" fn return_field_fn_wrapper(
164164
udaf: &FFI_AggregateUDF,
165-
arg_types: RVec<WrappedSchema>,
165+
arg_fields: RVec<WrappedSchema>,
166166
) -> RResult<WrappedSchema, RString> {
167167
let udaf = udaf.inner();
168168

169-
let arg_types = rresult_return!(rvec_wrapped_to_vec_datatype(&arg_types));
169+
let arg_fields = rresult_return!(rvec_wrapped_to_vec_fieldref(&arg_fields));
170170

171-
let return_type = udaf
172-
.return_type(&arg_types)
173-
.and_then(|v| FFI_ArrowSchema::try_from(v).map_err(DataFusionError::from))
171+
let return_field = udaf
172+
.return_field(&arg_fields)
173+
.and_then(|v| {
174+
FFI_ArrowSchema::try_from(v.as_ref()).map_err(DataFusionError::from)
175+
})
174176
.map(WrappedSchema);
175177

176-
rresult!(return_type)
178+
rresult!(return_field)
177179
}
178180

179181
unsafe extern "C" fn accumulator_fn_wrapper(
@@ -346,7 +348,7 @@ impl From<Arc<AggregateUDF>> for FFI_AggregateUDF {
346348
is_nullable,
347349
volatility,
348350
aliases,
349-
return_type: return_type_fn_wrapper,
351+
return_field: return_field_fn_wrapper,
350352
accumulator: accumulator_fn_wrapper,
351353
create_sliding_accumulator: create_sliding_accumulator_fn_wrapper,
352354
create_groups_accumulator: create_groups_accumulator_fn_wrapper,
@@ -425,14 +427,22 @@ impl AggregateUDFImpl for ForeignAggregateUDF {
425427
&self.signature
426428
}
427429

428-
fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
429-
let arg_types = vec_datatype_to_rvec_wrapped(arg_types)?;
430+
fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
431+
unimplemented!()
432+
}
433+
434+
fn return_field(&self, arg_fields: &[FieldRef]) -> Result<FieldRef> {
435+
let arg_fields = vec_fieldref_to_rvec_wrapped(arg_fields)?;
430436

431-
let result = unsafe { (self.udaf.return_type)(&self.udaf, arg_types) };
437+
let result = unsafe { (self.udaf.return_field)(&self.udaf, arg_fields) };
432438

433439
let result = df_result!(result);
434440

435-
result.and_then(|r| (&r.0).try_into().map_err(DataFusionError::from))
441+
result.and_then(|r| {
442+
Field::try_from(&r.0)
443+
.map(Arc::new)
444+
.map_err(DataFusionError::from)
445+
})
436446
}
437447

438448
fn is_nullable(&self) -> bool {
@@ -608,9 +618,43 @@ mod tests {
608618
physical_expr::PhysicalSortExpr, physical_plan::expressions::col,
609619
scalar::ScalarValue,
610620
};
621+
use std::any::Any;
622+
use std::collections::HashMap;
611623

612624
use super::*;
613625

626+
#[derive(Default, Debug, Hash, Eq, PartialEq)]
627+
struct SumWithCopiedMetadata {
628+
inner: Sum,
629+
}
630+
631+
impl AggregateUDFImpl for SumWithCopiedMetadata {
632+
fn as_any(&self) -> &dyn Any {
633+
self
634+
}
635+
636+
fn name(&self) -> &str {
637+
self.inner.name()
638+
}
639+
640+
fn signature(&self) -> &Signature {
641+
self.inner.signature()
642+
}
643+
644+
fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
645+
unimplemented!()
646+
}
647+
648+
fn return_field(&self, arg_fields: &[FieldRef]) -> Result<FieldRef> {
649+
// Copy the input field, so any metadata gets returned
650+
Ok(Arc::clone(&arg_fields[0]))
651+
}
652+
653+
fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
654+
self.inner.accumulator(acc_args)
655+
}
656+
}
657+
614658
fn create_test_foreign_udaf(
615659
original_udaf: impl AggregateUDFImpl + 'static,
616660
) -> Result<AggregateUDF> {
@@ -644,8 +688,11 @@ mod tests {
644688
let foreign_udaf =
645689
create_test_foreign_udaf(Sum::new())?.with_aliases(["my_function"]);
646690

647-
let return_type = foreign_udaf.return_type(&[DataType::Float64])?;
648-
assert_eq!(return_type, DataType::Float64);
691+
let return_field =
692+
foreign_udaf
693+
.return_field(&[Field::new("a", DataType::Float64, true).into()])?;
694+
let return_type = return_field.data_type();
695+
assert_eq!(return_type, &DataType::Float64);
649696
Ok(())
650697
}
651698

@@ -673,6 +720,31 @@ mod tests {
673720
Ok(())
674721
}
675722

723+
#[test]
724+
fn test_round_trip_udaf_metadata() -> Result<()> {
725+
let original_udaf = SumWithCopiedMetadata::default();
726+
let original_udaf = Arc::new(AggregateUDF::from(original_udaf));
727+
728+
// Convert to FFI format
729+
let local_udaf: FFI_AggregateUDF = Arc::clone(&original_udaf).into();
730+
731+
// Convert back to native format
732+
let foreign_udaf: ForeignAggregateUDF = (&local_udaf).try_into()?;
733+
let foreign_udaf: AggregateUDF = foreign_udaf.into();
734+
735+
let metadata: HashMap<String, String> =
736+
[("a_key".to_string(), "a_value".to_string())]
737+
.into_iter()
738+
.collect();
739+
let input_field = Arc::new(
740+
Field::new("a", DataType::Float64, false).with_metadata(metadata.clone()),
741+
);
742+
let return_field = foreign_udaf.return_field(&[input_field])?;
743+
744+
assert_eq!(&metadata, return_field.metadata());
745+
Ok(())
746+
}
747+
676748
#[test]
677749
fn test_beneficial_ordering() -> Result<()> {
678750
let foreign_udaf = create_test_foreign_udaf(

docs/source/library-user-guide/upgrading.md

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -285,6 +285,24 @@ If you have custom implementations of `FileOpener` or work directly with `FileOp
285285

286286
[#17397]: https://github.com/apache/datafusion/pull/17397
287287

288+
### FFI user defined aggregate function signature change
289+
290+
The Foreign Function Interface (FFI) signature for user defined aggregate functions
291+
has been updated to call `return_field` instead of `return_type` on the underlying
292+
aggregate function. This is to support metadata handling with these aggregate functions.
293+
This change should be transparent to most users. If you have written unit tests to call
294+
`return_type` directly, you may need to change them to calling `return_field` instead.
295+
296+
This update is a breaking change to the FFI API. The current best practice when using the
297+
FFI crate is to ensure that all libraries that are interacting are using the same
298+
underlying Rust version. Issue [#17374] has been opened to discuss stabilization of
299+
this interface so that these libraries can be used across different DataFusion versions.
300+
301+
See [#17407] for details.
302+
303+
[#17407]: https://github.com/apache/datafusion/pull/17407
304+
[#17374]: https://github.com/apache/datafusion/issues/17374
305+
288306
### Added `PhysicalExpr::is_volatile_node`
289307

290308
We added a method to `PhysicalExpr` to mark a `PhysicalExpr` as volatile:

0 commit comments

Comments
 (0)