Skip to content

Commit ca1352e

Browse files
committed
Set up UDWF to take the evaluator class and its constructor arguments, just like UDAF, to ensure we get a clean state when functions are reused
1 parent e3b4acc commit ca1352e

File tree

3 files changed

+46
-15
lines changed

3 files changed

+46
-15
lines changed

python/datafusion/udf.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -429,18 +429,19 @@ class WindowUDF:
429429
def __init__(
430430
self,
431431
name: str | None,
432-
func: WindowEvaluator,
432+
func: Type[WindowEvaluator],
433433
input_types: list[pyarrow.DataType],
434434
return_type: pyarrow.DataType,
435435
volatility: Volatility | str,
436+
arguments: list[Any],
436437
) -> None:
437438
"""Instantiate a user-defined window function (UDWF).
438439
439440
See :py:func:`udwf` for a convenience function and argument
440441
descriptions.
441442
"""
442443
self._udwf = df_internal.WindowUDF(
443-
name, func, input_types, return_type, str(volatility)
444+
name, func, input_types, return_type, str(volatility), arguments
444445
)
445446

446447
def __call__(self, *args: Expr) -> Expr:
@@ -454,10 +455,11 @@ def __call__(self, *args: Expr) -> Expr:
454455

455456
@staticmethod
456457
def udwf(
457-
func: WindowEvaluator,
458+
func: Type[WindowEvaluator],
458459
input_types: pyarrow.DataType | list[pyarrow.DataType],
459460
return_type: pyarrow.DataType,
460461
volatility: Volatility | str,
462+
arguments: Optional[list[Any]] = None,
461463
name: str | None = None,
462464
) -> WindowUDF:
463465
"""Create a new User-Defined Window Function.
@@ -467,23 +469,26 @@ def udwf(
467469
input_types: The data types of the arguments to ``func``.
468470
return_type: The data type of the return value.
469471
volatility: See :py:class:`Volatility` for allowed values.
472+
arguments: A list of arguments to pass in to the ``__init__`` method of ``func``.
470473
name: A descriptive name for the function.
471474
472475
Returns:
473476
A user-defined window function.
474477
"""
475-
if not isinstance(func, WindowEvaluator):
478+
if not issubclass(func, WindowEvaluator):
476479
raise TypeError(
477480
"`func` must implement the abstract base class WindowEvaluator"
478481
)
479482
if name is None:
480483
name = func.__class__.__qualname__.lower()
481484
if isinstance(input_types, pyarrow.DataType):
482485
input_types = [input_types]
486+
arguments = [] if arguments is None else arguments
483487
return WindowUDF(
484488
name=name,
485489
func=func,
486490
input_types=input_types,
487491
return_type=return_type,
488492
volatility=volatility,
493+
arguments=arguments,
489494
)

python/tests/test_udwf.py

Lines changed: 24 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424

2525

2626
class ExponentialSmoothDefault(WindowEvaluator):
27-
def __init__(self, alpha: float) -> None:
27+
def __init__(self, alpha: float = 0.8) -> None:
2828
self.alpha = alpha
2929

3030
def evaluate_all(self, values: list[pa.Array], num_rows: int) -> pa.Array:
@@ -183,46 +183,58 @@ def df():
183183
def test_udwf_errors(df):
184184
with pytest.raises(TypeError):
185185
udwf(
186-
NotSubclassOfWindowEvaluator(),
186+
NotSubclassOfWindowEvaluator,
187187
pa.float64(),
188188
pa.float64(),
189189
volatility="immutable",
190190
)
191191

192192

193193
smooth_default = udwf(
194-
ExponentialSmoothDefault(0.9),
194+
ExponentialSmoothDefault,
195+
pa.float64(),
196+
pa.float64(),
197+
volatility="immutable",
198+
arguments=[0.9],
199+
)
200+
201+
smooth_no_arugments = udwf(
202+
ExponentialSmoothDefault,
195203
pa.float64(),
196204
pa.float64(),
197205
volatility="immutable",
198206
)
199207

