@@ -24,39 +24,51 @@ use datafusion::arrow::datatypes::DataType;
2424use datafusion:: arrow:: pyarrow:: FromPyArrow ;
2525use datafusion:: arrow:: pyarrow:: { PyArrowType , ToPyArrow } ;
2626use datafusion:: error:: DataFusionError ;
27- use datafusion:: logical_expr:: create_udf;
2827use datafusion:: logical_expr:: function:: ScalarFunctionImplementation ;
2928use datafusion:: logical_expr:: ScalarUDF ;
29+ use datafusion:: logical_expr:: { create_udf, ColumnarValue } ;
3030
3131use crate :: expr:: PyExpr ;
3232use crate :: utils:: parse_volatility;
3333
34+ /// Create a Rust callable function fr a python function that expects pyarrow arrays
35+ fn pyarrow_function_to_rust (
36+ func : PyObject ,
37+ ) -> impl Fn ( & [ ArrayRef ] ) -> Result < ArrayRef , DataFusionError > {
38+ move |args : & [ ArrayRef ] | -> Result < ArrayRef , DataFusionError > {
39+ Python :: with_gil ( |py| {
40+ // 1. cast args to Pyarrow arrays
41+ let py_args = args
42+ . iter ( )
43+ . map ( |arg| arg. into_data ( ) . to_pyarrow ( py) . unwrap ( ) )
44+ . collect :: < Vec < _ > > ( ) ;
45+ let py_args = PyTuple :: new_bound ( py, py_args) ;
46+
47+ // 2. call function
48+ let value = func
49+ . call_bound ( py, py_args, None )
50+ . map_err ( |e| DataFusionError :: Execution ( format ! ( "{e:?}" ) ) ) ?;
51+
52+ // 3. cast to arrow::array::Array
53+ let array_data = ArrayData :: from_pyarrow_bound ( value. bind ( py) ) . unwrap ( ) ;
54+ Ok ( make_array ( array_data) )
55+ } )
56+ }
57+ }
58+
3459/// Create a DataFusion's UDF implementation from a python function
3560/// that expects pyarrow arrays. This is more efficient as it performs
3661/// a zero-copy of the contents.
37- fn to_rust_function ( func : PyObject ) -> ScalarFunctionImplementation {
38- #[ allow( deprecated) ]
39- datafusion:: physical_plan:: functions:: make_scalar_function (
40- move |args : & [ ArrayRef ] | -> Result < ArrayRef , DataFusionError > {
41- Python :: with_gil ( |py| {
42- // 1. cast args to Pyarrow arrays
43- let py_args = args
44- . iter ( )
45- . map ( |arg| arg. into_data ( ) . to_pyarrow ( py) . unwrap ( ) )
46- . collect :: < Vec < _ > > ( ) ;
47- let py_args = PyTuple :: new_bound ( py, py_args) ;
48-
49- // 2. call function
50- let value = func
51- . call_bound ( py, py_args, None )
52- . map_err ( |e| DataFusionError :: Execution ( format ! ( "{e:?}" ) ) ) ?;
62+ fn to_scalar_function_impl ( func : PyObject ) -> ScalarFunctionImplementation {
63+ // Make the python function callable from rust
64+ let pyarrow_func = pyarrow_function_to_rust ( func) ;
5365
54- // 3. cast to arrow::array::Array
55- let array_data = ArrayData :: from_pyarrow_bound ( value . bind ( py ) ) . unwrap ( ) ;
56- Ok ( make_array ( array_data ) )
57- } )
58- } ,
59- )
66+ // Convert input/output from datafusion ColumnarValue to arrow arrays
67+ Arc :: new ( move | args : & [ ColumnarValue ] | {
68+ let array_refs = ColumnarValue :: values_to_arrays ( args ) ? ;
69+ let array_result = pyarrow_func ( & array_refs ) ? ;
70+ Ok ( array_result . into ( ) )
71+ } )
6072}
6173
6274/// Represents a PyScalarUDF
@@ -82,7 +94,7 @@ impl PyScalarUDF {
8294 input_types. 0 ,
8395 Arc :: new ( return_type. 0 ) ,
8496 parse_volatility ( volatility) ?,
85- to_rust_function ( func) ,
97+ to_scalar_function_impl ( func) ,
8698 ) ;
8799 Ok ( Self { function } )
88100 }
0 commit comments