Skip to content

Commit 8663e77

Browse files
committed
adding unit tests for user defined window functions
1 parent 7232b4e commit 8663e77

File tree

2 files changed

+130
-18
lines changed

2 files changed

+130
-18
lines changed
Lines changed: 105 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,105 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
18+
import pyarrow as pa
19+
import pytest
20+
21+
from datafusion import SessionContext, column, udwf, lit, functions as f
22+
from datafusion.udf import WindowEvaluator
23+
24+
25+
class ExponentialSmooth(WindowEvaluator):
26+
"""Interface of a user-defined accumulation."""
27+
28+
def __init__(self) -> None:
29+
self.alpha = 0.9
30+
31+
def evaluate_all(self, values: pa.Array, num_rows: int) -> pa.Array:
32+
results = []
33+
curr_value = 0.0
34+
for idx in range(num_rows):
35+
if idx == 0:
36+
curr_value = values[idx].as_py()
37+
else:
38+
curr_value = values[idx].as_py() * self.alpha + curr_value * (
39+
1.0 - self.alpha
40+
)
41+
results.append(curr_value)
42+
43+
return pa.array(results)
44+
45+
46+
class NotSubclassOfWindowEvaluator:
47+
pass
48+
49+
50+
@pytest.fixture
51+
def df():
52+
ctx = SessionContext()
53+
54+
# create a RecordBatch and a new DataFrame from it
55+
batch = pa.RecordBatch.from_arrays(
56+
[
57+
pa.array([0, 1, 2, 3, 4, 5, 6]),
58+
pa.array([7, 4, 3, 8, 9, 1, 6]),
59+
pa.array(["A", "A", "A", "A", "B", "B", "B"]),
60+
],
61+
names=["a", "b", "c"],
62+
)
63+
return ctx.create_dataframe([[batch]])
64+
65+
66+
def test_udwf_errors(df):
67+
with pytest.raises(TypeError):
68+
udwf(
69+
NotSubclassOfWindowEvaluator,
70+
pa.float64(),
71+
pa.float64(),
72+
volatility="immutable",
73+
)
74+
75+
76+
smooth = udwf(
77+
ExponentialSmooth,
78+
pa.float64(),
79+
pa.float64(),
80+
volatility="immutable",
81+
)
82+
83+
data_test_udwf_functions = [
84+
("smooth_udwf", smooth(column("a")), [0, 0.9, 1.89, 2.889, 3.889, 4.889, 5.889]),
85+
(
86+
"partitioned_udwf",
87+
smooth(column("a")).partition_by(column("c")).build(),
88+
[0, 0.9, 1.89, 2.889, 4.0, 4.9, 5.89],
89+
),
90+
(
91+
"ordered_udwf",
92+
smooth(column("a")).order_by(column("b")).build(),
93+
[0.551, 1.13, 2.3, 2.755, 3.876, 5.0, 5.513],
94+
),
95+
]
96+
97+
98+
@pytest.mark.parametrize("name,expr,expected", data_test_udwf_functions)
99+
def test_udwf_functions(df, name, expr, expected):
100+
df = df.select("a", f.round(expr, lit(3)).alias(name))
101+
102+
# execute and collect the first (and only) batch
103+
result = df.sort(column("a")).select(column(name)).collect()[0]
104+
105+
assert result.column(0) == pa.array(expected)

python/datafusion/udf.py

Lines changed: 25 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121

