diff --git a/python/datafusion/udf.py b/python/datafusion/udf.py index 0bba3d723..af7bcf2ed 100644 --- a/python/datafusion/udf.py +++ b/python/datafusion/udf.py @@ -19,6 +19,7 @@ from __future__ import annotations +import functools from abc import ABCMeta, abstractmethod from enum import Enum from typing import TYPE_CHECKING, Callable, List, Optional, TypeVar @@ -110,43 +111,102 @@ def __call__(self, *args: Expr) -> Expr: args_raw = [arg.expr for arg in args] return Expr(self._udf.__call__(*args_raw)) - @staticmethod - def udf( - func: Callable[..., _R], - input_types: list[pyarrow.DataType], - return_type: _R, - volatility: Volatility | str, - name: Optional[str] = None, - ) -> ScalarUDF: - """Create a new User-Defined Function. + class udf: + """Create a new User-Defined Function (UDF). + + This class can be used both as a **function** and as a **decorator**. + + Usage: + - **As a function**: Call `udf(func, input_types, return_type, volatility, + name)`. + - **As a decorator**: Use `@udf(input_types, return_type, volatility, + name)`. In this case, do **not** pass `func` explicitly. Args: - func: A callable python function. - input_types: The data types of the arguments to ``func``. This list - must be of the same length as the number of arguments. - return_type: The data type of the return value from the python - function. - volatility: See ``Volatility`` for allowed values. - name: A descriptive name for the function. + func (Callable, optional): **Only needed when calling as a function.** + Skip this argument when using `udf` as a decorator. + input_types (list[pyarrow.DataType]): The data types of the arguments + to `func`. This list must be of the same length as the number of + arguments. + return_type (_R): The data type of the return value from the function. + volatility (Volatility | str): See `Volatility` for allowed values. + name (Optional[str]): A descriptive name for the function. Returns: - A user-defined aggregate function, which can be used in either data - aggregation or window function calls. + A user-defined function that can be used in SQL expressions, + data aggregation, or window function calls. + + Example: + **Using `udf` as a function:** + ``` + def double_func(x): + return x * 2 + double_udf = udf(double_func, [pyarrow.int32()], pyarrow.int32(), + "volatile", "double_it") + ``` + + **Using `udf` as a decorator:** + ``` + @udf([pyarrow.int32()], pyarrow.int32(), "volatile", "double_it") + def double_udf(x): + return x * 2 + ``` """ - if not callable(func): - raise TypeError("`func` argument must be callable") - if name is None: - if hasattr(func, "__qualname__"): - name = func.__qualname__.lower() + + def __new__(cls, *args, **kwargs): + """Create a new UDF. + + Trigger UDF function or decorator depending on if the first args is callable + """ + if args and callable(args[0]): + # Case 1: Used as a function, require the first parameter to be callable + return cls._function(*args, **kwargs) else: - name = func.__class__.__name__.lower() - return ScalarUDF( - name=name, - func=func, - input_types=input_types, - return_type=return_type, - volatility=volatility, - ) + # Case 2: Used as a decorator with parameters + return cls._decorator(*args, **kwargs) + + @staticmethod + def _function( + func: Callable[..., _R], + input_types: list[pyarrow.DataType], + return_type: _R, + volatility: Volatility | str, + name: Optional[str] = None, + ) -> ScalarUDF: + if not callable(func): + raise TypeError("`func` argument must be callable") + if name is None: + if hasattr(func, "__qualname__"): + name = func.__qualname__.lower() + else: + name = func.__class__.__name__.lower() + return ScalarUDF( + name=name, + func=func, + input_types=input_types, + return_type=return_type, + volatility=volatility, + ) + + @staticmethod + def _decorator( + input_types: list[pyarrow.DataType], + return_type: _R, + volatility: Volatility | str, + name: Optional[str] = None, + ): + def decorator(func): + udf_caller = ScalarUDF.udf( + func, input_types, return_type, volatility, name + ) + + @functools.wraps(func) + def wrapper(*args, **kwargs): + return udf_caller(*args, **kwargs) + + return wrapper + + return decorator class Accumulator(metaclass=ABCMeta): @@ -212,25 +272,27 @@ def __call__(self, *args: Expr) -> Expr: args_raw = [arg.expr for arg in args] return Expr(self._udaf.__call__(*args_raw)) - @staticmethod - def udaf( - accum: Callable[[], Accumulator], - input_types: pyarrow.DataType | list[pyarrow.DataType], - return_type: pyarrow.DataType, - state_type: list[pyarrow.DataType], - volatility: Volatility | str, - name: Optional[str] = None, - ) -> AggregateUDF: - """Create a new User-Defined Aggregate Function. + class udaf: + """Create a new User-Defined Aggregate Function (UDAF). - If your :py:class:`Accumulator` can be instantiated with no arguments, you - can simply pass it's type as ``accum``. If you need to pass additional arguments - to it's constructor, you can define a lambda or a factory method. During runtime - the :py:class:`Accumulator` will be constructed for every instance in - which this UDAF is used. The following examples are all valid. + This class allows you to define an **aggregate function** that can be used in + data aggregation or window function calls. - .. code-block:: python + Usage: + - **As a function**: Call `udaf(accum, input_types, return_type, state_type, + volatility, name)`. + - **As a decorator**: Use `@udaf(input_types, return_type, state_type, + volatility, name)`. + When using `udaf` as a decorator, **do not pass `accum` explicitly**. + **Function example:** + + If your `:py:class:Accumulator` can be instantiated with no arguments, you + can simply pass it's type as `accum`. If you need to pass additional + arguments to it's constructor, you can define a lambda or a factory method. + During runtime the `:py:class:Accumulator` will be constructed for every + instance in which this UDAF is used. The following examples are all valid. + ``` import pyarrow as pa import pyarrow.compute as pc @@ -253,12 +315,24 @@ def evaluate(self) -> pa.Scalar: def sum_bias_10() -> Summarize: return Summarize(10.0) - udaf1 = udaf(Summarize, pa.float64(), pa.float64(), [pa.float64()], "immutable") - udaf2 = udaf(sum_bias_10, pa.float64(), pa.float64(), [pa.float64()], "immutable") - udaf3 = udaf(lambda: Summarize(20.0), pa.float64(), pa.float64(), [pa.float64()], "immutable") + udaf1 = udaf(Summarize, pa.float64(), pa.float64(), [pa.float64()], + "immutable") + udaf2 = udaf(sum_bias_10, pa.float64(), pa.float64(), [pa.float64()], + "immutable") + udaf3 = udaf(lambda: Summarize(20.0), pa.float64(), pa.float64(), + [pa.float64()], "immutable") + ``` + + **Decorator example:** + ``` + @udaf(pa.float64(), pa.float64(), [pa.float64()], "immutable") + def udf4() -> Summarize: + return Summarize(10.0) + ``` Args: - accum: The accumulator python function. + accum: The accumulator python function. **Only needed when calling as a + function. Skip this argument when using `udaf` as a decorator.** input_types: The data types of the arguments to ``accum``. return_type: The data type of the return value. state_type: The data types of the intermediate accumulation. @@ -268,26 +342,69 @@ def sum_bias_10() -> Summarize: Returns: A user-defined aggregate function, which can be used in either data aggregation or window function calls. - """ # noqa W505 - if not callable(accum): - raise TypeError("`func` must be callable.") - if not isinstance(accum.__call__(), Accumulator): - raise TypeError( - "Accumulator must implement the abstract base class Accumulator" + """ + + def __new__(cls, *args, **kwargs): + """Create a new UDAF. + + Trigger UDAF function or decorator depending on if the first args is + callable + """ + if args and callable(args[0]): + # Case 1: Used as a function, require the first parameter to be callable + return cls._function(*args, **kwargs) + else: + # Case 2: Used as a decorator with parameters + return cls._decorator(*args, **kwargs) + + @staticmethod + def _function( + accum: Callable[[], Accumulator], + input_types: pyarrow.DataType | list[pyarrow.DataType], + return_type: pyarrow.DataType, + state_type: list[pyarrow.DataType], + volatility: Volatility | str, + name: Optional[str] = None, + ) -> AggregateUDF: + if not callable(accum): + raise TypeError("`func` must be callable.") + if not isinstance(accum.__call__(), Accumulator): + raise TypeError( + "Accumulator must implement the abstract base class Accumulator" + ) + if name is None: + name = accum.__call__().__class__.__qualname__.lower() + if isinstance(input_types, pyarrow.DataType): + input_types = [input_types] + return AggregateUDF( + name=name, + accumulator=accum, + input_types=input_types, + return_type=return_type, + state_type=state_type, + volatility=volatility, ) - if name is None: - name = accum.__call__().__class__.__qualname__.lower() - assert name is not None - if isinstance(input_types, pyarrow.DataType): - input_types = [input_types] - return AggregateUDF( - name=name, - accumulator=accum, - input_types=input_types, - return_type=return_type, - state_type=state_type, - volatility=volatility, - ) + + @staticmethod + def _decorator( + input_types: pyarrow.DataType | list[pyarrow.DataType], + return_type: pyarrow.DataType, + state_type: list[pyarrow.DataType], + volatility: Volatility | str, + name: Optional[str] = None, + ): + def decorator(accum: Callable[[], Accumulator]): + udaf_caller = AggregateUDF.udaf( + accum, input_types, return_type, state_type, volatility, name + ) + + @functools.wraps(accum) + def wrapper(*args, **kwargs): + return udaf_caller(*args, **kwargs) + + return wrapper + + return decorator class WindowEvaluator(metaclass=ABCMeta): diff --git a/python/tests/test_udaf.py b/python/tests/test_udaf.py index 0005a3da8..e69c77d3c 100644 --- a/python/tests/test_udaf.py +++ b/python/tests/test_udaf.py @@ -117,6 +117,26 @@ def test_udaf_aggregate(df): assert result.column(0) == pa.array([1.0 + 2.0 + 3.0]) +def test_udaf_decorator_aggregate(df): + @udaf(pa.float64(), pa.float64(), [pa.float64()], "immutable") + def summarize(): + return Summarize() + + df1 = df.aggregate([], [summarize(column("a"))]) + + # execute and collect the first (and only) batch + result = df1.collect()[0] + + assert result.column(0) == pa.array([1.0 + 2.0 + 3.0]) + + df2 = df.aggregate([], [summarize(column("a"))]) + + # Run a second time to ensure the state is properly reset + result = df2.collect()[0] + + assert result.column(0) == pa.array([1.0 + 2.0 + 3.0]) + + def test_udaf_aggregate_with_arguments(df): bias = 10.0 @@ -143,6 +163,28 @@ def test_udaf_aggregate_with_arguments(df): assert result.column(0) == pa.array([bias + 1.0 + 2.0 + 3.0]) +def test_udaf_decorator_aggregate_with_arguments(df): + bias = 10.0 + + @udaf(pa.float64(), pa.float64(), [pa.float64()], "immutable") + def summarize(): + return Summarize(bias) + + df1 = df.aggregate([], [summarize(column("a"))]) + + # execute and collect the first (and only) batch + result = df1.collect()[0] + + assert result.column(0) == pa.array([bias + 1.0 + 2.0 + 3.0]) + + df2 = df.aggregate([], [summarize(column("a"))]) + + # Run a second time to ensure the state is properly reset + result = df2.collect()[0] + + assert result.column(0) == pa.array([bias + 1.0 + 2.0 + 3.0]) + + def test_group_by(df): summarize = udaf( Summarize, diff --git a/python/tests/test_udf.py b/python/tests/test_udf.py index 3a5dce6d6..a6c047552 100644 --- a/python/tests/test_udf.py +++ b/python/tests/test_udf.py @@ -24,7 +24,7 @@ def df(ctx): # create a RecordBatch and a new DataFrame from it batch = pa.RecordBatch.from_arrays( - [pa.array([1, 2, 3]), pa.array([4, 4, 6])], + [pa.array([1, 2, 3]), pa.array([4, 4, None])], names=["a", "b"], ) return ctx.create_dataframe([[batch]], name="test_table") @@ -39,10 +39,20 @@ def test_udf(df): volatility="immutable", ) - df = df.select(is_null(column("a"))) + df = df.select(is_null(column("b"))) result = df.collect()[0].column(0) - assert result == pa.array([False, False, False]) + assert result == pa.array([False, False, True]) + + +def test_udf_decorator(df): + @udf([pa.int64()], pa.bool_(), "immutable") + def is_null(x: pa.Array) -> pa.Array: + return x.is_null() + + df = df.select(is_null(column("b"))) + result = df.collect()[0].column(0) + assert result == pa.array([False, False, True]) def test_register_udf(ctx, df) -> None: @@ -56,10 +66,10 @@ def test_register_udf(ctx, df) -> None: ctx.register_udf(is_null) - df_result = ctx.sql("select is_null(a) from test_table") + df_result = ctx.sql("select is_null(b) from test_table") result = df_result.collect()[0].column(0) - assert result == pa.array([False, False, False]) + assert result == pa.array([False, False, True]) class OverThresholdUDF: @@ -70,7 +80,7 @@ def __call__(self, values: pa.Array) -> pa.Array: return pa.array(v.as_py() >= self.threshold for v in values) -def test_udf_with_parameters(df) -> None: +def test_udf_with_parameters_function(df) -> None: udf_no_param = udf( OverThresholdUDF(), pa.int64(), @@ -94,3 +104,23 @@ def test_udf_with_parameters(df) -> None: result = df2.collect()[0].column(0) assert result == pa.array([False, True, True]) + + +def test_udf_with_parameters_decorator(df) -> None: + @udf([pa.int64()], pa.bool_(), "immutable") + def udf_no_param(values: pa.Array) -> pa.Array: + return OverThresholdUDF()(values) + + df1 = df.select(udf_no_param(column("a"))) + result = df1.collect()[0].column(0) + + assert result == pa.array([True, True, True]) + + @udf([pa.int64()], pa.bool_(), "immutable") + def udf_with_param(values: pa.Array) -> pa.Array: + return OverThresholdUDF(2)(values) + + df2 = df.select(udf_with_param(column("a"))) + result = df2.collect()[0].column(0) + + assert result == pa.array([False, True, True])