Skip to content

Commit cd972b5

Browse files
committed
Add udwf tests for multiple input types and decorator syntax
1 parent a52af17 commit cd972b5

File tree

1 file changed

+291
-11
lines changed

1 file changed

+291
-11
lines changed

python/tests/test_udwf.py

Lines changed: 291 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -20,10 +20,148 @@
2020
import pyarrow as pa
2121
import pytest
2222
from datafusion import SessionContext, column, lit, udwf
23+
from datafusion import functions as f
2324
from datafusion.expr import WindowFrame
2425
from datafusion.udf import WindowEvaluator
2526

2627

28+
class ExponentialSmoothDefault(WindowEvaluator):
    """Exponentially smooth the first input column over a whole partition.

    The first row is passed through unchanged; every later row is
    ``value * alpha + previous_result * (1 - alpha)``.
    """

    def __init__(self, alpha: float = 0.9) -> None:
        """Store the smoothing factor applied to the current row's value."""
        self.alpha = alpha

    def evaluate_all(self, values: list[pa.Array], num_rows: int) -> pa.Array:
        """Return the smoothed series for the first ``num_rows`` rows.

        Only ``values[0]`` is read; any further input arrays are ignored.
        """
        series = values[0]
        smoothed = []
        running = 0.0
        for row in range(num_rows):
            sample = series[row].as_py()
            if row > 0:
                # Blend the current sample with the previous smoothed result.
                sample = sample * self.alpha + running * (1.0 - self.alpha)
            running = sample
            smoothed.append(running)

        return pa.array(smoothed)
46+
47+
48+
class ExponentialSmoothBounded(WindowEvaluator):
    """Smooth each row against its immediate predecessor only.

    The evaluator opts in to bounded execution and supplies its own
    evaluation range through ``get_range``.
    """

    def __init__(self, alpha: float = 0.9) -> None:
        """Store the smoothing factor applied to the current row's value."""
        self.alpha = alpha

    def supports_bounded_execution(self) -> bool:
        """Report that this evaluator can run with bounded memory."""
        return True

    def get_range(self, idx: int, num_rows: int) -> tuple[int, int]:
        """Return the ``(start, stop)`` row range used to evaluate row ``idx``.

        ``uses_window_frame`` is not overridden here, so this replaces the
        default current-row range: for this test each row is smoothed from
        the previous row to the current one. The first row has no
        predecessor and therefore covers only itself.
        """
        return (0, 0) if idx == 0 else (idx - 1, idx)

    def evaluate(
        self, values: list[pa.Array], eval_range: tuple[int, int]
    ) -> pa.Scalar:
        """Smooth ``values[0]`` over ``eval_range`` and return the last result.

        ``stop`` is treated as inclusive, matching what ``get_range`` returns.
        """
        start, stop = eval_range
        series = values[0]
        smoothed = 0.0
        for position in range(start, stop + 1):
            sample = series[position].as_py()
            smoothed = (
                sample
                if position == start
                else sample * self.alpha + smoothed * (1.0 - self.alpha)
            )
        return pa.scalar(smoothed).cast(pa.float64())
77+
78+
79+
class ExponentialSmoothRank(WindowEvaluator):
    """Smooth over the 1-based rank group of each row instead of its value."""

    def __init__(self, alpha: float = 0.9) -> None:
        """Store the smoothing factor applied to the current row's rank."""
        self.alpha = alpha

    def include_rank(self) -> bool:
        """Request rank information so ``evaluate_all_with_rank`` is used."""
        return True

    def evaluate_all_with_rank(
        self, num_rows: int, ranks_in_partition: list[tuple[int, int]]
    ) -> pa.Array:
        """Return one smoothed rank per row.

        Each entry of ``ranks_in_partition`` is read as a half-open
        ``(start, end)`` row range; a row's rank is the 1-based position of
        the first range containing it. The first row is blended with an
        initial prior of ``1.0``; later rows are blended with the previous
        row's (unsmoothed) rank.
        """
        smoothed = []
        previous_rank = 1.0
        for row in range(num_rows):
            containing = [
                position
                for position, bounds in enumerate(ranks_in_partition)
                if bounds[0] <= row < bounds[1]
            ]
            rank = containing[0] + 1
            smoothed.append(rank * self.alpha + previous_rank * (1.0 - self.alpha))
            previous_rank = rank

        return pa.array(smoothed)
