diff --git a/guests/python/src/conversion.rs b/guests/python/src/conversion.rs index 15c70077..8b23504b 100644 --- a/guests/python/src/conversion.rs +++ b/guests/python/src/conversion.rs @@ -2,11 +2,11 @@ use std::{ops::ControlFlow, sync::Arc}; use arrow::{ - array::{Array, ArrayRef, BooleanBuilder, Int64Builder}, + array::{Array, ArrayRef, BooleanBuilder, Float64Builder, Int64Builder}, datatypes::DataType, }; use datafusion_common::{ - cast::{as_boolean_array, as_int64_array}, + cast::{as_boolean_array, as_float64_array, as_int64_array}, error::Result as DataFusionResult, exec_datafusion_err, exec_err, }; @@ -38,6 +38,7 @@ impl PythonType { pub(crate) fn data_type(&self) -> DataType { match self { Self::Bool => DataType::Boolean, + Self::Float => DataType::Float64, Self::Int => DataType::Int64, } } @@ -66,6 +67,23 @@ impl PythonType { Ok(Box::new(it)) } + Self::Float => { + let array = as_float64_array(array)?; + + let it = array.into_iter().map(move |maybe_val| { + maybe_val + .map(|val| { + val.into_bound_py_any(py).map_err(|e| { + exec_datafusion_err!( + "cannot convert Rust `f64` value to Python: {e}" + ) + }) + }) + .transpose() + }); + + Ok(Box::new(it)) + } Self::Int => { let array = as_int64_array(array)?; @@ -92,6 +110,7 @@ impl PythonType { fn python_to_arrow<'py>(&self, num_rows: usize) -> Box> { match self { Self::Bool => Box::new(BooleanBuilder::with_capacity(num_rows)), + Self::Float => Box::new(Float64Builder::with_capacity(num_rows)), Self::Int => Box::new(Int64Builder::with_capacity(num_rows)), } } @@ -218,6 +237,24 @@ impl<'py> ArrayBuilder<'py> for BooleanBuilder { } } +impl<'py> ArrayBuilder<'py> for Float64Builder { + fn push(&mut self, val: Bound<'py, PyAny>) -> DataFusionResult<()> { + let val: f64 = val.extract().map_err(|_| { + exec_datafusion_err!("expected `f64` but got {}", py_representation(&val)) + })?; + self.append_value(val); + Ok(()) + } + + fn skip(&mut self) { + self.append_null(); + } + + fn finish(&mut self) -> ArrayRef { + Arc::new(self.finish()) + } +} + impl<'py> ArrayBuilder<'py> for Int64Builder { fn push(&mut self, val: Bound<'py, PyAny>) -> DataFusionResult<()> { // in Python, `bool` is a sub-class of int we should probably not silently cast bools to integers diff --git a/guests/python/src/inspect.rs b/guests/python/src/inspect.rs index 4907e492..d7c2247e 100644 --- a/guests/python/src/inspect.rs +++ b/guests/python/src/inspect.rs @@ -21,10 +21,13 @@ impl<'py> FromPyObject<'py> for PythonType { // https://docs.python.org/3/library/builtins.html let mod_builtins = py.import(intern!(py, "builtins"))?; let type_bool = mod_builtins.getattr(intern!(py, "bool"))?; + let type_float = mod_builtins.getattr(intern!(py, "float"))?; let type_int = mod_builtins.getattr(intern!(py, "int"))?; if ob.is(type_bool) { Ok(Self::Bool) + } else if ob.is(type_float) { + Ok(Self::Float) } else if ob.is(type_int) { Ok(Self::Int) } else { diff --git a/guests/python/src/signature.rs b/guests/python/src/signature.rs index ed7637be..88d3e28b 100644 --- a/guests/python/src/signature.rs +++ b/guests/python/src/signature.rs @@ -18,6 +18,18 @@ pub(crate) enum PythonType { /// We map this to [`Boolean`](arrow::datatypes::DataType::Boolean). Bool, + /// Float. + /// + /// # Python + /// The type is called `float`, documentation can be found here: + /// + /// - + /// - + /// + /// # Arrow + /// We map this to [`Float64`](arrow::datatypes::DataType::Float64). + Float, + /// Signed integer. /// /// # Python diff --git a/host/tests/integration_tests/python/types/float.rs b/host/tests/integration_tests/python/types/float.rs new file mode 100644 index 00000000..c278f9ac --- /dev/null +++ b/host/tests/integration_tests/python/types/float.rs @@ -0,0 +1,85 @@ +use std::sync::Arc; + +use arrow::{ + array::{Array, Float64Array}, + datatypes::{DataType, Field}, +}; +use datafusion_common::cast::as_float64_array; +use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility}; + +use crate::integration_tests::{ + python::test_utils::python_scalar_udf, test_utils::ColumnarValueExt, +}; + +#[tokio::test(flavor = "multi_thread")] +async fn test_roundtrip() { + const CODE: &str = " +def foo(x: float) -> float: + return x +"; + let udf = python_scalar_udf(CODE).await.unwrap(); + + assert_eq!( + udf.signature(), + &Signature::exact(vec![DataType::Float64], Volatility::Volatile), + ); + + assert_eq!( + udf.return_type(&[DataType::Float64]).unwrap(), + DataType::Float64, + ); + + let values = &[ + Some(13.37), + Some(f64::NAN), + Some(f64::NEG_INFINITY), + Some(f64::MIN_POSITIVE), + Some(f64::INFINITY), + Some(f64::MAX), + Some(f64::MIN), + Some(0.0), + Some(-0.0), + ]; + + let array = udf + .invoke_with_args(ScalarFunctionArgs { + args: vec![ColumnarValue::Array(Arc::new(Float64Array::from_iter( + values.iter().copied(), + )))], + arg_fields: vec![Arc::new(Field::new("a1", DataType::Float64, true))], + number_rows: values.len(), + return_field: Arc::new(Field::new("r", DataType::Float64, true)), + }) + .unwrap() + .unwrap_array(); + assert_float_total_eq(&array, values); +} + +#[test] +#[should_panic(expected = "Not equal")] +fn test_assert_float_total_eq_uses_total_eq() { + let array = Float64Array::from_iter(&[Some(-0.0)]); + assert_float_total_eq(&array, &[Some(0.0)]); +} + +#[track_caller] +fn assert_float_total_eq(array: &dyn Array, expected: &[Option]) { + assert_eq!(array.len(), expected.len()); + + let array = as_float64_array(array).unwrap(); + let actual = array.into_iter().collect::>(); + + let is_eq = actual + .iter() + .zip(expected) + .map(|(actual, expected)| match (actual, expected) { + (None, None) => true, + (Some(actual), Some(expected)) => actual.total_cmp(expected).is_eq(), + _ => false, + }) + .collect::>(); + + if !is_eq.iter().all(|eq| *eq) { + panic!("Not equal:\n\nActual:\n{actual:?}\n\nExpected:\n{expected:?}\n\nEq:\n{is_eq:?}"); + } +} diff --git a/host/tests/integration_tests/python/types/mod.rs b/host/tests/integration_tests/python/types/mod.rs index 802437d0..ceae0ff2 100644 --- a/host/tests/integration_tests/python/types/mod.rs +++ b/host/tests/integration_tests/python/types/mod.rs @@ -1,2 +1,3 @@ mod bool; +mod float; mod int;