Skip to content

Commit d904d08

Browse files
committed
Removed unnecessary cloning of scalar value when going from rust to python. Also removed the rust unit tests copied over from upstream repo that were failing due to #941 in pyo3
1 parent f0d25a2 commit d904d08

File tree

3 files changed

+11
-102
lines changed

3 files changed

+11
-102
lines changed

src/expr.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,6 @@ use std::convert::{From, Into};
2424
use std::sync::Arc;
2525
use window::PyWindowFrame;
2626

27-
use arrow::pyarrow::ToPyArrow;
2827
use datafusion::arrow::datatypes::{DataType, Field};
2928
use datafusion::arrow::pyarrow::PyArrowType;
3029
use datafusion::functions::core::expr_ext::FieldAccessor;
@@ -41,6 +40,7 @@ use crate::expr::binary_expr::PyBinaryExpr;
4140
use crate::expr::column::PyColumn;
4241
use crate::expr::literal::PyLiteral;
4342
use crate::functions::add_builder_fns_to_window;
43+
use crate::pyarrow_util::scalar_to_pyarrow;
4444
use crate::sql::logical::PyLogicalPlan;
4545

4646
use self::alias::PyAlias;
@@ -355,7 +355,7 @@ impl PyExpr {
355355
/// Extracts the Expr value into a PyObject that can be shared with Python
356356
pub fn python_value(&self, py: Python) -> PyResult<PyObject> {
357357
match &self.expr {
358-
Expr::Literal(scalar_value) => Ok(PyScalarValue(scalar_value.clone()).to_pyarrow(py)?),
358+
Expr::Literal(scalar_value) => scalar_to_pyarrow(scalar_value, py),
359359
_ => Err(py_type_err(format!(
360360
"Non Expr::Literal encountered in types: {:?}",
361361
&self.expr

src/pyarrow_filter_expression.rs

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -21,12 +21,11 @@ use pyo3::prelude::*;
2121
use std::convert::TryFrom;
2222
use std::result::Result;
2323

24-
use arrow::pyarrow::ToPyArrow;
2524
use datafusion::common::{Column, ScalarValue};
2625
use datafusion::logical_expr::{expr::InList, Between, BinaryExpr, Expr, Operator};
2726

28-
use crate::common::data_type::PyScalarValue;
2927
use crate::errors::PyDataFusionError;
28+
use crate::pyarrow_util::scalar_to_pyarrow;
3029

3130
#[derive(Debug)]
3231
#[repr(transparent)]
@@ -103,9 +102,7 @@ impl TryFrom<&Expr> for PyArrowFilterExpression {
103102
let op_module = Python::import_bound(py, "operator")?;
104103
let pc_expr: Result<Bound<'_, PyAny>, PyDataFusionError> = match expr {
105104
Expr::Column(Column { name, .. }) => Ok(pc.getattr("field")?.call1((name,))?),
106-
Expr::Literal(scalar) => {
107-
Ok(PyScalarValue(scalar.clone()).to_pyarrow(py)?.into_bound(py))
108-
}
105+
Expr::Literal(scalar) => Ok(scalar_to_pyarrow(scalar, py)?.into_bound(py)),
109106
Expr::BinaryExpr(BinaryExpr { left, op, right }) => {
110107
let operator = operator_to_py(op, &op_module)?;
111108
let left = PyArrowFilterExpression::try_from(left.as_ref())?.0;

src/pyarrow_util.rs

Lines changed: 7 additions & 95 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ use arrow::array::{Array, ArrayData};
2121
use arrow::pyarrow::{FromPyArrow, ToPyArrow};
2222
use datafusion::scalar::ScalarValue;
2323
use pyo3::types::{PyAnyMethods, PyList};
24-
use pyo3::{Bound, FromPyObject, IntoPy, PyAny, PyObject, PyResult, Python};
24+
use pyo3::{Bound, FromPyObject, PyAny, PyObject, PyResult, Python};
2525

2626
use crate::common::data_type::PyScalarValue;
2727
use crate::errors::PyDataFusionError;
@@ -45,105 +45,17 @@ impl FromPyArrow for PyScalarValue {
4545
}
4646
}
4747

48-
impl ToPyArrow for PyScalarValue {
49-
fn to_pyarrow(&self, py: Python) -> PyResult<PyObject> {
50-
let array = self.0.to_array().map_err(PyDataFusionError::from)?;
51-
// convert to pyarrow array using C data interface
52-
let pyarray = array.to_data().to_pyarrow(py)?;
53-
let pyscalar = pyarray.call_method1(py, "__getitem__", (0,))?;
54-
55-
Ok(pyscalar)
56-
}
57-
}
58-
5948
impl<'source> FromPyObject<'source> for PyScalarValue {
6049
fn extract_bound(value: &Bound<'source, PyAny>) -> PyResult<Self> {
6150
Self::from_pyarrow_bound(value)
6251
}
6352
}
6453

65-
impl IntoPy<PyObject> for PyScalarValue {
66-
fn into_py(self, py: Python) -> PyObject {
67-
self.to_pyarrow(py).unwrap()
68-
}
69-
}
70-
71-
#[cfg(test)]
72-
mod tests {
73-
use pyo3::prepare_freethreaded_python;
74-
use pyo3::py_run;
75-
use pyo3::types::PyDict;
76-
77-
use super::*;
78-
79-
fn init_python() {
80-
prepare_freethreaded_python();
81-
Python::with_gil(|py| {
82-
if py.run_bound("import pyarrow", None, None).is_err() {
83-
let locals = PyDict::new_bound(py);
84-
py.run_bound(
85-
"import sys; executable = sys.executable; python_path = sys.path",
86-
None,
87-
Some(&locals),
88-
)
89-
.expect("Couldn't get python info");
90-
let executable = locals.get_item("executable").unwrap();
91-
let executable: String = executable.extract().unwrap();
92-
93-
let python_path = locals.get_item("python_path").unwrap();
94-
let python_path: Vec<String> = python_path.extract().unwrap();
54+
pub fn scalar_to_pyarrow(scalar: &ScalarValue, py: Python) -> PyResult<PyObject> {
55+
let array = scalar.to_array().map_err(PyDataFusionError::from)?;
56+
// convert to pyarrow array using C data interface
57+
let pyarray = array.to_data().to_pyarrow(py)?;
58+
let pyscalar = pyarray.call_method1(py, "__getitem__", (0,))?;
9559

96-
panic!(
97-
"pyarrow not found\nExecutable: {executable}\nPython path: {python_path:?}\n\
98-
HINT: try `pip install pyarrow`\n\
99-
NOTE: On Mac OS, you must compile against a Framework Python \
100-
(default in python.org installers and brew, but not pyenv)\n\
101-
NOTE: On Mac OS, PYO3 might point to incorrect Python library \
102-
path when using virtual environments. Try \
103-
`export PYTHONPATH=$(python -c \"import sys; print(sys.path[-1])\")`\n"
104-
)
105-
}
106-
})
107-
}
108-
109-
#[test]
110-
fn test_roundtrip() {
111-
init_python();
112-
113-
let example_scalars = vec![
114-
ScalarValue::Boolean(Some(true)),
115-
ScalarValue::Int32(Some(23)),
116-
ScalarValue::Float64(Some(12.34)),
117-
ScalarValue::from("Hello!"),
118-
ScalarValue::Date32(Some(1234)),
119-
];
120-
121-
Python::with_gil(|py| {
122-
for scalar in example_scalars.into_iter() {
123-
let scalar = PyScalarValue(scalar);
124-
let result =
125-
PyScalarValue::from_pyarrow_bound(scalar.to_pyarrow(py).unwrap().bind(py))
126-
.unwrap();
127-
assert_eq!(scalar, result);
128-
}
129-
});
130-
}
131-
132-
#[test]
133-
fn test_py_scalar() {
134-
init_python();
135-
136-
// TODO: remove this attribute when bumping pyo3 to v0.23.0
137-
// See: <https://github.com/PyO3/pyo3/blob/v0.23.0/guide/src/migration.md#gil-refs-feature-removed>
138-
#[allow(unexpected_cfgs)]
139-
Python::with_gil(|py| {
140-
let scalar_float = PyScalarValue(ScalarValue::Float64(Some(12.34)));
141-
let py_float = scalar_float.into_py(py).call_method0(py, "as_py").unwrap();
142-
py_run!(py, py_float, "assert py_float == 12.34");
143-
144-
let scalar_string = PyScalarValue(ScalarValue::Utf8(Some("Hello!".to_string())));
145-
let py_string = scalar_string.into_py(py).call_method0(py, "as_py").unwrap();
146-
py_run!(py, py_string, "assert py_string == 'Hello!'");
147-
});
148-
}
60+
Ok(pyscalar)
14961
}

0 commit comments

Comments
 (0)