Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
257 changes: 187 additions & 70 deletions python/datafusion/udf.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

from __future__ import annotations

import functools
from abc import ABCMeta, abstractmethod
from enum import Enum
from typing import TYPE_CHECKING, Callable, List, Optional, TypeVar
Expand Down Expand Up @@ -110,43 +111,102 @@ def __call__(self, *args: Expr) -> Expr:
args_raw = [arg.expr for arg in args]
return Expr(self._udf.__call__(*args_raw))

@staticmethod
def udf(
func: Callable[..., _R],
input_types: list[pyarrow.DataType],
return_type: _R,
volatility: Volatility | str,
name: Optional[str] = None,
) -> ScalarUDF:
"""Create a new User-Defined Function.
class udf:
    """Create a new User-Defined Function (UDF).

    This class can be used both as a **function** and as a **decorator**.

    Usage:
    - **As a function**: Call `udf(func, input_types, return_type, volatility,
      name)`.
    - **As a decorator**: Use `@udf(input_types, return_type, volatility,
      name)`. In this case, do **not** pass `func` explicitly.

    Args:
        func (Callable, optional): **Only needed when calling as a function.**
            Skip this argument when using `udf` as a decorator.
        input_types (list[pyarrow.DataType]): The data types of the arguments
            to `func`. This list must be of the same length as the number of
            arguments.
        return_type (_R): The data type of the return value from the function.
        volatility (Volatility | str): See `Volatility` for allowed values.
        name (Optional[str]): A descriptive name for the function.

    Returns:
        A user-defined function that can be used in SQL expressions,
        data aggregation, or window function calls.

    Example:
        **Using `udf` as a function:**
        ```
        def double_func(x):
            return x * 2
        double_udf = udf(double_func, [pyarrow.int32()], pyarrow.int32(),
            "volatile", "double_it")
        ```

        **Using `udf` as a decorator:**
        ```
        @udf([pyarrow.int32()], pyarrow.int32(), "volatile", "double_it")
        def double_udf(x):
            return x * 2
        ```
    """

def __new__(cls, *args, **kwargs):
"""Create a new UDF.

