Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
296 changes: 170 additions & 126 deletions Cargo.lock

Large diffs are not rendered by default.

18 changes: 9 additions & 9 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -34,24 +34,24 @@ protoc = [ "datafusion-substrait/protoc" ]
substrait = ["dep:datafusion-substrait"]

[dependencies]
tokio = { version = "1.42", features = ["macros", "rt", "rt-multi-thread", "sync"] }
tokio = { version = "1.43", features = ["macros", "rt", "rt-multi-thread", "sync"] }
pyo3 = { version = "0.23", features = ["extension-module", "abi3", "abi3-py39"] }
pyo3-async-runtimes = { version = "0.23", features = ["tokio-runtime"]}
arrow = { version = "54", features = ["pyarrow"] }
datafusion = { version = "45.0.0", features = ["avro", "unicode_expressions"] }
datafusion-substrait = { version = "45.0.0", optional = true }
datafusion-proto = { version = "45.0.0" }
datafusion-ffi = { version = "45.0.0" }
prost = "0.13" # keep in line with `datafusion-substrait`
arrow = { version = "54.2.1", features = ["pyarrow"] }
datafusion = { version = "46.0.1", features = ["avro", "unicode_expressions"] }
datafusion-substrait = { version = "46.0.1", optional = true }
datafusion-proto = { version = "46.0.1" }
datafusion-ffi = { version = "46.0.1" }
prost = "0.13.1" # keep in line with `datafusion-substrait`
uuid = { version = "1.12", features = ["v4"] }
mimalloc = { version = "0.1", optional = true, default-features = false, features = ["local_dynamic_tls"] }
async-trait = "0.1"
async-trait = "0.1.73"
futures = "0.3"
object_store = { version = "0.11.0", features = ["aws", "gcp", "azure", "http"] }
url = "2"

[build-dependencies]
prost-types = "0.13" # keep in line with `datafusion-substrait`
prost-types = "0.13.1" # keep in line with `datafusion-substrait`
pyo3-build-config = "0.23"

[lib]
Expand Down
3 changes: 2 additions & 1 deletion python/tests/test_dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -752,7 +752,8 @@ def test_execution_plan(aggregate_df):
assert "AggregateExec:" in indent
assert "CoalesceBatchesExec:" in indent
assert "RepartitionExec:" in indent
assert "CsvExec:" in indent
assert "DataSourceExec:" in indent
assert "file_type=csv" in indent

ctx = SessionContext()
rows_returned = 0
Expand Down
39 changes: 22 additions & 17 deletions src/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
// specific language governing permissions and limitations
// under the License.

