Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 39 additions & 2 deletions guests/python/src/conversion.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,11 @@
use std::{ops::ControlFlow, sync::Arc};

use arrow::{
array::{Array, ArrayRef, BooleanBuilder, Int64Builder},
array::{Array, ArrayRef, BooleanBuilder, Float64Builder, Int64Builder},
datatypes::DataType,
};
use datafusion_common::{
cast::{as_boolean_array, as_int64_array},
cast::{as_boolean_array, as_float64_array, as_int64_array},
error::Result as DataFusionResult,
exec_datafusion_err, exec_err,
};
Expand Down Expand Up @@ -38,6 +38,7 @@ impl PythonType {
/// Returns the Arrow [`DataType`] used to represent values of this Python type.
///
/// The mapping is `Bool` -> `Boolean`, `Float` -> `Float64` and `Int` -> `Int64`.
pub(crate) fn data_type(&self) -> DataType {
match self {
Self::Bool => DataType::Boolean,
Self::Float => DataType::Float64,
Self::Int => DataType::Int64,
}
}
Expand Down Expand Up @@ -66,6 +67,23 @@ impl PythonType {

Ok(Box::new(it))
}
Self::Float => {
let array = as_float64_array(array)?;

let it = array.into_iter().map(move |maybe_val| {
maybe_val
.map(|val| {
val.into_bound_py_any(py).map_err(|e| {
exec_datafusion_err!(
"cannot convert Rust `f64` value to Python: {e}"
)
})
})
.transpose()
});

Ok(Box::new(it))
}
Self::Int => {
let array = as_int64_array(array)?;

Expand All @@ -92,6 +110,7 @@ impl PythonType {
/// Creates the boxed [`ArrayBuilder`] that matches this Python type.
///
/// Each variant selects the Arrow builder for its data type (`BooleanBuilder`,
/// `Float64Builder` or `Int64Builder`), constructed with a capacity of `num_rows`.
fn python_to_arrow<'py>(&self, num_rows: usize) -> Box<dyn ArrayBuilder<'py>> {
match self {
Self::Bool => Box::new(BooleanBuilder::with_capacity(num_rows)),
Self::Float => Box::new(Float64Builder::with_capacity(num_rows)),
Self::Int => Box::new(Int64Builder::with_capacity(num_rows)),
}
}
Expand Down Expand Up @@ -218,6 +237,24 @@ impl<'py> ArrayBuilder<'py> for BooleanBuilder {
}
}

impl<'py> ArrayBuilder<'py> for Float64Builder {
fn push(&mut self, val: Bound<'py, PyAny>) -> DataFusionResult<()> {
let val: f64 = val.extract().map_err(|_| {
exec_datafusion_err!("expected `f64` but got {}", py_representation(&val))
})?;
self.append_value(val);
Ok(())
}

fn skip(&mut self) {
self.append_null();
}

fn finish(&mut self) -> ArrayRef {
Arc::new(self.finish())
}
}

impl<'py> ArrayBuilder<'py> for Int64Builder {
fn push(&mut self, val: Bound<'py, PyAny>) -> DataFusionResult<()> {
// In Python, `bool` is a sub-class of `int`; we should probably not silently cast bools to integers.
Expand Down
3 changes: 3 additions & 0 deletions guests/python/src/inspect.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,13 @@ impl<'py> FromPyObject<'py> for PythonType {
// https://docs.python.org/3/library/builtins.html
let mod_builtins = py.import(intern!(py, "builtins"))?;
let type_bool = mod_builtins.getattr(intern!(py, "bool"))?;
let type_float = mod_builtins.getattr(intern!(py, "float"))?;
let type_int = mod_builtins.getattr(intern!(py, "int"))?;

if ob.is(type_bool) {
Ok(Self::Bool)
} else if ob.is(type_float) {
Ok(Self::Float)
} else if ob.is(type_int) {
Ok(Self::Int)
} else {
Expand Down
12 changes: 12 additions & 0 deletions guests/python/src/signature.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,18 @@ pub(crate) enum PythonType {
/// We map this to [`Boolean`](arrow::datatypes::DataType::Boolean).
Bool,

/// Float.
///
/// # Python
/// The type is called `float`, documentation can be found here:
///
/// - <https://docs.python.org/3/library/stdtypes.html#numeric-types-int-float-complex>
/// - <https://docs.python.org/3/library/functions.html#float>
///
/// # Arrow
/// We map this to [`Float64`](arrow::datatypes::DataType::Float64).
Float,

/// Signed integer.
///
/// # Python
Expand Down
85 changes: 85 additions & 0 deletions host/tests/integration_tests/python/types/float.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
use std::sync::Arc;

use arrow::{
array::{Array, Float64Array},
datatypes::{DataType, Field},
};
use datafusion_common::cast::as_float64_array;
use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility};

use crate::integration_tests::{
python::test_utils::python_scalar_udf, test_utils::ColumnarValueExt,
};

/// Round-trips `Float64` values through a Python identity function annotated with `float`.
///
/// Checks the signature and return type derived for the UDF, then invokes it on an
/// array of edge-case values and expects the output to match the input exactly.
#[tokio::test(flavor = "multi_thread")]
async fn test_roundtrip() {
// NOTE(review): the Python body below appears without indentation in this view;
// confirm that `return x` is indented in the actual file.
const CODE: &str = "
def foo(x: float) -> float:
return x
";
let udf = python_scalar_udf(CODE).await.unwrap();

// The `float` parameter annotation must surface as a single `Float64` argument.
assert_eq!(
udf.signature(),
&Signature::exact(vec![DataType::Float64], Volatility::Volatile),
);

// The `float` return annotation must surface as a `Float64` return type.
assert_eq!(
udf.return_type(&[DataType::Float64]).unwrap(),
DataType::Float64,
);

// Covers an ordinary value, NaN, both infinities, extreme finite values and both signed zeros.
let values = &[
Some(13.37),
Some(f64::NAN),
Some(f64::NEG_INFINITY),
Some(f64::MIN_POSITIVE),
Some(f64::INFINITY),
Some(f64::MAX),
Some(f64::MIN),
Some(0.0),
Some(-0.0),
];

let array = udf
.invoke_with_args(ScalarFunctionArgs {
args: vec![ColumnarValue::Array(Arc::new(Float64Array::from_iter(
values.iter().copied(),
)))],
arg_fields: vec![Arc::new(Field::new("a1", DataType::Float64, true))],
number_rows: values.len(),
return_field: Arc::new(Field::new("r", DataType::Float64, true)),
})
.unwrap()
.unwrap_array();
// Compared with `f64::total_cmp` (see `assert_float_total_eq`), so NaN and the
// sign of zero are checked rather than skipped or conflated by `==`.
assert_float_total_eq(&array, values);
}

/// Guards the comparison helper itself.
///
/// `-0.0` and `0.0` are equal under `==`, so the expected panic demonstrates
/// that `assert_float_total_eq` compares by total ordering instead.
#[test]
#[should_panic(expected = "Not equal")]
fn test_assert_float_total_eq_uses_total_eq() {
    let negative_zero = Float64Array::from_iter([Some(-0.0)]);
    let positive_zero = [Some(0.0)];
    assert_float_total_eq(&negative_zero, &positive_zero);
}

/// Asserts that `array` is a `Float64` array whose slots match `expected` exactly.
///
/// Non-null slots are compared with [`f64::total_cmp`], so identical NaNs count as
/// equal while `-0.0` and `0.0` do not. A null slot only matches a `None`.
///
/// # Panics
/// Panics if the lengths differ, if `array` cannot be downcast to a `Float64`
/// array, or if any slot differs. The mismatch message lists the actual values,
/// the expected values and the per-slot comparison results.
#[track_caller]
fn assert_float_total_eq(array: &dyn Array, expected: &[Option<f64>]) {
    assert_eq!(array.len(), expected.len());

    let observed: Vec<Option<f64>> = as_float64_array(array).unwrap().iter().collect();

    // One flag per slot, kept around so a failure can show where the mismatch is.
    let verdicts: Vec<bool> = observed
        .iter()
        .zip(expected)
        .map(|pair| match pair {
            (Some(lhs), Some(rhs)) => matches!(lhs.total_cmp(rhs), std::cmp::Ordering::Equal),
            (None, None) => true,
            _ => false,
        })
        .collect();

    if verdicts.contains(&false) {
        panic!(
            "Not equal:\n\nActual:\n{actual:?}\n\nExpected:\n{expected:?}\n\nEq:\n{is_eq:?}",
            actual = observed,
            is_eq = verdicts,
        );
    }
}
1 change: 1 addition & 0 deletions host/tests/integration_tests/python/types/mod.rs
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
mod bool;
mod float;
mod int;