Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 41 additions & 24 deletions src/udf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,39 +24,56 @@ use datafusion::arrow::datatypes::DataType;
use datafusion::arrow::pyarrow::FromPyArrow;
use datafusion::arrow::pyarrow::{PyArrowType, ToPyArrow};
use datafusion::error::DataFusionError;
use datafusion::logical_expr::create_udf;
use datafusion::logical_expr::function::ScalarFunctionImplementation;
use datafusion::logical_expr::ScalarUDF;
use datafusion::logical_expr::{create_udf, ColumnarValue};

use crate::expr::PyExpr;
use crate::utils::parse_volatility;

/// Create a Rust callable function fr a python function that expects pyarrow arrays
fn pyarrow_function_to_rust(
func: PyObject,
) -> impl Fn(&[ArrayRef]) -> Result<ArrayRef, DataFusionError> {
move |args: &[ArrayRef]| -> Result<ArrayRef, DataFusionError> {
Python::with_gil(|py| {
// 1. cast args to Pyarrow arrays
let py_args = args
.iter()
.map(|arg| {
arg.into_data()
.to_pyarrow(py)
.map_err(|e| DataFusionError::Execution(format!("{e:?}")))
})
.collect::<Result<Vec<_>, _>>()?;
let py_args = PyTuple::new_bound(py, py_args);

// 2. call function
let value = func
.call_bound(py, py_args, None)
.map_err(|e| DataFusionError::Execution(format!("{e:?}")))?;

// 3. cast to arrow::array::Array
let array_data = ArrayData::from_pyarrow_bound(value.bind(py))
.map_err(|e| DataFusionError::Execution(format!("{e:?}")))?;
Ok(make_array(array_data))
})
}
}

/// Create a DataFusion's UDF implementation from a python function
/// that expects pyarrow arrays. This is more efficient as it performs
/// a zero-copy of the contents.
fn to_rust_function(func: PyObject) -> ScalarFunctionImplementation {
#[allow(deprecated)]
datafusion::physical_plan::functions::make_scalar_function(
move |args: &[ArrayRef]| -> Result<ArrayRef, DataFusionError> {
Python::with_gil(|py| {
// 1. cast args to Pyarrow arrays
let py_args = args
.iter()
.map(|arg| arg.into_data().to_pyarrow(py).unwrap())
.collect::<Vec<_>>();
let py_args = PyTuple::new_bound(py, py_args);

// 2. call function
let value = func
.call_bound(py, py_args, None)
.map_err(|e| DataFusionError::Execution(format!("{e:?}")))?;
fn to_scalar_function_impl(func: PyObject) -> ScalarFunctionImplementation {
// Make the python function callable from rust
let pyarrow_func = pyarrow_function_to_rust(func);

// 3. cast to arrow::array::Array
let array_data = ArrayData::from_pyarrow_bound(value.bind(py)).unwrap();
Ok(make_array(array_data))
})
},
)
// Convert input/output from datafusion ColumnarValue to arrow arrays
Arc::new(move |args: &[ColumnarValue]| {
let array_refs = ColumnarValue::values_to_arrays(args)?;
let array_result = pyarrow_func(&array_refs)?;
Ok(array_result.into())
})
}

/// Represents a PyScalarUDF
Expand All @@ -82,7 +99,7 @@ impl PyScalarUDF {
input_types.0,
Arc::new(return_type.0),
parse_volatility(volatility)?,
to_rust_function(func),
to_scalar_function_impl(func),
);
Ok(Self { function })
}
Expand Down
Loading