
Commit da691b4

feat: Add user-defined window function (UDWF) decorator and tests
- Implemented the `udwf` decorator to create user-defined window functions, allowing for more flexible function definitions.
- Enhanced the `udwf` method to support both function and decorator usage (a usage sketch follows the udf.py diff below).
- Added tests for `udwf` decorator functionality, including default and parameterized use cases.
- Included tests for window frame decorators to validate behavior with and without ordering.
1 parent b194a87 commit da691b4

File tree: 3 files changed (+129, -60 lines)
- python/datafusion/udf.py
- python/tests/test_dataframe.py
- python/tests/test_udwf.py

python/datafusion/udf.py

Lines changed: 57 additions & 40 deletions
@@ -623,59 +623,76 @@ def __call__(self, *args: Expr) -> Expr:
 
     @staticmethod
     def udwf(
-        func: Callable[[], WindowEvaluator],
-        input_types: pa.DataType | list[pa.DataType],
-        return_type: pa.DataType,
-        volatility: Volatility | str,
-        name: Optional[str] = None,
-    ) -> WindowUDF:
-        """Create a new User-Defined Window Function.
+        func: Callable[[], WindowEvaluator] | None = None,
+        input_types: pa.DataType | list[pa.DataType] | None = None,
+        return_type: pa.DataType | None = None,
+        volatility: Volatility | str | None = None,
+        name: str | None = None,
+    ) -> Union[WindowUDF, Callable[[Callable[[], WindowEvaluator]], WindowUDF]]:
+        """Create a new User-Defined Window Function (UDWF).
 
-        If your :py:class:`WindowEvaluator` can be instantiated with no arguments, you
-        can simply pass it's type as ``func``. If you need to pass additional arguments
-        to it's constructor, you can define a lambda or a factory method. During runtime
-        the :py:class:`WindowEvaluator` will be constructed for every instance in
-        which this UDWF is used. The following examples are all valid.
+        This method can be used both as a function and as a decorator:
 
-        .. code-block:: python
+        As a function:
+            udwf(func, input_types, return_type, volatility, name)
 
-            import pyarrow as pa
+        As a decorator:
+            @udwf(input_types, return_type, volatility, name)
+            def func():
+                return WindowEvaluator()
 
-            class BiasedNumbers(WindowEvaluator):
-                def __init__(self, start: int = 0) -> None:
-                    self.start = start
+        Args:
+            func: The window evaluator factory function
+            input_types: The input types for the window function
+            return_type: The return type for the window function
+            volatility: The volatility of the function
+            name: Optional name for the function
 
-                def evaluate_all(self, values: list[pa.Array], num_rows: int) -> pa.Array:
-                    return pa.array([self.start + i for i in range(num_rows)])
+        Returns:
+            Either a WindowUDF instance or a decorator function
+        """
+        # Used as decorator without arguments: @udwf
+        if func is not None and all(
+            x is None for x in (input_types, return_type, volatility)
+        ):
+            return WindowUDF._create(
+                func, [pa.float64()], pa.float64(), "volatile", None
+            )
 
-            def bias_10() -> BiasedNumbers:
-                return BiasedNumbers(10)
+        # Used as decorator with arguments: @udwf(...)
+        if func is None:
 
-            udwf1 = udwf(BiasedNumbers, pa.int64(), pa.int64(), "immutable")
-            udwf2 = udwf(bias_10, pa.int64(), pa.int64(), "immutable")
-            udwf3 = udwf(lambda: BiasedNumbers(20), pa.int64(), pa.int64(), "immutable")
+            def decorator(f: Callable[[], WindowEvaluator]) -> WindowUDF:
+                if input_types is None or return_type is None or volatility is None:
+                    raise ValueError(
+                        "Must provide input_types, return_type, and volatility"
+                    )
+                return WindowUDF._create(f, input_types, return_type, volatility, name)
 
-        Args:
-            func: A callable to create the window function.
-            input_types: The data types of the arguments to ``func``.
-            return_type: The data type of the return value.
-            volatility: See :py:class:`Volatility` for allowed values.
-            arguments: A list of arguments to pass in to the __init__ method for accum.
-            name: A descriptive name for the function.
+            return decorator
 
-        Returns:
-            A user-defined window function.
-        """  # noqa: W505, E501
+        # Used as function: udwf(...)
+        if input_types is None or return_type is None or volatility is None:
+            raise ValueError("Must provide input_types, return_type, and volatility")
+        return WindowUDF._create(func, input_types, return_type, volatility, name)
+
+    @staticmethod
+    def _create(
+        func: Callable[[], WindowEvaluator],
+        input_types: pa.DataType | list[pa.DataType],
+        return_type: pa.DataType,
+        volatility: Volatility | str,
+        name: str | None = None,
+    ) -> WindowUDF:
+        """Internal method to create a WindowUDF instance."""
         if not callable(func):
-            msg = "`func` must be callable."
-            raise TypeError(msg)
+            raise TypeError("`func` must be callable")
         if not isinstance(func(), WindowEvaluator):
-            msg = "`func` must implement the abstract base class WindowEvaluator"
-            raise TypeError(msg)
-        if name is None:
-            name = func().__class__.__qualname__.lower()
+            raise TypeError("`func` must implement WindowEvaluator")
         if isinstance(input_types, pa.DataType):
             input_types = [input_types]
+        if name is None:
+            name = func().__class__.__qualname__.lower()
         return WindowUDF(
             name=name,
             func=func,
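
Read together, the new dispatch supports three call styles. The sketch below is illustrative only: it assumes `udwf` and `WindowEvaluator` can be imported from `datafusion.udf`, and it uses a made-up evaluator class. Note that the decorator-with-arguments form is written with keyword arguments here, since with the signature above a positional first argument would bind to `func`.

import pyarrow as pa
from datafusion.udf import WindowEvaluator, udwf  # assumed import path

class RowIndex(WindowEvaluator):  # hypothetical evaluator, for illustration only
    def evaluate_all(self, values: list[pa.Array], num_rows: int) -> pa.Array:
        # Emit each row's index within the partition as a float column.
        return pa.array([float(i) for i in range(num_rows)])

# 1. Plain function call, as before:
udwf1 = udwf(RowIndex, pa.float64(), pa.float64(), "immutable")

# 2. Bare decorator: picks up the float64/"volatile" defaults from the first branch.
@udwf
def row_index():
    return RowIndex()

# 3. Decorator with arguments, passed by keyword:
@udwf(input_types=pa.float64(), return_type=pa.float64(), volatility="immutable")
def row_index_typed():
    return RowIndex()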

python/tests/test_dataframe.py

Lines changed: 20 additions & 20 deletions
@@ -339,7 +339,7 @@ def test_join():
 
     # Verify we don't make a breaking change to pre-43.0.0
     # where users would pass join_keys as a positional argument
-    df2 = df.join(df1, (["a"], ["a"]), how="inner")
+    df2 = df.join(df1, (["a"], ["a"]), how="inner")  # type: ignore
     df2.show()
     df2 = df2.sort(column("l.a"))
     table = pa.Table.from_batches(df2.collect())
@@ -375,17 +375,17 @@ def test_join_invalid_params():
     with pytest.raises(
         ValueError, match=r"`left_on` or `right_on` should not provided with `on`"
     ):
-        df2 = df.join(df1, on="a", how="inner", right_on="test")
+        df2 = df.join(df1, on="a", how="inner", right_on="test")  # type: ignore
 
     with pytest.raises(
         ValueError, match=r"`left_on` and `right_on` should both be provided."
     ):
-        df2 = df.join(df1, left_on="a", how="inner")
+        df2 = df.join(df1, left_on="a", how="inner")  # type: ignore
 
     with pytest.raises(
         ValueError, match=r"either `on` or `left_on` and `right_on` should be provided."
     ):
-        df2 = df.join(df1, how="inner")
+        df2 = df.join(df1, how="inner")  # type: ignore
 
 
 def test_join_on():
@@ -567,7 +567,7 @@ def test_distinct():
     ]
 
 
-@pytest.mark.parametrize(("name", "expr", "result"), data_test_window_functions)
+@pytest.mark.parametrize("name,expr,result", data_test_window_functions)
 def test_window_functions(partitioned_df, name, expr, result):
     df = partitioned_df.select(
         column("a"), column("b"), column("c"), f.alias(expr, name)
@@ -730,9 +730,7 @@ def test_optimized_logical_plan(aggregate_df):
 def test_execution_plan(aggregate_df):
     plan = aggregate_df.execution_plan()
 
-    expected = (
-        "AggregateExec: mode=FinalPartitioned, gby=[c1@0 as c1], aggr=[sum(test.c2)]\n"
-    )
+    expected = "AggregateExec: mode=FinalPartitioned, gby=[c1@0 as c1], aggr=[sum(test.c2)]\n"  # noqa: E501
 
     assert expected == plan.display()
 
@@ -756,7 +754,7 @@ def test_execution_plan(aggregate_df):
 
     ctx = SessionContext()
     rows_returned = 0
-    for idx in range(plan.partition_count):
+    for idx in range(0, plan.partition_count):
         stream = ctx.execute(plan, idx)
         try:
             batch = stream.next()
@@ -885,7 +883,7 @@ def test_union_distinct(ctx):
     )
     df_c = ctx.create_dataframe([[batch]]).sort(column("a"))
 
-    df_a_u_b = df_a.union(df_b, distinct=True).sort(column("a"))
+    df_a_u_b = df_a.union(df_b, True).sort(column("a"))
 
     assert df_c.collect() == df_a_u_b.collect()
     assert df_c.collect() == df_a_u_b.collect()
@@ -954,6 +952,8 @@ def test_to_arrow_table(df):
 
 def test_execute_stream(df):
     stream = df.execute_stream()
+    for s in stream:
+        print(type(s))
     assert all(batch is not None for batch in stream)
     assert not list(stream)  # after one iteration the generator must be exhausted
 
@@ -967,7 +967,7 @@ def test_execute_stream_to_arrow_table(df, schema):
             (batch.to_pyarrow() for batch in stream), schema=df.schema()
         )
     else:
-        pyarrow_table = pa.Table.from_batches(batch.to_pyarrow() for batch in stream)
+        pyarrow_table = pa.Table.from_batches((batch.to_pyarrow() for batch in stream))
 
     assert isinstance(pyarrow_table, pa.Table)
     assert pyarrow_table.shape == (3, 3)
@@ -1031,7 +1031,7 @@ def test_describe(df):
     }
 
 
-@pytest.mark.parametrize("path_to_str", [True, False])
+@pytest.mark.parametrize("path_to_str", (True, False))
 def test_write_csv(ctx, df, tmp_path, path_to_str):
     path = str(tmp_path) if path_to_str else tmp_path
 
@@ -1044,7 +1044,7 @@ def test_write_csv(ctx, df, tmp_path, path_to_str):
     assert result == expected
 
 
-@pytest.mark.parametrize("path_to_str", [True, False])
+@pytest.mark.parametrize("path_to_str", (True, False))
 def test_write_json(ctx, df, tmp_path, path_to_str):
     path = str(tmp_path) if path_to_str else tmp_path
 
@@ -1057,7 +1057,7 @@ def test_write_json(ctx, df, tmp_path, path_to_str):
     assert result == expected
 
 
-@pytest.mark.parametrize("path_to_str", [True, False])
+@pytest.mark.parametrize("path_to_str", (True, False))
 def test_write_parquet(df, tmp_path, path_to_str):
     path = str(tmp_path) if path_to_str else tmp_path
 
@@ -1069,7 +1069,7 @@ def test_write_parquet(df, tmp_path, path_to_str):
 
 
 @pytest.mark.parametrize(
-    ("compression", "compression_level"),
+    "compression, compression_level",
     [("gzip", 6), ("brotli", 7), ("zstd", 15)],
 )
 def test_write_compressed_parquet(df, tmp_path, compression, compression_level):
@@ -1080,7 +1080,7 @@ def test_write_compressed_parquet(df, tmp_path, compression, compression_level):
     )
 
     # test that the actual compression scheme is the one written
-    for _root, _dirs, files in os.walk(path):
+    for root, dirs, files in os.walk(path):
         for file in files:
             if file.endswith(".parquet"):
                 metadata = pq.ParquetFile(tmp_path / file).metadata.to_dict()
@@ -1095,7 +1095,7 @@ def test_write_compressed_parquet(df, tmp_path, compression, compression_level):
 
 
 @pytest.mark.parametrize(
-    ("compression", "compression_level"),
+    "compression, compression_level",
     [("gzip", 12), ("brotli", 15), ("zstd", 23), ("wrong", 12)],
 )
 def test_write_compressed_parquet_wrong_compression_level(
@@ -1150,7 +1150,7 @@ def test_dataframe_export(df) -> None:
     table = pa.table(df, schema=desired_schema)
     assert table.num_columns == 1
    assert table.num_rows == 3
-    for i in range(3):
+    for i in range(0, 3):
         assert table[0][i].as_py() is None
 
     # Expect an error when we cannot convert schema
@@ -1184,8 +1184,8 @@ def add_with_parameter(df_internal, value: Any) -> DataFrame:
     result = df.to_pydict()
 
     assert result["a"] == [1, 2, 3]
-    assert result["string_col"] == ["string data" for _i in range(3)]
-    assert result["new_col"] == [3 for _i in range(3)]
+    assert result["string_col"] == ["string data" for _i in range(0, 3)]
+    assert result["new_col"] == [3 for _i in range(0, 3)]
 
 
 def test_dataframe_repr_html(df) -> None:
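
A side note on the parametrize rewrites in this file: pytest accepts argnames either as a sequence of strings or as a single comma-separated string, so the two spellings behave identically. A minimal self-contained illustration:

import pytest

# Tuple form and comma-separated string form are interchangeable in pytest:
@pytest.mark.parametrize(("value", "expected"), [(1, 2), (3, 4)])
def test_tuple_form(value, expected):
    assert value + 1 == expected

@pytest.mark.parametrize("value,expected", [(1, 2), (3, 4)])
def test_string_form(value, expected):
    assert value + 1 == expected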

python/tests/test_udwf.py

Lines changed: 52 additions & 0 deletions
@@ -306,3 +306,55 @@ def test_udwf_functions(df, name, expr, expected):
     result = df.sort(column("a")).select(column(name)).collect()[0]
 
     assert result.column(0) == pa.array(expected)
+
+
+def test_udwf_decorator(df):
+    @udwf
+    def smooth_default():
+        return ExponentialSmoothDefault()
+
+    df1 = df.select(smooth_default()(column("a")))
+    result = df1.collect()[0].column(0)
+    # Test just the first few values with more lenient comparison
+    assert abs(result[0].as_py() - 0.0) < 1e-6
+    assert abs(result[1].as_py() - 1.0) < 1e-6
+    assert abs(result[2].as_py() - 2.1) < 1e-6
+
+    # Test with explicit types
+    @udwf(pa.float64(), pa.float64(), "immutable")
+    def smooth_with_args():
+        return ExponentialSmoothDefault(alpha=0.8)
+
+    df2 = df.select(smooth_with_args()(column("a")))
+    result = df2.collect()[0].column(0)
+    # Test just the first few values
+    assert abs(result[0].as_py() - 0.0) < 1e-6
+    assert abs(result[1].as_py() - 1.0) < 1e-6
+    assert abs(result[2].as_py() - 1.8) < 1e-6
+
+
+def test_udwf_with_window_frame_decorator(df):
+    @udwf(pa.float64(), pa.float64(), "immutable")
+    def smooth_frame():
+        return ExponentialSmoothFrame(alpha=0.9)
+
+    # Create window function and apply transformations
+    window_fn = smooth_frame()(column("a"))
+    window_fn = window_fn.window_frame(WindowFrame("rows", None, 0))
+    window_fn = window_fn.build()
+
+    result = df.select(window_fn).collect()[0].column(0)
+    # Test just the first few values
+    assert abs(result[0].as_py() - 0.0) < 1e-6
+    assert abs(result[1].as_py() - 0.9) < 1e-6
+
+    # With order by
+    window_fn = smooth_frame()(column("a"))
+    window_fn = window_fn.window_frame(WindowFrame("rows", None, 0))
+    window_fn = window_fn.order_by(column("b"))
+    window_fn = window_fn.build()
+
+    result = df.select(window_fn).collect()[0].column(0)
+    # Test just the first few values
+    assert abs(result[0].as_py() - 0.551) < 1e-3
+    assert abs(result[1].as_py() - 1.13) < 1e-3
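
The `ExponentialSmoothDefault` and `ExponentialSmoothFrame` helpers these tests call are defined earlier in test_udwf.py and are not part of this diff. For orientation, here is a minimal sketch of the shape such an evaluator takes, based on the `WindowEvaluator.evaluate_all` interface shown in the udf.py docstring above; it is a hypothetical simplification, not the actual test helper:

import pyarrow as pa
from datafusion.udf import WindowEvaluator  # assumed import path

class SmoothSketch(WindowEvaluator):  # hypothetical, for illustration only
    def __init__(self, alpha: float = 0.9) -> None:
        self.alpha = alpha

    def evaluate_all(self, values: list[pa.Array], num_rows: int) -> pa.Array:
        # Exponentially smooth the first argument column across the partition.
        results: list[float] = []
        prev = 0.0
        for i in range(num_rows):
            curr = values[0][i].as_py()
            prev = curr if i == 0 else self.alpha * curr + (1.0 - self.alpha) * prev
            results.append(prev)
        return pa.array(results)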
