136 changes: 101 additions & 35 deletions python/datafusion/udf.py
@@ -23,7 +23,7 @@
from datafusion.expr import Expr
from typing import Callable, TYPE_CHECKING, TypeVar
from abc import ABCMeta, abstractmethod
from typing import List
from typing import List, Optional
from enum import Enum
import pyarrow

@@ -84,16 +84,18 @@ class ScalarUDF:

def __init__(
self,
name: str | None,
name: Optional[str],
func: Callable[..., _R],
input_types: list[pyarrow.DataType],
input_types: pyarrow.DataType | list[pyarrow.DataType],
return_type: _R,
volatility: Volatility | str,
) -> None:
"""Instantiate a scalar user-defined function (UDF).

See helper method :py:func:`udf` for argument details.
"""
if isinstance(input_types, pyarrow.DataType):
input_types = [input_types]
self._udf = df_internal.ScalarUDF(
name, func, input_types, return_type, str(volatility)
)
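The constructor now wraps a bare ``pyarrow.DataType`` in a one-element list, so a single input type no longer has to be spelled as a list. A minimal sketch of the effect below constructs ``ScalarUDF`` directly for illustration; most callers will go through the ``udf`` helper further down, and the import path is an assumption based on the module shown in this diff.

.. code-block:: python

    import pyarrow as pa
    from datafusion.udf import ScalarUDF

    # Both forms are now equivalent; a bare DataType is wrapped into a list.
    is_null = ScalarUDF(
        name="is_null",
        func=lambda arr: arr.is_null(),
        input_types=pa.int64(),        # previously required [pa.int64()]
        return_type=pa.bool_(),
        volatility="immutable",
    )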
@@ -104,16 +106,16 @@ def __call__(self, *args: Expr) -> Expr:
This function is not typically called by an end user. These calls will
occur during the evaluation of the dataframe.
"""
args = [arg.expr for arg in args]
return Expr(self._udf.__call__(*args))
args_raw = [arg.expr for arg in args]
return Expr(self._udf.__call__(*args_raw))

@staticmethod
def udf(
func: Callable[..., _R],
input_types: list[pyarrow.DataType],
return_type: _R,
volatility: Volatility | str,
name: str | None = None,
name: Optional[str] = None,
) -> ScalarUDF:
"""Create a new User-Defined Function.

@@ -133,7 +135,10 @@ def udf(
if not callable(func):
raise TypeError("`func` argument must be callable")
if name is None:
name = func.__qualname__.lower()
if hasattr(func, "__qualname__"):
name = func.__qualname__.lower()
else:
name = func.__class__.__name__.lower()
return ScalarUDF(
name=name,
func=func,
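Because the name now falls back to ``func.__class__.__name__`` when ``__qualname__`` is absent, callable objects (not just plain functions) can be passed without an explicit ``name``. A hedged sketch: the ``MultiplyBy`` class is a hypothetical example, and the top-level ``udf`` re-export matches the import removed from ``test_dataframe.py`` below.

.. code-block:: python

    import pyarrow as pa
    import pyarrow.compute as pc
    from datafusion import udf

    class MultiplyBy:
        """Stateful callable used as a scalar UDF (illustrative)."""

        def __init__(self, factor: float) -> None:
            self.factor = factor

        def __call__(self, values: pa.Array) -> pa.Array:
            return pc.multiply(values, pa.scalar(self.factor))

    # Instances carry no __qualname__, so the UDF name falls back to the
    # class name, i.e. "multiplyby".
    double_it = udf(MultiplyBy(2.0), [pa.float64()], pa.float64(), "immutable")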
@@ -167,10 +172,6 @@ def evaluate(self) -> pyarrow.Scalar:
pass


if TYPE_CHECKING:
_A = TypeVar("_A", bound=(Callable[..., _R], Accumulator))


class AggregateUDF:
"""Class for performing scalar user-defined functions (UDF).

@@ -180,10 +181,10 @@ class AggregateUDF:

def __init__(
self,
name: str | None,
accumulator: _A,
name: Optional[str],
accumulator: Callable[[], Accumulator],
input_types: list[pyarrow.DataType],
return_type: _R,
return_type: pyarrow.DataType,
state_type: list[pyarrow.DataType],
volatility: Volatility | str,
) -> None:
@@ -193,7 +194,12 @@ def __init__(
descriptions.
"""
self._udaf = df_internal.AggregateUDF(
name, accumulator, input_types, return_type, state_type, str(volatility)
name,
accumulator,
input_types,
return_type,
state_type,
str(volatility),
)

def __call__(self, *args: Expr) -> Expr:
@@ -202,21 +208,52 @@ def __call__(self, *args: Expr) -> Expr:
This function is not typically called by an end user. These calls will
occur during the evaluation of the dataframe.
"""
args = [arg.expr for arg in args]
return Expr(self._udaf.__call__(*args))
args_raw = [arg.expr for arg in args]
return Expr(self._udaf.__call__(*args_raw))

@staticmethod
def udaf(
accum: _A,
input_types: list[pyarrow.DataType],
return_type: _R,
accum: Callable[[], Accumulator],
input_types: pyarrow.DataType | list[pyarrow.DataType],
return_type: pyarrow.DataType,
state_type: list[pyarrow.DataType],
volatility: Volatility | str,
name: str | None = None,
name: Optional[str] = None,
) -> AggregateUDF:
"""Create a new User-Defined Aggregate Function.

``accum`` must be a callable that returns an instance of :py:class:`Accumulator`.
If your :py:class:`Accumulator` can be instantiated with no arguments, you
can simply pass its type as ``accum``. If you need to pass additional arguments
to its constructor, you can define a lambda or a factory method. At runtime
the :py:class:`Accumulator` will be constructed for every instance in
which this UDAF is used. The following examples are all valid.

.. code-block:: python

    import pyarrow as pa
    import pyarrow.compute as pc

    class Summarize(Accumulator):
        def __init__(self, bias: float = 0.0):
            self._sum = pa.scalar(bias)

        def state(self) -> List[pa.Scalar]:
            return [self._sum]

        def update(self, values: pa.Array) -> None:
            self._sum = pa.scalar(self._sum.as_py() + pc.sum(values).as_py())

        def merge(self, states: List[pa.Array]) -> None:
            self._sum = pa.scalar(self._sum.as_py() + pc.sum(states[0]).as_py())

        def evaluate(self) -> pa.Scalar:
            return self._sum

    def sum_bias_10() -> Summarize:
        return Summarize(10.0)

    udaf1 = udaf(Summarize, pa.float64(), pa.float64(), [pa.float64()], "immutable")
    udaf2 = udaf(sum_bias_10, pa.float64(), pa.float64(), [pa.float64()], "immutable")
    udaf3 = udaf(lambda: Summarize(20.0), pa.float64(), pa.float64(), [pa.float64()], "immutable")

Args:
accum: A callable that returns an :py:class:`Accumulator` instance.
@@ -229,14 +266,16 @@ def udaf(
Returns:
A user-defined aggregate function, which can be used in either data
aggregation or window function calls.
"""
if not issubclass(accum, Accumulator):
""" # noqa W505
if not callable(accum):
raise TypeError("`func` must be callable.")
if not isinstance(accum.__call__(), Accumulator):
raise TypeError(
"`accum` must implement the abstract base class Accumulator"
"Accumulator must implement the abstract base class Accumulator"
)
if name is None:
name = accum.__qualname__.lower()
if isinstance(input_types, pyarrow.lib.DataType):
name = accum.__call__().__class__.__qualname__.lower()
if isinstance(input_types, pyarrow.DataType):
input_types = [input_types]
return AggregateUDF(
name=name,
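For context, a hedged end-to-end sketch of how such a factory-built UDAF might be applied. It assumes the usual top-level re-exports (``Accumulator``, ``SessionContext``, ``column``, ``udaf``) and that ``SessionContext.from_pydict`` is available; the ``Total`` accumulator is a hypothetical example mirroring ``Summarize`` above.

.. code-block:: python

    import pyarrow as pa
    import pyarrow.compute as pc
    from datafusion import Accumulator, SessionContext, column, udaf

    class Total(Accumulator):
        """Minimal sum accumulator, mirroring the Summarize example above."""

        def __init__(self) -> None:
            self._sum = pa.scalar(0.0)

        def state(self) -> list[pa.Scalar]:
            return [self._sum]

        def update(self, values: pa.Array) -> None:
            self._sum = pa.scalar(self._sum.as_py() + pc.sum(values).as_py())

        def merge(self, states: list[pa.Array]) -> None:
            self._sum = pa.scalar(self._sum.as_py() + pc.sum(states[0]).as_py())

        def evaluate(self) -> pa.Scalar:
            return self._sum

    total = udaf(Total, pa.float64(), pa.float64(), [pa.float64()], "immutable")

    ctx = SessionContext()
    df = ctx.from_pydict({"a": [1.0, 2.0, 3.0]})
    # Aggregate the whole frame with no grouping columns.
    df.aggregate([], [total(column("a"))]).show()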
@@ -421,8 +460,8 @@ class WindowUDF:

def __init__(
self,
name: str | None,
func: WindowEvaluator,
name: Optional[str],
func: Callable[[], WindowEvaluator],
input_types: list[pyarrow.DataType],
return_type: pyarrow.DataType,
volatility: Volatility | str,
@@ -447,30 +486,57 @@ def __call__(self, *args: Expr) -> Expr:

@staticmethod
def udwf(
func: WindowEvaluator,
func: Callable[[], WindowEvaluator],
input_types: pyarrow.DataType | list[pyarrow.DataType],
return_type: pyarrow.DataType,
volatility: Volatility | str,
name: str | None = None,
name: Optional[str] = None,
) -> WindowUDF:
"""Create a new User-Defined Window Function.

If your :py:class:`WindowEvaluator` can be instantiated with no arguments, you
can simply pass its type as ``func``. If you need to pass additional arguments
to its constructor, you can define a lambda or a factory method. At runtime
the :py:class:`WindowEvaluator` will be constructed for every instance in
which this UDWF is used. The following examples are all valid.

.. code-block:: python

    import pyarrow as pa

    class BiasedNumbers(WindowEvaluator):
        def __init__(self, start: int = 0) -> None:
            self.start = start

        def evaluate_all(self, values: list[pa.Array], num_rows: int) -> pa.Array:
            return pa.array([self.start + i for i in range(num_rows)])

    def bias_10() -> BiasedNumbers:
        return BiasedNumbers(10)

    udwf1 = udwf(BiasedNumbers, pa.int64(), pa.int64(), "immutable")
    udwf2 = udwf(bias_10, pa.int64(), pa.int64(), "immutable")
    udwf3 = udwf(lambda: BiasedNumbers(20), pa.int64(), pa.int64(), "immutable")

Args:
func: The python function.
func: A callable to create the window function.
input_types: The data types of the arguments to ``func``.
return_type: The data type of the return value.
volatility: See :py:class:`Volatility` for allowed values.
name: A descriptive name for the function.

Returns:
A user-defined window function.
"""
if not isinstance(func, WindowEvaluator):
""" # noqa W505
if not callable(func):
raise TypeError("`func` must be callable.")
if not isinstance(func.__call__(), WindowEvaluator):
raise TypeError(
"`func` must implement the abstract base class WindowEvaluator"
)
if name is None:
name = func.__class__.__qualname__.lower()
name = func.__call__().__class__.__qualname__.lower()
if isinstance(input_types, pyarrow.DataType):
input_types = [input_types]
return WindowUDF(
16 changes: 0 additions & 16 deletions python/tests/test_dataframe.py
@@ -29,7 +29,6 @@
WindowFrame,
column,
literal,
udf,
)
from datafusion.expr import Window

@@ -236,21 +235,6 @@ def test_unnest_without_nulls(nested_df):
assert result.column(1) == pa.array([7, 8, 8, 9, 9, 9])


def test_udf(df):
# is_null is a pa function over arrays
is_null = udf(
lambda x: x.is_null(),
[pa.int64()],
pa.bool_(),
volatility="immutable",
)

df = df.select(is_null(column("a")))
result = df.collect()[0].column(0)

assert result == pa.array([False, False, False])


def test_join():
ctx = SessionContext()

File renamed without changes.