Skip to content

Commit c191be7

Browse files
committed
refactor: Simplify UDWF test suite and introduce SimpleWindowCount evaluator
- Removed multiple exponential smoothing classes to streamline the code.
- Introduced SimpleWindowCount class for basic row counting functionality.
- Updated test cases to validate the new SimpleWindowCount evaluator.
- Refactored fixture and test functions for clarity and consistency.
- Enhanced error handling in UDWF creation tests.
1 parent 3002567 commit c191be7

File tree

1 file changed

+89
-251
lines changed

1 file changed

+89
-251
lines changed

python/tests/test_udwf.py

Lines changed: 89 additions & 251 deletions
Original file line numberDiff line numberDiff line change
@@ -20,289 +20,127 @@
2020
import pyarrow as pa
2121
import pytest
2222
from datafusion import SessionContext, column, lit, udwf
23-
from datafusion import functions as f
2423
from datafusion.expr import WindowFrame
2524
from datafusion.udf import WindowEvaluator
2625

2726

28-
class ExponentialSmoothDefault(WindowEvaluator):
29-
def __init__(self, alpha: float = 0.9) -> None:
30-
self.alpha = alpha
27+
class SimpleWindowCount(WindowEvaluator):
28+
"""A simple window evaluator that counts rows."""
3129

32-
def evaluate_all(self, values: list[pa.Array], num_rows: int) -> pa.Array:
33-
results = []
34-
curr_value = 0.0
35-
values = values[0]
36-
for idx in range(num_rows):
37-
if idx == 0:
38-
curr_value = values[idx].as_py()
39-
else:
40-
curr_value = values[idx].as_py() * self.alpha + curr_value * (
41-
1.0 - self.alpha
42-
)
43-
results.append(curr_value)
44-
45-
return pa.array(results)
46-
47-
48-
class ExponentialSmoothBounded(WindowEvaluator):
49-
def __init__(self, alpha: float = 0.9) -> None:
50-
self.alpha = alpha
51-
52-
def supports_bounded_execution(self) -> bool:
53-
return True
54-
55-
def get_range(self, idx: int, num_rows: int) -> tuple[int, int]:
56-
# Override the default range of current row since uses_window_frame is False
57-
# So for the purpose of this test we just smooth from the previous row to
58-
# current.
59-
if idx == 0:
60-
return (0, 0)
61-
return (idx - 1, idx)
62-
63-
def evaluate(
64-
self, values: list[pa.Array], eval_range: tuple[int, int]
65-
) -> pa.Scalar:
66-
(start, stop) = eval_range
67-
curr_value = 0.0
68-
values = values[0]
69-
for idx in range(start, stop + 1):
70-
if idx == start:
71-
curr_value = values[idx].as_py()
72-
else:
73-
curr_value = values[idx].as_py() * self.alpha + curr_value * (
74-
1.0 - self.alpha
75-
)
76-
return pa.scalar(curr_value).cast(pa.float64())
77-
78-
79-
class ExponentialSmoothRank(WindowEvaluator):
80-
def __init__(self, alpha: float = 0.9) -> None:
81-
self.alpha = alpha
82-
83-
def include_rank(self) -> bool:
84-
return True
85-
86-
def evaluate_all_with_rank(
87-
self, num_rows: int, ranks_in_partition: list[tuple[int, int]]
88-
) -> pa.Array:
89-
results = []
90-
for idx in range(num_rows):
91-
if idx == 0:
92-
prior_value = 1.0
93-
matching_row = [
94-
i
95-
for i in range(len(ranks_in_partition))
96-
if ranks_in_partition[i][0] <= idx and ranks_in_partition[i][1] > idx
97-
][0] + 1
98-
curr_value = matching_row * self.alpha + prior_value * (1.0 - self.alpha)
99-
results.append(curr_value)
100-
prior_value = matching_row
101-
102-
return pa.array(results)
103-
104-
105-
class ExponentialSmoothFrame(WindowEvaluator):
106-
def __init__(self, alpha: float = 0.9) -> None:
107-
self.alpha = alpha
108-
109-
def uses_window_frame(self) -> bool:
110-
return True
111-
112-
def evaluate(
113-
self, values: list[pa.Array], eval_range: tuple[int, int]
114-
) -> pa.Scalar:
115-
(start, stop) = eval_range
116-
curr_value = 0.0
117-
if len(values) > 1:
118-
order_by = values[1] # noqa: F841
119-
values = values[0]
120-
else:
121-
values = values[0]
122-
for idx in range(start, stop):
123-
if idx == start:
124-
curr_value = values[idx].as_py()
125-
else:
126-
curr_value = values[idx].as_py() * self.alpha + curr_value * (
127-
1.0 - self.alpha
128-
)
129-
return pa.scalar(curr_value).cast(pa.float64())
130-
131-
132-
class SmoothTwoColumn(WindowEvaluator):
133-
"""This class demonstrates using two columns.
134-
135-
If the second column is above a threshold, then smooth over the first column from
136-
the previous and next rows.
137-
"""
138-
139-
def __init__(self, alpha: float = 0.9) -> None:
140-
self.alpha = alpha
30+
def __init__(self, base: int = 0) -> None:
31+
self.base = base
14132

