From c119b5df27fc61dbd3fed931acd2c6804f020d88 Mon Sep 17 00:00:00 2001 From: Tim Saucer Date: Wed, 2 Oct 2024 08:49:20 -0400 Subject: [PATCH 1/9] Add option for passing in constructor arguments to the udaf --- python/datafusion/udf.py | 15 +++++++++++++-- python/tests/test_udaf.py | 37 ++++++++++++++++++++++++++++++++----- src/udaf.rs | 13 +++++++++---- 3 files changed, 54 insertions(+), 11 deletions(-) diff --git a/python/datafusion/udf.py b/python/datafusion/udf.py index bb7a90866..2667c60e7 100644 --- a/python/datafusion/udf.py +++ b/python/datafusion/udf.py @@ -23,7 +23,7 @@ from datafusion.expr import Expr from typing import Callable, TYPE_CHECKING, TypeVar from abc import ABCMeta, abstractmethod -from typing import List +from typing import List, Any, Optional from enum import Enum import pyarrow @@ -186,6 +186,7 @@ def __init__( return_type: _R, state_type: list[pyarrow.DataType], volatility: Volatility | str, + arguments: list[Any], ) -> None: """Instantiate a user-defined aggregate function (UDAF). @@ -193,7 +194,13 @@ def __init__( descriptions. """ self._udaf = df_internal.AggregateUDF( - name, accumulator, input_types, return_type, state_type, str(volatility) + name, + accumulator, + input_types, + return_type, + state_type, + str(volatility), + arguments, ) def __call__(self, *args: Expr) -> Expr: @@ -212,6 +219,7 @@ def udaf( return_type: _R, state_type: list[pyarrow.DataType], volatility: Volatility | str, + arguments: Optional[list[Any]] = None, name: str | None = None, ) -> AggregateUDF: """Create a new User-Defined Aggregate Function. @@ -224,6 +232,7 @@ def udaf( return_type: The data type of the return value. state_type: The data types of the intermediate accumulation. volatility: See :py:class:`Volatility` for allowed values. + arguments: A list of arguments to pass in to the __init__ method for accum. name: A descriptive name for the function. Returns: @@ -238,6 +247,7 @@ def udaf( name = accum.__qualname__.lower() if isinstance(input_types, pyarrow.lib.DataType): input_types = [input_types] + arguments = [] if arguments is None else arguments return AggregateUDF( name=name, accumulator=accum, @@ -245,6 +255,7 @@ def udaf( return_type=return_type, state_type=state_type, volatility=volatility, + arguments=arguments, ) diff --git a/python/tests/test_udaf.py b/python/tests/test_udaf.py index 6f2525b0f..792e45f0a 100644 --- a/python/tests/test_udaf.py +++ b/python/tests/test_udaf.py @@ -27,8 +27,8 @@ class Summarize(Accumulator): """Interface of a user-defined accumulation.""" - def __init__(self): - self._sum = pa.scalar(0.0) + def __init__(self, initial_value: float = 0.0): + self._sum = pa.scalar(initial_value) def state(self) -> List[pa.Scalar]: return [self._sum] @@ -97,7 +97,7 @@ def test_errors(df): df.collect() -def test_aggregate(df): +def test_udaf_aggregate(df): summarize = udaf( Summarize, pa.float64(), @@ -106,14 +106,41 @@ def test_aggregate(df): volatility="immutable", ) - df = df.aggregate([], [summarize(column("a"))]) + df1 = df.aggregate([], [summarize(column("a"))]) # execute and collect the first (and only) batch - result = df.collect()[0] + result = df1.collect()[0] + + assert result.column(0) == pa.array([1.0 + 2.0 + 3.0]) + + df2 = df.aggregate([], [summarize(column("a"))]) + + # Run a second time to ensure the state is properly reset + result = df2.collect()[0] assert result.column(0) == pa.array([1.0 + 2.0 + 3.0]) +def test_udaf_aggregate_with_arguments(df): + bias = 10.0 + + summarize = udaf( + Summarize, + pa.float64(), + pa.float64(), + [pa.float64()], + volatility="immutable", + arguments=[bias], + ) + + df1 = df.aggregate([], [summarize(column("a"))]) + + # execute and collect the first (and only) batch + result = df1.collect()[0] + + assert result.column(0) == pa.array([bias + 1.0 + 2.0 + 3.0]) + + def test_group_by(df): summarize = udaf( Summarize, diff --git a/src/udaf.rs b/src/udaf.rs index a6aa59ac3..b9db47a88 100644 --- a/src/udaf.rs +++ b/src/udaf.rs @@ -128,11 +128,15 @@ impl Accumulator for RustAccumulator { } } -pub fn to_rust_accumulator(accum: PyObject) -> AccumulatorFactoryFunction { +pub fn to_rust_accumulator( + accum: PyObject, + arguments: Vec, +) -> AccumulatorFactoryFunction { Arc::new(move |_| -> Result> { let accum = Python::with_gil(|py| { + let py_args = PyTuple::new_bound(py, arguments.iter()); accum - .call0(py) + .call1(py, py_args) .map_err(|e| DataFusionError::Execution(format!("{e}"))) })?; Ok(Box::new(RustAccumulator::new(accum))) @@ -149,7 +153,7 @@ pub struct PyAggregateUDF { #[pymethods] impl PyAggregateUDF { #[new] - #[pyo3(signature=(name, accumulator, input_type, return_type, state_type, volatility))] + #[pyo3(signature=(name, accumulator, input_type, return_type, state_type, volatility, arguments))] fn new( name: &str, accumulator: PyObject, @@ -157,13 +161,14 @@ impl PyAggregateUDF { return_type: PyArrowType, state_type: PyArrowType>, volatility: &str, + arguments: Vec, ) -> PyResult { let function = create_udaf( name, input_type.0, Arc::new(return_type.0), parse_volatility(volatility)?, - to_rust_accumulator(accumulator), + to_rust_accumulator(accumulator, arguments), Arc::new(state_type.0), ); Ok(Self { function }) From 14fc1667339015b72eab333978f610774457563b Mon Sep 17 00:00:00 2001 From: Tim Saucer Date: Wed, 2 Oct 2024 08:51:04 -0400 Subject: [PATCH 2/9] Fix small warnings in pylance --- python/datafusion/udf.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/python/datafusion/udf.py b/python/datafusion/udf.py index 2667c60e7..d8419b7da 100644 --- a/python/datafusion/udf.py +++ b/python/datafusion/udf.py @@ -104,8 +104,8 @@ def __call__(self, *args: Expr) -> Expr: This function is not typically called by an end user. These calls will occur during the evaluation of the dataframe. """ - args = [arg.expr for arg in args] - return Expr(self._udf.__call__(*args)) + args_raw = [arg.expr for arg in args] + return Expr(self._udf.__call__(*args_raw)) @staticmethod def udf( @@ -209,13 +209,13 @@ def __call__(self, *args: Expr) -> Expr: This function is not typically called by an end user. These calls will occur during the evaluation of the dataframe. """ - args = [arg.expr for arg in args] - return Expr(self._udaf.__call__(*args)) + args_raw = [arg.expr for arg in args] + return Expr(self._udaf.__call__(*args_raw)) @staticmethod def udaf( accum: _A, - input_types: list[pyarrow.DataType], + input_types: pyarrow.DataType | list[pyarrow.DataType], return_type: _R, state_type: list[pyarrow.DataType], volatility: Volatility | str, @@ -245,7 +245,7 @@ def udaf( ) if name is None: name = accum.__qualname__.lower() - if isinstance(input_types, pyarrow.lib.DataType): + if isinstance(input_types, pyarrow.DataType): input_types = [input_types] arguments = [] if arguments is None else arguments return AggregateUDF( From e3b4acc0edf1e506e3beb968805d5944d41f5abc Mon Sep 17 00:00:00 2001 From: Tim Saucer Date: Wed, 2 Oct 2024 08:57:40 -0400 Subject: [PATCH 3/9] Improve type hinting for udaf and fix one pylance warning --- python/datafusion/udf.py | 14 +++++--------- 1 file changed, 5 insertions(+), 9 deletions(-) diff --git a/python/datafusion/udf.py b/python/datafusion/udf.py index d8419b7da..6f68d085b 100644 --- a/python/datafusion/udf.py +++ b/python/datafusion/udf.py @@ -21,7 +21,7 @@ import datafusion._internal as df_internal from datafusion.expr import Expr -from typing import Callable, TYPE_CHECKING, TypeVar +from typing import Callable, TYPE_CHECKING, TypeVar, Type from abc import ABCMeta, abstractmethod from typing import List, Any, Optional from enum import Enum @@ -167,10 +167,6 @@ def evaluate(self) -> pyarrow.Scalar: pass -if TYPE_CHECKING: - _A = TypeVar("_A", bound=(Callable[..., _R], Accumulator)) - - class AggregateUDF: """Class for performing scalar user-defined functions (UDF). @@ -181,9 +177,9 @@ class AggregateUDF: def __init__( self, name: str | None, - accumulator: _A, + accumulator: Type[Accumulator], input_types: list[pyarrow.DataType], - return_type: _R, + return_type: pyarrow.DataType, state_type: list[pyarrow.DataType], volatility: Volatility | str, arguments: list[Any], @@ -214,9 +210,9 @@ def __call__(self, *args: Expr) -> Expr: @staticmethod def udaf( - accum: _A, + accum: Type[Accumulator], input_types: pyarrow.DataType | list[pyarrow.DataType], - return_type: _R, + return_type: pyarrow.DataType, state_type: list[pyarrow.DataType], volatility: Volatility | str, arguments: Optional[list[Any]] = None, From ca1352e921f0e821045a8653be6832d2c6e759c6 Mon Sep 17 00:00:00 2001 From: Tim Saucer Date: Wed, 2 Oct 2024 09:12:58 -0400 Subject: [PATCH 4/9] Set up UDWF to take arguments as constructor just like UDAF to ensure we get a clean state when functions are reused --- python/datafusion/udf.py | 13 +++++++++---- python/tests/test_udwf.py | 31 ++++++++++++++++++++++++------- src/udwf.rs | 17 +++++++++++++---- 3 files changed, 46 insertions(+), 15 deletions(-) diff --git a/python/datafusion/udf.py b/python/datafusion/udf.py index 6f68d085b..e33b1ccae 100644 --- a/python/datafusion/udf.py +++ b/python/datafusion/udf.py @@ -429,10 +429,11 @@ class WindowUDF: def __init__( self, name: str | None, - func: WindowEvaluator, + func: Type[WindowEvaluator], input_types: list[pyarrow.DataType], return_type: pyarrow.DataType, volatility: Volatility | str, + arguments: list[Any], ) -> None: """Instantiate a user-defined window function (UDWF). @@ -440,7 +441,7 @@ def __init__( descriptions. """ self._udwf = df_internal.WindowUDF( - name, func, input_types, return_type, str(volatility) + name, func, input_types, return_type, str(volatility), arguments ) def __call__(self, *args: Expr) -> Expr: @@ -454,10 +455,11 @@ def __call__(self, *args: Expr) -> Expr: @staticmethod def udwf( - func: WindowEvaluator, + func: Type[WindowEvaluator], input_types: pyarrow.DataType | list[pyarrow.DataType], return_type: pyarrow.DataType, volatility: Volatility | str, + arguments: Optional[list[Any]] = None, name: str | None = None, ) -> WindowUDF: """Create a new User-Defined Window Function. @@ -467,12 +469,13 @@ def udwf( input_types: The data types of the arguments to ``func``. return_type: The data type of the return value. volatility: See :py:class:`Volatility` for allowed values. + arguments: A list of arguments to pass in to the __init__ method for accum. name: A descriptive name for the function. Returns: A user-defined window function. """ - if not isinstance(func, WindowEvaluator): + if not issubclass(func, WindowEvaluator): raise TypeError( "`func` must implement the abstract base class WindowEvaluator" ) @@ -480,10 +483,12 @@ def udwf( name = func.__class__.__qualname__.lower() if isinstance(input_types, pyarrow.DataType): input_types = [input_types] + arguments = [] if arguments is None else arguments return WindowUDF( name=name, func=func, input_types=input_types, return_type=return_type, volatility=volatility, + arguments=arguments, ) diff --git a/python/tests/test_udwf.py b/python/tests/test_udwf.py index 67c0979fe..67966eeaa 100644 --- a/python/tests/test_udwf.py +++ b/python/tests/test_udwf.py @@ -24,7 +24,7 @@ class ExponentialSmoothDefault(WindowEvaluator): - def __init__(self, alpha: float) -> None: + def __init__(self, alpha: float = 0.8) -> None: self.alpha = alpha def evaluate_all(self, values: list[pa.Array], num_rows: int) -> pa.Array: @@ -183,7 +183,7 @@ def df(): def test_udwf_errors(df): with pytest.raises(TypeError): udwf( - NotSubclassOfWindowEvaluator(), + NotSubclassOfWindowEvaluator, pa.float64(), pa.float64(), volatility="immutable", @@ -191,38 +191,50 @@ def test_udwf_errors(df): smooth_default = udwf( - ExponentialSmoothDefault(0.9), + ExponentialSmoothDefault, + pa.float64(), + pa.float64(), + volatility="immutable", + arguments=[0.9], +) + +smooth_no_arugments = udwf( + ExponentialSmoothDefault, pa.float64(), pa.float64(), volatility="immutable", ) smooth_bounded = udwf( - ExponentialSmoothBounded(0.9), + ExponentialSmoothBounded, pa.float64(), pa.float64(), volatility="immutable", + arguments=[0.9], ) smooth_rank = udwf( - ExponentialSmoothRank(0.9), + ExponentialSmoothRank, pa.utf8(), pa.float64(), volatility="immutable", + arguments=[0.9], ) smooth_frame = udwf( - ExponentialSmoothFrame(0.9), + ExponentialSmoothFrame, pa.float64(), pa.float64(), volatility="immutable", + arguments=[0.9], ) smooth_two_col = udwf( - SmoothTwoColumn(0.9), + SmoothTwoColumn, [pa.int64(), pa.int64()], pa.float64(), volatility="immutable", + arguments=[0.9], ) data_test_udwf_functions = [ @@ -231,6 +243,11 @@ def test_udwf_errors(df): smooth_default(column("a")), [0, 0.9, 1.89, 2.889, 3.889, 4.889, 5.889], ), + ( + "default_udwf_no_arguments", + smooth_no_arugments(column("a")), + [0, 0.8, 1.76, 2.752, 3.75, 4.75, 5.75], + ), ( "default_udwf_partitioned", smooth_default(column("a")).partition_by(column("c")).build(), diff --git a/src/udwf.rs b/src/udwf.rs index 31cc5e60e..68ef6620d 100644 --- a/src/udwf.rs +++ b/src/udwf.rs @@ -195,9 +195,17 @@ impl PartitionEvaluator for RustPartitionEvaluator { } } -pub fn to_rust_partition_evaluator(evaluator: PyObject) -> PartitionEvaluatorFactory { +pub fn to_rust_partition_evaluator( + evaluator: PyObject, + arguments: Vec, +) -> PartitionEvaluatorFactory { Arc::new(move || -> Result> { - let evaluator = Python::with_gil(|py| evaluator.clone_ref(py)); + let evaluator = Python::with_gil(|py| { + let py_args = PyTuple::new_bound(py, arguments.iter()); + evaluator + .call1(py, py_args) + .map_err(|e| DataFusionError::Execution(e.to_string())) + })?; Ok(Box::new(RustPartitionEvaluator::new(evaluator))) }) } @@ -212,13 +220,14 @@ pub struct PyWindowUDF { #[pymethods] impl PyWindowUDF { #[new] - #[pyo3(signature=(name, evaluator, input_types, return_type, volatility))] + #[pyo3(signature=(name, evaluator, input_types, return_type, volatility, arguments))] fn new( name: &str, evaluator: PyObject, input_types: Vec>, return_type: PyArrowType, volatility: &str, + arguments: Vec, ) -> PyResult { let return_type = return_type.0; let input_types = input_types.into_iter().map(|t| t.0).collect(); @@ -228,7 +237,7 @@ impl PyWindowUDF { input_types, return_type, parse_volatility(volatility)?, - to_rust_partition_evaluator(evaluator), + to_rust_partition_evaluator(evaluator, arguments), )); Ok(Self { function }) } From 15bc0ad78e80dfbf9cf09bfaa8eddab4c09f288b Mon Sep 17 00:00:00 2001 From: Tim Saucer Date: Wed, 2 Oct 2024 09:36:30 -0400 Subject: [PATCH 5/9] Improve handling of udf when user provides a class instead of bare function --- python/datafusion/udf.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/python/datafusion/udf.py b/python/datafusion/udf.py index e33b1ccae..88e9a1936 100644 --- a/python/datafusion/udf.py +++ b/python/datafusion/udf.py @@ -86,7 +86,7 @@ def __init__( self, name: str | None, func: Callable[..., _R], - input_types: list[pyarrow.DataType], + input_types: pyarrow.DataType | list[pyarrow.DataType], return_type: _R, volatility: Volatility | str, ) -> None: @@ -94,6 +94,8 @@ def __init__( See helper method :py:func:`udf` for argument details. """ + if isinstance(input_types, pyarrow.DataType): + input_types = [input_types] self._udf = df_internal.ScalarUDF( name, func, input_types, return_type, str(volatility) ) @@ -133,7 +135,10 @@ def udf( if not callable(func): raise TypeError("`func` argument must be callable") if name is None: - name = func.__qualname__.lower() + if hasattr(func, "__qualname__"): + name = func.__qualname__.lower() + else: + name = func.__class__.__name__.lower() return ScalarUDF( name=name, func=func, From a989d27b1f06d57adf0dc4c9b6e85eac239b1ece Mon Sep 17 00:00:00 2001 From: Tim Saucer Date: Wed, 2 Oct 2024 09:37:27 -0400 Subject: [PATCH 6/9] Add unit tests for UDF showing callable class --- python/datafusion/tests/test_udf.py | 79 +++++++++++++++++++++++++++++ python/tests/test_dataframe.py | 16 ------ python/tests/test_udaf.py | 19 +------ 3 files changed, 80 insertions(+), 34 deletions(-) create mode 100644 python/datafusion/tests/test_udf.py diff --git a/python/datafusion/tests/test_udf.py b/python/datafusion/tests/test_udf.py new file mode 100644 index 000000000..a601ddcad --- /dev/null +++ b/python/datafusion/tests/test_udf.py @@ -0,0 +1,79 @@ +from datafusion import udf, column +import pyarrow as pa +import pytest + + +@pytest.fixture +def df(ctx): + # create a RecordBatch and a new DataFrame from it + batch = pa.RecordBatch.from_arrays( + [pa.array([1, 2, 3]), pa.array([4, 4, 6])], + names=["a", "b"], + ) + return ctx.create_dataframe([[batch]], name="test_table") + + +def test_udf(df): + # is_null is a pa function over arrays + is_null = udf( + lambda x: x.is_null(), + [pa.int64()], + pa.bool_(), + volatility="immutable", + ) + + df = df.select(is_null(column("a"))) + result = df.collect()[0].column(0) + + assert result == pa.array([False, False, False]) + + +def test_register_udf(ctx, df) -> None: + is_null = udf( + lambda x: x.is_null(), + [pa.float64()], + pa.bool_(), + volatility="immutable", + name="is_null", + ) + + ctx.register_udf(is_null) + + df_result = ctx.sql("select is_null(a) from test_table") + result = df_result.collect()[0].column(0) + + assert result == pa.array([False, False, False]) + + +class OverThresholdUDF: + def __init__(self, threshold: int = 0) -> None: + self.threshold = threshold + + def __call__(self, values: pa.Array) -> pa.Array: + return pa.array(v.as_py() >= self.threshold for v in values) + + +def test_udf_with_parameters(df) -> None: + udf_no_param = udf( + OverThresholdUDF(), + pa.int64(), + pa.bool_(), + volatility="immutable", + ) + + df1 = df.select(udf_no_param(column("a"))) + result = df1.collect()[0].column(0) + + assert result == pa.array([True, True, True]) + + udf_with_param = udf( + OverThresholdUDF(2), + pa.int64(), + pa.bool_(), + volatility="immutable", + ) + + df2 = df.select(udf_with_param(column("a"))) + result = df2.collect()[0].column(0) + + assert result == pa.array([False, True, True]) diff --git a/python/tests/test_dataframe.py b/python/tests/test_dataframe.py index ad7f728b4..e89c57159 100644 --- a/python/tests/test_dataframe.py +++ b/python/tests/test_dataframe.py @@ -29,7 +29,6 @@ WindowFrame, column, literal, - udf, ) from datafusion.expr import Window @@ -236,21 +235,6 @@ def test_unnest_without_nulls(nested_df): assert result.column(1) == pa.array([7, 8, 8, 9, 9, 9]) -def test_udf(df): - # is_null is a pa function over arrays - is_null = udf( - lambda x: x.is_null(), - [pa.int64()], - pa.bool_(), - volatility="immutable", - ) - - df = df.select(is_null(column("a"))) - result = df.collect()[0].column(0) - - assert result == pa.array([False, False, False]) - - def test_join(): ctx = SessionContext() diff --git a/python/tests/test_udaf.py b/python/tests/test_udaf.py index 792e45f0a..434608093 100644 --- a/python/tests/test_udaf.py +++ b/python/tests/test_udaf.py @@ -21,7 +21,7 @@ import pyarrow.compute as pc import pytest -from datafusion import Accumulator, column, udaf, udf +from datafusion import Accumulator, column, udaf class Summarize(Accumulator): @@ -173,20 +173,3 @@ def test_register_udaf(ctx, df) -> None: df_result = ctx.sql("select summarize(b) from test_table") assert df_result.collect()[0][0][0].as_py() == 14.0 - - -def test_register_udf(ctx, df) -> None: - is_null = udf( - lambda x: x.is_null(), - [pa.float64()], - pa.bool_(), - volatility="immutable", - name="is_null", - ) - - ctx.register_udf(is_null) - - df_result = ctx.sql("select is_null(a) from test_table") - result = df_result.collect()[0].column(0) - - assert result == pa.array([False, False, False]) From b7468dc93ccd8d93d3eb4b42496a406428527fe2 Mon Sep 17 00:00:00 2001 From: Tim Saucer Date: Wed, 2 Oct 2024 09:50:10 -0400 Subject: [PATCH 7/9] Add license text --- python/datafusion/tests/test_udf.py | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/python/datafusion/tests/test_udf.py b/python/datafusion/tests/test_udf.py index a601ddcad..568a66dbb 100644 --- a/python/datafusion/tests/test_udf.py +++ b/python/datafusion/tests/test_udf.py @@ -1,3 +1,20 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + from datafusion import udf, column import pyarrow as pa import pytest From 250baea24c39d453b56ff78ed1bdafe7f19da113 Mon Sep 17 00:00:00 2001 From: Tim Saucer Date: Fri, 4 Oct 2024 08:31:28 -0400 Subject: [PATCH 8/9] Switching to use factory methods for udaf and udwf --- python/datafusion/udf.py | 113 +++++++++++++++++++++++++++----------- python/tests/test_udaf.py | 27 +++++---- python/tests/test_udwf.py | 25 ++++----- src/udaf.rs | 13 ++--- src/udwf.rs | 13 ++--- 5 files changed, 114 insertions(+), 77 deletions(-) diff --git a/python/datafusion/udf.py b/python/datafusion/udf.py index 88e9a1936..291ef2bae 100644 --- a/python/datafusion/udf.py +++ b/python/datafusion/udf.py @@ -21,9 +21,9 @@ import datafusion._internal as df_internal from datafusion.expr import Expr -from typing import Callable, TYPE_CHECKING, TypeVar, Type +from typing import Callable, TYPE_CHECKING, TypeVar from abc import ABCMeta, abstractmethod -from typing import List, Any, Optional +from typing import List, Optional from enum import Enum import pyarrow @@ -84,7 +84,7 @@ class ScalarUDF: def __init__( self, - name: str | None, + name: Optional[str], func: Callable[..., _R], input_types: pyarrow.DataType | list[pyarrow.DataType], return_type: _R, @@ -115,7 +115,7 @@ def udf( input_types: list[pyarrow.DataType], return_type: _R, volatility: Volatility | str, - name: str | None = None, + name: Optional[str] = None, ) -> ScalarUDF: """Create a new User-Defined Function. @@ -181,13 +181,12 @@ class AggregateUDF: def __init__( self, - name: str | None, - accumulator: Type[Accumulator], + name: Optional[str], + accumulator: Callable[[], Accumulator], input_types: list[pyarrow.DataType], return_type: pyarrow.DataType, state_type: list[pyarrow.DataType], volatility: Volatility | str, - arguments: list[Any], ) -> None: """Instantiate a user-defined aggregate function (UDAF). @@ -201,7 +200,6 @@ def __init__( return_type, state_type, str(volatility), - arguments, ) def __call__(self, *args: Expr) -> Expr: @@ -215,17 +213,47 @@ def __call__(self, *args: Expr) -> Expr: @staticmethod def udaf( - accum: Type[Accumulator], + accum: Callable[[], Accumulator], input_types: pyarrow.DataType | list[pyarrow.DataType], return_type: pyarrow.DataType, state_type: list[pyarrow.DataType], volatility: Volatility | str, - arguments: Optional[list[Any]] = None, - name: str | None = None, + name: Optional[str] = None, ) -> AggregateUDF: """Create a new User-Defined Aggregate Function. - The accumulator function must be callable and implement :py:class:`Accumulator`. + If your :py:class:`Accumulator` can be instantiated with no arguments, you + can simply pass it's type as ``accum``. If you need to pass additional arguments + to it's constructor, you can define a lambda or a factory method. During runtime + the :py:class:`Accumulator` will be constructed for every instance in + which this UDAF is used. The following examples are all valid. + + .. code-block:: python + import pyarrow as pa + import pyarrow.compute as pc + + class Summarize(Accumulator): + def __init__(self, bias: float = 0.0): + self._sum = pa.scalar(bias) + + def state(self) -> List[pa.Scalar]: + return [self._sum] + + def update(self, values: pa.Array) -> None: + self._sum = pa.scalar(self._sum.as_py() + pc.sum(values).as_py()) + + def merge(self, states: List[pa.Array]) -> None: + self._sum = pa.scalar(self._sum.as_py() + pc.sum(states[0]).as_py()) + + def evaluate(self) -> pa.Scalar: + return self._sum + + def sum_bias_10() -> Summarize: + return Summarize(10.0) + + udaf1 = udaf(Summarize, pa.float64(), pa.float64(), [pa.float64()], "immutable") + udaf2 = udaf(sum_bias_10, pa.float64(), pa.float64(), [pa.float64()], "immutable") + udaf3 = udaf(lambda: Summarize(20.0), pa.float64(), pa.float64(), [pa.float64()], "immutable") Args: accum: The accumulator python function. @@ -233,22 +261,22 @@ def udaf( return_type: The data type of the return value. state_type: The data types of the intermediate accumulation. volatility: See :py:class:`Volatility` for allowed values. - arguments: A list of arguments to pass in to the __init__ method for accum. name: A descriptive name for the function. Returns: A user-defined aggregate function, which can be used in either data aggregation or window function calls. - """ - if not issubclass(accum, Accumulator): + """ # noqa W505 + if not callable(accum): + raise TypeError("`func` must be callable.") + if not isinstance(accum.__call__(), Accumulator): raise TypeError( - "`accum` must implement the abstract base class Accumulator" + "Accumulator must implement the abstract base class Accumulator" ) if name is None: - name = accum.__qualname__.lower() + name = accum.__call__().__class__.__qualname__.lower() if isinstance(input_types, pyarrow.DataType): input_types = [input_types] - arguments = [] if arguments is None else arguments return AggregateUDF( name=name, accumulator=accum, @@ -256,7 +284,6 @@ def udaf( return_type=return_type, state_type=state_type, volatility=volatility, - arguments=arguments, ) @@ -433,12 +460,11 @@ class WindowUDF: def __init__( self, - name: str | None, - func: Type[WindowEvaluator], + name: Optional[str], + func: Callable[[], WindowEvaluator], input_types: list[pyarrow.DataType], return_type: pyarrow.DataType, volatility: Volatility | str, - arguments: list[Any], ) -> None: """Instantiate a user-defined window function (UDWF). @@ -446,7 +472,7 @@ def __init__( descriptions. """ self._udwf = df_internal.WindowUDF( - name, func, input_types, return_type, str(volatility), arguments + name, func, input_types, return_type, str(volatility) ) def __call__(self, *args: Expr) -> Expr: @@ -460,17 +486,40 @@ def __call__(self, *args: Expr) -> Expr: @staticmethod def udwf( - func: Type[WindowEvaluator], + func: Callable[[], WindowEvaluator], input_types: pyarrow.DataType | list[pyarrow.DataType], return_type: pyarrow.DataType, volatility: Volatility | str, - arguments: Optional[list[Any]] = None, - name: str | None = None, + name: Optional[str] = None, ) -> WindowUDF: """Create a new User-Defined Window Function. + If your :py:class:`WindowEvaluator` can be instantiated with no arguments, you + can simply pass it's type as ``func``. If you need to pass additional arguments + to it's constructor, you can define a lambda or a factory method. During runtime + the :py:class:`WindowEvaluator` will be constructed for every instance in + which this UDWF is used. The following examples are all valid. + + .. code-block:: python + + import pyarrow as pa + + class BiasedNumbers(WindowEvaluator): + def __init__(self, start: int = 0) -> None: + self.start = start + + def evaluate_all(self, values: list[pa.Array], num_rows: int) -> pa.Array: + return pa.array([self.start + i for i in range(num_rows)]) + + def bias_10() -> BiasedNumbers: + return BiasedNumbers(10) + + udwf1 = udwf(BiasedNumbers, pa.int64(), pa.int64(), "immutable") + udwf2 = udwf(bias_10, pa.int64(), pa.int64(), "immutable") + udwf3 = udwf(lambda: BiasedNumbers(20), pa.int64(), pa.int64(), "immutable") + Args: - func: The python function. + func: A callable to create the window function. input_types: The data types of the arguments to ``func``. return_type: The data type of the return value. volatility: See :py:class:`Volatility` for allowed values. @@ -479,21 +528,21 @@ def udwf( Returns: A user-defined window function. - """ - if not issubclass(func, WindowEvaluator): + """ # noqa W505 + if not callable(func): + raise TypeError("`func` must be callable.") + if not isinstance(func.__call__(), WindowEvaluator): raise TypeError( "`func` must implement the abstract base class WindowEvaluator" ) if name is None: - name = func.__class__.__qualname__.lower() + name = func.__call__().__class__.__qualname__.lower() if isinstance(input_types, pyarrow.DataType): input_types = [input_types] - arguments = [] if arguments is None else arguments return WindowUDF( name=name, func=func, input_types=input_types, return_type=return_type, volatility=volatility, - arguments=arguments, ) diff --git a/python/tests/test_udaf.py b/python/tests/test_udaf.py index 434608093..8f31748e0 100644 --- a/python/tests/test_udaf.py +++ b/python/tests/test_udaf.py @@ -79,22 +79,19 @@ def test_errors(df): volatility="immutable", ) - accum = udaf( - MissingMethods, - pa.int64(), - pa.int64(), - [pa.int64()], - volatility="immutable", - ) - df = df.aggregate([], [accum(column("a"))]) - msg = ( "Can't instantiate abstract class MissingMethods (without an implementation " "for abstract methods 'evaluate', 'merge', 'update'|with abstract methods " "evaluate, merge, update)" ) with pytest.raises(Exception, match=msg): - df.collect() + accum = udaf( # noqa F841 + MissingMethods, + pa.int64(), + pa.int64(), + [pa.int64()], + volatility="immutable", + ) def test_udaf_aggregate(df): @@ -125,12 +122,11 @@ def test_udaf_aggregate_with_arguments(df): bias = 10.0 summarize = udaf( - Summarize, + lambda: Summarize(bias), pa.float64(), pa.float64(), [pa.float64()], volatility="immutable", - arguments=[bias], ) df1 = df.aggregate([], [summarize(column("a"))]) @@ -140,6 +136,13 @@ def test_udaf_aggregate_with_arguments(df): assert result.column(0) == pa.array([bias + 1.0 + 2.0 + 3.0]) + df2 = df.aggregate([], [summarize(column("a"))]) + + # Run a second time to ensure the state is properly reset + result = df2.collect()[0] + + assert result.column(0) == pa.array([bias + 1.0 + 2.0 + 3.0]) + def test_group_by(df): summarize = udaf( diff --git a/python/tests/test_udwf.py b/python/tests/test_udwf.py index 67966eeaa..2099ac9bc 100644 --- a/python/tests/test_udwf.py +++ b/python/tests/test_udwf.py @@ -24,7 +24,7 @@ class ExponentialSmoothDefault(WindowEvaluator): - def __init__(self, alpha: float = 0.8) -> None: + def __init__(self, alpha: float = 0.9) -> None: self.alpha = alpha def evaluate_all(self, values: list[pa.Array], num_rows: int) -> pa.Array: @@ -44,7 +44,7 @@ def evaluate_all(self, values: list[pa.Array], num_rows: int) -> pa.Array: class ExponentialSmoothBounded(WindowEvaluator): - def __init__(self, alpha: float) -> None: + def __init__(self, alpha: float = 0.9) -> None: self.alpha = alpha def supports_bounded_execution(self) -> bool: @@ -75,7 +75,7 @@ def evaluate( class ExponentialSmoothRank(WindowEvaluator): - def __init__(self, alpha: float) -> None: + def __init__(self, alpha: float = 0.9) -> None: self.alpha = alpha def include_rank(self) -> bool: @@ -101,7 +101,7 @@ def evaluate_all_with_rank( class ExponentialSmoothFrame(WindowEvaluator): - def __init__(self, alpha: float) -> None: + def __init__(self, alpha: float = 0.9) -> None: self.alpha = alpha def uses_window_frame(self) -> bool: @@ -134,7 +134,7 @@ class SmoothTwoColumn(WindowEvaluator): the previous and next rows. """ - def __init__(self, alpha: float) -> None: + def __init__(self, alpha: float = 0.9) -> None: self.alpha = alpha def evaluate_all(self, values: list[pa.Array], num_rows: int) -> pa.Array: @@ -195,11 +195,10 @@ def test_udwf_errors(df): pa.float64(), pa.float64(), volatility="immutable", - arguments=[0.9], ) -smooth_no_arugments = udwf( - ExponentialSmoothDefault, +smooth_w_arguments = udwf( + lambda: ExponentialSmoothDefault(0.8), pa.float64(), pa.float64(), volatility="immutable", @@ -210,7 +209,6 @@ def test_udwf_errors(df): pa.float64(), pa.float64(), volatility="immutable", - arguments=[0.9], ) smooth_rank = udwf( @@ -218,7 +216,6 @@ def test_udwf_errors(df): pa.utf8(), pa.float64(), volatility="immutable", - arguments=[0.9], ) smooth_frame = udwf( @@ -226,7 +223,6 @@ def test_udwf_errors(df): pa.float64(), pa.float64(), volatility="immutable", - arguments=[0.9], ) smooth_two_col = udwf( @@ -234,18 +230,17 @@ def test_udwf_errors(df): [pa.int64(), pa.int64()], pa.float64(), volatility="immutable", - arguments=[0.9], ) data_test_udwf_functions = [ ( - "default_udwf", + "default_udwf_no_arguments", smooth_default(column("a")), [0, 0.9, 1.89, 2.889, 3.889, 4.889, 5.889], ), ( - "default_udwf_no_arguments", - smooth_no_arugments(column("a")), + "default_udwf_w_arguments", + smooth_w_arguments(column("a")), [0, 0.8, 1.76, 2.752, 3.75, 4.75, 5.75], ), ( diff --git a/src/udaf.rs b/src/udaf.rs index b9db47a88..a6aa59ac3 100644 --- a/src/udaf.rs +++ b/src/udaf.rs @@ -128,15 +128,11 @@ impl Accumulator for RustAccumulator { } } -pub fn to_rust_accumulator( - accum: PyObject, - arguments: Vec, -) -> AccumulatorFactoryFunction { +pub fn to_rust_accumulator(accum: PyObject) -> AccumulatorFactoryFunction { Arc::new(move |_| -> Result> { let accum = Python::with_gil(|py| { - let py_args = PyTuple::new_bound(py, arguments.iter()); accum - .call1(py, py_args) + .call0(py) .map_err(|e| DataFusionError::Execution(format!("{e}"))) })?; Ok(Box::new(RustAccumulator::new(accum))) @@ -153,7 +149,7 @@ pub struct PyAggregateUDF { #[pymethods] impl PyAggregateUDF { #[new] - #[pyo3(signature=(name, accumulator, input_type, return_type, state_type, volatility, arguments))] + #[pyo3(signature=(name, accumulator, input_type, return_type, state_type, volatility))] fn new( name: &str, accumulator: PyObject, @@ -161,14 +157,13 @@ impl PyAggregateUDF { return_type: PyArrowType, state_type: PyArrowType>, volatility: &str, - arguments: Vec, ) -> PyResult { let function = create_udaf( name, input_type.0, Arc::new(return_type.0), parse_volatility(volatility)?, - to_rust_accumulator(accumulator, arguments), + to_rust_accumulator(accumulator), Arc::new(state_type.0), ); Ok(Self { function }) diff --git a/src/udwf.rs b/src/udwf.rs index 68ef6620d..43c21ec7b 100644 --- a/src/udwf.rs +++ b/src/udwf.rs @@ -195,15 +195,11 @@ impl PartitionEvaluator for RustPartitionEvaluator { } } -pub fn to_rust_partition_evaluator( - evaluator: PyObject, - arguments: Vec, -) -> PartitionEvaluatorFactory { +pub fn to_rust_partition_evaluator(evaluator: PyObject) -> PartitionEvaluatorFactory { Arc::new(move || -> Result> { let evaluator = Python::with_gil(|py| { - let py_args = PyTuple::new_bound(py, arguments.iter()); evaluator - .call1(py, py_args) + .call0(py) .map_err(|e| DataFusionError::Execution(e.to_string())) })?; Ok(Box::new(RustPartitionEvaluator::new(evaluator))) @@ -220,14 +216,13 @@ pub struct PyWindowUDF { #[pymethods] impl PyWindowUDF { #[new] - #[pyo3(signature=(name, evaluator, input_types, return_type, volatility, arguments))] + #[pyo3(signature=(name, evaluator, input_types, return_type, volatility))] fn new( name: &str, evaluator: PyObject, input_types: Vec>, return_type: PyArrowType, volatility: &str, - arguments: Vec, ) -> PyResult { let return_type = return_type.0; let input_types = input_types.into_iter().map(|t| t.0).collect(); @@ -237,7 +232,7 @@ impl PyWindowUDF { input_types, return_type, parse_volatility(volatility)?, - to_rust_partition_evaluator(evaluator, arguments), + to_rust_partition_evaluator(evaluator), )); Ok(Self { function }) } From cb7dc7c3a03955a1b5bff0d2c72266963efd312d Mon Sep 17 00:00:00 2001 From: Tim Saucer Date: Fri, 4 Oct 2024 11:38:52 -0400 Subject: [PATCH 9/9] Move new tests to the new testing directory --- python/{datafusion => }/tests/test_plans.py | 0 python/{datafusion => }/tests/test_udf.py | 0 2 files changed, 0 insertions(+), 0 deletions(-) rename python/{datafusion => }/tests/test_plans.py (100%) rename python/{datafusion => }/tests/test_udf.py (100%) diff --git a/python/datafusion/tests/test_plans.py b/python/tests/test_plans.py similarity index 100% rename from python/datafusion/tests/test_plans.py rename to python/tests/test_plans.py diff --git a/python/datafusion/tests/test_udf.py b/python/tests/test_udf.py similarity index 100% rename from python/datafusion/tests/test_udf.py rename to python/tests/test_udf.py