Commit 6ee8696

Revert "feat: Add user-defined window function (UDWF) decorator and tests"

This reverts commit da691b4.

1 parent da691b4 · commit 6ee8696

File tree

3 files changed: +60 −129 lines changed

python/datafusion/udf.py
python/tests/test_dataframe.py
python/tests/test_udwf.py

python/datafusion/udf.py

Lines changed: 40 additions & 57 deletions
@@ -623,76 +623,59 @@ def __call__(self, *args: Expr) -> Expr:
 
     @staticmethod
     def udwf(
-        func: Callable[[], WindowEvaluator] | None = None,
-        input_types: pa.DataType | list[pa.DataType] | None = None,
-        return_type: pa.DataType | None = None,
-        volatility: Volatility | str | None = None,
-        name: str | None = None,
-    ) -> Union[WindowUDF, Callable[[Callable[[], WindowEvaluator]], WindowUDF]]:
-        """Create a new User-Defined Window Function (UDWF).
-
-        This method can be used both as a function and as a decorator:
+        func: Callable[[], WindowEvaluator],
+        input_types: pa.DataType | list[pa.DataType],
+        return_type: pa.DataType,
+        volatility: Volatility | str,
+        name: Optional[str] = None,
+    ) -> WindowUDF:
+        """Create a new User-Defined Window Function.
 
-        As a function:
-            udwf(func, input_types, return_type, volatility, name)
+        If your :py:class:`WindowEvaluator` can be instantiated with no arguments, you
+        can simply pass its type as ``func``. If you need to pass additional arguments
+        to its constructor, you can define a lambda or a factory method. During runtime
+        the :py:class:`WindowEvaluator` will be constructed for every instance in
+        which this UDWF is used. The following examples are all valid.
 
-        As a decorator:
-            @udwf(input_types, return_type, volatility, name)
-            def func():
-                return WindowEvaluator()
+        .. code-block:: python
 
-        Args:
-            func: The window evaluator factory function
-            input_types: The input types for the window function
-            return_type: The return type for the window function
-            volatility: The volatility of the function
-            name: Optional name for the function
+            import pyarrow as pa
 
-        Returns:
-            Either a WindowUDF instance or a decorator function
-        """
-        # Used as decorator without arguments: @udwf
-        if func is not None and all(
-            x is None for x in (input_types, return_type, volatility)
-        ):
-            return WindowUDF._create(
-                func, [pa.float64()], pa.float64(), "volatile", None
-            )
+            class BiasedNumbers(WindowEvaluator):
+                def __init__(self, start: int = 0) -> None:
+                    self.start = start
 
-        # Used as decorator with arguments: @udwf(...)
-        if func is None:
+                def evaluate_all(self, values: list[pa.Array], num_rows: int) -> pa.Array:
+                    return pa.array([self.start + i for i in range(num_rows)])
 
-            def decorator(f: Callable[[], WindowEvaluator]) -> WindowUDF:
-                if input_types is None or return_type is None or volatility is None:
-                    raise ValueError(
-                        "Must provide input_types, return_type, and volatility"
-                    )
-                return WindowUDF._create(f, input_types, return_type, volatility, name)
+            def bias_10() -> BiasedNumbers:
+                return BiasedNumbers(10)
 
-            return decorator
+            udwf1 = udwf(BiasedNumbers, pa.int64(), pa.int64(), "immutable")
+            udwf2 = udwf(bias_10, pa.int64(), pa.int64(), "immutable")
+            udwf3 = udwf(lambda: BiasedNumbers(20), pa.int64(), pa.int64(), "immutable")
 
-        # Used as function: udwf(...)
-        if input_types is None or return_type is None or volatility is None:
-            raise ValueError("Must provide input_types, return_type, and volatility")
-        return WindowUDF._create(func, input_types, return_type, volatility, name)
+        Args:
+            func: A callable to create the window function.
+            input_types: The data types of the arguments to ``func``.
+            return_type: The data type of the return value.
+            volatility: See :py:class:`Volatility` for allowed values.
+            arguments: A list of arguments to pass in to the __init__ method for accum.
+            name: A descriptive name for the function.
 
-    @staticmethod
-    def _create(
-        func: Callable[[], WindowEvaluator],
-        input_types: pa.DataType | list[pa.DataType],
-        return_type: pa.DataType,
-        volatility: Volatility | str,
-        name: str | None = None,
-    ) -> WindowUDF:
-        """Internal method to create a WindowUDF instance."""
+        Returns:
+            A user-defined window function.
+        """  # noqa: W505, E501
         if not callable(func):
-            raise TypeError("`func` must be callable")
+            msg = "`func` must be callable."
+            raise TypeError(msg)
         if not isinstance(func(), WindowEvaluator):
-            raise TypeError("`func` must implement WindowEvaluator")
-        if isinstance(input_types, pa.DataType):
-            input_types = [input_types]
+            msg = "`func` must implement the abstract base class WindowEvaluator"
+            raise TypeError(msg)
         if name is None:
             name = func().__class__.__qualname__.lower()
+        if isinstance(input_types, pa.DataType):
+            input_types = [input_types]
         return WindowUDF(
             name=name,
             func=func,
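
For context, a minimal runnable sketch of how the restored, explicit-argument API is exercised end to end. It assumes the top-level `datafusion` re-exports (`SessionContext`, `col`, `udwf`), imports `WindowEvaluator` from `datafusion.udf`, and reuses the `BiasedNumbers` evaluator from the docstring above; the `from_pydict` setup is illustrative, not part of this commit.

import pyarrow as pa
from datafusion import SessionContext, col, udwf
from datafusion.udf import WindowEvaluator

class BiasedNumbers(WindowEvaluator):
    """Emit start, start + 1, ... for each row of the partition."""

    def __init__(self, start: int = 0) -> None:
        self.start = start

    def evaluate_all(self, values: list[pa.Array], num_rows: int) -> pa.Array:
        return pa.array([self.start + i for i in range(num_rows)])

# Both argument styles from the docstring tell udwf how to construct
# the evaluator at runtime:
udwf1 = udwf(BiasedNumbers, pa.int64(), pa.int64(), "immutable")
udwf2 = udwf(lambda: BiasedNumbers(20), pa.int64(), pa.int64(), "immutable")

ctx = SessionContext()
df = ctx.from_pydict({"a": [1, 2, 3]})
df.select(udwf1(col("a"))).show()  # emits 0, 1, 2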

python/tests/test_dataframe.py

Lines changed: 20 additions & 20 deletions
@@ -339,7 +339,7 @@ def test_join():
 
     # Verify we don't make a breaking change to pre-43.0.0
     # where users would pass join_keys as a positional argument
-    df2 = df.join(df1, (["a"], ["a"]), how="inner")  # type: ignore
+    df2 = df.join(df1, (["a"], ["a"]), how="inner")
     df2.show()
     df2 = df2.sort(column("l.a"))
     table = pa.Table.from_batches(df2.collect())

@@ -375,17 +375,17 @@ def test_join_invalid_params():
     with pytest.raises(
         ValueError, match=r"`left_on` or `right_on` should not provided with `on`"
     ):
-        df2 = df.join(df1, on="a", how="inner", right_on="test")  # type: ignore
+        df2 = df.join(df1, on="a", how="inner", right_on="test")
 
     with pytest.raises(
         ValueError, match=r"`left_on` and `right_on` should both be provided."
     ):
-        df2 = df.join(df1, left_on="a", how="inner")  # type: ignore
+        df2 = df.join(df1, left_on="a", how="inner")
 
     with pytest.raises(
         ValueError, match=r"either `on` or `left_on` and `right_on` should be provided."
     ):
-        df2 = df.join(df1, how="inner")  # type: ignore
+        df2 = df.join(df1, how="inner")
 
 
 def test_join_on():

@@ -567,7 +567,7 @@ def test_distinct():
     ]
 
 
-@pytest.mark.parametrize("name,expr,result", data_test_window_functions)
+@pytest.mark.parametrize(("name", "expr", "result"), data_test_window_functions)
 def test_window_functions(partitioned_df, name, expr, result):
     df = partitioned_df.select(
         column("a"), column("b"), column("c"), f.alias(expr, name)

@@ -730,7 +730,9 @@ def test_optimized_logical_plan(aggregate_df):
 def test_execution_plan(aggregate_df):
     plan = aggregate_df.execution_plan()
 
-    expected = "AggregateExec: mode=FinalPartitioned, gby=[c1@0 as c1], aggr=[sum(test.c2)]\n"  # noqa: E501
+    expected = (
+        "AggregateExec: mode=FinalPartitioned, gby=[c1@0 as c1], aggr=[sum(test.c2)]\n"
+    )
 
     assert expected == plan.display()
 

@@ -754,7 +756,7 @@ def test_execution_plan(aggregate_df):
 
     ctx = SessionContext()
     rows_returned = 0
-    for idx in range(0, plan.partition_count):
+    for idx in range(plan.partition_count):
         stream = ctx.execute(plan, idx)
         try:
             batch = stream.next()

@@ -883,7 +885,7 @@ def test_union_distinct(ctx):
     )
     df_c = ctx.create_dataframe([[batch]]).sort(column("a"))
 
-    df_a_u_b = df_a.union(df_b, True).sort(column("a"))
+    df_a_u_b = df_a.union(df_b, distinct=True).sort(column("a"))
 
     assert df_c.collect() == df_a_u_b.collect()

@@ -952,8 +954,6 @@ def test_to_arrow_table(df):
 
 def test_execute_stream(df):
     stream = df.execute_stream()
-    for s in stream:
-        print(type(s))
     assert all(batch is not None for batch in stream)
     assert not list(stream)  # after one iteration the generator must be exhausted
 

@@ -967,7 +967,7 @@ def test_execute_stream_to_arrow_table(df, schema):
             (batch.to_pyarrow() for batch in stream), schema=df.schema()
         )
     else:
-        pyarrow_table = pa.Table.from_batches((batch.to_pyarrow() for batch in stream))
+        pyarrow_table = pa.Table.from_batches(batch.to_pyarrow() for batch in stream)
 
     assert isinstance(pyarrow_table, pa.Table)
     assert pyarrow_table.shape == (3, 3)

@@ -1031,7 +1031,7 @@ def test_describe(df):
     }
 
 
-@pytest.mark.parametrize("path_to_str", (True, False))
+@pytest.mark.parametrize("path_to_str", [True, False])
 def test_write_csv(ctx, df, tmp_path, path_to_str):
     path = str(tmp_path) if path_to_str else tmp_path
 

@@ -1044,7 +1044,7 @@ def test_write_csv(ctx, df, tmp_path, path_to_str):
     assert result == expected
 
 
-@pytest.mark.parametrize("path_to_str", (True, False))
+@pytest.mark.parametrize("path_to_str", [True, False])
 def test_write_json(ctx, df, tmp_path, path_to_str):
     path = str(tmp_path) if path_to_str else tmp_path
 

@@ -1057,7 +1057,7 @@ def test_write_json(ctx, df, tmp_path, path_to_str):
     assert result == expected
 
 
-@pytest.mark.parametrize("path_to_str", (True, False))
+@pytest.mark.parametrize("path_to_str", [True, False])
 def test_write_parquet(df, tmp_path, path_to_str):
     path = str(tmp_path) if path_to_str else tmp_path
 

@@ -1069,7 +1069,7 @@ def test_write_parquet(df, tmp_path, path_to_str):
 
 
 @pytest.mark.parametrize(
-    "compression, compression_level",
+    ("compression", "compression_level"),
     [("gzip", 6), ("brotli", 7), ("zstd", 15)],
 )
 def test_write_compressed_parquet(df, tmp_path, compression, compression_level):

@@ -1080,7 +1080,7 @@ def test_write_compressed_parquet(df, tmp_path, compression, compression_level):
     )
 
     # test that the actual compression scheme is the one written
-    for root, dirs, files in os.walk(path):
+    for _root, _dirs, files in os.walk(path):
         for file in files:
             if file.endswith(".parquet"):
                 metadata = pq.ParquetFile(tmp_path / file).metadata.to_dict()

@@ -1095,7 +1095,7 @@ def test_write_compressed_parquet(df, tmp_path, compression, compression_level):
 
 
 @pytest.mark.parametrize(
-    "compression, compression_level",
+    ("compression", "compression_level"),
     [("gzip", 12), ("brotli", 15), ("zstd", 23), ("wrong", 12)],
 )
 def test_write_compressed_parquet_wrong_compression_level(

@@ -1150,7 +1150,7 @@ def test_dataframe_export(df) -> None:
     table = pa.table(df, schema=desired_schema)
     assert table.num_columns == 1
     assert table.num_rows == 3
-    for i in range(0, 3):
+    for i in range(3):
         assert table[0][i].as_py() is None
 
     # Expect an error when we cannot convert schema

@@ -1184,8 +1184,8 @@ def add_with_parameter(df_internal, value: Any) -> DataFrame:
     result = df.to_pydict()
 
     assert result["a"] == [1, 2, 3]
-    assert result["string_col"] == ["string data" for _i in range(0, 3)]
-    assert result["new_col"] == [3 for _i in range(0, 3)]
+    assert result["string_col"] == ["string data" for _i in range(3)]
+    assert result["new_col"] == [3 for _i in range(3)]
 
 
 def test_dataframe_repr_html(df) -> None:
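
Most of the edits above are mechanical lint fixes, and the `parametrize` changes look like ruff's pytest-style rules (PT006 for argument names, PT007 for argument values): several names are passed as a tuple of strings instead of one comma-separated string, and value collections become lists. A small sketch of the two styles; the test name and body here are illustrative only.

import pytest

# Before: names in one comma-separated string, values in a tuple.
# @pytest.mark.parametrize("compression, compression_level", (("gzip", 6),))

# After: names as a tuple of strings, values as a list.
@pytest.mark.parametrize(("compression", "compression_level"), [("gzip", 6), ("zstd", 15)])
def test_compression_pair(compression, compression_level):
    assert isinstance(compression, str)
    assert compression_level > 0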

python/tests/test_udwf.py

Lines changed: 0 additions & 52 deletions
@@ -306,55 +306,3 @@ def test_udwf_functions(df, name, expr, expected):
     result = df.sort(column("a")).select(column(name)).collect()[0]
 
     assert result.column(0) == pa.array(expected)
-
-
-def test_udwf_decorator(df):
-    @udwf
-    def smooth_default():
-        return ExponentialSmoothDefault()
-
-    df1 = df.select(smooth_default()(column("a")))
-    result = df1.collect()[0].column(0)
-    # Test just the first few values with more lenient comparison
-    assert abs(result[0].as_py() - 0.0) < 1e-6
-    assert abs(result[1].as_py() - 1.0) < 1e-6
-    assert abs(result[2].as_py() - 2.1) < 1e-6
-
-    # Test with explicit types
-    @udwf(pa.float64(), pa.float64(), "immutable")
-    def smooth_with_args():
-        return ExponentialSmoothDefault(alpha=0.8)
-
-    df2 = df.select(smooth_with_args()(column("a")))
-    result = df2.collect()[0].column(0)
-    # Test just the first few values
-    assert abs(result[0].as_py() - 0.0) < 1e-6
-    assert abs(result[1].as_py() - 1.0) < 1e-6
-    assert abs(result[2].as_py() - 1.8) < 1e-6
-
-
-def test_udwf_with_window_frame_decorator(df):
-    @udwf(pa.float64(), pa.float64(), "immutable")
-    def smooth_frame():
-        return ExponentialSmoothFrame(alpha=0.9)
-
-    # Create window function and apply transformations
-    window_fn = smooth_frame()(column("a"))
-    window_fn = window_fn.window_frame(WindowFrame("rows", None, 0))
-    window_fn = window_fn.build()
-
-    result = df.select(window_fn).collect()[0].column(0)
-    # Test just the first few values
-    assert abs(result[0].as_py() - 0.0) < 1e-6
-    assert abs(result[1].as_py() - 0.9) < 1e-6
-
-    # With order by
-    window_fn = smooth_frame()(column("a"))
-    window_fn = window_fn.window_frame(WindowFrame("rows", None, 0))
-    window_fn = window_fn.order_by(column("b"))
-    window_fn = window_fn.build()
-
-    result = df.select(window_fn).collect()[0].column(0)
-    # Test just the first few values
-    assert abs(result[0].as_py() - 0.551) < 1e-3
-    assert abs(result[1].as_py() - 1.13) < 1e-3
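
With the decorator forms removed, the same evaluators are registered through the plain `udwf` call that the surviving tests use. A sketch under the assumption that the `ExponentialSmoothDefault` evaluator and the `df` fixture defined earlier in test_udwf.py are in scope:

import pyarrow as pa
from datafusion import column, udwf

# No constructor arguments needed: pass the WindowEvaluator type directly.
smooth_default = udwf(ExponentialSmoothDefault, pa.float64(), pa.float64(), "immutable")

# Constructor arguments needed: wrap the evaluator in a lambda or factory.
smooth_with_args = udwf(
    lambda: ExponentialSmoothDefault(alpha=0.8),
    pa.float64(),
    pa.float64(),
    "immutable",
)

result = df.select(smooth_default(column("a"))).collect()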