14233
def evaluate_all(self, values: list[pa.Array], num_rows: int) -> pa.Array:
143-
results = []
144-
values_a = values[0]
145-
values_b = values[1]
146-
for idx in range(num_rows):
147-
if values_b[idx].as_py() > 7:
148-
if idx == 0:
149-
results.append(values_a[1].cast(pa.float64()))
150-
elif idx == num_rows - 1:
151-
results.append(values_a[num_rows - 2].cast(pa.float64()))
152-
else:
153-
results.append(
154-
pa.scalar(
155-
values_a[idx - 1].as_py() * self.alpha
156-
+ values_a[idx + 1].as_py() * (1.0 - self.alpha)
157-
)
158-
)
159-
else:
160-
results.append(values_a[idx].cast(pa.float64()))
161-
162-
return pa.array(results)
34+
return pa.array([self.base + i for i in range(num_rows)])
16335

16436

16537
class NotSubclassOfWindowEvaluator:
16638
pass
16739

16840

16941
@pytest.fixture
170-
def df():
171-
ctx = SessionContext()
42+
def ctx():
43+
return SessionContext()
17244

45+
46+
@pytest.fixture
47+
def df(ctx):
17348
# create a RecordBatch and a new DataFrame from it
17449
batch = pa.RecordBatch.from_arrays(
175-
[
176-
pa.array([0, 1, 2, 3, 4, 5, 6]),
177-
pa.array([7, 4, 3, 8, 9, 1, 6]),
178-
pa.array(["A", "A", "A", "A", "B", "B", "B"]),
179-
],
180-
names=["a", "b", "c"],
50+
[pa.array([1, 2, 3]), pa.array([4, 4, 6])],
51+
names=["a", "b"],
18152
)
182-
return ctx.create_dataframe([[batch]])
53+
return ctx.create_dataframe([[batch]], name="test_table")
18354

18455

185-
def test_udwf_errors(df):
186-
with pytest.raises(TypeError):
56+
def test_udwf_errors():
57+
"""Test error cases for UDWF creation."""
58+
with pytest.raises(
59+
TypeError, match="`func` must implement the abstract base class WindowEvaluator"
60+
):
18761
udwf(
188-
NotSubclassOfWindowEvaluator,
189-
pa.float64(),
190-
pa.float64(),
191-
volatility="immutable",
62+
NotSubclassOfWindowEvaluator, pa.int64(), pa.int64(), volatility="immutable"
19263
)
19364

19465

195-
smooth_default = udwf(
196-
ExponentialSmoothDefault,
197-
pa.float64(),
198-
pa.float64(),
199-
volatility="immutable",
200-
)
66+
def test_udwf_basic_usage(df):
67+
"""Test basic UDWF usage with a simple counting window function."""
68+
simple_count = udwf(
69+
SimpleWindowCount, pa.int64(), pa.int64(), volatility="immutable"
70+
)
71+
72+
df = df.select(
73+
simple_count(column("a"))
74+
.window_frame(WindowFrame("rows", None, None))
75+
.build()
76+
.alias("count")
77+
)
78+
result = df.collect()[0]
79+
assert result.column(0) == pa.array([0, 1, 2])
80+
20181

202-
smooth_w_arguments = udwf(
203-
lambda: ExponentialSmoothDefault(0.8),
204-
pa.float64(),
205-
pa.float64(),
206-
volatility="immutable",
207-
)
82+
def test_udwf_with_args(df):
83+
"""Test UDWF with constructor arguments."""
84+
count_base10 = udwf(
85+
lambda: SimpleWindowCount(10), pa.int64(), pa.int64(), volatility="immutable"
86+
)
20887

209-
smooth_bounded = udwf(
210-
ExponentialSmoothBounded,
211-
pa.float64(),
212-
pa.float64(),
213-
volatility="immutable",
214-
)
88+
df = df.select(
89+
count_base10(column("a"))
90+
.window_frame(WindowFrame("rows", None, None))
91+
.build()
92+
.alias("count")
93+
)
94+
result = df.collect()[0]
95+
assert result.column(0) == pa.array([10, 11, 12])
21596

216-
smooth_rank = udwf(
217-
ExponentialSmoothRank,
218-
pa.utf8(),
219-
pa.float64(),
220-
volatility="immutable",
221-
)
22297

223-
smooth_frame = udwf(
224-
ExponentialSmoothFrame,
225-
pa.float64(),
226-
pa.float64(),
227-
volatility="immutable",
228-
)
98+
def test_udwf_decorator_basic(df):
99+
"""Test UDWF used as a decorator."""
229100

