Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
18 commits
Select commit Hold shift + click to select a range
b708133
feat: Introduce create_udwf method for User-Defined Window Functions
kosiew Mar 13, 2025
333b80e
refactor: Simplify UDWF test suite and introduce SimpleWindowCount ev…
kosiew Mar 13, 2025
a52af17
fix: Update type alias import to use typing_extensions for compatibility
kosiew Mar 13, 2025
cd972b5
Add udwf tests for multiple input types and decorator syntax
kosiew Mar 13, 2025
d7ffa02
replace old def udwf
kosiew Mar 13, 2025
3eade95
refactor: Simplify df fixture by passing ctx as an argument
kosiew Mar 13, 2025
86cc70e
refactor: Rename DataFrame fixtures and update test functions
kosiew Mar 13, 2025
ae62383
refactor: Update udwf calls in WindowUDF to use BiasedNumbers directly
kosiew Mar 13, 2025
4c397cf
feat: Add overloads for udwf function to support multiple input types…
kosiew Mar 13, 2025
1164374
refactor: Simplify udwf method signature by removing redundant type h…
kosiew Mar 13, 2025
d29acf6
refactor: Remove state_type from udwf method signature and update ret…
kosiew Mar 13, 2025
46097d1
refactor: Update volatility parameter type in udwf method signature t…
kosiew Mar 13, 2025
b0a1803
Fix ruff errors
kosiew Mar 13, 2025
ad33378
fix C901 for def udwf
kosiew Mar 13, 2025
6f25337
refactor: Update udwf method signature and simplify input handling
kosiew Mar 13, 2025
20d5dd9
refactor: Rename input_type to input_types in udwf method signature f…
kosiew Mar 13, 2025
16dbe5f
refactor: Enhance typing in udf.py by introducing Protocol for Window…
kosiew Mar 14, 2025
78c0203
Revert "refactor: Enhance typing in udf.py by introducing Protocol fo…
kosiew Mar 14, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
123 changes: 99 additions & 24 deletions python/datafusion/udf.py
Original file line number Diff line number Diff line change
Expand Up @@ -621,31 +621,48 @@ def __call__(self, *args: Expr) -> Expr:
args_raw = [arg.expr for arg in args]
return Expr(self._udwf.__call__(*args_raw))

@overload
@staticmethod
def udwf(
input_types: pa.DataType | list[pa.DataType],
return_type: pa.DataType,
volatility: Volatility | str,
name: Optional[str] = None,
) -> Callable[..., WindowUDF]: ...

@overload
@staticmethod
def udwf(
func: Callable[[], WindowEvaluator],
input_types: pa.DataType | list[pa.DataType],
return_type: pa.DataType,
volatility: Volatility | str,
name: Optional[str] = None,
) -> WindowUDF:
"""Create a new User-Defined Window Function.
) -> WindowUDF: ...

If your :py:class:`WindowEvaluator` can be instantiated with no arguments, you
can simply pass it's type as ``func``. If you need to pass additional arguments
to it's constructor, you can define a lambda or a factory method. During runtime
the :py:class:`WindowEvaluator` will be constructed for every instance in
which this UDWF is used. The following examples are all valid.
@staticmethod
def udwf(*args: Any, **kwargs: Any): # noqa: D417
"""Create a new User-Defined Window Function (UDWF).

.. code-block:: python
This class can be used both as a **function** and as a **decorator**.

Usage:
- **As a function**: Call `udwf(func, input_types, return_type, volatility,
name)`.
- **As a decorator**: Use `@udwf(input_types, return_type, volatility,
name)`. When using `udwf` as a decorator, **do not pass `func`
explicitly**.

**Function example:**
```
import pyarrow as pa

class BiasedNumbers(WindowEvaluator):
def __init__(self, start: int = 0) -> None:
self.start = start

def evaluate_all(self, values: list[pa.Array], num_rows: int) -> pa.Array:
def evaluate_all(self, values: list[pa.Array],
num_rows: int) -> pa.Array:
return pa.array([self.start + i for i in range(num_rows)])

def bias_10() -> BiasedNumbers:
Expand All @@ -655,35 +672,93 @@ def bias_10() -> BiasedNumbers:
udwf2 = udwf(bias_10, pa.int64(), pa.int64(), "immutable")
udwf3 = udwf(lambda: BiasedNumbers(20), pa.int64(), pa.int64(), "immutable")

```

**Decorator example:**
```
@udwf(pa.int64(), pa.int64(), "immutable")
def biased_numbers() -> BiasedNumbers:
return BiasedNumbers(10)
```