Trigger UDF function or decorator depending on if the first args is callable
"""
if args and callable(args[0]):
# Case 1: Used as a function, require the first parameter to be callable
return cls._function(*args, **kwargs)
else:
name = func.__class__.__name__.lower()
return ScalarUDF(
name=name,
func=func,
input_types=input_types,
return_type=return_type,
volatility=volatility,
)
# Case 2: Used as a decorator with parameters
return cls._decorator(*args, **kwargs)

@staticmethod
def _function(
func: Callable[..., _R],
input_types: list[pyarrow.DataType],
return_type: _R,
volatility: Volatility | str,
name: Optional[str] = None,
) -> ScalarUDF:
if not callable(func):
raise TypeError("`func` argument must be callable")
if name is None:
if hasattr(func, "__qualname__"):
name = func.__qualname__.lower()
else:
name = func.__class__.__name__.lower()
return ScalarUDF(
name=name,
func=func,
input_types=input_types,
return_type=return_type,
volatility=volatility,
)

@staticmethod
def _decorator(
input_types: list[pyarrow.DataType],
return_type: _R,
volatility: Volatility | str,
name: Optional[str] = None,
):
def decorator(func):
udf_caller = ScalarUDF.udf(
func, input_types, return_type, volatility, name
)

@functools.wraps(func)
def wrapper(*args, **kwargs):
return udf_caller(*args, **kwargs)

return wrapper

return decorator


class Accumulator(metaclass=ABCMeta):
Expand Down Expand Up @@ -212,25 +272,27 @@ def __call__(self, *args: Expr) -> Expr:
args_raw = [arg.expr for arg in args]
return Expr(self._udaf.__call__(*args_raw))

@staticmethod
def udaf(
accum: Callable[[], Accumulator],
input_types: pyarrow.DataType | list[pyarrow.DataType],
return_type: pyarrow.DataType,
state_type: list[pyarrow.DataType],
volatility: Volatility | str,
name: Optional[str] = None,
) -> AggregateUDF:
"""Create a new User-Defined Aggregate Function.
class udaf:
"""Create a new User-Defined Aggregate Function (UDAF).

If your :py:class:`Accumulator` can be instantiated with no arguments, you
can simply pass it's type as ``accum``. If you need to pass additional arguments
to it's constructor, you can define a lambda or a factory method. During runtime
the :py:class:`Accumulator` will be constructed for every instance in
which this UDAF is used. The following examples are all valid.
This class allows you to define an **aggregate function** that can be used in
data aggregation or window function calls.

.. code-block:: python
Usage:
- **As a function**: Call `udaf(accum, input_types, return_type, state_type,
volatility, name)`.
- **As a decorator**: Use `@udaf(input_types, return_type, state_type,
volatility, name)`.
When using `udaf` as a decorator, **do not pass `accum` explicitly**.

**Function example:**

If your :py:class:`Accumulator` can be instantiated with no arguments, you
can simply pass its type as `accum`. If you need to pass additional
arguments to its constructor, you can define a lambda or a factory method.
During runtime the :py:class:`Accumulator` will be constructed for every
instance in which this UDAF is used. The following examples are all valid.
```
import pyarrow as pa
import pyarrow.compute as pc

Expand All @@ -253,12 +315,24 @@ def evaluate(self) -> pa.Scalar:
def sum_bias_10() -> Summarize:
return Summarize(10.0)

udaf1 = udaf(Summarize, pa.float64(), pa.float64(), [pa.float64()], "immutable")
udaf2 = udaf(sum_bias_10, pa.float64(), pa.float64(), [pa.float64()], "immutable")
udaf3 = udaf(lambda: Summarize(20.0), pa.float64(), pa.float64(), [pa.float64()], "immutable")
udaf1 = udaf(Summarize, pa.float64(), pa.float64(), [pa.float64()],
"immutable")
udaf2 = udaf(sum_bias_10, pa.float64(), pa.float64(), [pa.float64()],
"immutable")
udaf3 = udaf(lambda: Summarize(20.0), pa.float64(), pa.float64(),
[pa.float64()], "immutable")
```

**Decorator example:**
```
@udaf(pa.float64(), pa.float64(), [pa.float64()], "immutable")
def udaf4() -> Summarize:
return Summarize(10.0)
```

Args:
accum: The accumulator python function.
accum: The accumulator python function. **Only needed when calling as a
function. Skip this argument when using `udaf` as a decorator.**
input_types: The data types of the arguments to ``accum``.
return_type: The data type of the return value.
state_type: The data types of the intermediate accumulation.
Expand All @@ -268,26 +342,69 @@ def sum_bias_10() -> Summarize:
Returns:
A user-defined aggregate function, which can be used in either data
aggregation or window function calls.
""" # noqa W505
if not callable(accum):
raise TypeError("`func` must be callable.")
if not isinstance(accum.__call__(), Accumulator):
raise TypeError(
"Accumulator must implement the abstract base class Accumulator"
"""

def __new__(cls, *args, **kwargs):
"""Create a new UDAF.

Trigger UDAF function or decorator depending on if the first args is
callable
"""
if args and callable(args[0]):
# Case 1: Used as a function, require the first parameter to be callable
return cls._function(*args, **kwargs)
else:
# Case 2: Used as a decorator with parameters
return cls._decorator(*args, **kwargs)

@staticmethod
def _function(
accum: Callable[[], Accumulator],
input_types: pyarrow.DataType | list[pyarrow.DataType],
return_type: pyarrow.DataType,
state_type: list[pyarrow.DataType],
volatility: Volatility | str,
name: Optional[str] = None,
) -> AggregateUDF:
if not callable(accum):
raise TypeError("`func` must be callable.")
if not isinstance(accum.__call__(), Accumulator):
raise TypeError(
"Accumulator must implement the abstract base class Accumulator"
)
if name is None:
name = accum.__call__().__class__.__qualname__.lower()
if isinstance(input_types, pyarrow.DataType):
input_types = [input_types]
return AggregateUDF(
name=name,
accumulator=accum,
input_types=input_types,
return_type=return_type,
state_type=state_type,
volatility=volatility,
)
if name is None:
name = accum.__call__().__class__.__qualname__.lower()
assert name is not None
if isinstance(input_types, pyarrow.DataType):
input_types = [input_types]
return AggregateUDF(
name=name,
accumulator=accum,
input_types=input_types,
return_type=return_type,
state_type=state_type,
volatility=volatility,
)

@staticmethod
def _decorator(
input_types: pyarrow.DataType | list[pyarrow.DataType],
return_type: pyarrow.DataType,
state_type: list[pyarrow.DataType],
volatility: Volatility | str,
name: Optional[str] = None,
):
def decorator(accum: Callable[[], Accumulator]):
udaf_caller = AggregateUDF.udaf(
accum, input_types, return_type, state_type, volatility, name
)

@functools.wraps(accum)
def wrapper(*args, **kwargs):
return udaf_caller(*args, **kwargs)

return wrapper

return decorator


class WindowEvaluator(metaclass=ABCMeta):
Expand Down
42 changes: 42 additions & 0 deletions python/tests/test_udaf.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,26 @@ def test_udaf_aggregate(df):
assert result.column(0) == pa.array([1.0 + 2.0 + 3.0])


def test_udaf_decorator_aggregate(df):
    """Aggregate with a decorator-defined UDAF, twice, to check state reset."""

    @udaf(pa.float64(), pa.float64(), [pa.float64()], "immutable")
    def summarize():
        return Summarize()

    expected = pa.array([1.0 + 2.0 + 3.0])

    # Build and collect two independent plans; the second run verifies that
    # accumulator state does not leak between executions.
    for _ in range(2):
        batch = df.aggregate([], [summarize(column("a"))]).collect()[0]
        assert batch.column(0) == expected


def test_udaf_aggregate_with_arguments(df):
bias = 10.0

Expand All @@ -143,6 +163,28 @@ def test_udaf_aggregate_with_arguments(df):
assert result.column(0) == pa.array([bias + 1.0 + 2.0 + 3.0])


def test_udaf_decorator_aggregate_with_arguments(df):
    """Decorator-defined UDAF whose factory closes over a constructor arg."""
    bias = 10.0

    @udaf(pa.float64(), pa.float64(), [pa.float64()], "immutable")
    def summarize():
        return Summarize(bias)

    expected = pa.array([bias + 1.0 + 2.0 + 3.0])

    # Build and collect two independent plans; the second run verifies that
    # accumulator state does not leak between executions.
    for _ in range(2):
        batch = df.aggregate([], [summarize(column("a"))]).collect()[0]
        assert batch.column(0) == expected


def test_group_by(df):
summarize = udaf(
Summarize,
Expand Down
Loading