103+
104+
105+
class ExponentialSmoothFrame(WindowEvaluator):
    """Exponentially smooth the first input column over the window frame."""

    def __init__(self, alpha: float = 0.9) -> None:
        """Store the smoothing factor applied to the current row's value."""
        self.alpha = alpha

    def uses_window_frame(self) -> bool:
        """Report that evaluation depends on the window frame."""
        return True

    def evaluate(
        self, values: list[pa.Array], eval_range: tuple[int, int]
    ) -> pa.Scalar:
        """Smooth ``values[0]`` over ``[start, stop)`` and return the result.

        ``stop`` is exclusive here. An empty range yields ``0.0``.
        """
        start, stop = eval_range
        # Only the first array is smoothed. When more than one array is
        # supplied, the second one (e.g. an ORDER BY column) is not needed.
        series = values[0]
        smoothed = 0.0
        for position in range(start, stop):
            sample = series[position].as_py()
            smoothed = (
                sample
                if position == start
                else sample * self.alpha + smoothed * (1.0 - self.alpha)
            )
        return pa.scalar(smoothed).cast(pa.float64())
130+
131+
132+
class SmoothTwoColumn(WindowEvaluator):
    """This class demonstrates using two columns.

    If the second column is above a threshold, then smooth over the first column from
    the previous and next rows.
    """

    def __init__(self, alpha: float = 0.9) -> None:
        """Store the weight given to the previous row's value."""
        self.alpha = alpha

    def evaluate_all(self, values: list[pa.Array], num_rows: int) -> pa.Array:
        """Return one float64 result per row of the partition.

        Args:
            values: ``values[0]`` is the column to smooth and ``values[1]``
                is the column compared against the threshold.
            num_rows: Number of rows in the partition.

        Returns:
            For rows whose second-column value is greater than 7, a blend of
            the neighbouring first-column values (``previous * alpha +
            next * (1 - alpha)``). The first row uses the next row's value
            and the last row uses the previous row's value. All other rows,
            and the only row of a single-row partition, pass their own
            first-column value through.
        """
        results = []
        values_a = values[0]
        values_b = values[1]
        last = num_rows - 1
        for idx in range(num_rows):
            above_threshold = values_b[idx].as_py() > 7
            if not above_threshold or num_rows == 1:
                # A single-row partition has no neighbour to borrow from, so
                # it is passed through instead of indexing out of bounds.
                results.append(values_a[idx].cast(pa.float64()))
            elif idx == 0:
                results.append(values_a[1].cast(pa.float64()))
            elif idx == last:
                results.append(values_a[last - 1].cast(pa.float64()))
            else:
                results.append(
                    pa.scalar(
                        values_a[idx - 1].as_py() * self.alpha
                        + values_a[idx + 1].as_py() * (1.0 - self.alpha)
                    )
                )

        return pa.array(results)
163+
164+
27165
class SimpleWindowCount(WindowEvaluator):
28166
"""A simple window evaluator that counts rows."""
29167

@@ -44,7 +182,23 @@ def ctx():
44182

45183

46184
@pytest.fixture
47-
def df(ctx):
185+
def df():
    """Build the seven-row DataFrame with columns ``a``, ``b`` and ``c``.

    A fresh ``SessionContext`` is created for every use of the fixture.
    """
    columns = {
        "a": pa.array([0, 1, 2, 3, 4, 5, 6]),
        "b": pa.array([7, 4, 3, 8, 9, 1, 6]),
        "c": pa.array(["A", "A", "A", "A", "B", "B", "B"]),
    }
    # Wrap the arrays in a single RecordBatch and expose it as a DataFrame.
    batch = pa.RecordBatch.from_arrays(list(columns.values()), names=list(columns))
    return SessionContext().create_dataframe([[batch]])
