Skip to content

Commit 0c48f4f

Browse files
committed
Adding PyWindowUDF and implementing PartitionEvaluator for it. Still requires python side work.
1 parent a00cfbf commit 0c48f4f

File tree

3 files changed

+246
-0
lines changed

3 files changed

+246
-0
lines changed

src/context.rs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@ use crate::sql::logical::PyLogicalPlan;
4242
use crate::store::StorageContexts;
4343
use crate::udaf::PyAggregateUDF;
4444
use crate::udf::PyScalarUDF;
45+
use crate::udwf::PyWindowUDF;
4546
use crate::utils::{get_tokio_runtime, wait_for_future};
4647
use datafusion::arrow::datatypes::{DataType, Schema, SchemaRef};
4748
use datafusion::arrow::pyarrow::PyArrowType;
@@ -746,6 +747,11 @@ impl PySessionContext {
746747
Ok(())
747748
}
748749

750+
pub fn register_udwf(&mut self, udwf: PyWindowUDF) -> PyResult<()> {
751+
self.ctx.register_udwf(udwf.function);
752+
Ok(())
753+
}
754+
749755
#[pyo3(signature = (name="datafusion"))]
750756
pub fn catalog(&self, name: &str) -> PyResult<PyCatalog> {
751757
match self.ctx.catalog(name) {

src/lib.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@ pub mod substrait;
5858
mod udaf;
5959
#[allow(clippy::borrow_deref_ref)]
6060
mod udf;
61+
mod udwf;
6162
pub mod utils;
6263

6364
#[cfg(feature = "mimalloc")]

src/udwf.rs

Lines changed: 239 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,239 @@
1+
// Licensed to the Apache Software Foundation (ASF) under one
2+
// or more contributor license agreements. See the NOTICE file
3+
// distributed with this work for additional information
4+
// regarding copyright ownership. The ASF licenses this file
5+
// to you under the Apache License, Version 2.0 (the
6+
// "License"); you may not use this file except in compliance
7+
// with the License. You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing,
12+
// software distributed under the License is distributed on an
13+
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
// KIND, either express or implied. See the License for the
15+
// specific language governing permissions and limitations
16+
// under the License.
17+
18+
use std::ops::Range;
19+
use std::sync::Arc;
20+
21+
use arrow::array::{make_array, Array, ArrayData, ArrayRef};
22+
use datafusion::logical_expr::window_state::WindowAggState;
23+
use datafusion::prelude::create_udwf;
24+
use datafusion::scalar::ScalarValue;
25+
use pyo3::exceptions::PyValueError;
26+
use pyo3::prelude::*;
27+
28+
use datafusion::arrow::datatypes::DataType;
29+
use datafusion::arrow::pyarrow::{FromPyArrow, PyArrowType, ToPyArrow};
30+
use datafusion::error::{DataFusionError, Result};
31+
use datafusion::logical_expr::{PartitionEvaluator, PartitionEvaluatorFactory, WindowUDF};
32+
use pyo3::types::{PyList, PyTuple};
33+
34+
use crate::expr::PyExpr;
35+
use crate::utils::parse_volatility;
36+
37+
#[derive(Debug)]
38+
struct RustPartitionEvaluator {
39+
evaluator: PyObject,
40+
}
41+
42+
impl RustPartitionEvaluator {
43+
fn new(evaluator: PyObject) -> Self {
44+
Self { evaluator }
45+
}
46+
}
47+
48+
impl PartitionEvaluator for RustPartitionEvaluator {
49+
fn memoize(&mut self, _state: &mut WindowAggState) -> Result<()> {
50+
Python::with_gil(|py| self.evaluator.bind(py).call_method0("memoize").map(|_| ()))
51+
.map_err(|e| DataFusionError::Execution(format!("{e}")))
52+
}
53+
54+
fn get_range(&self, idx: usize, n_rows: usize) -> Result<Range<usize>> {
55+
Python::with_gil(|py| {
56+
let py_args = vec![idx.to_object(py), n_rows.to_object(py)];
57+
let py_args = PyTuple::new_bound(py, py_args);
58+
59+
self.evaluator
60+
.bind(py)
61+
.call_method1("get_range", py_args)
62+
.and_then(|v| {
63+
let tuple: Bound<'_, PyTuple> = v.extract()?;
64+
if tuple.len() != 2 {
65+
return Err(PyValueError::new_err(format!(
66+
"Expected get_range to return tuple of length 2. Received length {}",
67+
tuple.len()
68+
)));
69+
}
70+
71+
let start: usize = tuple.get_item(0).unwrap().extract()?;
72+
let end: usize = tuple.get_item(1).unwrap().extract()?;
73+
74+
Ok(Range { start, end })
75+
})
76+
})
77+
.map_err(|e| DataFusionError::Execution(format!("{e}")))
78+
}
79+
80+
fn is_causal(&self) -> bool {
81+
Python::with_gil(|py| {
82+
self.evaluator
83+
.bind(py)
84+
.call_method0("is_causal")
85+
.and_then(|v| v.extract())
86+
})
87+
.unwrap_or(false)
88+
}
89+
90+
fn evaluate_all(&mut self, values: &[ArrayRef], num_rows: usize) -> Result<ArrayRef> {
91+
Python::with_gil(|py| {
92+
// 1. cast args to Pyarrow array
93+
let mut py_args = values
94+
.iter()
95+
.map(|arg| arg.into_data().to_pyarrow(py).unwrap())
96+
.collect::<Vec<_>>();
97+
py_args.push(num_rows.to_object(py));
98+
let py_args = PyTuple::new_bound(py, py_args);
99+
100+
// 2. call function
101+
self.evaluator
102+
.bind(py)
103+
.call_method1("evaluate_all", py_args)
104+
.map_err(|e| DataFusionError::Execution(format!("{e}")))
105+
.map(|v| {
106+
let array_data = ArrayData::from_pyarrow_bound(&v).unwrap();
107+
make_array(array_data)
108+
})
109+
})
110+
}
111+
112+
fn evaluate(&mut self, values: &[ArrayRef], range: &Range<usize>) -> Result<ScalarValue> {
113+
Python::with_gil(|py| {
114+
// 1. cast args to Pyarrow array
115+
let mut py_args = values
116+
.iter()
117+
.map(|arg| arg.into_data().to_pyarrow(py).unwrap())
118+
.collect::<Vec<_>>();
119+
py_args.push(range.start.to_object(py));
120+
py_args.push(range.end.to_object(py));
121+
let py_args = PyTuple::new_bound(py, py_args);
122+
123+
// 2. call function
124+
self.evaluator
125+
.bind(py)
126+
.call_method1("evaluate", py_args)
127+
.and_then(|v| v.extract())
128+
.map_err(|e| DataFusionError::Execution(format!("{e}")))
129+
})
130+
}
131+
132+
fn evaluate_all_with_rank(
133+
&self,
134+
num_rows: usize,
135+
ranks_in_partition: &[Range<usize>],
136+
) -> Result<ArrayRef> {
137+
Python::with_gil(|py| {
138+
let ranks = ranks_in_partition
139+
.iter()
140+
.map(|r| PyTuple::new_bound(py, vec![r.start, r.end]));
141+
142+
// 1. cast args to Pyarrow array
143+
let py_args = vec![num_rows.to_object(py), PyList::new_bound(py, ranks).into()];
144+
145+
let py_args = PyTuple::new_bound(py, py_args);
146+
147+
// 2. call function
148+
self.evaluator
149+
.bind(py)
150+
.call_method1("evaluate_all_with_rank", py_args)
151+
.map_err(|e| DataFusionError::Execution(format!("{e}")))
152+
.map(|v| {
153+
let array_data = ArrayData::from_pyarrow_bound(&v).unwrap();
154+
make_array(array_data)
155+
})
156+
})
157+
}
158+
159+
fn supports_bounded_execution(&self) -> bool {
160+
Python::with_gil(|py| {
161+
self.evaluator
162+
.bind(py)
163+
.call_method0("supports_bounded_execution")
164+
.and_then(|v| v.extract())
165+
})
166+
.unwrap_or(false)
167+
}
168+
169+
fn uses_window_frame(&self) -> bool {
170+
Python::with_gil(|py| {
171+
self.evaluator
172+
.bind(py)
173+
.call_method0("uses_window_frame")
174+
.and_then(|v| v.extract())
175+
})
176+
.unwrap_or(false)
177+
}
178+
179+
fn include_rank(&self) -> bool {
180+
Python::with_gil(|py| {
181+
self.evaluator
182+
.bind(py)
183+
.call_method0("include_rank")
184+
.and_then(|v| v.extract())
185+
})
186+
.unwrap_or(false)
187+
}
188+
}
189+
190+
pub fn to_rust_partition_evaluator(evalutor: PyObject) -> PartitionEvaluatorFactory {
191+
Arc::new(move || -> Result<Box<dyn PartitionEvaluator>> {
192+
let evalutor = Python::with_gil(|py| {
193+
evalutor
194+
.call0(py)
195+
.map_err(|e| DataFusionError::Execution(format!("{e}")))
196+
})?;
197+
Ok(Box::new(RustPartitionEvaluator::new(evalutor)))
198+
})
199+
}
200+
201+
/// Represents an WindowUDF
202+
#[pyclass(name = "WindowUDF", module = "datafusion", subclass)]
203+
#[derive(Debug, Clone)]
204+
pub struct PyWindowUDF {
205+
pub(crate) function: WindowUDF,
206+
}
207+
208+
#[pymethods]
209+
impl PyWindowUDF {
210+
#[new]
211+
#[pyo3(signature=(name, evaluator, input_type, return_type, volatility))]
212+
fn new(
213+
name: &str,
214+
evaluator: PyObject,
215+
input_type: PyArrowType<DataType>,
216+
return_type: PyArrowType<DataType>,
217+
volatility: &str,
218+
) -> PyResult<Self> {
219+
let function = create_udwf(
220+
name,
221+
input_type.0,
222+
Arc::new(return_type.0),
223+
parse_volatility(volatility)?,
224+
to_rust_partition_evaluator(evaluator),
225+
);
226+
Ok(Self { function })
227+
}
228+
229+
/// creates a new PyExpr with the call of the udf
230+
#[pyo3(signature = (*args))]
231+
fn __call__(&self, args: Vec<PyExpr>) -> PyResult<PyExpr> {
232+
let args = args.iter().map(|e| e.expr.clone()).collect();
233+
Ok(self.function.call(args).into())
234+
}
235+
236+
fn __repr__(&self) -> PyResult<String> {
237+
Ok(format!("WindowUDF({})", self.function.name()))
238+
}
239+
}

0 commit comments

Comments
 (0)