|
20 | 20 | import pyarrow as pa |
21 | 21 | import pytest |
22 | 22 | from datafusion import SessionContext, column, lit, udwf |
23 | | -from datafusion import functions as f |
24 | 23 | from datafusion.expr import WindowFrame |
25 | 24 | from datafusion.udf import WindowEvaluator |
26 | 25 |
|
27 | 26 |
|
28 | | -class ExponentialSmoothDefault(WindowEvaluator): |
29 | | - def __init__(self, alpha: float = 0.9) -> None: |
30 | | - self.alpha = alpha |
| 27 | +class SimpleWindowCount(WindowEvaluator): |
| 28 | + """A simple window evaluator that counts rows.""" |
31 | 29 |
|
32 | | - def evaluate_all(self, values: list[pa.Array], num_rows: int) -> pa.Array: |
33 | | - results = [] |
34 | | - curr_value = 0.0 |
35 | | - values = values[0] |
36 | | - for idx in range(num_rows): |
37 | | - if idx == 0: |
38 | | - curr_value = values[idx].as_py() |
39 | | - else: |
40 | | - curr_value = values[idx].as_py() * self.alpha + curr_value * ( |
41 | | - 1.0 - self.alpha |
42 | | - ) |
43 | | - results.append(curr_value) |
44 | | - |
45 | | - return pa.array(results) |
46 | | - |
47 | | - |
48 | | -class ExponentialSmoothBounded(WindowEvaluator): |
49 | | - def __init__(self, alpha: float = 0.9) -> None: |
50 | | - self.alpha = alpha |
51 | | - |
52 | | - def supports_bounded_execution(self) -> bool: |
53 | | - return True |
54 | | - |
55 | | - def get_range(self, idx: int, num_rows: int) -> tuple[int, int]: |
56 | | - # Override the default range of current row since uses_window_frame is False |
57 | | - # So for the purpose of this test we just smooth from the previous row to |
58 | | - # current. |
59 | | - if idx == 0: |
60 | | - return (0, 0) |
61 | | - return (idx - 1, idx) |
62 | | - |
63 | | - def evaluate( |
64 | | - self, values: list[pa.Array], eval_range: tuple[int, int] |
65 | | - ) -> pa.Scalar: |
66 | | - (start, stop) = eval_range |
67 | | - curr_value = 0.0 |
68 | | - values = values[0] |
69 | | - for idx in range(start, stop + 1): |
70 | | - if idx == start: |
71 | | - curr_value = values[idx].as_py() |
72 | | - else: |
73 | | - curr_value = values[idx].as_py() * self.alpha + curr_value * ( |
74 | | - 1.0 - self.alpha |
75 | | - ) |
76 | | - return pa.scalar(curr_value).cast(pa.float64()) |
77 | | - |
78 | | - |
79 | | -class ExponentialSmoothRank(WindowEvaluator): |
80 | | - def __init__(self, alpha: float = 0.9) -> None: |
81 | | - self.alpha = alpha |
82 | | - |
83 | | - def include_rank(self) -> bool: |
84 | | - return True |
85 | | - |
86 | | - def evaluate_all_with_rank( |
87 | | - self, num_rows: int, ranks_in_partition: list[tuple[int, int]] |
88 | | - ) -> pa.Array: |
89 | | - results = [] |
90 | | - for idx in range(num_rows): |
91 | | - if idx == 0: |
92 | | - prior_value = 1.0 |
93 | | - matching_row = [ |
94 | | - i |
95 | | - for i in range(len(ranks_in_partition)) |
96 | | - if ranks_in_partition[i][0] <= idx and ranks_in_partition[i][1] > idx |
97 | | - ][0] + 1 |
98 | | - curr_value = matching_row * self.alpha + prior_value * (1.0 - self.alpha) |
99 | | - results.append(curr_value) |
100 | | - prior_value = matching_row |
101 | | - |
102 | | - return pa.array(results) |
103 | | - |
104 | | - |
105 | | -class ExponentialSmoothFrame(WindowEvaluator): |
106 | | - def __init__(self, alpha: float = 0.9) -> None: |
107 | | - self.alpha = alpha |
108 | | - |
109 | | - def uses_window_frame(self) -> bool: |
110 | | - return True |
111 | | - |
112 | | - def evaluate( |
113 | | - self, values: list[pa.Array], eval_range: tuple[int, int] |
114 | | - ) -> pa.Scalar: |
115 | | - (start, stop) = eval_range |
116 | | - curr_value = 0.0 |
117 | | - if len(values) > 1: |
118 | | - order_by = values[1] # noqa: F841 |
119 | | - values = values[0] |
120 | | - else: |
121 | | - values = values[0] |
122 | | - for idx in range(start, stop): |
123 | | - if idx == start: |
124 | | - curr_value = values[idx].as_py() |
125 | | - else: |
126 | | - curr_value = values[idx].as_py() * self.alpha + curr_value * ( |
127 | | - 1.0 - self.alpha |
128 | | - ) |
129 | | - return pa.scalar(curr_value).cast(pa.float64()) |
130 | | - |
131 | | - |
132 | | -class SmoothTwoColumn(WindowEvaluator): |
133 | | - """This class demonstrates using two columns. |
134 | | -
|
135 | | - If the second column is above a threshold, then smooth over the first column from |
136 | | - the previous and next rows. |
137 | | - """ |
138 | | - |
139 | | - def __init__(self, alpha: float = 0.9) -> None: |
140 | | - self.alpha = alpha |
| 30 | + def __init__(self, base: int = 0) -> None: |
| 31 | + self.base = base |
141 | 32 |
|
142 | 33 | def evaluate_all(self, values: list[pa.Array], num_rows: int) -> pa.Array: |
143 | | - results = [] |
144 | | - values_a = values[0] |
145 | | - values_b = values[1] |
146 | | - for idx in range(num_rows): |
147 | | - if values_b[idx].as_py() > 7: |
148 | | - if idx == 0: |
149 | | - results.append(values_a[1].cast(pa.float64())) |
150 | | - elif idx == num_rows - 1: |
151 | | - results.append(values_a[num_rows - 2].cast(pa.float64())) |
152 | | - else: |
153 | | - results.append( |
154 | | - pa.scalar( |
155 | | - values_a[idx - 1].as_py() * self.alpha |
156 | | - + values_a[idx + 1].as_py() * (1.0 - self.alpha) |
157 | | - ) |
158 | | - ) |
159 | | - else: |
160 | | - results.append(values_a[idx].cast(pa.float64())) |
161 | | - |
162 | | - return pa.array(results) |
| 34 | + return pa.array([self.base + i for i in range(num_rows)]) |
163 | 35 |
|
164 | 36 |
|
165 | 37 | class NotSubclassOfWindowEvaluator: |
166 | 38 | pass |
167 | 39 |
|
168 | 40 |
|
169 | 41 | @pytest.fixture |
170 | | -def df(): |
171 | | - ctx = SessionContext() |
| 42 | +def ctx(): |
| 43 | + return SessionContext() |
172 | 44 |
|
| 45 | + |
| 46 | +@pytest.fixture |
| 47 | +def df(ctx): |
173 | 48 | # create a RecordBatch and a new DataFrame from it |
174 | 49 | batch = pa.RecordBatch.from_arrays( |
175 | | - [ |
176 | | - pa.array([0, 1, 2, 3, 4, 5, 6]), |
177 | | - pa.array([7, 4, 3, 8, 9, 1, 6]), |
178 | | - pa.array(["A", "A", "A", "A", "B", "B", "B"]), |
179 | | - ], |
180 | | - names=["a", "b", "c"], |
| 50 | + [pa.array([1, 2, 3]), pa.array([4, 4, 6])], |
| 51 | + names=["a", "b"], |
181 | 52 | ) |
182 | | - return ctx.create_dataframe([[batch]]) |
| 53 | + return ctx.create_dataframe([[batch]], name="test_table") |
183 | 54 |
|
184 | 55 |
|
185 | | -def test_udwf_errors(df): |
186 | | - with pytest.raises(TypeError): |
| 56 | +def test_udwf_errors(): |
| 57 | + """Test error cases for UDWF creation.""" |
| 58 | + with pytest.raises( |
| 59 | + TypeError, match="`func` must implement the abstract base class WindowEvaluator" |
| 60 | + ): |
187 | 61 | udwf( |
188 | | - NotSubclassOfWindowEvaluator, |
189 | | - pa.float64(), |
190 | | - pa.float64(), |
191 | | - volatility="immutable", |
| 62 | + NotSubclassOfWindowEvaluator, pa.int64(), pa.int64(), volatility="immutable" |
192 | 63 | ) |
193 | 64 |
|
194 | 65 |
|
195 | | -smooth_default = udwf( |
196 | | - ExponentialSmoothDefault, |
197 | | - pa.float64(), |
198 | | - pa.float64(), |
199 | | - volatility="immutable", |
200 | | -) |
| 66 | +def test_udwf_basic_usage(df): |
| 67 | + """Test basic UDWF usage with a simple counting window function.""" |
| 68 | + simple_count = udwf( |
| 69 | + SimpleWindowCount, pa.int64(), pa.int64(), volatility="immutable" |
| 70 | + ) |
| 71 | + |
| 72 | + df = df.select( |
| 73 | + simple_count(column("a")) |
| 74 | + .window_frame(WindowFrame("rows", None, None)) |
| 75 | + .build() |
| 76 | + .alias("count") |
| 77 | + ) |
| 78 | + result = df.collect()[0] |
| 79 | + assert result.column(0) == pa.array([0, 1, 2]) |
| 80 | + |
201 | 81 |
|
202 | | -smooth_w_arguments = udwf( |
203 | | - lambda: ExponentialSmoothDefault(0.8), |
204 | | - pa.float64(), |
205 | | - pa.float64(), |
206 | | - volatility="immutable", |
207 | | -) |
| 82 | +def test_udwf_with_args(df): |
| 83 | + """Test UDWF with constructor arguments.""" |
| 84 | + count_base10 = udwf( |
| 85 | + lambda: SimpleWindowCount(10), pa.int64(), pa.int64(), volatility="immutable" |
| 86 | + ) |
208 | 87 |
|
209 | | -smooth_bounded = udwf( |
210 | | - ExponentialSmoothBounded, |
211 | | - pa.float64(), |
212 | | - pa.float64(), |
213 | | - volatility="immutable", |
214 | | -) |
| 88 | + df = df.select( |
| 89 | + count_base10(column("a")) |
| 90 | + .window_frame(WindowFrame("rows", None, None)) |
| 91 | + .build() |
| 92 | + .alias("count") |
| 93 | + ) |
| 94 | + result = df.collect()[0] |
| 95 | + assert result.column(0) == pa.array([10, 11, 12]) |
215 | 96 |
|
216 | | -smooth_rank = udwf( |
217 | | - ExponentialSmoothRank, |
218 | | - pa.utf8(), |
219 | | - pa.float64(), |
220 | | - volatility="immutable", |
221 | | -) |
222 | 97 |
|
223 | | -smooth_frame = udwf( |
224 | | - ExponentialSmoothFrame, |
225 | | - pa.float64(), |
226 | | - pa.float64(), |
227 | | - volatility="immutable", |
228 | | -) |
| 98 | +def test_udwf_decorator_basic(df): |
| 99 | + """Test UDWF used as a decorator.""" |
229 | 100 |
|
230 | | -smooth_two_col = udwf( |
231 | | - SmoothTwoColumn, |
232 | | - [pa.int64(), pa.int64()], |
233 | | - pa.float64(), |
234 | | - volatility="immutable", |
235 | | -) |
| 101 | + @udwf([pa.int64()], pa.int64(), "immutable") |
| 102 | + def window_count() -> WindowEvaluator: |
| 103 | + return SimpleWindowCount() |
236 | 104 |
|
237 | | -data_test_udwf_functions = [ |
238 | | - ( |
239 | | - "default_udwf_no_arguments", |
240 | | - smooth_default(column("a")), |
241 | | - [0, 0.9, 1.89, 2.889, 3.889, 4.889, 5.889], |
242 | | - ), |
243 | | - ( |
244 | | - "default_udwf_w_arguments", |
245 | | - smooth_w_arguments(column("a")), |
246 | | - [0, 0.8, 1.76, 2.752, 3.75, 4.75, 5.75], |
247 | | - ), |
248 | | - ( |
249 | | - "default_udwf_partitioned", |
250 | | - smooth_default(column("a")).partition_by(column("c")).build(), |
251 | | - [0, 0.9, 1.89, 2.889, 4.0, 4.9, 5.89], |
252 | | - ), |
253 | | - ( |
254 | | - "default_udwf_ordered", |
255 | | - smooth_default(column("a")).order_by(column("b")).build(), |
256 | | - [0.551, 1.13, 2.3, 2.755, 3.876, 5.0, 5.513], |
257 | | - ), |
258 | | - ( |
259 | | - "bounded_udwf", |
260 | | - smooth_bounded(column("a")), |
261 | | - [0, 0.9, 1.9, 2.9, 3.9, 4.9, 5.9], |
262 | | - ), |
263 | | - ( |
264 | | - "bounded_udwf_ignores_frame", |
265 | | - smooth_bounded(column("a")) |
| 105 | + df = df.select( |
| 106 | + window_count(column("a")) |
266 | 107 | .window_frame(WindowFrame("rows", None, None)) |
267 | | - .build(), |
268 | | - [0, 0.9, 1.9, 2.9, 3.9, 4.9, 5.9], |
269 | | - ), |
270 | | - ( |
271 | | - "rank_udwf", |
272 | | - smooth_rank(column("c")).order_by(column("c")).build(), |
273 | | - [1, 1, 1, 1, 1.9, 2, 2], |
274 | | - ), |
275 | | - ( |
276 | | - "frame_unbounded_udwf", |
277 | | - smooth_frame(column("a")).window_frame(WindowFrame("rows", None, None)).build(), |
278 | | - [5.889, 5.889, 5.889, 5.889, 5.889, 5.889, 5.889], |
279 | | - ), |
280 | | - ( |
281 | | - "frame_bounded_udwf", |
282 | | - smooth_frame(column("a")).window_frame(WindowFrame("rows", None, 0)).build(), |
283 | | - [0.0, 0.9, 1.89, 2.889, 3.889, 4.889, 5.889], |
284 | | - ), |
285 | | - ( |
286 | | - "frame_bounded_udwf", |
287 | | - smooth_frame(column("a")) |
288 | | - .window_frame(WindowFrame("rows", None, 0)) |
289 | | - .order_by(column("b")) |
290 | | - .build(), |
291 | | - [0.551, 1.13, 2.3, 2.755, 3.876, 5.0, 5.513], |
292 | | - ), |
293 | | - ( |
294 | | - "two_column_udwf", |
295 | | - smooth_two_col(column("a"), column("b")), |
296 | | - [0.0, 1.0, 2.0, 2.2, 3.2, 5.0, 6.0], |
297 | | - ), |
298 | | -] |
| 108 | + .build() |
| 109 | + .alias("count") |
| 110 | + ) |
| 111 | + result = df.collect()[0] |
| 112 | + assert result.column(0) == pa.array([0, 1, 2]) |
299 | 113 |
|
300 | 114 |
|
301 | | -@pytest.mark.parametrize(("name", "expr", "expected"), data_test_udwf_functions) |
302 | | -def test_udwf_functions(df, name, expr, expected): |
303 | | - df = df.select("a", "b", f.round(expr, lit(3)).alias(name)) |
| 115 | +def test_udwf_decorator_with_args(df): |
| 116 | + """Test UDWF decorator with constructor arguments.""" |
304 | 117 |
|
305 | | - # execute and collect the first (and only) batch |
306 | | - result = df.sort(column("a")).select(column(name)).collect()[0] |
| 118 | + @udwf([pa.int64()], pa.int64(), "immutable") |
| 119 | + def window_count_base10() -> WindowEvaluator: |
| 120 | + return SimpleWindowCount(10) |
| 121 | + |
| 122 | + df = df.select( |
| 123 | + window_count_base10(column("a")) |
| 124 | + .window_frame(WindowFrame("rows", None, None)) |
| 125 | + .build() |
| 126 | + .alias("count") |
| 127 | + ) |
| 128 | + result = df.collect()[0] |
| 129 | + assert result.column(0) == pa.array([10, 11, 12]) |
| 130 | + |
| 131 | + |
| 132 | +def test_register_udwf(ctx, df): |
| 133 | + """Test registering and using UDWF in SQL context.""" |
| 134 | + window_count = udwf( |
| 135 | + SimpleWindowCount, |
| 136 | + [pa.int64()], |
| 137 | + pa.int64(), |
| 138 | + volatility="immutable", |
| 139 | + name="window_count", |
| 140 | + ) |
307 | 141 |
|
308 | | - assert result.column(0) == pa.array(expected) |
| 142 | + ctx.register_udwf(window_count) |
| 143 | + result = ctx.sql( |
| 144 | + "SELECT window_count(a) OVER (ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) FROM test_table" |
| 145 | + ).collect()[0] |
| 146 | + assert result.column(0) == pa.array([0, 1, 2]) |
0 commit comments