Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions python/datafusion/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,9 @@ def lit(value):


udf = ScalarUDF.udf
udf_decorator = ScalarUDF.udf_decorator

udaf = AggregateUDF.udaf
udaf_decorator = AggregateUDF.udaf_decorator

udwf = WindowUDF.udwf
47 changes: 47 additions & 0 deletions python/datafusion/udf.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from abc import ABCMeta, abstractmethod
from enum import Enum
from typing import TYPE_CHECKING, Callable, List, Optional, TypeVar
import functools

import pyarrow

Expand Down Expand Up @@ -148,6 +149,27 @@ def udf(
volatility=volatility,
)

@staticmethod
def udf_decorator(
    input_types: list[pyarrow.DataType],
    return_type: _R,
    volatility: Volatility | str,
    name: Optional[str] = None,
):
    """Create a decorator that turns a plain function into a scalar UDF.

    The decorated function is passed to :py:func:`ScalarUDF.udf` once, at
    decoration time, and the returned wrapper forwards calls to the
    resulting UDF so it can be invoked directly with column expressions.

    Args:
        input_types: Data types of the arguments to the function.
        return_type: Data type of the return value.
        volatility: See :py:class:`Volatility` for allowed values.
        name: Descriptive name for the function; when ``None`` the UDF
            machinery derives one from the decorated function.

    Returns:
        A decorator producing a callable that evaluates the scalar UDF.
    """

    def decorator(func):
        # Build the ScalarUDF once rather than on every call.
        udf_caller = ScalarUDF.udf(func, input_types, return_type, volatility, name)

        # functools.wraps preserves the decorated function's metadata
        # (__name__, __doc__) on the returned wrapper.
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            return udf_caller(*args, **kwargs)

        return wrapper

    return decorator

class Accumulator(metaclass=ABCMeta):
"""Defines how an :py:class:`AggregateUDF` accumulates values."""
Expand Down Expand Up @@ -287,6 +309,31 @@ def sum_bias_10() -> Summarize:
state_type=state_type,
volatility=volatility,
)

@staticmethod
def udaf_decorator(
    input_types: pyarrow.DataType | list[pyarrow.DataType],
    return_type: pyarrow.DataType,
    state_type: list[pyarrow.DataType],
    volatility: Volatility | str,
    name: Optional[str] = None,
):
    """Create a decorator that turns an accumulator factory into a UDAF.

    The decorated callable must return a fresh :py:class:`Accumulator`
    each time it is invoked; it is handed to :py:func:`AggregateUDF.udaf`
    once, at decoration time, and the returned wrapper forwards calls to
    the resulting aggregate UDF.

    Args:
        input_types: Data type(s) of the aggregated column(s).
        return_type: Data type of the aggregate result.
        state_type: Data types of the accumulator's intermediate state.
        volatility: See :py:class:`Volatility` for allowed values.
        name: Descriptive name for the function; when ``None`` the UDF
            machinery derives one from the decorated callable.

    Returns:
        A decorator producing a callable that evaluates the aggregate UDF.
    """

    def decorator(accum: Callable[[], Accumulator]):
        # Build the AggregateUDF once rather than on every call.
        udaf_caller = AggregateUDF.udaf(
            accum, input_types, return_type, state_type, volatility, name
        )

        # functools.wraps preserves the factory's metadata on the wrapper.
        @functools.wraps(accum)
        def wrapper(*args, **kwargs):
            return udaf_caller(*args, **kwargs)

        return wrapper

    return decorator



class WindowEvaluator(metaclass=ABCMeta):
Expand Down
50 changes: 49 additions & 1 deletion python/tests/test_udaf.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
import pyarrow as pa
import pyarrow.compute as pc
import pytest
from datafusion import Accumulator, column, udaf
from datafusion import Accumulator, column, udaf, udaf_decorator


class Summarize(Accumulator):
Expand Down Expand Up @@ -116,6 +116,29 @@ def test_udaf_aggregate(df):

assert result.column(0) == pa.array([1.0 + 2.0 + 3.0])

