Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 0 additions & 3 deletions python/datafusion/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -562,8 +562,6 @@ def literal(value: Any) -> Expr:
"""
if isinstance(value, str):
value = pa.scalar(value, type=pa.string_view())
if not isinstance(value, pa.Scalar):
value = pa.scalar(value)
return Expr(expr_internal.RawExpr.literal(value))

@staticmethod
Expand All @@ -576,7 +574,6 @@ def literal_with_metadata(value: Any, metadata: dict[str, str]) -> Expr:
"""
if isinstance(value, str):
value = pa.scalar(value, type=pa.string_view())
value = value if isinstance(value, pa.Scalar) else pa.scalar(value)

return Expr(expr_internal.RawExpr.literal_with_metadata(value, metadata))

Expand Down
30 changes: 30 additions & 0 deletions python/tests/test_expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@
from datetime import date, datetime, time, timezone
from decimal import Decimal

import arro3.core
import nanoarrow
import pyarrow as pa
import pytest
from datafusion import (
Expand Down Expand Up @@ -980,6 +982,34 @@ def test_literal_metadata(ctx):
assert expected_field.metadata == actual_field.metadata


def test_scalar_conversion() -> None:
    """Verify lit() accepts scalars/arrays from multiple Arrow implementations.

    Checks that plain Python values, pyarrow, nanoarrow, and arro3 objects all
    convert to the same DataFusion literal expression, for both single scalars
    and list values.
    """
    expected_value = lit(1)
    assert str(expected_value) == "Expr(Int64(1))"

    # Test pyarrow scalars (both default int64 and an explicit narrower type)
    assert expected_value == lit(pa.scalar(1))
    assert expected_value == lit(pa.scalar(1, type=pa.int32()))

    # Test nanoarrow
    na_scalar = nanoarrow.Array([1], nanoarrow.int32())[0]
    assert expected_value == lit(na_scalar)

    # Test arro3 (a pyo3-arrow based implementation)
    arro3_scalar = arro3.core.Scalar(1, type=arro3.core.DataType.int32())
    assert expected_value == lit(arro3_scalar)

    # Whole arrays should convert to a single List literal
    expected_value = lit([1, 2, 3])
    assert str(expected_value) == "Expr(List([1, 2, 3]))"

    assert expected_value == lit(pa.scalar([1, 2, 3]))

    na_array = nanoarrow.Array([1, 2, 3], nanoarrow.int32())
    assert expected_value == lit(na_array)

    arro3_array = arro3.core.Array([1, 2, 3], type=arro3.core.DataType.int32())
    assert expected_value == lit(arro3_array)


def test_ensure_expr():
e = col("a")
assert ensure_expr(e) is e.expr
Expand Down
6 changes: 3 additions & 3 deletions src/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@ use parking_lot::RwLock;
use pyo3::prelude::*;
use pyo3::types::*;

use crate::common::data_type::PyScalarValue;
use crate::errors::PyDataFusionResult;
use crate::utils::py_obj_to_scalar_value;
#[pyclass(name = "Config", module = "datafusion", subclass, frozen)]
#[derive(Clone)]
pub(crate) struct PyConfig {
Expand Down Expand Up @@ -65,9 +65,9 @@ impl PyConfig {

/// Set a configuration option
pub fn set(&self, key: &str, value: Py<PyAny>, py: Python) -> PyDataFusionResult<()> {
let scalar_value = py_obj_to_scalar_value(py, value)?;
let scalar_value: PyScalarValue = value.extract(py)?;
let mut options = self.config.write();
options.set(key, scalar_value.to_string().as_str())?;
options.set(key, scalar_value.0.to_string().as_str())?;
Ok(())
}

Expand Down
9 changes: 4 additions & 5 deletions src/dataframe.rs
Original file line number Diff line number Diff line change
Expand Up @@ -48,16 +48,15 @@ use pyo3::pybacked::PyBackedStr;
use pyo3::types::{PyCapsule, PyList, PyTuple, PyTupleMethods};
use pyo3::PyErr;

use crate::common::data_type::PyScalarValue;
use crate::errors::{py_datafusion_err, PyDataFusionError, PyDataFusionResult};
use crate::expr::sort_expr::{to_sort_expressions, PySortExpr};
use crate::expr::PyExpr;
use crate::physical_plan::PyExecutionPlan;
use crate::record_batch::{poll_next_batch, PyRecordBatchStream};
use crate::sql::logical::PyLogicalPlan;
use crate::table::{PyTable, TempViewTable};
use crate::utils::{
is_ipython_env, py_obj_to_scalar_value, spawn_future, validate_pycapsule, wait_for_future,
};
use crate::utils::{is_ipython_env, spawn_future, validate_pycapsule, wait_for_future};

/// File-level static CStr for the Arrow array stream capsule name.
static ARROW_ARRAY_STREAM_NAME: &CStr = cstr!("arrow_array_stream");
Expand Down Expand Up @@ -1191,14 +1190,14 @@ impl PyDataFrame {
columns: Option<Vec<PyBackedStr>>,
py: Python,
) -> PyDataFusionResult<Self> {
let scalar_value = py_obj_to_scalar_value(py, value)?;
let scalar_value: PyScalarValue = value.extract(py)?;

let cols = match columns {
Some(col_names) => col_names.iter().map(|c| c.to_string()).collect(),
None => Vec::new(), // Empty vector means fill null for all columns
};

let df = self.df.as_ref().clone().fill_null(scalar_value, cols)?;
let df = self.df.as_ref().clone().fill_null(scalar_value.0, cols)?;
Ok(Self::new(df))
}
}
Expand Down
117 changes: 107 additions & 10 deletions src/pyarrow_util.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,30 +17,127 @@

//! Conversions between PyArrow and DataFusion types

use arrow::array::{Array, ArrayData};
use std::sync::Arc;

use arrow::array::{make_array, Array, ArrayData, ListArray};
use arrow::buffer::OffsetBuffer;
use arrow::datatypes::Field;
use arrow::pyarrow::{FromPyArrow, ToPyArrow};
use datafusion::common::exec_err;
use datafusion::scalar::ScalarValue;
use pyo3::types::{PyAnyMethods, PyList};
use pyo3::{Bound, FromPyObject, PyAny, PyResult, Python};

use crate::common::data_type::PyScalarValue;
use crate::errors::PyDataFusionError;

/// Import an Arrow-compatible Python object through the Arrow C data
/// interface and convert it into a [`PyScalarValue`].
///
/// When `as_list_array` is `false`, the first element of the imported array
/// becomes the scalar. When `true`, the entire array is wrapped as the single
/// element of a `ScalarValue::List`.
fn pyobj_extract_scalar_via_capsule(
    value: &Bound<'_, PyAny>,
    as_list_array: bool,
) -> PyResult<PyScalarValue> {
    let imported = make_array(ArrayData::from_pyarrow_bound(value)?);

    if !as_list_array {
        // Scalar case: take element 0 of the imported array.
        let scalar = ScalarValue::try_from_array(&imported, 0).map_err(PyDataFusionError::from)?;
        return Ok(PyScalarValue(scalar));
    }

    // List case: wrap the whole array as one list element. The item field is
    // nullable only if the source array actually carries a null bitmap.
    let item_field = Arc::new(Field::new_list_field(
        imported.data_type().clone(),
        imported.nulls().is_some(),
    ));
    let offsets = OffsetBuffer::from_lengths([imported.len()]);
    let wrapped = ListArray::new(item_field, offsets, imported, None);
    Ok(PyScalarValue(ScalarValue::List(Arc::new(wrapped))))
}

impl FromPyArrow for PyScalarValue {
fn from_pyarrow_bound(value: &Bound<'_, PyAny>) -> PyResult<Self> {
let py = value.py();
let typ = value.getattr("type")?;
let pyarrow_mod = py.import("pyarrow");

// construct pyarrow array from the python value and pyarrow type
let factory = py.import("pyarrow")?.getattr("array")?;
let args = PyList::new(py, [value])?;
let array = factory.call1((args, typ))?;
// Is it a PyArrow object?
if let Ok(pa) = pyarrow_mod.as_ref() {
let scalar_type = pa.getattr("Scalar")?;
if value.is_instance(&scalar_type)? {
let typ = value.getattr("type")?;

// convert the pyarrow array to rust array using C data interface
let array = arrow::array::make_array(ArrayData::from_pyarrow_bound(&array)?);
let scalar = ScalarValue::try_from_array(&array, 0).map_err(PyDataFusionError::from)?;
// construct pyarrow array from the python value and pyarrow type
let factory = py.import("pyarrow")?.getattr("array")?;
let args = PyList::new(py, [value])?;
let array = factory.call1((args, typ))?;

Ok(PyScalarValue(scalar))
return pyobj_extract_scalar_via_capsule(&array, false);
}

let array_type = pa.getattr("Array")?;
if value.is_instance(&array_type)? {
return pyobj_extract_scalar_via_capsule(value, true);
}
}

// Is it a NanoArrow scalar?
if let Ok(na) = py.import("nanoarrow") {
let type_name = value.get_type().repr()?;
if type_name.contains("nanoarrow")? && type_name.contains("Scalar")? {
return pyobj_extract_scalar_via_capsule(value, false);
}
let array_type = na.getattr("Array")?;
if value.is_instance(&array_type)? {
return pyobj_extract_scalar_via_capsule(value, true);
}
}

// Is it a arro3 scalar?
if let Ok(arro3) = py.import("arro3").and_then(|arro3| arro3.getattr("core")) {
let scalar_type = arro3.getattr("Scalar")?;
if value.is_instance(&scalar_type)? {
return pyobj_extract_scalar_via_capsule(value, false);
}
let array_type = arro3.getattr("Array")?;
if value.is_instance(&array_type)? {
return pyobj_extract_scalar_via_capsule(value, true);
}
}

// Does it have a PyCapsule interface but isn't one of our known libraries?
// If so do our "best guess". Try checking type name, and if that fails
// return a single value if the length is 1 and return a List value otherwise
if value.hasattr("__arrow_c_array__")? {
let type_name = value.get_type().repr()?;
if type_name.contains("Scalar")? {
return pyobj_extract_scalar_via_capsule(value, false);
}
if type_name.contains("Array")? {
return pyobj_extract_scalar_via_capsule(value, true);
}

let array_data = ArrayData::from_pyarrow_bound(value)?;
let array = make_array(array_data);
if array.len() == 1 {
let scalar =
ScalarValue::try_from_array(&array, 0).map_err(PyDataFusionError::from)?;
return Ok(PyScalarValue(scalar));
} else {
let field = Arc::new(Field::new_list_field(
array.data_type().clone(),
array.nulls().is_some(),
));
let offsets = OffsetBuffer::from_lengths(vec![array.len()]);
let list_array = ListArray::new(field, offsets, array, None);
return Ok(PyScalarValue(ScalarValue::List(Arc::new(list_array))));
}
}

// Last attempt - try to create a PyArrow scalar from a plain Python object
if let Ok(pa) = pyarrow_mod.as_ref() {
let scalar = pa.call_method1("scalar", (value,))?;

PyScalarValue::from_pyarrow_bound(&scalar)
} else {
exec_err!("Unable to import scalar value").map_err(PyDataFusionError::from)?
}
}
}

Expand Down
12 changes: 3 additions & 9 deletions src/udaf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ use pyo3::types::{PyCapsule, PyTuple};
use crate::common::data_type::PyScalarValue;
use crate::errors::{py_datafusion_err, to_datafusion_err, PyDataFusionResult};
use crate::expr::PyExpr;
use crate::utils::{parse_volatility, py_obj_to_scalar_value, validate_pycapsule};
use crate::utils::{parse_volatility, validate_pycapsule};

#[derive(Debug)]
struct RustAccumulator {
Expand All @@ -52,10 +52,7 @@ impl Accumulator for RustAccumulator {
let mut scalars = Vec::new();
for item in values.try_iter()? {
let item: Bound<'_, PyAny> = item?;
let scalar = match item.extract::<PyScalarValue>() {
Ok(py_scalar) => py_scalar.0,
Err(_) => py_obj_to_scalar_value(py, item.unbind())?,
};
let scalar = item.extract::<PyScalarValue>()?.0;
scalars.push(scalar);
}
Ok(scalars)
Expand All @@ -66,10 +63,7 @@ impl Accumulator for RustAccumulator {
fn evaluate(&mut self) -> Result<ScalarValue> {
Python::attach(|py| -> PyResult<ScalarValue> {
let value = self.accum.bind(py).call_method0("evaluate")?;
match value.extract::<PyScalarValue>() {
Ok(py_scalar) => Ok(py_scalar.0),
Err(_) => py_obj_to_scalar_value(py, value.unbind()),
}
value.extract::<PyScalarValue>().map(|v| v.0)
})
.map_err(|e| DataFusionError::Execution(format!("{e}")))
}
Expand Down
1 change: 0 additions & 1 deletion src/udwf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,6 @@ impl PartitionEvaluator for RustPartitionEvaluator {
}

fn evaluate_all(&mut self, values: &[ArrayRef], num_rows: usize) -> Result<ArrayRef> {
println!("evaluate all called with number of values {}", values.len());
Python::attach(|py| {
let py_values = PyList::new(
py,
Expand Down
57 changes: 0 additions & 57 deletions src/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,6 @@ use std::future::Future;
use std::sync::{Arc, OnceLock};
use std::time::Duration;

use datafusion::arrow::array::{make_array, ArrayData, ListArray};
use datafusion::arrow::buffer::{OffsetBuffer, ScalarBuffer};
use datafusion::arrow::datatypes::Field;
use datafusion::arrow::pyarrow::FromPyArrow;
use datafusion::common::ScalarValue;
use datafusion::datasource::TableProvider;
use datafusion::execution::context::SessionContext;
use datafusion::logical_expr::Volatility;
Expand All @@ -37,7 +32,6 @@ use tokio::runtime::Runtime;
use tokio::task::JoinHandle;
use tokio::time::sleep;

use crate::common::data_type::PyScalarValue;
use crate::context::PySessionContext;
use crate::errors::{py_datafusion_err, to_datafusion_err, PyDataFusionError, PyDataFusionResult};
use crate::TokioRuntime;
Expand Down Expand Up @@ -203,57 +197,6 @@ pub(crate) fn table_provider_from_pycapsule<'py>(
}
}

/// Convert an arbitrary Python object into a DataFusion [`ScalarValue`].
///
/// Conversion strategy, in order:
/// 1. `pyarrow.Scalar` -> extracted directly via `PyScalarValue`.
/// 2. `pyarrow.Array` / `pyarrow.ChunkedArray` -> wrapped whole as a single
///    `ScalarValue::List` element (chunked arrays are combined first).
/// 3. Anything else -> passed to `pyarrow.scalar(...)` and extracted.
///
/// Errors are surfaced as `PyValueError` with context about which extraction
/// step failed. NOTE(review): only pyarrow objects are recognized here; other
/// Arrow implementations (nanoarrow, arro3) fall through to `pyarrow.scalar`,
/// which may not accept them — presumably why this helper was replaced.
pub(crate) fn py_obj_to_scalar_value(py: Python, obj: Py<PyAny>) -> PyResult<ScalarValue> {
    // convert Python object to PyScalarValue to ScalarValue

    // Look up the pyarrow classes used for the isinstance checks below.
    let pa = py.import("pyarrow")?;
    let scalar_attr = pa.getattr("Scalar")?;
    let scalar_type = scalar_attr.downcast::<PyType>()?;
    let array_attr = pa.getattr("Array")?;
    let array_type = array_attr.downcast::<PyType>()?;
    let chunked_array_attr = pa.getattr("ChunkedArray")?;
    let chunked_array_type = chunked_array_attr.downcast::<PyType>()?;

    let obj_ref = obj.bind(py);

    // Case 1: already a pyarrow.Scalar — extract it directly.
    if obj_ref.is_instance(scalar_type)? {
        let py_scalar = PyScalarValue::extract_bound(obj_ref)
            .map_err(|e| PyValueError::new_err(format!("Failed to extract PyScalarValue: {e}")))?;
        return Ok(py_scalar.into());
    }

    // Case 2: an array (possibly chunked) — wrap it whole as a list scalar.
    if obj_ref.is_instance(array_type)? || obj_ref.is_instance(chunked_array_type)? {
        // ChunkedArray must be flattened to a single contiguous Array before
        // it can cross the Arrow C data interface.
        let array_obj = if obj_ref.is_instance(chunked_array_type)? {
            obj_ref.call_method0("combine_chunks")?.unbind()
        } else {
            obj_ref.clone().unbind()
        };
        let array_bound = array_obj.bind(py);
        let array_data = ArrayData::from_pyarrow_bound(array_bound)
            .map_err(|e| PyValueError::new_err(format!("Failed to extract pyarrow array: {e}")))?;
        let array = make_array(array_data);
        // One offset pair [0, len] makes the whole array a single list element.
        let offsets = OffsetBuffer::new(ScalarBuffer::from(vec![0, array.len() as i32]));
        let list_array = Arc::new(ListArray::new(
            Arc::new(Field::new_list_field(array.data_type().clone(), true)),
            offsets,
            array,
            None,
        ));

        return Ok(ScalarValue::List(list_array));
    }

    // Case 3 (fallback): let pyarrow infer the type from the plain Python value.
    // Convert Python object to PyArrow scalar
    let scalar = pa.call_method1("scalar", (obj,))?;

    // Convert PyArrow scalar to PyScalarValue
    let py_scalar = PyScalarValue::extract_bound(scalar.as_ref())
        .map_err(|e| PyValueError::new_err(format!("Failed to extract PyScalarValue: {e}")))?;

    // Convert PyScalarValue to ScalarValue
    Ok(py_scalar.into())
}

pub(crate) fn extract_logical_extension_codec(
py: Python,
obj: Option<Bound<PyAny>>,
Expand Down