From 660035de53beddd2532d9de66487a5c191a0e620 Mon Sep 17 00:00:00 2001
From: Crystal Zhou
Date: Sun, 2 Mar 2025 16:29:29 -0500
Subject: [PATCH 1/5] Implementation of udf and udaf decorator

---
 python/datafusion/__init__.py |  2 ++
 python/datafusion/udf.py      | 47 ++++++++++++++++++++++++++++++++
 python/tests/test_udaf.py     | 50 ++++++++++++++++++++++++++++++++++-
 python/tests/test_udf.py      | 42 ++++++++++++++++++++++++-----
 4 files changed, 134 insertions(+), 7 deletions(-)

diff --git a/python/datafusion/__init__.py b/python/datafusion/__init__.py
index 85aefcce7..fea7f50bd 100644
--- a/python/datafusion/__init__.py
+++ b/python/datafusion/__init__.py
@@ -118,7 +118,9 @@ def lit(value):
 
 udf = ScalarUDF.udf
+udf_decorator = ScalarUDF.udf_decorator
 
 udaf = AggregateUDF.udaf
+udaf_decorator = AggregateUDF.udaf_decorator
 
 udwf = WindowUDF.udwf
 
diff --git a/python/datafusion/udf.py b/python/datafusion/udf.py
index c97f453d0..08c2750c8 100644
--- a/python/datafusion/udf.py
+++ b/python/datafusion/udf.py
@@ -22,6 +22,7 @@
 from abc import ABCMeta, abstractmethod
 from enum import Enum
 from typing import TYPE_CHECKING, Callable, List, Optional, TypeVar
+import functools
 
 import pyarrow
 
@@ -148,6 +149,27 @@ def udf(
             volatility=volatility,
         )
 
+    @staticmethod
+    def udf_decorator(
+        input_types: list[pyarrow.DataType],
+        return_type: _R,
+        volatility: Volatility | str,
+        name: Optional[str] = None
+    ):
+        def decorator(func):
+            udf_caller = ScalarUDF.udf(
+                func,
+                input_types,
+                return_type,
+                volatility,
+                name
+            )
+
+            @functools.wraps(func)
+            def wrapper(*args, **kwargs):
+                return udf_caller(*args, **kwargs)
+            return wrapper
+        return decorator
 
 class Accumulator(metaclass=ABCMeta):
     """Defines how an :py:class:`AggregateUDF` accumulates values."""
@@ -287,6 +309,31 @@ def sum_bias_10() -> Summarize:
             state_type=state_type,
             volatility=volatility,
         )
+
+    @staticmethod
+    def udaf_decorator(
+        input_types: pyarrow.DataType | list[pyarrow.DataType],
+        return_type: pyarrow.DataType,
+        state_type: list[pyarrow.DataType],
+        volatility: Volatility | str,
+        name: Optional[str] = None
+    ):
+        def decorator(accum: Callable[[], Accumulator]):
+            udaf_caller = AggregateUDF.udaf(
+                accum,
+                input_types,
+                return_type,
+                state_type,
+                volatility,
+                name
+            )
+
+            @functools.wraps(accum)
+            def wrapper(*args, **kwargs):
+                return udaf_caller(*args, **kwargs)
+            return wrapper
+        return decorator
+
 
 class WindowEvaluator(metaclass=ABCMeta):
 
diff --git a/python/tests/test_udaf.py b/python/tests/test_udaf.py
index 0005a3da8..0e06d60a7 100644
--- a/python/tests/test_udaf.py
+++ b/python/tests/test_udaf.py
@@ -20,7 +20,7 @@
 import pyarrow as pa
 import pyarrow.compute as pc
 import pytest
-from datafusion import Accumulator, column, udaf
+from datafusion import Accumulator, column, udaf, udaf_decorator
 
 
 class Summarize(Accumulator):
@@ -116,6 +116,29 @@ def test_udaf_aggregate(df):
     assert result.column(0) == pa.array([1.0 + 2.0 + 3.0])
 
+def test_udaf_decorator_aggregate(df):
+
+    @udaf_decorator(pa.float64(),
+                    pa.float64(),
+                    [pa.float64()],
+                    "immutable")
+    def summarize():
+        return Summarize()
+
+    df1 = df.aggregate([], [summarize(column("a"))])
+
+    # execute and collect the first (and only) batch
+    result = df1.collect()[0]
+
+    assert result.column(0) == pa.array([1.0 + 2.0 + 3.0])
+
+    df2 = df.aggregate([], [summarize(column("a"))])
+
+    # Run a second time to ensure the state is properly reset
+    result = df2.collect()[0]
+
+    assert result.column(0) == pa.array([1.0 + 2.0 + 3.0])
+
 
 def test_udaf_aggregate_with_arguments(df):
     bias = 10.0
 
@@ -143,6 +166,31 @@ def test_udaf_aggregate_with_arguments(df):
     assert result.column(0) == pa.array([bias + 1.0 + 2.0 + 3.0])
 
 
+def test_udaf_decorator_aggregate_with_arguments(df):
+    bias = 10.0
+
+    @udaf_decorator(pa.float64(),
+                    pa.float64(),
+                    [pa.float64()],
+                    "immutable")
+    def summarize():
+        return Summarize(bias)
+
+    df1 = df.aggregate([], [summarize(column("a"))])
+
+    # execute and collect the first (and only) batch
+    result = df1.collect()[0]
+
+    assert result.column(0) == pa.array([bias + 1.0 + 2.0 + 3.0])
+
+    df2 = df.aggregate([], [summarize(column("a"))])
+
+    # Run a second time to ensure the state is properly reset
+    result = df2.collect()[0]
+
+    assert result.column(0) == pa.array([bias + 1.0 + 2.0 + 3.0])
+
+
 def test_group_by(df):
     summarize = udaf(
         Summarize,
diff --git a/python/tests/test_udf.py b/python/tests/test_udf.py
index 3a5dce6d6..8d95b989b 100644
--- a/python/tests/test_udf.py
+++ b/python/tests/test_udf.py
@@ -17,14 +17,14 @@
 
 import pyarrow as pa
 import pytest
-from datafusion import column, udf
+from datafusion import column, udf, udf_decorator
 
 
 @pytest.fixture
 def df(ctx):
     # create a RecordBatch and a new DataFrame from it
     batch = pa.RecordBatch.from_arrays(
-        [pa.array([1, 2, 3]), pa.array([4, 4, 6])],
+        [pa.array([1, 2, 3]), pa.array([4, 4, None])],
         names=["a", "b"],
     )
     return ctx.create_dataframe([[batch]], name="test_table")
@@ -39,10 +39,20 @@ def test_udf(df):
         volatility="immutable",
     )
 
-    df = df.select(is_null(column("a")))
+    df = df.select(is_null(column("b")))
     result = df.collect()[0].column(0)
 
-    assert result == pa.array([False, False, False])
+    assert result == pa.array([False, False, True])
+
+
+def test_udf_decorator(df):
+    @udf_decorator([pa.int64()], pa.bool_(), "immutable")
+    def is_null(x: pa.Array) -> pa.Array:
+        return x.is_null()
+
+    df = df.select(is_null(column("b")))
+    result = df.collect()[0].column(0)
+    assert result == pa.array([False, False, True])
 
 
 def test_register_udf(ctx, df) -> None:
@@ -56,10 +66,10 @@ def test_register_udf(ctx, df) -> None:
 
     ctx.register_udf(is_null)
 
-    df_result = ctx.sql("select is_null(a) from test_table")
+    df_result = ctx.sql("select is_null(b) from test_table")
     result = df_result.collect()[0].column(0)
 
-    assert result == pa.array([False, False, False])
+    assert result == pa.array([False, False, True])
 
 
 class OverThresholdUDF:
@@ -94,3 +104,23 @@ def test_udf_with_parameters(df) -> None:
     result = df2.collect()[0].column(0)
 
     assert result == pa.array([False, True, True])
+
+
+def test_udf_with_parameters(df) -> None:
+    @udf_decorator([pa.int64()], pa.bool_(), "immutable")
+    def udf_no_param(values: pa.Array) -> pa.Array:
+        return OverThresholdUDF()(values)
+
+    df1 = df.select(udf_no_param(column("a")))
+    result = df1.collect()[0].column(0)
+
+    assert result == pa.array([True, True, True])
+
+    @udf_decorator([pa.int64()], pa.bool_(), "immutable")
+    def udf_with_param(values: pa.Array) -> pa.Array:
+        return OverThresholdUDF(2)(values)
+
+    df2 = df.select(udf_with_param(column("a")))
+    result = df2.collect()[0].column(0)
+
+    assert result == pa.array([False, True, True])

From f68c0ef08b11e039a7d62110d614bcaf60d90278 Mon Sep 17 00:00:00 2001
From: Crystal Zhou
Date: Mon, 3 Mar 2025 11:53:51 -0500
Subject: [PATCH 2/5] Rename decorators back to udf and udaf, update documentations

---
 python/datafusion/__init__.py |   2 -
 python/datafusion/udf.py      | 285 +++++++++++++++++++++-------------
 python/tests/test_udaf.py     |   9 +-
 python/tests/test_udf.py      |   8 +-
 4 files changed, 180 insertions(+), 124 deletions(-)

diff --git
a/python/datafusion/__init__.py b/python/datafusion/__init__.py index fea7f50bd..85aefcce7 100644 --- a/python/datafusion/__init__.py +++ b/python/datafusion/__init__.py @@ -118,9 +118,7 @@ def lit(value): udf = ScalarUDF.udf -udf_decorator = ScalarUDF.udf_decorator udaf = AggregateUDF.udaf -udaf_decorator = AggregateUDF.udaf_decorator udwf = WindowUDF.udwf diff --git a/python/datafusion/udf.py b/python/datafusion/udf.py index 08c2750c8..ae06c0ad5 100644 --- a/python/datafusion/udf.py +++ b/python/datafusion/udf.py @@ -111,65 +111,96 @@ def __call__(self, *args: Expr) -> Expr: args_raw = [arg.expr for arg in args] return Expr(self._udf.__call__(*args_raw)) - @staticmethod - def udf( - func: Callable[..., _R], - input_types: list[pyarrow.DataType], - return_type: _R, - volatility: Volatility | str, - name: Optional[str] = None, - ) -> ScalarUDF: - """Create a new User-Defined Function. + class udf: + """Create a new User-Defined Function (UDF). + + This class can be used both as a **function** and as a **decorator**. + + Usage: + - **As a function**: Call `udf(func, input_types, return_type, volatility, name)`. + - **As a decorator**: Use `@udf(input_types, return_type, volatility, name)`. + In this case, do **not** pass `func` explicitly. Args: - func: A callable python function. - input_types: The data types of the arguments to ``func``. This list - must be of the same length as the number of arguments. - return_type: The data type of the return value from the python - function. - volatility: See ``Volatility`` for allowed values. - name: A descriptive name for the function. + func (Callable, optional): **Only needed when calling as a function.** + Skip this argument when using `udf` as a decorator. + input_types (list[pyarrow.DataType]): The data types of the arguments + to `func`. This list must be of the same length as the number of arguments. + return_type (_R): The data type of the return value from the function. + volatility (Volatility | str): See `Volatility` for allowed values. + name (Optional[str]): A descriptive name for the function. Returns: - A user-defined aggregate function, which can be used in either data - aggregation or window function calls. + A user-defined function that can be used in SQL expressions, + data aggregation, or window function calls. 
+ + Example: + **Using `udf` as a function:** + ```python + def double_func(x): + return x * 2 + double_udf = udf(double_func, [pyarrow.int32()], pyarrow.int32(), "volatile", "double_it") + ``` + + **Using `udf` as a decorator:** + ```python + @udf([pyarrow.int32()], pyarrow.int32(), "volatile", "double_it") + def double_udf(x): + return x * 2 + ``` """ - if not callable(func): - raise TypeError("`func` argument must be callable") - if name is None: - if hasattr(func, "__qualname__"): - name = func.__qualname__.lower() + def __new__(cls, *args, **kwargs): + if args and callable(args[0]): + # Case 1: Used as a function, require the first parameter to be callable + return cls._function(*args, **kwargs) else: - name = func.__class__.__name__.lower() - return ScalarUDF( - name=name, - func=func, - input_types=input_types, - return_type=return_type, - volatility=volatility, - ) - - @staticmethod - def udf_decorator( - input_types: list[pyarrow.DataType], - return_type: _R, - volatility: Volatility | str, - name: Optional[str] = None - ): - def decorator(func): - udf_caller = ScalarUDF.udf( - func, - input_types, - return_type, - volatility, - name + # Case 2: Used as a decorator with parameters + return cls._decorator(*args, **kwargs) + + @staticmethod + def _function( + func: Callable[..., _R], + input_types: list[pyarrow.DataType], + return_type: _R, + volatility: Volatility | str, + name: Optional[str] = None, + ) -> ScalarUDF: + if not callable(func): + raise TypeError("`func` argument must be callable") + if name is None: + if hasattr(func, "__qualname__"): + name = func.__qualname__.lower() + else: + name = func.__class__.__name__.lower() + return ScalarUDF( + name=name, + func=func, + input_types=input_types, + return_type=return_type, + volatility=volatility, ) - - @functools.wraps(func) - def wrapper(*args, **kwargs): - return udf_caller(*args, **kwargs) - return wrapper - return decorator + + @staticmethod + def _decorator( + input_types: list[pyarrow.DataType], + return_type: _R, + volatility: Volatility | str, + name: Optional[str] = None + ): + def decorator(func): + udf_caller = ScalarUDF.udf( + func, + input_types, + return_type, + volatility, + name + ) + + @functools.wraps(func) + def wrapper(*args, **kwargs): + return udf_caller(*args, **kwargs) + return wrapper + return decorator class Accumulator(metaclass=ABCMeta): """Defines how an :py:class:`AggregateUDF` accumulates values.""" @@ -234,25 +265,27 @@ def __call__(self, *args: Expr) -> Expr: args_raw = [arg.expr for arg in args] return Expr(self._udaf.__call__(*args_raw)) - @staticmethod - def udaf( - accum: Callable[[], Accumulator], - input_types: pyarrow.DataType | list[pyarrow.DataType], - return_type: pyarrow.DataType, - state_type: list[pyarrow.DataType], - volatility: Volatility | str, - name: Optional[str] = None, - ) -> AggregateUDF: - """Create a new User-Defined Aggregate Function. + class udaf: + """Create a new User-Defined Aggregate Function (UDAF). - If your :py:class:`Accumulator` can be instantiated with no arguments, you - can simply pass it's type as ``accum``. If you need to pass additional arguments - to it's constructor, you can define a lambda or a factory method. During runtime - the :py:class:`Accumulator` will be constructed for every instance in - which this UDAF is used. The following examples are all valid. + This class allows you to define an **aggregate function** that can be used in data + aggregation or window function calls. - .. 
code-block:: python + Usage: + - **As a function**: Call `udaf(accum, input_types, return_type, state_type, + volatility, name)`. + - **As a decorator**: Use `@udaf(input_types, return_type, state_type, + volatility, name)`. + When using `udaf` as a decorator, **do not pass `accum` explicitly**. + **Function example:** + + If your `:py:class:Accumulator` can be instantiated with no arguments, you can + simply pass it's type as `accum`. If you need to pass additional arguments to + it's constructor, you can define a lambda or a factory method. During runtime the + `:py:class:Accumulator` will be constructed for every instance in which this UDAF is + used. The following examples are all valid. + ``` import pyarrow as pa import pyarrow.compute as pc @@ -278,61 +311,89 @@ def sum_bias_10() -> Summarize: udaf1 = udaf(Summarize, pa.float64(), pa.float64(), [pa.float64()], "immutable") udaf2 = udaf(sum_bias_10, pa.float64(), pa.float64(), [pa.float64()], "immutable") udaf3 = udaf(lambda: Summarize(20.0), pa.float64(), pa.float64(), [pa.float64()], "immutable") + ``` + + **Decorator example:** + ``` + @udaf(pa.float64(), pa.float64(), [pa.float64()], "immutable") + def udf4() -> Summarize: + return Summarize(10.0) + ``` Args: - accum: The accumulator python function. - input_types: The data types of the arguments to ``accum``. + accum: The accumulator python function. **Only needed when calling as a function. Skip this argument when using `udaf` as a decorator.** + input_types: The data types of the arguments to `accum. return_type: The data type of the return value. state_type: The data types of the intermediate accumulation. - volatility: See :py:class:`Volatility` for allowed values. + volatility: See :py:class:Volatility for allowed values. name: A descriptive name for the function. Returns: A user-defined aggregate function, which can be used in either data aggregation or window function calls. 
- """ # noqa W505 - if not callable(accum): - raise TypeError("`func` must be callable.") - if not isinstance(accum.__call__(), Accumulator): - raise TypeError( - "Accumulator must implement the abstract base class Accumulator" - ) - if name is None: - name = accum.__call__().__class__.__qualname__.lower() - if isinstance(input_types, pyarrow.DataType): - input_types = [input_types] - return AggregateUDF( - name=name, - accumulator=accum, - input_types=input_types, - return_type=return_type, - state_type=state_type, - volatility=volatility, - ) - - @staticmethod - def udaf_decorator( - input_types: pyarrow.DataType | list[pyarrow.DataType], - return_type: pyarrow.DataType, - state_type: list[pyarrow.DataType], - volatility: Volatility | str, - name: Optional[str] = None - ): - def decorator(accum: Callable[[], Accumulator]): - udaf_caller = AggregateUDF.udaf( - accum, - input_types, - return_type, - state_type, - volatility, - name + """ + + + + def __new__(cls, *args, **kwargs): + if args and callable(args[0]): + # Case 1: Used as a function, require the first parameter to be callable + return cls._function(*args, **kwargs) + else: + # Case 2: Used as a decorator with parameters + return cls._decorator(*args, **kwargs) + + @staticmethod + def _function( + accum: Callable[[], Accumulator], + input_types: pyarrow.DataType | list[pyarrow.DataType], + return_type: pyarrow.DataType, + state_type: list[pyarrow.DataType], + volatility: Volatility | str, + name: Optional[str] = None, + ) -> AggregateUDF: + if not callable(accum): + raise TypeError("`func` must be callable.") + if not isinstance(accum.__call__(), Accumulator): + raise TypeError( + "Accumulator must implement the abstract base class Accumulator" + ) + if name is None: + name = accum.__call__().__class__.__qualname__.lower() + if isinstance(input_types, pyarrow.DataType): + input_types = [input_types] + return AggregateUDF( + name=name, + accumulator=accum, + input_types=input_types, + return_type=return_type, + state_type=state_type, + volatility=volatility, ) - @functools.wraps(accum) - def wrapper(*args, **kwargs): - return udaf_caller(*args, **kwargs) - return wrapper - return decorator + @staticmethod + def _decorator( + input_types: pyarrow.DataType | list[pyarrow.DataType], + return_type: pyarrow.DataType, + state_type: list[pyarrow.DataType], + volatility: Volatility | str, + name: Optional[str] = None + ): + def decorator(accum: Callable[[], Accumulator]): + udaf_caller = AggregateUDF.udaf( + accum, + input_types, + return_type, + state_type, + volatility, + name + ) + + @functools.wraps(accum) + def wrapper(*args, **kwargs): + return udaf_caller(*args, **kwargs) + return wrapper + return decorator diff --git a/python/tests/test_udaf.py b/python/tests/test_udaf.py index 0e06d60a7..87cfafdf9 100644 --- a/python/tests/test_udaf.py +++ b/python/tests/test_udaf.py @@ -20,7 +20,7 @@ import pyarrow as pa import pyarrow.compute as pc import pytest -from datafusion import Accumulator, column, udaf, udaf_decorator +from datafusion import Accumulator, column, udaf class Summarize(Accumulator): @@ -118,10 +118,7 @@ def test_udaf_aggregate(df): def test_udaf_decorator_aggregate(df): - @udaf_decorator(pa.float64(), - pa.float64(), - [pa.float64()], - "immutable") + @udaf(pa.float64(), pa.float64(), [pa.float64()], "immutable") def summarize(): return Summarize() @@ -169,7 +166,7 @@ def test_udaf_aggregate_with_arguments(df): def test_udaf_decorator_aggregate_with_arguments(df): bias = 10.0 - @udaf_decorator(pa.float64(), + 
@udaf(pa.float64(), pa.float64(), [pa.float64()], "immutable") diff --git a/python/tests/test_udf.py b/python/tests/test_udf.py index 8d95b989b..5f820d4dc 100644 --- a/python/tests/test_udf.py +++ b/python/tests/test_udf.py @@ -17,7 +17,7 @@ import pyarrow as pa import pytest -from datafusion import column, udf, udf_decorator +from datafusion import column, udf @pytest.fixture @@ -46,7 +46,7 @@ def test_udf(df): def test_udf_decorator(df): - @udf_decorator([pa.int64()], pa.bool_(), "immutable") + @udf([pa.int64()], pa.bool_(), "immutable") def is_null(x: pa.Array) -> pa.Array: return x.is_null() @@ -107,7 +107,7 @@ def test_udf_with_parameters(df) -> None: def test_udf_with_parameters(df) -> None: - @udf_decorator([pa.int64()], pa.bool_(), "immutable") + @udf([pa.int64()], pa.bool_(), "immutable") def udf_no_param(values: pa.Array) -> pa.Array: return OverThresholdUDF()(values) @@ -116,7 +116,7 @@ def udf_no_param(values: pa.Array) -> pa.Array: assert result == pa.array([True, True, True]) - @udf_decorator([pa.int64()], pa.bool_(), "immutable") + @udf([pa.int64()], pa.bool_(), "immutable") def udf_with_param(values: pa.Array) -> pa.Array: return OverThresholdUDF(2)(values) From 23d6c0ce5b4f10238e5e856ae26eeb2ffffc02e7 Mon Sep 17 00:00:00 2001 From: Crystal Zhou Date: Mon, 3 Mar 2025 11:57:48 -0500 Subject: [PATCH 3/5] Minor typo fixes --- python/datafusion/udf.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/python/datafusion/udf.py b/python/datafusion/udf.py index ae06c0ad5..2a47d2882 100644 --- a/python/datafusion/udf.py +++ b/python/datafusion/udf.py @@ -136,14 +136,14 @@ class udf: Example: **Using `udf` as a function:** - ```python + ``` def double_func(x): return x * 2 double_udf = udf(double_func, [pyarrow.int32()], pyarrow.int32(), "volatile", "double_it") ``` **Using `udf` as a decorator:** - ```python + ``` @udf([pyarrow.int32()], pyarrow.int32(), "volatile", "double_it") def double_udf(x): return x * 2 @@ -322,10 +322,10 @@ def udf4() -> Summarize: Args: accum: The accumulator python function. **Only needed when calling as a function. Skip this argument when using `udaf` as a decorator.** - input_types: The data types of the arguments to `accum. + input_types: The data types of the arguments to ``accum``. return_type: The data type of the return value. state_type: The data types of the intermediate accumulation. - volatility: See :py:class:Volatility for allowed values. + volatility: See :py:class:`Volatility` for allowed values. name: A descriptive name for the function. Returns: @@ -333,8 +333,6 @@ def udf4() -> Summarize: aggregation or window function calls. 
""" - - def __new__(cls, *args, **kwargs): if args and callable(args[0]): # Case 1: Used as a function, require the first parameter to be callable From 6d7a61fdb9f175b2dda2251d6f07b461a4683941 Mon Sep 17 00:00:00 2001 From: Crystal Zhou Date: Mon, 3 Mar 2025 21:33:34 -0500 Subject: [PATCH 4/5] Fixing linting errors --- python/datafusion/udf.py | 90 +++++++++++++++++++++++----------------- 1 file changed, 51 insertions(+), 39 deletions(-) diff --git a/python/datafusion/udf.py b/python/datafusion/udf.py index 2a47d2882..e5249921b 100644 --- a/python/datafusion/udf.py +++ b/python/datafusion/udf.py @@ -19,10 +19,10 @@ from __future__ import annotations +import functools from abc import ABCMeta, abstractmethod from enum import Enum from typing import TYPE_CHECKING, Callable, List, Optional, TypeVar -import functools import pyarrow @@ -117,15 +117,17 @@ class udf: This class can be used both as a **function** and as a **decorator**. Usage: - - **As a function**: Call `udf(func, input_types, return_type, volatility, name)`. - - **As a decorator**: Use `@udf(input_types, return_type, volatility, name)`. - In this case, do **not** pass `func` explicitly. + - **As a function**: Call `udf(func, input_types, return_type, volatility, + name)`. + - **As a decorator**: Use `@udf(input_types, return_type, volatility, + name)`. In this case, do **not** pass `func` explicitly. Args: func (Callable, optional): **Only needed when calling as a function.** Skip this argument when using `udf` as a decorator. input_types (list[pyarrow.DataType]): The data types of the arguments - to `func`. This list must be of the same length as the number of arguments. + to `func`. This list must be of the same length as the number of + arguments. return_type (_R): The data type of the return value from the function. volatility (Volatility | str): See `Volatility` for allowed values. name (Optional[str]): A descriptive name for the function. @@ -139,7 +141,8 @@ class udf: ``` def double_func(x): return x * 2 - double_udf = udf(double_func, [pyarrow.int32()], pyarrow.int32(), "volatile", "double_it") + double_udf = udf(double_func, [pyarrow.int32()], pyarrow.int32(), + "volatile", "double_it") ``` **Using `udf` as a decorator:** @@ -149,8 +152,13 @@ def double_udf(x): return x * 2 ``` """ + def __new__(cls, *args, **kwargs): - if args and callable(args[0]): + """Create a new UDF. + + Trigger UDF function or decorator depending on if the first args is callable + """ + if args and callable(args[0]): # Case 1: Used as a function, require the first parameter to be callable return cls._function(*args, **kwargs) else: @@ -185,23 +193,22 @@ def _decorator( input_types: list[pyarrow.DataType], return_type: _R, volatility: Volatility | str, - name: Optional[str] = None + name: Optional[str] = None, ): def decorator(func): udf_caller = ScalarUDF.udf( - func, - input_types, - return_type, - volatility, - name + func, input_types, return_type, volatility, name ) - + @functools.wraps(func) def wrapper(*args, **kwargs): return udf_caller(*args, **kwargs) + return wrapper + return decorator + class Accumulator(metaclass=ABCMeta): """Defines how an :py:class:`AggregateUDF` accumulates values.""" @@ -268,8 +275,8 @@ def __call__(self, *args: Expr) -> Expr: class udaf: """Create a new User-Defined Aggregate Function (UDAF). - This class allows you to define an **aggregate function** that can be used in data - aggregation or window function calls. 
+ This class allows you to define an **aggregate function** that can be used in + data aggregation or window function calls. Usage: - **As a function**: Call `udaf(accum, input_types, return_type, state_type, @@ -279,12 +286,12 @@ class udaf: When using `udaf` as a decorator, **do not pass `accum` explicitly**. **Function example:** - - If your `:py:class:Accumulator` can be instantiated with no arguments, you can - simply pass it's type as `accum`. If you need to pass additional arguments to - it's constructor, you can define a lambda or a factory method. During runtime the - `:py:class:Accumulator` will be constructed for every instance in which this UDAF is - used. The following examples are all valid. + + If your `:py:class:Accumulator` can be instantiated with no arguments, you + can simply pass it's type as `accum`. If you need to pass additional + arguments to it's constructor, you can define a lambda or a factory method. + During runtime the `:py:class:Accumulator` will be constructed for every + instance in which this UDAF is used. The following examples are all valid. ``` import pyarrow as pa import pyarrow.compute as pc @@ -308,11 +315,14 @@ def evaluate(self) -> pa.Scalar: def sum_bias_10() -> Summarize: return Summarize(10.0) - udaf1 = udaf(Summarize, pa.float64(), pa.float64(), [pa.float64()], "immutable") - udaf2 = udaf(sum_bias_10, pa.float64(), pa.float64(), [pa.float64()], "immutable") - udaf3 = udaf(lambda: Summarize(20.0), pa.float64(), pa.float64(), [pa.float64()], "immutable") + udaf1 = udaf(Summarize, pa.float64(), pa.float64(), [pa.float64()], + "immutable") + udaf2 = udaf(sum_bias_10, pa.float64(), pa.float64(), [pa.float64()], + "immutable") + udaf3 = udaf(lambda: Summarize(20.0), pa.float64(), pa.float64(), + [pa.float64()], "immutable") ``` - + **Decorator example:** ``` @udaf(pa.float64(), pa.float64(), [pa.float64()], "immutable") @@ -321,7 +331,8 @@ def udf4() -> Summarize: ``` Args: - accum: The accumulator python function. **Only needed when calling as a function. Skip this argument when using `udaf` as a decorator.** + accum: The accumulator python function. **Only needed when calling as a + function. Skip this argument when using `udaf` as a decorator.** input_types: The data types of the arguments to ``accum``. return_type: The data type of the return value. state_type: The data types of the intermediate accumulation. @@ -334,13 +345,18 @@ def udf4() -> Summarize: """ def __new__(cls, *args, **kwargs): - if args and callable(args[0]): + """Create a new UDAF. 
+ + Trigger UDAF function or decorator depending on if the first args is + callable + """ + if args and callable(args[0]): # Case 1: Used as a function, require the first parameter to be callable return cls._function(*args, **kwargs) else: # Case 2: Used as a decorator with parameters return cls._decorator(*args, **kwargs) - + @staticmethod def _function( accum: Callable[[], Accumulator], @@ -368,31 +384,27 @@ def _function( state_type=state_type, volatility=volatility, ) - + @staticmethod def _decorator( input_types: pyarrow.DataType | list[pyarrow.DataType], return_type: pyarrow.DataType, state_type: list[pyarrow.DataType], volatility: Volatility | str, - name: Optional[str] = None + name: Optional[str] = None, ): def decorator(accum: Callable[[], Accumulator]): udaf_caller = AggregateUDF.udaf( - accum, - input_types, - return_type, - state_type, - volatility, - name + accum, input_types, return_type, state_type, volatility, name ) - + @functools.wraps(accum) def wrapper(*args, **kwargs): return udaf_caller(*args, **kwargs) + return wrapper - return decorator + return decorator class WindowEvaluator(metaclass=ABCMeta): From fbe7ee2f2d0c36bb46b9c4ae7bad8e41d0228efa Mon Sep 17 00:00:00 2001 From: Tim Saucer Date: Sat, 8 Mar 2025 16:14:09 -0500 Subject: [PATCH 5/5] ruff formatting --- python/tests/test_udaf.py | 9 +++------ python/tests/test_udf.py | 4 ++-- 2 files changed, 5 insertions(+), 8 deletions(-) diff --git a/python/tests/test_udaf.py b/python/tests/test_udaf.py index 87cfafdf9..e69c77d3c 100644 --- a/python/tests/test_udaf.py +++ b/python/tests/test_udaf.py @@ -116,8 +116,8 @@ def test_udaf_aggregate(df): assert result.column(0) == pa.array([1.0 + 2.0 + 3.0]) + def test_udaf_decorator_aggregate(df): - @udaf(pa.float64(), pa.float64(), [pa.float64()], "immutable") def summarize(): return Summarize() @@ -165,11 +165,8 @@ def test_udaf_aggregate_with_arguments(df): def test_udaf_decorator_aggregate_with_arguments(df): bias = 10.0 - - @udaf(pa.float64(), - pa.float64(), - [pa.float64()], - "immutable") + + @udaf(pa.float64(), pa.float64(), [pa.float64()], "immutable") def summarize(): return Summarize(bias) diff --git a/python/tests/test_udf.py b/python/tests/test_udf.py index 5f820d4dc..a6c047552 100644 --- a/python/tests/test_udf.py +++ b/python/tests/test_udf.py @@ -80,7 +80,7 @@ def __call__(self, values: pa.Array) -> pa.Array: return pa.array(v.as_py() >= self.threshold for v in values) -def test_udf_with_parameters(df) -> None: +def test_udf_with_parameters_function(df) -> None: udf_no_param = udf( OverThresholdUDF(), pa.int64(), @@ -106,7 +106,7 @@ def test_udf_with_parameters(df) -> None: assert result == pa.array([False, True, True]) -def test_udf_with_parameters(df) -> None: +def test_udf_with_parameters_decorator(df) -> None: @udf([pa.int64()], pa.bool_(), "immutable") def udf_no_param(values: pa.Array) -> pa.Array: return OverThresholdUDF()(values)
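A minimal end-to-end sketch of the decorator usage these patches enable, based on the `test_udf_decorator` test in the series above (the `SessionContext`/`create_dataframe` setup mirrors the test fixture and is not part of the patch itself):

```python
import pyarrow as pa

from datafusion import SessionContext, column, udf

# Same table layout as the fixture in python/tests/test_udf.py
ctx = SessionContext()
batch = pa.RecordBatch.from_arrays(
    [pa.array([1, 2, 3]), pa.array([4, 4, None])],
    names=["a", "b"],
)
df = ctx.create_dataframe([[batch]], name="test_table")


# After these patches, `udf` doubles as a decorator: the decorated function is
# wrapped via ScalarUDF.udf and can be called on expressions directly.
@udf([pa.int64()], pa.bool_(), "immutable")
def is_null(x: pa.Array) -> pa.Array:
    return x.is_null()


result = df.select(is_null(column("b"))).collect()[0].column(0)
assert result == pa.array([False, False, True])
```

The aggregate form works the same way: `@udaf(...)` decorates a zero-argument factory returning an `Accumulator`, as exercised by `test_udaf_decorator_aggregate` above.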