def test_udaf_decorator_aggregate(df):
    """Aggregate with a decorator-defined UDAF, twice, to verify state reset."""

    @udaf_decorator(
        pa.float64(),
        pa.float64(),
        [pa.float64()],
        "immutable",
    )
    def summarize():
        return Summarize()

    expected = pa.array([1.0 + 2.0 + 3.0])

    # The second pass ensures the accumulator state is properly reset
    # between executions.
    for _ in range(2):
        batch = df.aggregate([], [summarize(column("a"))]).collect()[0]
        assert batch.column(0) == expected


def test_udaf_aggregate_with_arguments(df):
bias = 10.0
Expand Down Expand Up @@ -143,6 +166,31 @@ def test_udaf_aggregate_with_arguments(df):
assert result.column(0) == pa.array([bias + 1.0 + 2.0 + 3.0])


def test_udaf_decorator_aggregate_with_arguments(df):
    """Decorator-defined UDAF whose factory closes over an argument."""
    bias = 10.0

    @udaf_decorator(
        pa.float64(),
        pa.float64(),
        [pa.float64()],
        "immutable",
    )
    def summarize():
        return Summarize(bias)

    expected = pa.array([bias + 1.0 + 2.0 + 3.0])

    # The second pass ensures the accumulator state is properly reset
    # between executions.
    for _ in range(2):
        batch = df.aggregate([], [summarize(column("a"))]).collect()[0]
        assert batch.column(0) == expected


def test_group_by(df):
summarize = udaf(
Summarize,
Expand Down
42 changes: 36 additions & 6 deletions python/tests/test_udf.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,14 @@

import pyarrow as pa
import pytest
from datafusion import column, udf
from datafusion import column, udf, udf_decorator


@pytest.fixture
def df(ctx):
    """DataFrame with column ``a`` (1, 2, 3) and column ``b`` (4, 4, null)."""
    arrays = [pa.array([1, 2, 3]), pa.array([4, 4, None])]
    batch = pa.RecordBatch.from_arrays(arrays, names=["a", "b"])
    return ctx.create_dataframe([[batch]], name="test_table")
Expand All @@ -39,10 +39,20 @@ def test_udf(df):
volatility="immutable",
)

df = df.select(is_null(column("a")))
df = df.select(is_null(column("b")))
result = df.collect()[0].column(0)

assert result == pa.array([False, False, False])
assert result == pa.array([False, False, True])


def test_udf_decorator(df):
    """A decorator-created scalar UDF detects the null in column ``b``."""

    @udf_decorator([pa.int64()], pa.bool_(), "immutable")
    def is_null(x: pa.Array) -> pa.Array:
        return x.is_null()

    batch = df.select(is_null(column("b"))).collect()[0]
    assert batch.column(0) == pa.array([False, False, True])


def test_register_udf(ctx, df) -> None:
Expand All @@ -56,10 +66,10 @@ def test_register_udf(ctx, df) -> None:

ctx.register_udf(is_null)

df_result = ctx.sql("select is_null(a) from test_table")
df_result = ctx.sql("select is_null(b) from test_table")
result = df_result.collect()[0].column(0)

assert result == pa.array([False, False, False])
assert result == pa.array([False, False, True])


class OverThresholdUDF:
Expand Down Expand Up @@ -94,3 +104,23 @@ def test_udf_with_parameters(df) -> None:
result = df2.collect()[0].column(0)

assert result == pa.array([False, True, True])


def test_udf_decorator_with_parameters(df) -> None:
    """Decorator-created UDFs built from a parameterized callable class.

    Renamed from ``test_udf_with_parameters``: the original name duplicated
    the existing test of the same name, so pytest silently collected only
    one of the two.
    """

    @udf_decorator([pa.int64()], pa.bool_(), "immutable")
    def udf_no_param(values: pa.Array) -> pa.Array:
        # Uses OverThresholdUDF's default threshold.
        return OverThresholdUDF()(values)

    df1 = df.select(udf_no_param(column("a")))
    result = df1.collect()[0].column(0)

    assert result == pa.array([True, True, True])

    @udf_decorator([pa.int64()], pa.bool_(), "immutable")
    def udf_with_param(values: pa.Array) -> pa.Array:
        # Explicit threshold of 2 excludes the first value.
        return OverThresholdUDF(2)(values)

    df2 = df.select(udf_with_param(column("a")))
    result = df2.collect()[0].column(0)

    assert result == pa.array([False, True, True])