200208
smooth_bounded = udwf(
201-
ExponentialSmoothBounded(0.9),
209+
ExponentialSmoothBounded,
202210
pa.float64(),
203211
pa.float64(),
204212
volatility="immutable",
213+
arguments=[0.9],
205214
)
206215

207216
smooth_rank = udwf(
208-
ExponentialSmoothRank(0.9),
217+
ExponentialSmoothRank,
209218
pa.utf8(),
210219
pa.float64(),
211220
volatility="immutable",
221+
arguments=[0.9],
212222
)
213223

214224
smooth_frame = udwf(
215-
ExponentialSmoothFrame(0.9),
225+
ExponentialSmoothFrame,
216226
pa.float64(),
217227
pa.float64(),
218228
volatility="immutable",
229+
arguments=[0.9],
219230
)
220231

221232
smooth_two_col = udwf(
222-
SmoothTwoColumn(0.9),
233+
SmoothTwoColumn,
223234
[pa.int64(), pa.int64()],
224235
pa.float64(),
225236
volatility="immutable",
237+
arguments=[0.9],
226238
)
227239

228240
data_test_udwf_functions = [
@@ -231,6 +243,11 @@ def test_udwf_errors(df):
231243
smooth_default(column("a")),
232244
[0, 0.9, 1.89, 2.889, 3.889, 4.889, 5.889],
233245
),
246+
(
247+
"default_udwf_no_arguments",
248+
smooth_no_arugments(column("a")),
249+
[0, 0.8, 1.76, 2.752, 3.75, 4.75, 5.75],
250+
),
234251
(
235252
"default_udwf_partitioned",
236253
smooth_default(column("a")).partition_by(column("c")).build(),

src/udwf.rs

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -195,9 +195,17 @@ impl PartitionEvaluator for RustPartitionEvaluator {
195195
}
196196
}
197197

198-
pub fn to_rust_partition_evaluator(evaluator: PyObject) -> PartitionEvaluatorFactory {
198+
pub fn to_rust_partition_evaluator(
199+
evaluator: PyObject,
200+
arguments: Vec<PyObject>,
201+
) -> PartitionEvaluatorFactory {
199202
Arc::new(move || -> Result<Box<dyn PartitionEvaluator>> {
200-
let evaluator = Python::with_gil(|py| evaluator.clone_ref(py));
203+
let evaluator = Python::with_gil(|py| {
204+
let py_args = PyTuple::new_bound(py, arguments.iter());
205+
evaluator
206+
.call1(py, py_args)
207+
.map_err(|e| DataFusionError::Execution(e.to_string()))
208+
})?;
201209
Ok(Box::new(RustPartitionEvaluator::new(evaluator)))
202210
})
203211
}
@@ -212,13 +220,14 @@ pub struct PyWindowUDF {
212220
#[pymethods]
213221
impl PyWindowUDF {
214222
#[new]
215-
#[pyo3(signature=(name, evaluator, input_types, return_type, volatility))]
223+
#[pyo3(signature=(name, evaluator, input_types, return_type, volatility, arguments))]
216224
fn new(
217225
name: &str,
218226
evaluator: PyObject,
219227
input_types: Vec<PyArrowType<DataType>>,
220228
return_type: PyArrowType<DataType>,
221229
volatility: &str,
230+
arguments: Vec<PyObject>,
222231
) -> PyResult<Self> {
223232
let return_type = return_type.0;
224233
let input_types = input_types.into_iter().map(|t| t.0).collect();
@@ -228,7 +237,7 @@ impl PyWindowUDF {
228237
input_types,
229238
return_type,
230239
parse_volatility(volatility)?,
231-
to_rust_partition_evaluator(evaluator),
240+
to_rust_partition_evaluator(evaluator, arguments),
232241
));
233242
Ok(Self { function })
234243
}

0 commit comments

Comments
 (0)