2222
import datafusion._internal as df_internal
2323
from datafusion.expr import Expr
24-
from typing import Callable, TYPE_CHECKING, TypeVar
24+
from typing import Callable, TYPE_CHECKING, TypeVar, Type
2525
from abc import ABCMeta, abstractmethod
2626
from typing import List
2727
from enum import Enum
@@ -251,10 +251,15 @@ def udaf(
251251
class WindowEvaluator(metaclass=ABCMeta):
252252
"""Evaluator class for user defined window functions (UDWF).
253253
254-
Users should inherit from this class and implement ``evaluate``, ``evaluate_all``,
255-
and/or ``evaluate_all_with_rank``. If using `evaluate` only you will need to
256-
override ``supports_bounded_execution``.
257-
"""
254+
It is up to the user to decide which evaluate function is appropriate.
255+
256+
|``uses_window_frame``|``supports_bounded_execution``|``include_rank``|function_to_implement|
257+
|---|---|----|----|
258+
|False (default) |False (default) |False (default) | ``evaluate_all`` |
259+
|False |True |False | ``evaluate`` |
260+
|False |True/False |True | ``evaluate_all_with_rank`` |
261+
|True |True/False |True/False | ``evaluate`` |
262+
""" # noqa: W505
258263

259264
def memoize(self) -> None:
260265
"""Perform a memoize operation to improve performance.
@@ -329,15 +334,8 @@ def evaluate_all(self, values: pyarrow.Array, num_rows: int) -> pyarrow.Array:
329334
avg(x) OVER (PARTITION BY y ORDER BY z ROWS BETWEEN 2 PRECEDING AND 3 FOLLOWING)
330335
```
331336
"""
332-
if self.supports_bounded_execution() and not self.uses_window_frame():
333-
res = []
334-
for idx in range(0, num_rows):
335-
res.append(self.evaluate(values, self.get_range(idx, num_rows)))
336-
return pyarrow.array(res)
337-
else:
338-
raise
337+
pass
339338

340-
@abstractmethod
341339
def evaluate(self, values: pyarrow.Array, range: tuple[int, int]) -> pyarrow.Scalar:
342340
"""Evaluate window function on a range of rows in an input partition.
343341
@@ -355,7 +353,6 @@ def evaluate(self, values: pyarrow.Array, range: tuple[int, int]) -> pyarrow.Sca
355353
"""
356354
pass
357355

358-
@abstractmethod
359356
def evaluate_all_with_rank(
360357
self, num_rows: int, ranks_in_partition: list[tuple[int, int]]
361358
) -> pyarrow.Array:
@@ -383,6 +380,8 @@ def evaluate_all_with_rank(
383380
(2,2),
384381
(3,4),
385382
]
383+
384+
The user must implement this method if ``include_rank`` returns True.
386385
"""
387386
pass
388387

@@ -399,6 +398,10 @@ def include_rank(self) -> bool:
399398
return False
400399

401400

401+
if TYPE_CHECKING:
402+
_W = TypeVar("_W", bound=WindowEvaluator)
403+
404+
402405
class WindowUDF:
403406
"""Class for performing window user defined functions (UDF).
404407
@@ -409,9 +412,9 @@ class WindowUDF:
409412
def __init__(
410413
self,
411414
name: str | None,
412-
func: WindowEvaluator,
415+
func: Type[WindowEvaluator],
413416
input_type: pyarrow.DataType,
414-
return_type: _R,
417+
return_type: pyarrow.DataType,
415418
volatility: Volatility | str,
416419
) -> None:
417420
"""Instantiate a user defined window function (UDWF).
@@ -434,9 +437,9 @@ def __call__(self, *args: Expr) -> Expr:
434437

435438
@staticmethod
436439
def udwf(
437-
func: Callable[..., _R],
440+
func: Type[WindowEvaluator],
438441
input_type: pyarrow.DataType,
439-
return_type: _R,
442+
return_type: pyarrow.DataType,
440443
volatility: Volatility | str,
441444
name: str | None = None,
442445
) -> WindowUDF:
@@ -452,6 +455,10 @@ def udwf(
452455
Returns:
453456
A user defined window function.
454457
"""
458+
if not issubclass(func, WindowEvaluator):
459+
raise TypeError(
460+
"`func` must implement the abstract base class WindowEvaluator"
461+
)
455462
if name is None:
456463
name = func.__qualname__.lower()
457464
return WindowUDF(

0 commit comments

Comments
 (0)