diff --git a/python/datafusion/udf.py b/python/datafusion/udf.py
index bb7a90866..291ef2bae 100644
--- a/python/datafusion/udf.py
+++ b/python/datafusion/udf.py
@@ -23,7 +23,7 @@
 from datafusion.expr import Expr
 from typing import Callable, TYPE_CHECKING, TypeVar
 from abc import ABCMeta, abstractmethod
-from typing import List
+from typing import List, Optional
 from enum import Enum

 import pyarrow
@@ -84,9 +84,9 @@ class ScalarUDF:

     def __init__(
         self,
-        name: str | None,
+        name: Optional[str],
         func: Callable[..., _R],
-        input_types: list[pyarrow.DataType],
+        input_types: pyarrow.DataType | list[pyarrow.DataType],
         return_type: _R,
         volatility: Volatility | str,
     ) -> None:
@@ -94,6 +94,8 @@ def __init__(

         See helper method :py:func:`udf` for argument details.
         """
+        if isinstance(input_types, pyarrow.DataType):
+            input_types = [input_types]
         self._udf = df_internal.ScalarUDF(
             name, func, input_types, return_type, str(volatility)
         )
@@ -104,8 +106,8 @@ def __call__(self, *args: Expr) -> Expr:
         This function is not typically called by an end user. These calls will
         occur during the evaluation of the dataframe.
         """
-        args = [arg.expr for arg in args]
-        return Expr(self._udf.__call__(*args))
+        args_raw = [arg.expr for arg in args]
+        return Expr(self._udf.__call__(*args_raw))

     @staticmethod
     def udf(
@@ -113,7 +115,7 @@ def udf(
         input_types: list[pyarrow.DataType],
         return_type: _R,
         volatility: Volatility | str,
-        name: str | None = None,
+        name: Optional[str] = None,
     ) -> ScalarUDF:
         """Create a new User-Defined Function.

@@ -133,7 +135,10 @@ def udf(
         if not callable(func):
             raise TypeError("`func` argument must be callable")
         if name is None:
-            name = func.__qualname__.lower()
+            if hasattr(func, "__qualname__"):
+                name = func.__qualname__.lower()
+            else:
+                name = func.__class__.__name__.lower()
         return ScalarUDF(
             name=name,
             func=func,
@@ -167,10 +172,6 @@ def evaluate(self) -> pyarrow.Scalar:
         pass


-if TYPE_CHECKING:
-    _A = TypeVar("_A", bound=(Callable[..., _R], Accumulator))
-
-
 class AggregateUDF:
     """Class for performing scalar user-defined functions (UDF).
@@ -180,10 +181,10 @@ class AggregateUDF:

     def __init__(
         self,
-        name: str | None,
-        accumulator: _A,
+        name: Optional[str],
+        accumulator: Callable[[], Accumulator],
         input_types: list[pyarrow.DataType],
-        return_type: _R,
+        return_type: pyarrow.DataType,
         state_type: list[pyarrow.DataType],
         volatility: Volatility | str,
     ) -> None:
@@ -193,7 +194,12 @@ def __init__(
         descriptions.
         """
         self._udaf = df_internal.AggregateUDF(
-            name, accumulator, input_types, return_type, state_type, str(volatility)
+            name,
+            accumulator,
+            input_types,
+            return_type,
+            state_type,
+            str(volatility),
         )

     def __call__(self, *args: Expr) -> Expr:
@@ -202,21 +208,53 @@ def __call__(self, *args: Expr) -> Expr:
         This function is not typically called by an end user. These calls will
         occur during the evaluation of the dataframe.
         """
-        args = [arg.expr for arg in args]
-        return Expr(self._udaf.__call__(*args))
+        args_raw = [arg.expr for arg in args]
+        return Expr(self._udaf.__call__(*args_raw))

     @staticmethod
     def udaf(
-        accum: _A,
-        input_types: list[pyarrow.DataType],
-        return_type: _R,
+        accum: Callable[[], Accumulator],
+        input_types: pyarrow.DataType | list[pyarrow.DataType],
+        return_type: pyarrow.DataType,
         state_type: list[pyarrow.DataType],
         volatility: Volatility | str,
-        name: str | None = None,
+        name: Optional[str] = None,
     ) -> AggregateUDF:
         """Create a new User-Defined Aggregate Function.

-        The accumulator function must be callable and implement :py:class:`Accumulator`.
+        If your :py:class:`Accumulator` can be instantiated with no arguments, you
+        can simply pass its type as ``accum``. If you need to pass additional
+        arguments to its constructor, you can define a lambda or a factory method.
+        At runtime a new :py:class:`Accumulator` will be constructed for every
+        instance in which this UDAF is used. The following examples are all valid.
+
+        .. code-block:: python
+
+            import pyarrow as pa
+            import pyarrow.compute as pc
+
+            class Summarize(Accumulator):
+                def __init__(self, bias: float = 0.0):
+                    self._sum = pa.scalar(bias)
+
+                def state(self) -> list[pa.Scalar]:
+                    return [self._sum]
+
+                def update(self, values: pa.Array) -> None:
+                    self._sum = pa.scalar(self._sum.as_py() + pc.sum(values).as_py())
+
+                def merge(self, states: list[pa.Array]) -> None:
+                    self._sum = pa.scalar(self._sum.as_py() + pc.sum(states[0]).as_py())
+
+                def evaluate(self) -> pa.Scalar:
+                    return self._sum
+
+            def sum_bias_10() -> Summarize:
+                return Summarize(10.0)
+
+            udaf1 = udaf(Summarize, pa.float64(), pa.float64(), [pa.float64()], "immutable")
+            udaf2 = udaf(sum_bias_10, pa.float64(), pa.float64(), [pa.float64()], "immutable")
+            udaf3 = udaf(lambda: Summarize(20.0), pa.float64(), pa.float64(), [pa.float64()], "immutable")
+
         Args:
             accum: The accumulator python function.
@@ -229,14 +267,16 @@ def udaf(
         Returns:
             A user-defined aggregate function, which can be used in either data
             aggregation or window function calls.
-        """
-        if not issubclass(accum, Accumulator):
+        """  # noqa W505
+        if not callable(accum):
+            raise TypeError("`accum` must be callable.")
+        if not isinstance(accum(), Accumulator):
             raise TypeError(
-                "`accum` must implement the abstract base class Accumulator"
+                "`accum` must return an instance of Accumulator"
             )
         if name is None:
-            name = accum.__qualname__.lower()
-        if isinstance(input_types, pyarrow.lib.DataType):
+            name = accum().__class__.__qualname__.lower()
+        if isinstance(input_types, pyarrow.DataType):
             input_types = [input_types]
         return AggregateUDF(
             name=name,
@@ -421,8 +461,8 @@ class WindowUDF:

     def __init__(
         self,
-        name: str | None,
-        func: WindowEvaluator,
+        name: Optional[str],
+        func: Callable[[], WindowEvaluator],
         input_types: list[pyarrow.DataType],
         return_type: pyarrow.DataType,
         volatility: Volatility | str,
@@ -447,30 +487,56 @@ def __call__(self, *args: Expr) -> Expr:

     @staticmethod
     def udwf(
-        func: WindowEvaluator,
+        func: Callable[[], WindowEvaluator],
         input_types: pyarrow.DataType | list[pyarrow.DataType],
         return_type: pyarrow.DataType,
         volatility: Volatility | str,
-        name: str | None = None,
+        name: Optional[str] = None,
     ) -> WindowUDF:
         """Create a new User-Defined Window Function.

+        If your :py:class:`WindowEvaluator` can be instantiated with no arguments, you
+        can simply pass its type as ``func``. If you need to pass additional
+        arguments to its constructor, you can define a lambda or a factory method.
+        At runtime a new :py:class:`WindowEvaluator` will be constructed for every
+        instance in which this UDWF is used. The following examples are all valid.
+
+        .. code-block:: python
+
+            import pyarrow as pa
+
+            class BiasedNumbers(WindowEvaluator):
+                def __init__(self, start: int = 0) -> None:
+                    self.start = start
+
+                def evaluate_all(self, values: list[pa.Array], num_rows: int) -> pa.Array:
+                    return pa.array([self.start + i for i in range(num_rows)])
+
+            def bias_10() -> BiasedNumbers:
+                return BiasedNumbers(10)
+
+            udwf1 = udwf(BiasedNumbers, pa.int64(), pa.int64(), "immutable")
+            udwf2 = udwf(bias_10, pa.int64(), pa.int64(), "immutable")
+            udwf3 = udwf(lambda: BiasedNumbers(20), pa.int64(), pa.int64(), "immutable")
+
         Args:
-            func: The python function.
+            func: A callable to create the window function.
             input_types: The data types of the arguments to ``func``.
             return_type: The data type of the return value.
             volatility: See :py:class:`Volatility` for allowed values.
             name: A descriptive name for the function.

         Returns:
             A user-defined window function.
-        """
-        if not isinstance(func, WindowEvaluator):
+        """  # noqa W505
+        if not callable(func):
+            raise TypeError("`func` must be callable.")
+        if not isinstance(func(), WindowEvaluator):
             raise TypeError(
                 "`func` must implement the abstract base class WindowEvaluator"
             )
         if name is None:
-            name = func.__class__.__qualname__.lower()
+            name = func().__class__.__qualname__.lower()
         if isinstance(input_types, pyarrow.DataType):
             input_types = [input_types]
         return WindowUDF(
diff --git a/python/tests/test_dataframe.py b/python/tests/test_dataframe.py
index ad7f728b4..e89c57159 100644
--- a/python/tests/test_dataframe.py
+++ b/python/tests/test_dataframe.py
@@ -29,7 +29,6 @@
     WindowFrame,
     column,
     literal,
-    udf,
 )

 from datafusion.expr import Window
@@ -236,21 +235,6 @@ def test_unnest_without_nulls(nested_df):
     assert result.column(1) == pa.array([7, 8, 8, 9, 9, 9])


-def test_udf(df):
-    # is_null is a pa function over arrays
-    is_null = udf(
-        lambda x: x.is_null(),
-        [pa.int64()],
-        pa.bool_(),
-        volatility="immutable",
-    )
-
-    df = df.select(is_null(column("a")))
-    result = df.collect()[0].column(0)
-
-    assert result == pa.array([False, False, False])
-
-
 def test_join():
     ctx = SessionContext()

diff --git a/python/datafusion/tests/test_plans.py b/python/tests/test_plans.py
similarity index 100%
rename from python/datafusion/tests/test_plans.py
rename to python/tests/test_plans.py
diff --git a/python/tests/test_udaf.py b/python/tests/test_udaf.py
index 6f2525b0f..8f31748e0 100644
--- a/python/tests/test_udaf.py
+++ b/python/tests/test_udaf.py
@@ -21,14 +21,14 @@
 import pyarrow.compute as pc
 import pytest

-from datafusion import Accumulator, column, udaf, udf
+from datafusion import Accumulator, column, udaf


 class Summarize(Accumulator):
     """Interface of a user-defined accumulation."""

-    def __init__(self):
-        self._sum = pa.scalar(0.0)
+    def __init__(self, initial_value: float = 0.0):
+        self._sum = pa.scalar(initial_value)

     def state(self) -> List[pa.Scalar]:
         return [self._sum]
@@ -79,25 +79,22 @@ def test_errors(df):
         volatility="immutable",
     )

-    accum = udaf(
-        MissingMethods,
-        pa.int64(),
-        pa.int64(),
-        [pa.int64()],
-        volatility="immutable",
-    )
-    df = df.aggregate([], [accum(column("a"))])
-
     msg = (
         "Can't instantiate abstract class MissingMethods (without an implementation "
         "for abstract methods 'evaluate', 'merge', 'update'|with abstract methods "
         "evaluate, merge, update)"
     )
     with pytest.raises(Exception, match=msg):
-        df.collect()
+        accum = udaf(  # noqa F841
+            MissingMethods,
+            pa.int64(),
+            pa.int64(),
+            [pa.int64()],
+            volatility="immutable",
+        )


-def test_aggregate(df):
+def test_udaf_aggregate(df):
     summarize = udaf(
         Summarize,
         pa.float64(),
@@ -106,13 +103,46 @@
         volatility="immutable",
     )

-    df = df.aggregate([], [summarize(column("a"))])
+    df1 = df.aggregate([], [summarize(column("a"))])

     # execute and collect the first (and only) batch
-    result = df.collect()[0]
+    result = df1.collect()[0]

     assert result.column(0) == pa.array([1.0 + 2.0 + 3.0])

+    df2 = df.aggregate([], [summarize(column("a"))])
+
+    # Run a second time to ensure the state is properly reset
+    result = df2.collect()[0]
+
+    assert result.column(0) == pa.array([1.0 + 2.0 + 3.0])
+
+
+def test_udaf_aggregate_with_arguments(df):
+    bias = 10.0
+
+    summarize = udaf(
+        lambda: Summarize(bias),
+        pa.float64(),
+        pa.float64(),
+        [pa.float64()],
+        volatility="immutable",
+    )
+
+    df1 = df.aggregate([], [summarize(column("a"))])
+
+    # execute and collect the first (and only) batch
+    result = df1.collect()[0]
+
+    assert result.column(0) == pa.array([bias + 1.0 + 2.0 + 3.0])
+
+    df2 = df.aggregate([], [summarize(column("a"))])
+
+    # Run a second time to ensure the state is properly reset
+    result = df2.collect()[0]
+
+    assert result.column(0) == pa.array([bias + 1.0 + 2.0 + 3.0])
+

 def test_group_by(df):
     summarize = udaf(
@@ -146,20 +176,3 @@ def test_register_udaf(ctx, df) -> None:

     df_result = ctx.sql("select summarize(b) from test_table")
     assert df_result.collect()[0][0][0].as_py() == 14.0
-
-
-def test_register_udf(ctx, df) -> None:
-    is_null = udf(
-        lambda x: x.is_null(),
-        [pa.float64()],
-        pa.bool_(),
-        volatility="immutable",
-        name="is_null",
-    )
-
-    ctx.register_udf(is_null)
-
-    df_result = ctx.sql("select is_null(a) from test_table")
-    result = df_result.collect()[0].column(0)
-
-    assert result == pa.array([False, False, False])
diff --git a/python/tests/test_udf.py b/python/tests/test_udf.py
new file mode 100644
index 000000000..568a66dbb
--- /dev/null
+++ b/python/tests/test_udf.py
@@ -0,0 +1,96 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from datafusion import udf, column
+import pyarrow as pa
+import pytest
+
+
+@pytest.fixture
+def df(ctx):
+    # create a RecordBatch and a new DataFrame from it
+    batch = pa.RecordBatch.from_arrays(
+        [pa.array([1, 2, 3]), pa.array([4, 4, 6])],
+        names=["a", "b"],
+    )
+    return ctx.create_dataframe([[batch]], name="test_table")
+
+
+def test_udf(df):
+    # is_null is a pa function over arrays
+    is_null = udf(
+        lambda x: x.is_null(),
+        [pa.int64()],
+        pa.bool_(),
+        volatility="immutable",
+    )
+
+    df = df.select(is_null(column("a")))
+    result = df.collect()[0].column(0)
+
+    assert result == pa.array([False, False, False])
+
+
+def test_register_udf(ctx, df) -> None:
+    is_null = udf(
+        lambda x: x.is_null(),
+        [pa.float64()],
+        pa.bool_(),
+        volatility="immutable",
+        name="is_null",
+    )
+
+    ctx.register_udf(is_null)
+
+    df_result = ctx.sql("select is_null(a) from test_table")
+    result = df_result.collect()[0].column(0)
+
+    assert result == pa.array([False, False, False])
+
+
+class OverThresholdUDF:
+    def __init__(self, threshold: int = 0) -> None:
+        self.threshold = threshold
+
+    def __call__(self, values: pa.Array) -> pa.Array:
+        return pa.array(v.as_py() >= self.threshold for v in values)
+
+
+def test_udf_with_parameters(df) -> None:
+    udf_no_param = udf(
+        OverThresholdUDF(),
+        pa.int64(),
+        pa.bool_(),
+        volatility="immutable",
+    )
+
+    df1 = df.select(udf_no_param(column("a")))
+    result = df1.collect()[0].column(0)
+
+    assert result == pa.array([True, True, True])
+
+    udf_with_param = udf(
+        OverThresholdUDF(2),
+        pa.int64(),
+        pa.bool_(),
+        volatility="immutable",
+    )
+
+    df2 = df.select(udf_with_param(column("a")))
+    result = df2.collect()[0].column(0)
+
+    assert result == pa.array([False, True, True])
diff --git a/python/tests/test_udwf.py b/python/tests/test_udwf.py
index 67c0979fe..2099ac9bc 100644
--- a/python/tests/test_udwf.py
+++ b/python/tests/test_udwf.py
@@ -24,7 +24,7 @@


 class ExponentialSmoothDefault(WindowEvaluator):
-    def __init__(self, alpha: float) -> None:
+    def __init__(self, alpha: float = 0.9) -> None:
         self.alpha = alpha

     def evaluate_all(self, values: list[pa.Array], num_rows: int) -> pa.Array:
@@ -44,7 +44,7 @@ def evaluate_all(self, values: list[pa.Array], num_rows: int) -> pa.Array:


 class ExponentialSmoothBounded(WindowEvaluator):
-    def __init__(self, alpha: float) -> None:
+    def __init__(self, alpha: float = 0.9) -> None:
         self.alpha = alpha

     def supports_bounded_execution(self) -> bool:
@@ -75,7 +75,7 @@ def evaluate(


 class ExponentialSmoothRank(WindowEvaluator):
-    def __init__(self, alpha: float) -> None:
+    def __init__(self, alpha: float = 0.9) -> None:
         self.alpha = alpha

     def include_rank(self) -> bool:
@@ -101,7 +101,7 @@ def evaluate_all_with_rank(


 class ExponentialSmoothFrame(WindowEvaluator):
-    def __init__(self, alpha: float) -> None:
+    def __init__(self, alpha: float = 0.9) -> None:
         self.alpha = alpha

     def uses_window_frame(self) -> bool:
@@ -134,7 +134,7 @@ class SmoothTwoColumn(WindowEvaluator):
     the previous and next rows.
""" - def __init__(self, alpha: float) -> None: + def __init__(self, alpha: float = 0.9) -> None: self.alpha = alpha def evaluate_all(self, values: list[pa.Array], num_rows: int) -> pa.Array: @@ -183,7 +183,7 @@ def df(): def test_udwf_errors(df): with pytest.raises(TypeError): udwf( - NotSubclassOfWindowEvaluator(), + NotSubclassOfWindowEvaluator, pa.float64(), pa.float64(), volatility="immutable", @@ -191,35 +191,42 @@ def test_udwf_errors(df): smooth_default = udwf( - ExponentialSmoothDefault(0.9), + ExponentialSmoothDefault, + pa.float64(), + pa.float64(), + volatility="immutable", +) + +smooth_w_arguments = udwf( + lambda: ExponentialSmoothDefault(0.8), pa.float64(), pa.float64(), volatility="immutable", ) smooth_bounded = udwf( - ExponentialSmoothBounded(0.9), + ExponentialSmoothBounded, pa.float64(), pa.float64(), volatility="immutable", ) smooth_rank = udwf( - ExponentialSmoothRank(0.9), + ExponentialSmoothRank, pa.utf8(), pa.float64(), volatility="immutable", ) smooth_frame = udwf( - ExponentialSmoothFrame(0.9), + ExponentialSmoothFrame, pa.float64(), pa.float64(), volatility="immutable", ) smooth_two_col = udwf( - SmoothTwoColumn(0.9), + SmoothTwoColumn, [pa.int64(), pa.int64()], pa.float64(), volatility="immutable", @@ -227,10 +234,15 @@ def test_udwf_errors(df): data_test_udwf_functions = [ ( - "default_udwf", + "default_udwf_no_arguments", smooth_default(column("a")), [0, 0.9, 1.89, 2.889, 3.889, 4.889, 5.889], ), + ( + "default_udwf_w_arguments", + smooth_w_arguments(column("a")), + [0, 0.8, 1.76, 2.752, 3.75, 4.75, 5.75], + ), ( "default_udwf_partitioned", smooth_default(column("a")).partition_by(column("c")).build(), diff --git a/src/udwf.rs b/src/udwf.rs index 31cc5e60e..43c21ec7b 100644 --- a/src/udwf.rs +++ b/src/udwf.rs @@ -197,7 +197,11 @@ impl PartitionEvaluator for RustPartitionEvaluator { pub fn to_rust_partition_evaluator(evaluator: PyObject) -> PartitionEvaluatorFactory { Arc::new(move || -> Result> { - let evaluator = Python::with_gil(|py| evaluator.clone_ref(py)); + let evaluator = Python::with_gil(|py| { + evaluator + .call0(py) + .map_err(|e| DataFusionError::Execution(e.to_string())) + })?; Ok(Box::new(RustPartitionEvaluator::new(evaluator))) }) }