1717
1818use std:: sync:: Arc ;
1919
20+ use datafusion_ffi:: udf:: { FFI_ScalarUDF , ForeignScalarUDF } ;
21+ use pyo3:: types:: PyCapsule ;
2022use pyo3:: { prelude:: * , types:: PyTuple } ;
2123
2224use datafusion:: arrow:: array:: { make_array, Array , ArrayData , ArrayRef } ;
@@ -29,8 +31,9 @@ use datafusion::logical_expr::ScalarUDF;
2931use datafusion:: logical_expr:: { create_udf, ColumnarValue } ;
3032
3133use crate :: errors:: to_datafusion_err;
34+ use crate :: errors:: { py_datafusion_err, PyDataFusionResult } ;
3235use crate :: expr:: PyExpr ;
33- use crate :: utils:: parse_volatility;
36+ use crate :: utils:: { parse_volatility, validate_pycapsule } ;
3437
3538/// Create a Rust callable function from a python function that expects pyarrow arrays
3639fn pyarrow_function_to_rust (
@@ -105,6 +108,26 @@ impl PyScalarUDF {
105108 Ok ( Self { function } )
106109 }
107110
111+ #[ staticmethod]
112+ pub fn from_pycapsule ( func : Bound < ' _ , PyAny > ) -> PyDataFusionResult < Self > {
113+ if func. hasattr ( "__datafusion_scalar_udf__" ) ? {
114+ let capsule = func. getattr ( "__datafusion_scalar_udf__" ) ?. call0 ( ) ?;
115+ let capsule = capsule. downcast :: < PyCapsule > ( ) . map_err ( py_datafusion_err) ?;
116+ validate_pycapsule ( capsule, "datafusion_scalar_udf" ) ?;
117+
118+ let udf = unsafe { capsule. reference :: < FFI_ScalarUDF > ( ) } ;
119+ let udf: ForeignScalarUDF = udf. try_into ( ) ?;
120+
121+ Ok ( Self {
122+ function : udf. into ( ) ,
123+ } )
124+ } else {
125+ Err ( crate :: errors:: PyDataFusionError :: Common (
126+ "__datafusion_scalar_udf__ does not exist on ScalarUDF object." . to_string ( ) ,
127+ ) )
128+ }
129+ }
130+
108131 /// creates a new PyExpr with the call of the udf
109132 #[ pyo3( signature = ( * args) ) ]
110133 fn __call__ ( & self , args : Vec < PyExpr > ) -> PyResult < PyExpr > {
0 commit comments