Skip to content

Commit d6378af

Browse files
committed
move new tests to test_udwf2.py
1 parent 5c8dbcd commit d6378af

File tree

2 files changed

+397
-89
lines changed

2 files changed

+397
-89
lines changed

python/tests/test_udwf.py

Lines changed: 251 additions & 89 deletions
Original file line numberDiff line numberDiff line change
@@ -20,127 +20,289 @@
2020
import pyarrow as pa
2121
import pytest
2222
from datafusion import SessionContext, column, lit, udwf
23+
from datafusion import functions as f
2324
from datafusion.expr import WindowFrame
2425
from datafusion.udf import WindowEvaluator
2526

2627

27-
class SimpleWindowCount(WindowEvaluator):
28-
"""A simple window evaluator that counts rows."""
28+
class ExponentialSmoothDefault(WindowEvaluator):
29+
def __init__(self, alpha: float = 0.9) -> None:
30+
self.alpha = alpha
2931

30-
def __init__(self, base: int = 0) -> None:
31-
self.base = base
32+
def evaluate_all(self, values: list[pa.Array], num_rows: int) -> pa.Array:
33+
results = []
34+
curr_value = 0.0
35+
values = values[0]
36+
for idx in range(num_rows):
37+
if idx == 0:
38+
curr_value = values[idx].as_py()
39+
else:
40+
curr_value = values[idx].as_py() * self.alpha + curr_value * (
41+
1.0 - self.alpha
42+
)
43+
results.append(curr_value)
44+
45+
return pa.array(results)
46+
47+
48+
class ExponentialSmoothBounded(WindowEvaluator):
49+
def __init__(self, alpha: float = 0.9) -> None:
50+
self.alpha = alpha
51+
52+
def supports_bounded_execution(self) -> bool:
53+
return True
54+
55+
def get_range(self, idx: int, num_rows: int) -> tuple[int, int]:
56+
# Override the default range of current row since uses_window_frame is False
57+
# So for the purpose of this test we just smooth from the previous row to
58+
# current.
59+
if idx == 0:
60+
return (0, 0)
61+
return (idx - 1, idx)
62+
63+
def evaluate(
64+
self, values: list[pa.Array], eval_range: tuple[int, int]
65+
) -> pa.Scalar:
66+
(start, stop) = eval_range
67+
curr_value = 0.0
68+
values = values[0]
69+
for idx in range(start, stop + 1):
70+
if idx == start:
71+
curr_value = values[idx].as_py()
72+
else:
73+
curr_value = values[idx].as_py() * self.alpha + curr_value * (
74+
1.0 - self.alpha
75+
)
76+
return pa.scalar(curr_value).cast(pa.float64())
77+
78+
79+
class ExponentialSmoothRank(WindowEvaluator):
80+
def __init__(self, alpha: float = 0.9) -> None:
81+
self.alpha = alpha
82+
83+
def include_rank(self) -> bool:
84+
return True
85+
86+
def evaluate_all_with_rank(
87+
self, num_rows: int, ranks_in_partition: list[tuple[int, int]]
88+
) -> pa.Array:
89+
results = []
90+
for idx in range(num_rows):
91+
if idx == 0:
92+
prior_value = 1.0
93+
matching_row = [
94+
i
95+
for i in range(len(ranks_in_partition))
96+
if ranks_in_partition[i][0] <= idx and ranks_in_partition[i][1] > idx
97+
][0] + 1
98+
curr_value = matching_row * self.alpha + prior_value * (1.0 - self.alpha)
99+
results.append(curr_value)
100+
prior_value = matching_row
101+
102+
return pa.array(results)
103+
104+
105+
class ExponentialSmoothFrame(WindowEvaluator):
106+
def __init__(self, alpha: float = 0.9) -> None:
107+
self.alpha = alpha
108+
109+
def uses_window_frame(self) -> bool:
110+
return True
111+
112+
def evaluate(
113+
self, values: list[pa.Array], eval_range: tuple[int, int]
114+
) -> pa.Scalar:
115+
(start, stop) = eval_range
116+
curr_value = 0.0
117+
if len(values) > 1:
118+
order_by = values[1] # noqa: F841
119+
values = values[0]
120+
else:
121+
values = values[0]
122+
for idx in range(start, stop):
123+
if idx == start:
124+
curr_value = values[idx].as_py()
125+
else:
126+
curr_value = values[idx].as_py() * self.alpha + curr_value * (
127+
1.0 - self.alpha
128+
)
129+
return pa.scalar(curr_value).cast(pa.float64())
130+
131+
132+
class SmoothTwoColumn(WindowEvaluator):
133+
"""This class demonstrates using two columns.
134+
135+
If the second column is above a threshold, then smooth over the first column from
136+
the previous and next rows.
137+
"""
138+
139+
def __init__(self, alpha: float = 0.9) -> None:
140+
self.alpha = alpha
32141

