Skip to content

Commit dd3d1ef

Browse files
committed
Intermediate work adding ffi scalar udf
1 parent 1391078 commit dd3d1ef

File tree

1 file changed

+24
-1
lines changed

1 file changed

+24
-1
lines changed

src/udf.rs

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@
1717

1818
use std::sync::Arc;
1919

20+
use datafusion_ffi::udf::{FFI_ScalarUDF, ForeignScalarUDF};
21+
use pyo3::types::PyCapsule;
2022
use pyo3::{prelude::*, types::PyTuple};
2123

2224
use datafusion::arrow::array::{make_array, Array, ArrayData, ArrayRef};
@@ -29,8 +31,9 @@ use datafusion::logical_expr::ScalarUDF;
2931
use datafusion::logical_expr::{create_udf, ColumnarValue};
3032

3133
use crate::errors::to_datafusion_err;
34+
use crate::errors::{py_datafusion_err, PyDataFusionResult};
3235
use crate::expr::PyExpr;
33-
use crate::utils::parse_volatility;
36+
use crate::utils::{parse_volatility, validate_pycapsule};
3437

3538
/// Create a Rust callable function from a python function that expects pyarrow arrays
3639
fn pyarrow_function_to_rust(
@@ -105,6 +108,26 @@ impl PyScalarUDF {
105108
Ok(Self { function })
106109
}
107110

111+
#[staticmethod]
112+
pub fn from_pycapsule(func: Bound<'_, PyAny>) -> PyDataFusionResult<Self> {
113+
if func.hasattr("__datafusion_scalar_udf__")? {
114+
let capsule = func.getattr("__datafusion_scalar_udf__")?.call0()?;
115+
let capsule = capsule.downcast::<PyCapsule>().map_err(py_datafusion_err)?;
116+
validate_pycapsule(capsule, "datafusion_scalar_udf")?;
117+
118+
let udf = unsafe { capsule.reference::<FFI_ScalarUDF>() };
119+
let udf: ForeignScalarUDF = udf.try_into()?;
120+
121+
Ok(Self {
122+
function: udf.into(),
123+
})
124+
} else {
125+
Err(crate::errors::PyDataFusionError::Common(
126+
"__datafusion_scalar_udf__ does not exist on ScalarUDF object.".to_string(),
127+
))
128+
}
129+
}
130+
108131
/// creates a new PyExpr with the call of the udf
109132
#[pyo3(signature = (*args))]
110133
fn __call__(&self, args: Vec<PyExpr>) -> PyResult<PyExpr> {

0 commit comments

Comments
 (0)