diff --git a/python/datafusion/udf.py b/python/datafusion/udf.py index 603b7063d..e93a34ca5 100644 --- a/python/datafusion/udf.py +++ b/python/datafusion/udf.py @@ -621,6 +621,16 @@ def __call__(self, *args: Expr) -> Expr: args_raw = [arg.expr for arg in args] return Expr(self._udwf.__call__(*args_raw)) + @overload + @staticmethod + def udwf( + input_types: pa.DataType | list[pa.DataType], + return_type: pa.DataType, + volatility: Volatility | str, + name: Optional[str] = None, + ) -> Callable[..., WindowUDF]: ... + + @overload @staticmethod def udwf( func: Callable[[], WindowEvaluator], @@ -628,24 +638,31 @@ def udwf( return_type: pa.DataType, volatility: Volatility | str, name: Optional[str] = None, - ) -> WindowUDF: - """Create a new User-Defined Window Function. + ) -> WindowUDF: ... - If your :py:class:`WindowEvaluator` can be instantiated with no arguments, you - can simply pass it's type as ``func``. If you need to pass additional arguments - to it's constructor, you can define a lambda or a factory method. During runtime - the :py:class:`WindowEvaluator` will be constructed for every instance in - which this UDWF is used. The following examples are all valid. + @staticmethod + def udwf(*args: Any, **kwargs: Any): # noqa: D417 + """Create a new User-Defined Window Function (UDWF). - .. code-block:: python + This class can be used both as a **function** and as a **decorator**. + + Usage: + - **As a function**: Call `udwf(func, input_types, return_type, volatility, + name)`. + - **As a decorator**: Use `@udwf(input_types, return_type, volatility, + name)`. When using `udwf` as a decorator, **do not pass `func` + explicitly**. + **Function example:** + ``` import pyarrow as pa class BiasedNumbers(WindowEvaluator): def __init__(self, start: int = 0) -> None: self.start = start - def evaluate_all(self, values: list[pa.Array], num_rows: int) -> pa.Array: + def evaluate_all(self, values: list[pa.Array], + num_rows: int) -> pa.Array: return pa.array([self.start + i for i in range(num_rows)]) def bias_10() -> BiasedNumbers: @@ -655,35 +672,93 @@ def bias_10() -> BiasedNumbers: udwf2 = udwf(bias_10, pa.int64(), pa.int64(), "immutable") udwf3 = udwf(lambda: BiasedNumbers(20), pa.int64(), pa.int64(), "immutable") + ``` + + **Decorator example:** + ``` + @udwf(pa.int64(), pa.int64(), "immutable") + def biased_numbers() -> BiasedNumbers: + return BiasedNumbers(10) + ``` + Args: - func: A callable to create the window function. - input_types: The data types of the arguments to ``func``. + func: **Only needed when calling as a function. Skip this argument when + using `udwf` as a decorator.** + input_types: The data types of the arguments. return_type: The data type of the return value. volatility: See :py:class:`Volatility` for allowed values. - arguments: A list of arguments to pass in to the __init__ method for accum. name: A descriptive name for the function. Returns: - A user-defined window function. - """ # noqa: W505, E501 + A user-defined window function that can be used in window function calls. + """ + if args and callable(args[0]): + # Case 1: Used as a function, require the first parameter to be callable + return WindowUDF._create_window_udf(*args, **kwargs) + # Case 2: Used as a decorator with parameters + return WindowUDF._create_window_udf_decorator(*args, **kwargs) + + @staticmethod + def _create_window_udf( + func: Callable[[], WindowEvaluator], + input_types: pa.DataType | list[pa.DataType], + return_type: pa.DataType, + volatility: Volatility | str, + name: Optional[str] = None, + ) -> WindowUDF: + """Create a WindowUDF instance from function arguments.""" if not callable(func): msg = "`func` must be callable." raise TypeError(msg) if not isinstance(func(), WindowEvaluator): msg = "`func` must implement the abstract base class WindowEvaluator" raise TypeError(msg) - if name is None: - name = func().__class__.__qualname__.lower() - if isinstance(input_types, pa.DataType): - input_types = [input_types] - return WindowUDF( - name=name, - func=func, - input_types=input_types, - return_type=return_type, - volatility=volatility, + + name = name or func.__qualname__.lower() + input_types = ( + [input_types] if isinstance(input_types, pa.DataType) else input_types ) + return WindowUDF(name, func, input_types, return_type, volatility) + + @staticmethod + def _get_default_name(func: Callable) -> str: + """Get the default name for a function based on its attributes.""" + if hasattr(func, "__qualname__"): + return func.__qualname__.lower() + return func.__class__.__name__.lower() + + @staticmethod + def _normalize_input_types( + input_types: pa.DataType | list[pa.DataType], + ) -> list[pa.DataType]: + """Convert a single DataType to a list if needed.""" + if isinstance(input_types, pa.DataType): + return [input_types] + return input_types + + @staticmethod + def _create_window_udf_decorator( + input_types: pa.DataType | list[pa.DataType], + return_type: pa.DataType, + volatility: Volatility | str, + name: Optional[str] = None, + ) -> Callable[[Callable[[], WindowEvaluator]], Callable[..., Expr]]: + """Create a decorator for a WindowUDF.""" + + def decorator(func: Callable[[], WindowEvaluator]) -> Callable[..., Expr]: + udwf_caller = WindowUDF._create_window_udf( + func, input_types, return_type, volatility, name + ) + + @functools.wraps(func) + def wrapper(*args: Any, **kwargs: Any) -> Expr: + return udwf_caller(*args, **kwargs) + + return wrapper + + return decorator + # Convenience exports so we can import instead of treating as # variables at the package root diff --git a/python/tests/test_udwf.py b/python/tests/test_udwf.py index 3d6dcf9d8..4190e7d64 100644 --- a/python/tests/test_udwf.py +++ b/python/tests/test_udwf.py @@ -162,14 +162,27 @@ def evaluate_all(self, values: list[pa.Array], num_rows: int) -> pa.Array: return pa.array(results) +class SimpleWindowCount(WindowEvaluator): + """A simple window evaluator that counts rows.""" + + def __init__(self, base: int = 0) -> None: + self.base = base + + def evaluate_all(self, values: list[pa.Array], num_rows: int) -> pa.Array: + return pa.array([self.base + i for i in range(num_rows)]) + + class NotSubclassOfWindowEvaluator: pass @pytest.fixture -def df(): - ctx = SessionContext() +def ctx(): + return SessionContext() + +@pytest.fixture +def complex_window_df(ctx): # create a RecordBatch and a new DataFrame from it batch = pa.RecordBatch.from_arrays( [ @@ -182,7 +195,17 @@ def df(): return ctx.create_dataframe([[batch]]) -def test_udwf_errors(df): +@pytest.fixture +def count_window_df(ctx): + # create a RecordBatch and a new DataFrame from it + batch = pa.RecordBatch.from_arrays( + [pa.array([1, 2, 3]), pa.array([4, 4, 6])], + names=["a", "b"], + ) + return ctx.create_dataframe([[batch]], name="test_table") + + +def test_udwf_errors(complex_window_df): with pytest.raises(TypeError): udwf( NotSubclassOfWindowEvaluator, @@ -192,6 +215,103 @@ def test_udwf_errors(df): ) +def test_udwf_errors_with_message(): + """Test error cases for UDWF creation.""" + with pytest.raises( + TypeError, match="`func` must implement the abstract base class WindowEvaluator" + ): + udwf( + NotSubclassOfWindowEvaluator, pa.int64(), pa.int64(), volatility="immutable" + ) + + +def test_udwf_basic_usage(count_window_df): + """Test basic UDWF usage with a simple counting window function.""" + simple_count = udwf( + SimpleWindowCount, pa.int64(), pa.int64(), volatility="immutable" + ) + + df = count_window_df.select( + simple_count(column("a")) + .window_frame(WindowFrame("rows", None, None)) + .build() + .alias("count") + ) + result = df.collect()[0] + assert result.column(0) == pa.array([0, 1, 2]) + + +def test_udwf_with_args(count_window_df): + """Test UDWF with constructor arguments.""" + count_base10 = udwf( + lambda: SimpleWindowCount(10), pa.int64(), pa.int64(), volatility="immutable" + ) + + df = count_window_df.select( + count_base10(column("a")) + .window_frame(WindowFrame("rows", None, None)) + .build() + .alias("count") + ) + result = df.collect()[0] + assert result.column(0) == pa.array([10, 11, 12]) + + +def test_udwf_decorator_basic(count_window_df): + """Test UDWF used as a decorator.""" + + @udwf([pa.int64()], pa.int64(), "immutable") + def window_count() -> WindowEvaluator: + return SimpleWindowCount() + + df = count_window_df.select( + window_count(column("a")) + .window_frame(WindowFrame("rows", None, None)) + .build() + .alias("count") + ) + result = df.collect()[0] + assert result.column(0) == pa.array([0, 1, 2]) + + +def test_udwf_decorator_with_args(count_window_df): + """Test UDWF decorator with constructor arguments.""" + + @udwf([pa.int64()], pa.int64(), "immutable") + def window_count_base10() -> WindowEvaluator: + return SimpleWindowCount(10) + + df = count_window_df.select( + window_count_base10(column("a")) + .window_frame(WindowFrame("rows", None, None)) + .build() + .alias("count") + ) + result = df.collect()[0] + assert result.column(0) == pa.array([10, 11, 12]) + + +def test_register_udwf(ctx, count_window_df): + """Test registering and using UDWF in SQL context.""" + window_count = udwf( + SimpleWindowCount, + [pa.int64()], + pa.int64(), + volatility="immutable", + name="window_count", + ) + + ctx.register_udwf(window_count) + result = ctx.sql( + """ + SELECT window_count(a) + OVER (ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED + FOLLOWING) FROM test_table + """ + ).collect()[0] + assert result.column(0) == pa.array([0, 1, 2]) + + smooth_default = udwf( ExponentialSmoothDefault, pa.float64(), @@ -299,10 +419,50 @@ def test_udwf_errors(df): @pytest.mark.parametrize(("name", "expr", "expected"), data_test_udwf_functions) -def test_udwf_functions(df, name, expr, expected): - df = df.select("a", "b", f.round(expr, lit(3)).alias(name)) +def test_udwf_functions(complex_window_df, name, expr, expected): + df = complex_window_df.select("a", "b", f.round(expr, lit(3)).alias(name)) # execute and collect the first (and only) batch result = df.sort(column("a")).select(column(name)).collect()[0] assert result.column(0) == pa.array(expected) + + +@pytest.mark.parametrize( + "udwf_func", + [ + udwf(SimpleWindowCount, pa.int64(), pa.int64(), "immutable"), + udwf(SimpleWindowCount, [pa.int64()], pa.int64(), "immutable"), + udwf([pa.int64()], pa.int64(), "immutable")(lambda: SimpleWindowCount()), + udwf(pa.int64(), pa.int64(), "immutable")(lambda: SimpleWindowCount()), + ], +) +def test_udwf_overloads(udwf_func, count_window_df): + df = count_window_df.select( + udwf_func(column("a")) + .window_frame(WindowFrame("rows", None, None)) + .build() + .alias("count") + ) + result = df.collect()[0] + assert result.column(0) == pa.array([0, 1, 2]) + + +def test_udwf_named_function(ctx, count_window_df): + """Test UDWF with explicit name parameter.""" + window_count = udwf( + SimpleWindowCount, + pa.int64(), + pa.int64(), + volatility="immutable", + name="my_custom_counter", + ) + + ctx.register_udwf(window_count) + result = ctx.sql( + """ + SELECT my_custom_counter(a) + OVER (ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED + FOLLOWING) FROM test_table""" + ).collect()[0] + assert result.column(0) == pa.array([0, 1, 2])