198+
199+
200+
@pytest.fixture
201+
def simple_df(ctx):
48202
# create a RecordBatch and a new DataFrame from it
49203
batch = pa.RecordBatch.from_arrays(
50204
[pa.array([1, 2, 3]), pa.array([4, 4, 6])],
@@ -53,7 +207,17 @@ def df(ctx):
53207
return ctx.create_dataframe([[batch]], name="test_table")
54208

55209

56-
def test_udwf_errors():
210+
def test_udwf_errors(df):
    """Creating a UDWF from a class that is not a WindowEvaluator must fail."""
    float_type = pa.float64()
    with pytest.raises(TypeError):
        udwf(
            NotSubclassOfWindowEvaluator,
            float_type,
            float_type,
            volatility="immutable",
        )
218+
219+
220+
def test_udwf_errors_with_message():
57221
"""Test error cases for UDWF creation."""
58222
with pytest.raises(
59223
TypeError, match="`func` must implement the abstract base class WindowEvaluator"
@@ -63,13 +227,13 @@ def test_udwf_errors():
63227
)
64228

65229

66-
def test_udwf_basic_usage(df):
230+
def test_udwf_basic_usage(simple_df):
67231
"""Test basic UDWF usage with a simple counting window function."""
68232
simple_count = udwf(
69233
SimpleWindowCount, pa.int64(), pa.int64(), volatility="immutable"
70234
)
71235

72-
df = df.select(
236+
df = simple_df.select(
73237
simple_count(column("a"))
74238
.window_frame(WindowFrame("rows", None, None))
75239
.build()
@@ -79,13 +243,13 @@ def test_udwf_basic_usage(df):
79243
assert result.column(0) == pa.array([0, 1, 2])
80244

81245

82-
def test_udwf_with_args(df):
246+
def test_udwf_with_args(simple_df):
83247
"""Test UDWF with constructor arguments."""
84248
count_base10 = udwf(
85249
lambda: SimpleWindowCount(10), pa.int64(), pa.int64(), volatility="immutable"
86250
)
87251

88-
df = df.select(
252+
df = simple_df.select(
89253
count_base10(column("a"))
90254
.window_frame(WindowFrame("rows", None, None))
91255
.build()
@@ -95,14 +259,14 @@ def test_udwf_with_args(df):
95259
assert result.column(0) == pa.array([10, 11, 12])
96260

97261

98-
def test_udwf_decorator_basic(df):
262+
def test_udwf_decorator_basic(simple_df):
99263
"""Test UDWF used as a decorator."""
100264

101265
@udwf([pa.int64()], pa.int64(), "immutable")
102266
def window_count() -> WindowEvaluator:
103267
return SimpleWindowCount()
104268

105-
df = df.select(
269+
df = simple_df.select(
106270
window_count(column("a"))
107271
.window_frame(WindowFrame("rows", None, None))
108272
.build()
@@ -112,14 +276,14 @@ def window_count() -> WindowEvaluator:
112276
assert result.column(0) == pa.array([0, 1, 2])
113277

114278

115-
def test_udwf_decorator_with_args(df):
279+
def test_udwf_decorator_with_args(simple_df):
116280
"""Test UDWF decorator with constructor arguments."""
117281

118282
@udwf([pa.int64()], pa.int64(), "immutable")
119283
def window_count_base10() -> WindowEvaluator:
120284
return SimpleWindowCount(10)
121285

122-
df = df.select(
286+
df = simple_df.select(
123287
window_count_base10(column("a"))
124288
.window_frame(WindowFrame("rows", None, None))
125289
.build()
@@ -129,7 +293,7 @@ def window_count_base10() -> WindowEvaluator:
129293
assert result.column(0) == pa.array([10, 11, 12])
130294

131295

132-
def test_register_udwf(ctx, df):
296+
def test_register_udwf(ctx, simple_df):
133297
"""Test registering and using UDWF in SQL context."""
134298
window_count = udwf(
135299
SimpleWindowCount,
@@ -144,3 +308,119 @@ def test_register_udwf(ctx, df):
144308
"SELECT window_count(a) OVER (ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) FROM test_table"
145309
).collect()[0]
146310
assert result.column(0) == pa.array([0, 1, 2])
311+
312+
313+
# Window UDFs shared by the parametrized cases below. They are built once at
# import time from the evaluator classes defined in this module.
smooth_default = udwf(
    ExponentialSmoothDefault,
    pa.float64(),
    pa.float64(),
    volatility="immutable",
)

# Same evaluator, constructed with a non-default alpha through a factory.
smooth_w_arguments = udwf(
    lambda: ExponentialSmoothDefault(0.8),
    pa.float64(),
    pa.float64(),
    volatility="immutable",
)

smooth_bounded = udwf(
    ExponentialSmoothBounded,
    pa.float64(),
    pa.float64(),
    volatility="immutable",
)

# The rank evaluator is applied to the string column "c", hence utf8 input.
smooth_rank = udwf(
    ExponentialSmoothRank,
    pa.utf8(),
    pa.float64(),
    volatility="immutable",
)

smooth_frame = udwf(
    ExponentialSmoothFrame,
    pa.float64(),
    pa.float64(),
    volatility="immutable",
)

# Two input columns are declared as a list of types.
smooth_two_col = udwf(
    SmoothTwoColumn,
    [pa.int64(), pa.int64()],
    pa.float64(),
    volatility="immutable",
)

# Each case is (name, window expression, expected values rounded to 3 places).
# The name is also used as the result column alias, so every case gets a
# distinct one to keep failures unambiguous.
data_test_udwf_functions = [
    (
        "default_udwf_no_arguments",
        smooth_default(column("a")),
        [0, 0.9, 1.89, 2.889, 3.889, 4.889, 5.889],
    ),
    (
        "default_udwf_w_arguments",
        smooth_w_arguments(column("a")),
        [0, 0.8, 1.76, 2.752, 3.75, 4.75, 5.75],
    ),
    (
        "default_udwf_partitioned",
        smooth_default(column("a")).partition_by(column("c")).build(),
        [0, 0.9, 1.89, 2.889, 4.0, 4.9, 5.89],
    ),
    (
        "default_udwf_ordered",
        smooth_default(column("a")).order_by(column("b")).build(),
        [0.551, 1.13, 2.3, 2.755, 3.876, 5.0, 5.513],
    ),
    (
        "bounded_udwf",
        smooth_bounded(column("a")),
        [0, 0.9, 1.9, 2.9, 3.9, 4.9, 5.9],
    ),
    (
        "bounded_udwf_ignores_frame",
        smooth_bounded(column("a"))
        .window_frame(WindowFrame("rows", None, None))
        .build(),
        [0, 0.9, 1.9, 2.9, 3.9, 4.9, 5.9],
    ),
    (
        "rank_udwf",
        smooth_rank(column("c")).order_by(column("c")).build(),
        [1, 1, 1, 1, 1.9, 2, 2],
    ),
    (
        "frame_unbounded_udwf",
        smooth_frame(column("a")).window_frame(WindowFrame("rows", None, None)).build(),
        [5.889, 5.889, 5.889, 5.889, 5.889, 5.889, 5.889],
    ),
    (
        "frame_bounded_udwf",
        smooth_frame(column("a")).window_frame(WindowFrame("rows", None, 0)).build(),
        [0.0, 0.9, 1.89, 2.889, 3.889, 4.889, 5.889],
    ),
    (
        # Previously shared the label "frame_bounded_udwf" with the case above.
        "frame_bounded_udwf_ordered",
        smooth_frame(column("a"))
        .window_frame(WindowFrame("rows", None, 0))
        .order_by(column("b"))
        .build(),
        [0.551, 1.13, 2.3, 2.755, 3.876, 5.0, 5.513],
    ),
    (
        "two_column_udwf",
        smooth_two_col(column("a"), column("b")),
        [0.0, 1.0, 2.0, 2.2, 3.2, 5.0, 6.0],
    ),
]
417+
418+
419+
@pytest.mark.parametrize(("name", "expr", "expected"), data_test_udwf_functions)
def test_udwf_functions(df, name, expr, expected):
    """Evaluate one window expression and compare it with the expected column."""
    # Round to three decimals so the float results can be compared exactly.
    rounded = f.round(expr, lit(3)).alias(name)
    projected = df.select("a", "b", rounded)

    # Sort by "a" for a deterministic row order, then take the first (and
    # only) batch of the single result column.
    ordered = projected.sort(column("a")).select(column(name))
    result = ordered.collect()[0]

    assert result.column(0) == pa.array(expected)

0 commit comments

Comments
 (0)