use datafusion::logical_expr::expr::{AggregateFunctionParams, WindowFunctionParams};
use datafusion::logical_expr::utils::exprlist_to_fields;
use datafusion::logical_expr::{
ExprFuncBuilder, ExprFunctionExt, LogicalPlan, WindowFunctionDefinition,
Expand Down Expand Up @@ -172,6 +173,7 @@ impl PyExpr {
Expr::ScalarSubquery(value) => {
Ok(scalar_subquery::PyScalarSubquery::from(value.clone()).into_bound_py_any(py)?)
}
#[allow(deprecated)]
Expr::Wildcard { qualifier, options } => Err(py_unsupported_variant_err(format!(
"Converting Expr::Wildcard to a Python object is not implemented : {:?} {:?}",
qualifier, options
Expand Down Expand Up @@ -332,7 +334,6 @@ impl PyExpr {
| Expr::AggregateFunction { .. }
| Expr::WindowFunction { .. }
| Expr::InList { .. }
| Expr::Wildcard { .. }
| Expr::Exists { .. }
| Expr::InSubquery { .. }
| Expr::GroupingSet(..)
Expand All @@ -346,6 +347,10 @@ impl PyExpr {
| Expr::Unnest(_)
| Expr::IsNotUnknown(_) => RexType::Call,
Expr::ScalarSubquery(..) => RexType::ScalarSubquery,
#[allow(deprecated)]
Expr::Wildcard { .. } => {
return Err(py_unsupported_variant_err("Expr::Wildcard is unsupported"))
}
})
}

Expand Down Expand Up @@ -394,11 +399,15 @@ impl PyExpr {
| Expr::InSubquery(InSubquery { expr, .. }) => Ok(vec![PyExpr::from(*expr.clone())]),

// Expr variants containing a collection of Expr(s) for operands
Expr::AggregateFunction(AggregateFunction { args, .. })
Expr::AggregateFunction(AggregateFunction {
params: AggregateFunctionParams { args, .. },
..
})
| Expr::ScalarFunction(ScalarFunction { args, .. })
| Expr::WindowFunction(WindowFunction { args, .. }) => {
Ok(args.iter().map(|arg| PyExpr::from(arg.clone())).collect())
}
| Expr::WindowFunction(WindowFunction {
params: WindowFunctionParams { args, .. },
..
}) => Ok(args.iter().map(|arg| PyExpr::from(arg.clone())).collect()),

// Expr(s) that require more specific processing
Expr::Case(Case {
Expand Down Expand Up @@ -465,13 +474,17 @@ impl PyExpr {
Expr::GroupingSet(..)
| Expr::Unnest(_)
| Expr::OuterReferenceColumn(_, _)
| Expr::Wildcard { .. }
| Expr::ScalarSubquery(..)
| Expr::Placeholder { .. }
| Expr::Exists { .. } => Err(py_runtime_err(format!(
"Unimplemented Expr type: {}",
self.expr
))),

#[allow(deprecated)]
Expr::Wildcard { .. } => {
Err(py_unsupported_variant_err("Expr::Wildcard is unsupported"))
}
}
}

Expand Down Expand Up @@ -575,7 +588,7 @@ impl PyExpr {
Expr::AggregateFunction(agg_fn) => {
let window_fn = Expr::WindowFunction(WindowFunction::new(
WindowFunctionDefinition::AggregateUDF(agg_fn.func.clone()),
agg_fn.args.clone(),
agg_fn.params.args.clone(),
));

add_builder_fns_to_window(
Expand Down Expand Up @@ -663,16 +676,8 @@ impl PyExpr {

/// Create a [Field] representing an [Expr], given an input [LogicalPlan] to resolve against
///
/// The expression is handed to `exprlist_to_fields` as a one-element list; the `.1`
/// component of the first returned entry is cloned and returned. An error from
/// `exprlist_to_fields` is propagated to the caller via `?`.
pub fn expr_to_field(expr: &Expr, input_plan: &LogicalPlan) -> PyDataFusionResult<Arc<Field>> {
// NOTE(review): this page is a diff rendered without +/- markers. Going by the hunk
// header (16 old lines / 8 new lines), the whole `match` below is the removed version
// and the two statements after it are its replacement.
match expr {
Expr::Wildcard { .. } => {
// Since * could be any of the valid column names just return the first one
Ok(Arc::new(input_plan.schema().field(0).clone()))
}
_ => {
let fields = exprlist_to_fields(&[expr.clone()], input_plan)?;
Ok(fields[0].1.clone())
}
}
// Replacement body: no special case for `Expr::Wildcard`; every expression takes the
// `exprlist_to_fields` path.
let fields = exprlist_to_fields(&[expr.clone()], input_plan)?;
Ok(fields[0].1.clone())
}
fn _types(expr: &Expr) -> PyResult<DataTypeMap> {
match expr {
Expand Down
10 changes: 6 additions & 4 deletions src/expr/aggregate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
// under the License.

use datafusion::common::DataFusionError;
use datafusion::logical_expr::expr::{AggregateFunction, Alias};
use datafusion::logical_expr::expr::{AggregateFunction, AggregateFunctionParams, Alias};
use datafusion::logical_expr::logical_plan::Aggregate;
use datafusion::logical_expr::Expr;
use pyo3::{prelude::*, IntoPyObjectExt};
Expand Down Expand Up @@ -126,9 +126,11 @@ impl PyAggregate {
match expr {
// TODO: This Alias logic seems to be returning some strange results that we should investigate
Expr::Alias(Alias { expr, .. }) => self._aggregation_arguments(expr.as_ref()),
Expr::AggregateFunction(AggregateFunction { func: _, args, .. }) => {
Ok(args.iter().map(|e| PyExpr::from(e.clone())).collect())
}
Expr::AggregateFunction(AggregateFunction {
func: _,
params: AggregateFunctionParams { args, .. },
..
}) => Ok(args.iter().map(|e| PyExpr::from(e.clone())).collect()),
_ => Err(py_type_err(
"Encountered a non Aggregate type in aggregation_arguments",
)),
Expand Down
11 changes: 9 additions & 2 deletions src/expr/aggregate_expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,13 @@ impl From<AggregateFunction> for PyAggregateFunction {

/// Renders the wrapped aggregate function as `name(arg1, arg2, ...)`.
impl Display for PyAggregateFunction {
/// Writes the function name from `self.aggr.func.name()`, followed by the
/// arguments converted with `to_string()` and joined with `", "` inside parentheses.
fn fmt(&self, f: &mut Formatter) -> std::fmt::Result {
// NOTE(review): this page is a diff rendered without +/- markers. Going by the hunk
// header (7 old lines / 13 new lines), the single-line `let args` is the removed form
// and the multi-line `let args` after it is the replacement, which reads the
// arguments through `self.aggr.params.args` instead of `self.aggr.args`.
let args: Vec<String> = self.aggr.args.iter().map(|expr| expr.to_string()).collect();
let args: Vec<String> = self
.aggr
.params
.args
.iter()
.map(|expr| expr.to_string())
.collect();
write!(f, "{}({})", self.aggr.func.name(), args.join(", "))
}
}
Expand All @@ -54,12 +60,13 @@ impl PyAggregateFunction {

/// Is this a distinct aggregate such as `COUNT(DISTINCT expr)`?
///
/// Returns the `distinct` flag stored on the wrapped aggregate function.
fn is_distinct(&self) -> bool {
// NOTE(review): diff rendered without +/- markers. Going by the hunk header
// (12 old lines / 13 new lines), the first expression is the removed form and the
// second, which reads the flag through `params`, is its replacement.
self.aggr.distinct
self.aggr.params.distinct
}

/// Get the arguments to the aggregate function
fn args(&self) -> Vec<PyExpr> {
self.aggr
.params
.args
.iter()
.map(|expr| PyExpr::from(expr.clone()))
Expand Down
24 changes: 17 additions & 7 deletions src/expr/window.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
// under the License.

use datafusion::common::{DataFusionError, ScalarValue};
use datafusion::logical_expr::expr::WindowFunction;
use datafusion::logical_expr::expr::{WindowFunction, WindowFunctionParams};
use datafusion::logical_expr::{Expr, Window, WindowFrame, WindowFrameBound, WindowFrameUnits};
use pyo3::{prelude::*, IntoPyObjectExt};
use std::fmt::{self, Display, Formatter};
Expand Down Expand Up @@ -118,25 +118,32 @@ impl PyWindowExpr {
/// Returns order by columns in a window function expression
///
/// Matches on `expr.expr.unalias()`. For `Expr::WindowFunction` the `order_by` list is
/// converted with `py_sort_expr_list`; any other variant yields the error built by
/// `not_window_function_err`.
pub fn get_sort_exprs(&self, expr: PyExpr) -> PyResult<Vec<PySortExpr>> {
match expr.expr.unalias() {
// NOTE(review): diff rendered without +/- markers. Going by the hunk header
// (25 old lines / 32 new lines), the single-line arm is the removed form and the
// multi-line arm after it is the replacement, which destructures `order_by` out of
// `WindowFunctionParams`.
Expr::WindowFunction(WindowFunction { order_by, .. }) => py_sort_expr_list(&order_by),
Expr::WindowFunction(WindowFunction {
params: WindowFunctionParams { order_by, .. },
..
}) => py_sort_expr_list(&order_by),
other => Err(not_window_function_err(other)),
}
}

/// Return partition by columns in a window function expression
///
/// Matches on `expr.expr.unalias()`. For `Expr::WindowFunction` the `partition_by` list
/// is converted with `py_expr_list`; any other variant yields the error built by
/// `not_window_function_err`.
pub fn get_partition_exprs(&self, expr: PyExpr) -> PyResult<Vec<PyExpr>> {
match expr.expr.unalias() {
// NOTE(review): diff rendered without +/- markers. Going by the hunk header
// (25 old lines / 32 new lines), the first arm (three lines, with a block body) is the
// removed form and the arm after it is the replacement, which destructures
// `partition_by` out of `WindowFunctionParams`.
Expr::WindowFunction(WindowFunction { partition_by, .. }) => {
py_expr_list(&partition_by)
}
Expr::WindowFunction(WindowFunction {
params: WindowFunctionParams { partition_by, .. },
..
}) => py_expr_list(&partition_by),
other => Err(not_window_function_err(other)),
}
}

/// Return input args for window function
///
/// Matches on `expr.expr.unalias()`. For `Expr::WindowFunction` the `args` list is
/// converted with `py_expr_list`; any other variant yields the error built by
/// `not_window_function_err`.
pub fn get_args(&self, expr: PyExpr) -> PyResult<Vec<PyExpr>> {
match expr.expr.unalias() {
// NOTE(review): diff rendered without +/- markers. Going by the hunk header
// (25 old lines / 32 new lines), the single-line arm is the removed form and the
// multi-line arm after it is the replacement, which destructures `args` out of
// `WindowFunctionParams`.
Expr::WindowFunction(WindowFunction { args, .. }) => py_expr_list(&args),
Expr::WindowFunction(WindowFunction {
params: WindowFunctionParams { args, .. },
..
}) => py_expr_list(&args),
other => Err(not_window_function_err(other)),
}
}
Expand All @@ -152,7 +159,10 @@ impl PyWindowExpr {
/// Returns a [`PyWindowFrame`] for a given window function expression
///
/// Matches on `expr.expr.unalias()`. For `Expr::WindowFunction` the `window_frame` is
/// converted with `.into()` and wrapped in `Some`; every other variant returns `None`.
pub fn get_frame(&self, expr: PyExpr) -> Option<PyWindowFrame> {
match expr.expr.unalias() {
// NOTE(review): diff rendered without +/- markers. The single-line arm appears to be
// the removed form and the multi-line arm after it the replacement, which destructures
// `window_frame` out of `WindowFunctionParams` — confirm against the hunk header just
// above this function (cut off in this view).
Expr::WindowFunction(WindowFunction { window_frame, .. }) => Some(window_frame.into()),
Expr::WindowFunction(WindowFunction {
params: WindowFunctionParams { window_frame, .. },
..
}) => Some(window_frame.into()),
_ => None,
}
}
Expand Down
34 changes: 17 additions & 17 deletions src/functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

use datafusion::functions_aggregate::all_default_aggregate_functions;
use datafusion::functions_window::all_default_window_functions;
use datafusion::logical_expr::expr::WindowFunctionParams;
use datafusion::logical_expr::ExprFunctionExt;
use datafusion::logical_expr::WindowFrame;
use pyo3::{prelude::*, wrap_pyfunction};
Expand Down Expand Up @@ -215,10 +216,7 @@ fn alias(expr: PyExpr, name: &str) -> PyResult<PyExpr> {
/// Builds a [`PyExpr`] wrapping an `Expr::Column` for `name` with no relation qualifier.
///
/// The visible body only ever returns `Ok`.
#[pyfunction]
fn col(name: &str) -> PyResult<PyExpr> {
Ok(PyExpr {
// NOTE(review): diff rendered without +/- markers. Going by the hunk header
// (10 old lines / 7 new lines), the struct-literal `expr:` field (four lines) is the
// removed form and the `Column::new_unqualified(name)` field after it is the replacement.
expr: datafusion::logical_expr::Expr::Column(Column {
relation: None,
name: name.to_string(),
}),
expr: datafusion::logical_expr::Expr::Column(Column::new_unqualified(name)),
})
}

Expand Down Expand Up @@ -333,19 +331,21 @@ fn window(
Ok(PyExpr {
expr: datafusion::logical_expr::Expr::WindowFunction(WindowFunction {
fun,
args: args.into_iter().map(|x| x.expr).collect::<Vec<_>>(),
partition_by: partition_by
.unwrap_or_default()
.into_iter()
.map(|x| x.expr)
.collect::<Vec<_>>(),
order_by: order_by
.unwrap_or_default()
.into_iter()
.map(|x| x.into())
.collect::<Vec<_>>(),
window_frame,
null_treatment: None,
params: WindowFunctionParams {
args: args.into_iter().map(|x| x.expr).collect::<Vec<_>>(),
partition_by: partition_by
.unwrap_or_default()
.into_iter()
.map(|x| x.expr)
.collect::<Vec<_>>(),
order_by: order_by
.unwrap_or_default()
.into_iter()
.map(|x| x.into())
.collect::<Vec<_>>(),
window_frame,
null_treatment: None,
},
}),
})
}
Expand Down
2 changes: 1 addition & 1 deletion src/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ pub(crate) fn get_tokio_runtime() -> &'static TokioRuntime {
/// Returns a `'static` reference to a shared [`SessionContext`].
///
/// The context lives in a function-local `OnceLock`; `get_or_init` constructs it with
/// `SessionContext::new` on the first call and returns the stored value afterwards.
#[inline]
pub(crate) fn get_global_ctx() -> &'static SessionContext {
static CTX: OnceLock<SessionContext> = OnceLock::new();
// NOTE(review): diff rendered without +/- markers. Going by the hunk header
// (1 addition / 1 deletion), the closure form is the removed line and the
// function-path form after it is the replacement; both pass the same constructor.
CTX.get_or_init(|| SessionContext::new())
CTX.get_or_init(SessionContext::new)
}

/// Utility to collect rust futures with GIL released
Expand Down