Args:
func: A callable to create the window function.
input_types: The data types of the arguments to ``func``.
func: **Only needed when calling as a function. Skip this argument when
using `udwf` as a decorator.**
input_types: The data types of the arguments.
return_type: The data type of the return value.
volatility: See :py:class:`Volatility` for allowed values.
arguments: A list of arguments to pass in to the __init__ method for accum.
name: A descriptive name for the function.

Returns:
A user-defined window function.
""" # noqa: W505, E501
A user-defined window function that can be used in window function calls.
"""
if args and callable(args[0]):
# Case 1: Used as a function, require the first parameter to be callable
return WindowUDF._create_window_udf(*args, **kwargs)
# Case 2: Used as a decorator with parameters
return WindowUDF._create_window_udf_decorator(*args, **kwargs)

@staticmethod
def _create_window_udf(
func: Callable[[], WindowEvaluator],
input_types: pa.DataType | list[pa.DataType],
return_type: pa.DataType,
volatility: Volatility | str,
name: Optional[str] = None,
) -> WindowUDF:
"""Create a WindowUDF instance from function arguments."""
if not callable(func):
msg = "`func` must be callable."
raise TypeError(msg)
if not isinstance(func(), WindowEvaluator):
msg = "`func` must implement the abstract base class WindowEvaluator"
raise TypeError(msg)
if name is None:
name = func().__class__.__qualname__.lower()
if isinstance(input_types, pa.DataType):
input_types = [input_types]
return WindowUDF(
name=name,
func=func,
input_types=input_types,
return_type=return_type,
volatility=volatility,

name = name or func.__qualname__.lower()
input_types = (
[input_types] if isinstance(input_types, pa.DataType) else input_types
)

return WindowUDF(name, func, input_types, return_type, volatility)

@staticmethod
def _get_default_name(func: Callable) -> str:
"""Get the default name for a function based on its attributes."""
if hasattr(func, "__qualname__"):
return func.__qualname__.lower()
return func.__class__.__name__.lower()

@staticmethod
def _normalize_input_types(
input_types: pa.DataType | list[pa.DataType],
) -> list[pa.DataType]:
"""Convert a single DataType to a list if needed."""
if isinstance(input_types, pa.DataType):
return [input_types]
return input_types

@staticmethod
def _create_window_udf_decorator(
input_types: pa.DataType | list[pa.DataType],
return_type: pa.DataType,
volatility: Volatility | str,
name: Optional[str] = None,
) -> Callable[[Callable[[], WindowEvaluator]], Callable[..., Expr]]:
"""Create a decorator for a WindowUDF."""

def decorator(func: Callable[[], WindowEvaluator]) -> Callable[..., Expr]:
udwf_caller = WindowUDF._create_window_udf(
func, input_types, return_type, volatility, name
)

@functools.wraps(func)
def wrapper(*args: Any, **kwargs: Any) -> Expr:
return udwf_caller(*args, **kwargs)

return wrapper

return decorator


# Convenience exports so we can import instead of treating as
# variables at the package root
Expand Down
170 changes: 165 additions & 5 deletions python/tests/test_udwf.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,14 +162,27 @@ def evaluate_all(self, values: list[pa.Array], num_rows: int) -> pa.Array:
return pa.array(results)


class SimpleWindowCount(WindowEvaluator):
"""A simple window evaluator that counts rows."""

def __init__(self, base: int = 0) -> None:
self.base = base

def evaluate_all(self, values: list[pa.Array], num_rows: int) -> pa.Array:
return pa.array([self.base + i for i in range(num_rows)])


class NotSubclassOfWindowEvaluator:
pass


@pytest.fixture
def df():
ctx = SessionContext()
def ctx():
return SessionContext()


@pytest.fixture
def complex_window_df(ctx):
# create a RecordBatch and a new DataFrame from it
batch = pa.RecordBatch.from_arrays(
[
Expand All @@ -182,7 +195,17 @@ def df():
return ctx.create_dataframe([[batch]])


def test_udwf_errors(df):
@pytest.fixture
def count_window_df(ctx):
# create a RecordBatch and a new DataFrame from it
batch = pa.RecordBatch.from_arrays(
[pa.array([1, 2, 3]), pa.array([4, 4, 6])],
names=["a", "b"],
)
return ctx.create_dataframe([[batch]], name="test_table")


def test_udwf_errors(complex_window_df):
with pytest.raises(TypeError):
udwf(
NotSubclassOfWindowEvaluator,
Expand All @@ -192,6 +215,103 @@ def test_udwf_errors(df):
)


def test_udwf_errors_with_message():
"""Test error cases for UDWF creation."""
with pytest.raises(
TypeError, match="`func` must implement the abstract base class WindowEvaluator"
):
udwf(
NotSubclassOfWindowEvaluator, pa.int64(), pa.int64(), volatility="immutable"
)


def test_udwf_basic_usage(count_window_df):
"""Test basic UDWF usage with a simple counting window function."""
simple_count = udwf(
SimpleWindowCount, pa.int64(), pa.int64(), volatility="immutable"
)

df = count_window_df.select(
simple_count(column("a"))
.window_frame(WindowFrame("rows", None, None))
.build()
.alias("count")
)
result = df.collect()[0]
assert result.column(0) == pa.array([0, 1, 2])


def test_udwf_with_args(count_window_df):
"""Test UDWF with constructor arguments."""
count_base10 = udwf(
lambda: SimpleWindowCount(10), pa.int64(), pa.int64(), volatility="immutable"
)

df = count_window_df.select(
count_base10(column("a"))
.window_frame(WindowFrame("rows", None, None))
.build()
.alias("count")
)
result = df.collect()[0]
assert result.column(0) == pa.array([10, 11, 12])


def test_udwf_decorator_basic(count_window_df):
"""Test UDWF used as a decorator."""

@udwf([pa.int64()], pa.int64(), "immutable")
def window_count() -> WindowEvaluator:
return SimpleWindowCount()

df = count_window_df.select(
window_count(column("a"))
.window_frame(WindowFrame("rows", None, None))
.build()
.alias("count")
)
result = df.collect()[0]
assert result.column(0) == pa.array([0, 1, 2])


def test_udwf_decorator_with_args(count_window_df):
"""Test UDWF decorator with constructor arguments."""

@udwf([pa.int64()], pa.int64(), "immutable")
def window_count_base10() -> WindowEvaluator:
return SimpleWindowCount(10)

df = count_window_df.select(
window_count_base10(column("a"))
.window_frame(WindowFrame("rows", None, None))
.build()
.alias("count")
)
result = df.collect()[0]
assert result.column(0) == pa.array([10, 11, 12])


def test_register_udwf(ctx, count_window_df):
"""Test registering and using UDWF in SQL context."""
window_count = udwf(
SimpleWindowCount,
[pa.int64()],
pa.int64(),
volatility="immutable",
name="window_count",
)

ctx.register_udwf(window_count)
result = ctx.sql(
"""
SELECT window_count(a)
OVER (ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED
FOLLOWING) FROM test_table
"""
).collect()[0]
assert result.column(0) == pa.array([0, 1, 2])


smooth_default = udwf(
ExponentialSmoothDefault,
pa.float64(),
Expand Down Expand Up @@ -299,10 +419,50 @@ def test_udwf_errors(df):


@pytest.mark.parametrize(("name", "expr", "expected"), data_test_udwf_functions)
def test_udwf_functions(df, name, expr, expected):
df = df.select("a", "b", f.round(expr, lit(3)).alias(name))
def test_udwf_functions(complex_window_df, name, expr, expected):
df = complex_window_df.select("a", "b", f.round(expr, lit(3)).alias(name))

# execute and collect the first (and only) batch
result = df.sort(column("a")).select(column(name)).collect()[0]

assert result.column(0) == pa.array(expected)


@pytest.mark.parametrize(
"udwf_func",
[
udwf(SimpleWindowCount, pa.int64(), pa.int64(), "immutable"),
udwf(SimpleWindowCount, [pa.int64()], pa.int64(), "immutable"),
udwf([pa.int64()], pa.int64(), "immutable")(lambda: SimpleWindowCount()),
udwf(pa.int64(), pa.int64(), "immutable")(lambda: SimpleWindowCount()),
],
)
def test_udwf_overloads(udwf_func, count_window_df):
df = count_window_df.select(
udwf_func(column("a"))
.window_frame(WindowFrame("rows", None, None))
.build()
.alias("count")
)
result = df.collect()[0]
assert result.column(0) == pa.array([0, 1, 2])


def test_udwf_named_function(ctx, count_window_df):
"""Test UDWF with explicit name parameter."""
window_count = udwf(
SimpleWindowCount,
pa.int64(),
pa.int64(),
volatility="immutable",
name="my_custom_counter",
)

ctx.register_udwf(window_count)
result = ctx.sql(
"""
SELECT my_custom_counter(a)
OVER (ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED
FOLLOWING) FROM test_table"""
).collect()[0]
assert result.column(0) == pa.array([0, 1, 2])