33142
def evaluate_all(self, values: list[pa.Array], num_rows: int) -> pa.Array:
34-
return pa.array([self.base + i for i in range(num_rows)])
143+
results = []
144+
values_a = values[0]
145+
values_b = values[1]
146+
for idx in range(num_rows):
147+
if values_b[idx].as_py() > 7:
148+
if idx == 0:
149+
results.append(values_a[1].cast(pa.float64()))
150+
elif idx == num_rows - 1:
151+
results.append(values_a[num_rows - 2].cast(pa.float64()))
152+
else:
153+
results.append(
154+
pa.scalar(
155+
values_a[idx - 1].as_py() * self.alpha
156+
+ values_a[idx + 1].as_py() * (1.0 - self.alpha)
157+
)
158+
)
159+
else:
160+
results.append(values_a[idx].cast(pa.float64()))
161+
162+
return pa.array(results)
35163

36164

37165
class NotSubclassOfWindowEvaluator:
38166
pass
39167

40168

41169
@pytest.fixture
42-
def ctx():
43-
return SessionContext()
170+
def df():
171+
ctx = SessionContext()
44172

45-
46-
@pytest.fixture
47-
def df(ctx):
48173
# create a RecordBatch and a new DataFrame from it
49174
batch = pa.RecordBatch.from_arrays(
50-
[pa.array([1, 2, 3]), pa.array([4, 4, 6])],
51-
names=["a", "b"],
175+
[
176+
pa.array([0, 1, 2, 3, 4, 5, 6]),
177+
pa.array([7, 4, 3, 8, 9, 1, 6]),
178+
pa.array(["A", "A", "A", "A", "B", "B", "B"]),
179+
],
180+
names=["a", "b", "c"],
52181
)
53-
return ctx.create_dataframe([[batch]], name="test_table")
182+
return ctx.create_dataframe([[batch]])
54183

55184

56-
def test_udwf_errors():
57-
"""Test error cases for UDWF creation."""
58-
with pytest.raises(
59-
TypeError, match="`func` must implement the abstract base class WindowEvaluator"
60-
):
185+
def test_udwf_errors(df):
186+
with pytest.raises(TypeError):
61187
udwf(
62-
NotSubclassOfWindowEvaluator, pa.int64(), pa.int64(), volatility="immutable"
188+
NotSubclassOfWindowEvaluator,
189+
pa.float64(),
190+
pa.float64(),
191+
volatility="immutable",
63192
)
64193

65194

66-
def test_udwf_basic_usage(df):
67-
"""Test basic UDWF usage with a simple counting window function."""
68-
simple_count = udwf(
69-
SimpleWindowCount, pa.int64(), pa.int64(), volatility="immutable"
70-
)
71-
72-
df = df.select(
73-
simple_count(column("a"))
74-
.window_frame(WindowFrame("rows", None, None))
75-
.build()
76-
.alias("count")
77-
)
78-
result = df.collect()[0]
79-
assert result.column(0) == pa.array([0, 1, 2])
80-
195+
smooth_default = udwf(
196+
ExponentialSmoothDefault,
197+
pa.float64(),
198+
pa.float64(),
199+
volatility="immutable",
200+
)
81201

82-
def test_udwf_with_args(df):
83-
"""Test UDWF with constructor arguments."""
84-
count_base10 = udwf(
85-
lambda: SimpleWindowCount(10), pa.int64(), pa.int64(), volatility="immutable"
86-
)
202+
smooth_w_arguments = udwf(
203+
lambda: ExponentialSmoothDefault(0.8),
204+
pa.float64(),
205+
pa.float64(),
206+
volatility="immutable",
207+
)
87208

88-
df = df.select(
89-
count_base10(column("a"))
90-
.window_frame(WindowFrame("rows", None, None))
91-
.build()
92-
.alias("count")
93-
)
94-
result = df.collect()[0]
95-
assert result.column(0) == pa.array([10, 11, 12])
209+
smooth_bounded = udwf(
210+
ExponentialSmoothBounded,
211+
pa.float64(),
212+
pa.float64(),
213+
volatility="immutable",
214+
)
96215

216+
smooth_rank = udwf(
217+
ExponentialSmoothRank,
218+
pa.utf8(),
219+
pa.float64(),
220+
volatility="immutable",
221+
)
97222

98-
def test_udwf_decorator_basic(df):
99-
"""Test UDWF used as a decorator."""
223+
smooth_frame = udwf(
224+
ExponentialSmoothFrame,
225+
pa.float64(),
226+
pa.float64(),
227+
volatility="immutable",
228+
)
100229

