@@ -24,39 +24,56 @@ use datafusion::arrow::datatypes::DataType;
2424use datafusion:: arrow:: pyarrow:: FromPyArrow ;
2525use datafusion:: arrow:: pyarrow:: { PyArrowType , ToPyArrow } ;
2626use datafusion:: error:: DataFusionError ;
27- use datafusion:: logical_expr:: create_udf;
2827use datafusion:: logical_expr:: function:: ScalarFunctionImplementation ;
2928use datafusion:: logical_expr:: ScalarUDF ;
29+ use datafusion:: logical_expr:: { create_udf, ColumnarValue } ;
3030
3131use crate :: expr:: PyExpr ;
3232use crate :: utils:: parse_volatility;
3333
34+ /// Create a Rust callable function fr a python function that expects pyarrow arrays
35+ fn pyarrow_function_to_rust (
36+ func : PyObject ,
37+ ) -> impl Fn ( & [ ArrayRef ] ) -> Result < ArrayRef , DataFusionError > {
38+ move |args : & [ ArrayRef ] | -> Result < ArrayRef , DataFusionError > {
39+ Python :: with_gil ( |py| {
40+ // 1. cast args to Pyarrow arrays
41+ let py_args = args
42+ . iter ( )
43+ . map ( |arg| {
44+ arg. into_data ( )
45+ . to_pyarrow ( py)
46+ . map_err ( |e| DataFusionError :: Execution ( format ! ( "{e:?}" ) ) )
47+ } )
48+ . collect :: < Result < Vec < _ > , _ > > ( ) ?;
49+ let py_args = PyTuple :: new_bound ( py, py_args) ;
50+
51+ // 2. call function
52+ let value = func
53+ . call_bound ( py, py_args, None )
54+ . map_err ( |e| DataFusionError :: Execution ( format ! ( "{e:?}" ) ) ) ?;
55+
56+ // 3. cast to arrow::array::Array
57+ let array_data = ArrayData :: from_pyarrow_bound ( value. bind ( py) )
58+ . map_err ( |e| DataFusionError :: Execution ( format ! ( "{e:?}" ) ) ) ?;
59+ Ok ( make_array ( array_data) )
60+ } )
61+ }
62+ }
63+
3464/// Create a DataFusion's UDF implementation from a python function
3565/// that expects pyarrow arrays. This is more efficient as it performs
3666/// a zero-copy of the contents.
37- fn to_rust_function ( func : PyObject ) -> ScalarFunctionImplementation {
38- #[ allow( deprecated) ]
39- datafusion:: physical_plan:: functions:: make_scalar_function (
40- move |args : & [ ArrayRef ] | -> Result < ArrayRef , DataFusionError > {
41- Python :: with_gil ( |py| {
42- // 1. cast args to Pyarrow arrays
43- let py_args = args
44- . iter ( )
45- . map ( |arg| arg. into_data ( ) . to_pyarrow ( py) . unwrap ( ) )
46- . collect :: < Vec < _ > > ( ) ;
47- let py_args = PyTuple :: new_bound ( py, py_args) ;
48-
49- // 2. call function
50- let value = func
51- . call_bound ( py, py_args, None )
52- . map_err ( |e| DataFusionError :: Execution ( format ! ( "{e:?}" ) ) ) ?;
67+ fn to_scalar_function_impl ( func : PyObject ) -> ScalarFunctionImplementation {
68+ // Make the python function callable from rust
69+ let pyarrow_func = pyarrow_function_to_rust ( func) ;
5370
54- // 3. cast to arrow::array::Array
55- let array_data = ArrayData :: from_pyarrow_bound ( value . bind ( py ) ) . unwrap ( ) ;
56- Ok ( make_array ( array_data ) )
57- } )
58- } ,
59- )
71+ // Convert input/output from datafusion ColumnarValue to arrow arrays
72+ Arc :: new ( move | args : & [ ColumnarValue ] | {
73+ let array_refs = ColumnarValue :: values_to_arrays ( args ) ? ;
74+ let array_result = pyarrow_func ( & array_refs ) ? ;
75+ Ok ( array_result . into ( ) )
76+ } )
6077}
6178
6279/// Represents a PyScalarUDF
@@ -82,7 +99,7 @@ impl PyScalarUDF {
8299 input_types. 0 ,
83100 Arc :: new ( return_type. 0 ) ,
84101 parse_volatility ( volatility) ?,
85- to_rust_function ( func) ,
102+ to_scalar_function_impl ( func) ,
86103 ) ;
87104 Ok ( Self { function } )
88105 }
0 commit comments