230-
smooth_two_col = udwf(
231-
SmoothTwoColumn,
232-
[pa.int64(), pa.int64()],
233-
pa.float64(),
234-
volatility="immutable",
235-
)
101+
@udwf([pa.int64()], pa.int64(), "immutable")
102+
def window_count() -> WindowEvaluator:
103+
return SimpleWindowCount()
236104

237-
data_test_udwf_functions = [
238-
(
239-
"default_udwf_no_arguments",
240-
smooth_default(column("a")),
241-
[0, 0.9, 1.89, 2.889, 3.889, 4.889, 5.889],
242-
),
243-
(
244-
"default_udwf_w_arguments",
245-
smooth_w_arguments(column("a")),
246-
[0, 0.8, 1.76, 2.752, 3.75, 4.75, 5.75],
247-
),
248-
(
249-
"default_udwf_partitioned",
250-
smooth_default(column("a")).partition_by(column("c")).build(),
251-
[0, 0.9, 1.89, 2.889, 4.0, 4.9, 5.89],
252-
),
253-
(
254-
"default_udwf_ordered",
255-
smooth_default(column("a")).order_by(column("b")).build(),
256-
[0.551, 1.13, 2.3, 2.755, 3.876, 5.0, 5.513],
257-
),
258-
(
259-
"bounded_udwf",
260-
smooth_bounded(column("a")),
261-
[0, 0.9, 1.9, 2.9, 3.9, 4.9, 5.9],
262-
),
263-
(
264-
"bounded_udwf_ignores_frame",
265-
smooth_bounded(column("a"))
105+
df = df.select(
106+
window_count(column("a"))
266107
.window_frame(WindowFrame("rows", None, None))
267-
.build(),
268-
[0, 0.9, 1.9, 2.9, 3.9, 4.9, 5.9],
269-
),
270-
(
271-
"rank_udwf",
272-
smooth_rank(column("c")).order_by(column("c")).build(),
273-
[1, 1, 1, 1, 1.9, 2, 2],
274-
),
275-
(
276-
"frame_unbounded_udwf",
277-
smooth_frame(column("a")).window_frame(WindowFrame("rows", None, None)).build(),
278-
[5.889, 5.889, 5.889, 5.889, 5.889, 5.889, 5.889],
279-
),
280-
(
281-
"frame_bounded_udwf",
282-
smooth_frame(column("a")).window_frame(WindowFrame("rows", None, 0)).build(),
283-
[0.0, 0.9, 1.89, 2.889, 3.889, 4.889, 5.889],
284-
),
285-
(
286-
"frame_bounded_udwf",
287-
smooth_frame(column("a"))
288-
.window_frame(WindowFrame("rows", None, 0))
289-
.order_by(column("b"))
290-
.build(),
291-
[0.551, 1.13, 2.3, 2.755, 3.876, 5.0, 5.513],
292-
),
293-
(
294-
"two_column_udwf",
295-
smooth_two_col(column("a"), column("b")),
296-
[0.0, 1.0, 2.0, 2.2, 3.2, 5.0, 6.0],
297-
),
298-
]
108+
.build()
109+
.alias("count")
110+
)
111+
result = df.collect()[0]
112+
assert result.column(0) == pa.array([0, 1, 2])
299113

300114

301-
@pytest.mark.parametrize(("name", "expr", "expected"), data_test_udwf_functions)
302-
def test_udwf_functions(df, name, expr, expected):
303-
df = df.select("a", "b", f.round(expr, lit(3)).alias(name))
115+
def test_udwf_decorator_with_args(df):
116+
"""Test UDWF decorator with constructor arguments."""
304117

305-
# execute and collect the first (and only) batch
306-
result = df.sort(column("a")).select(column(name)).collect()[0]
118+
@udwf([pa.int64()], pa.int64(), "immutable")
119+
def window_count_base10() -> WindowEvaluator:
120+
return SimpleWindowCount(10)
121+
122+
df = df.select(
123+
window_count_base10(column("a"))
124+
.window_frame(WindowFrame("rows", None, None))
125+
.build()
126+
.alias("count")
127+
)
128+
result = df.collect()[0]
129+
assert result.column(0) == pa.array([10, 11, 12])
130+
131+
132+
def test_register_udwf(ctx, df):
133+
"""Test registering and using UDWF in SQL context."""
134+
window_count = udwf(
135+
SimpleWindowCount,
136+
[pa.int64()],
137+
pa.int64(),
138+
volatility="immutable",
139+
name="window_count",
140+
)
307141

308-
assert result.column(0) == pa.array(expected)
142+
ctx.register_udwf(window_count)
143+
result = ctx.sql(
144+
"SELECT window_count(a) OVER (ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) FROM test_table"
145+
).collect()[0]
146+
assert result.column(0) == pa.array([0, 1, 2])

0 commit comments

Comments
 (0)