101-
@udwf([pa.int64()], pa.int64(), "immutable")
102-
def window_count() -> WindowEvaluator:
103-
return SimpleWindowCount()
230+
smooth_two_col = udwf(
231+
SmoothTwoColumn,
232+
[pa.int64(), pa.int64()],
233+
pa.float64(),
234+
volatility="immutable",
235+
)
104236

105-
df = df.select(
106-
window_count(column("a"))
237+
data_test_udwf_functions = [
238+
(
239+
"default_udwf_no_arguments",
240+
smooth_default(column("a")),
241+
[0, 0.9, 1.89, 2.889, 3.889, 4.889, 5.889],
242+
),
243+
(
244+
"default_udwf_w_arguments",
245+
smooth_w_arguments(column("a")),
246+
[0, 0.8, 1.76, 2.752, 3.75, 4.75, 5.75],
247+
),
248+
(
249+
"default_udwf_partitioned",
250+
smooth_default(column("a")).partition_by(column("c")).build(),
251+
[0, 0.9, 1.89, 2.889, 4.0, 4.9, 5.89],
252+
),
253+
(
254+
"default_udwf_ordered",
255+
smooth_default(column("a")).order_by(column("b")).build(),
256+
[0.551, 1.13, 2.3, 2.755, 3.876, 5.0, 5.513],
257+
),
258+
(
259+
"bounded_udwf",
260+
smooth_bounded(column("a")),
261+
[0, 0.9, 1.9, 2.9, 3.9, 4.9, 5.9],
262+
),
263+
(
264+
"bounded_udwf_ignores_frame",
265+
smooth_bounded(column("a"))
107266
.window_frame(WindowFrame("rows", None, None))
108-
.build()
109-
.alias("count")
110-
)
111-
result = df.collect()[0]
112-
assert result.column(0) == pa.array([0, 1, 2])
267+
.build(),
268+
[0, 0.9, 1.9, 2.9, 3.9, 4.9, 5.9],
269+
),
270+
(
271+
"rank_udwf",
272+
smooth_rank(column("c")).order_by(column("c")).build(),
273+
[1, 1, 1, 1, 1.9, 2, 2],
274+
),
275+
(
276+
"frame_unbounded_udwf",
277+
smooth_frame(column("a")).window_frame(WindowFrame("rows", None, None)).build(),
278+
[5.889, 5.889, 5.889, 5.889, 5.889, 5.889, 5.889],
279+
),
280+
(
281+
"frame_bounded_udwf",
282+
smooth_frame(column("a")).window_frame(WindowFrame("rows", None, 0)).build(),
283+
[0.0, 0.9, 1.89, 2.889, 3.889, 4.889, 5.889],
284+
),
285+
(
286+
"frame_bounded_udwf",
287+
smooth_frame(column("a"))
288+
.window_frame(WindowFrame("rows", None, 0))
289+
.order_by(column("b"))
290+
.build(),
291+
[0.551, 1.13, 2.3, 2.755, 3.876, 5.0, 5.513],
292+
),
293+
(
294+
"two_column_udwf",
295+
smooth_two_col(column("a"), column("b")),
296+
[0.0, 1.0, 2.0, 2.2, 3.2, 5.0, 6.0],
297+
),
298+
]
113299

114300

115-
def test_udwf_decorator_with_args(df):
116-
"""Test UDWF decorator with constructor arguments."""
301+
@pytest.mark.parametrize(("name", "expr", "expected"), data_test_udwf_functions)
302+
def test_udwf_functions(df, name, expr, expected):
303+
df = df.select("a", "b", f.round(expr, lit(3)).alias(name))
117304

118-
@udwf([pa.int64()], pa.int64(), "immutable")
119-
def window_count_base10() -> WindowEvaluator:
120-
return SimpleWindowCount(10)
121-
122-
df = df.select(
123-
window_count_base10(column("a"))
124-
.window_frame(WindowFrame("rows", None, None))
125-
.build()
126-
.alias("count")
127-
)
128-
result = df.collect()[0]
129-
assert result.column(0) == pa.array([10, 11, 12])
130-
131-
132-
def test_register_udwf(ctx, df):
133-
"""Test registering and using UDWF in SQL context."""
134-
window_count = udwf(
135-
SimpleWindowCount,
136-
[pa.int64()],
137-
pa.int64(),
138-
volatility="immutable",
139-
name="window_count",
140-
)
305+
# execute and collect the first (and only) batch
306+
result = df.sort(column("a")).select(column(name)).collect()[0]
141307

142-
ctx.register_udwf(window_count)
143-
result = ctx.sql(
144-
"SELECT window_count(a) OVER (ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) FROM test_table"
145-
).collect()[0]
146-
assert result.column(0) == pa.array([0, 1, 2])
308+
assert result.column(0) == pa.array(expected)

0 commit comments

Comments
 (0)