2020import pyarrow as pa
2121import pytest
2222from datafusion import SessionContext , column , lit , udwf
23+ from datafusion import functions as f
2324from datafusion .expr import WindowFrame
2425from datafusion .udf import WindowEvaluator
2526
2627
class ExponentialSmoothDefault(WindowEvaluator):
    """Exponentially smooth the first input column over every row it is given.

    The first row is passed through unchanged. Every later row is
    ``value * alpha + previous_result * (1 - alpha)``.
    """

    def __init__(self, alpha: float = 0.9) -> None:
        # Weight applied to the current row's value.
        self.alpha = alpha

    def evaluate_all(self, values: list[pa.Array], num_rows: int) -> pa.Array:
        """Return one smoothed value per row.

        Args:
            values: Input arrays; only ``values[0]`` is read.
            num_rows: Number of rows to produce.

        Returns:
            An array built from the ``num_rows`` smoothed values.
        """
        series = values[0]
        smoothed = []
        running = 0.0
        for row in range(num_rows):
            sample = series[row].as_py()
            if row == 0:
                # Seed the recurrence with the first observation.
                running = sample
            else:
                running = sample * self.alpha + running * (1.0 - self.alpha)
            smoothed.append(running)

        return pa.array(smoothed)
46+
47+
class ExponentialSmoothBounded(WindowEvaluator):
    """Smooth each row against its immediate predecessor using bounded execution."""

    def __init__(self, alpha: float = 0.9) -> None:
        # Weight applied to the current row's value.
        self.alpha = alpha

    def supports_bounded_execution(self) -> bool:
        """Report that this evaluator supports bounded execution."""
        return True

    def get_range(self, idx: int, num_rows: int) -> tuple[int, int]:
        """Return the inclusive ``(start, stop)`` row range used for row ``idx``.

        Because uses_window_frame is False the default range would be only the
        current row, so it is overridden here: for the purpose of this test the
        smoothing runs from the previous row to the current one. Row 0 has no
        predecessor and maps to ``(0, 0)``.
        """
        return (0, 0) if idx == 0 else (idx - 1, idx)

    def evaluate(
        self, values: list[pa.Array], eval_range: tuple[int, int]
    ) -> pa.Scalar:
        """Smooth ``values[0]`` over ``eval_range``, treating ``stop`` as inclusive.

        Returns:
            The final smoothed value as a float64 scalar.
        """
        first, last = eval_range
        series = values[0]
        smoothed = 0.0
        for position in range(first, last + 1):
            observation = series[position].as_py()
            smoothed = (
                observation
                if position == first
                else observation * self.alpha + smoothed * (1.0 - self.alpha)
            )
        return pa.scalar(smoothed).cast(pa.float64())
77+
78+
class ExponentialSmoothRank(WindowEvaluator):
    """Smooth the 1-based rank-group number of each row instead of a column value."""

    def __init__(self, alpha: float = 0.9) -> None:
        # Weight applied to the current row's rank.
        self.alpha = alpha

    def include_rank(self) -> bool:
        """Report that this evaluator wants rank information."""
        return True

    def evaluate_all_with_rank(
        self, num_rows: int, ranks_in_partition: list[tuple[int, int]]
    ) -> pa.Array:
        """Return one value per row, blending its rank with the previous row's rank.

        Args:
            num_rows: Number of rows to produce.
            ranks_in_partition: ``(start, end)`` row ranges; a row belongs to the
                first range with ``start <= row < end``.

        Returns:
            An array of ``rank * alpha + previous_rank * (1 - alpha)`` values,
            where the previous rank for the first row is taken to be 1.0.
        """
        blended = []
        previous_rank = 1.0
        for row in range(num_rows):
            # 1-based position of the first range that contains this row.
            rank = [
                position
                for position, bounds in enumerate(ranks_in_partition)
                if bounds[0] <= row < bounds[1]
            ][0] + 1
            blended.append(rank * self.alpha + previous_rank * (1.0 - self.alpha))
            # Note: the raw rank, not the blended value, feeds the next row.
            previous_rank = rank

        return pa.array(blended)
103+
104+
class ExponentialSmoothFrame(WindowEvaluator):
    """Exponentially smooth the first input column over the supplied window frame."""

    def __init__(self, alpha: float = 0.9) -> None:
        # Weight applied to the current row's value.
        self.alpha = alpha

    def uses_window_frame(self) -> bool:
        """Report that evaluation depends on the window frame."""
        return True

    def evaluate(
        self, values: list[pa.Array], eval_range: tuple[int, int]
    ) -> pa.Scalar:
        """Smooth ``values[0]`` over ``eval_range``, treating ``stop`` as exclusive.

        Args:
            values: Input arrays. Only the first is read; a second array, when
                present, is presumably the ORDER BY column and is ignored.
            eval_range: ``(start, stop)`` row bounds of the frame.

        Returns:
            The final smoothed value as a float64 scalar, or 0.0 when the range
            is empty.
        """
        start, stop = eval_range
        series = values[0]
        curr_value = 0.0
        for idx in range(start, stop):
            if idx == start:
                # Seed the recurrence with the first row of the frame.
                curr_value = series[idx].as_py()
            else:
                curr_value = series[idx].as_py() * self.alpha + curr_value * (
                    1.0 - self.alpha
                )
        return pa.scalar(curr_value).cast(pa.float64())
130+
131+
class SmoothTwoColumn(WindowEvaluator):
    """This class demonstrates using two columns.

    If the second column is above a threshold, then smooth over the first column from
    the previous and next rows. Rows at either edge of the partition take the value
    of their single neighbour; a row with no neighbour at all keeps its own value.
    """

    def __init__(self, alpha: float = 0.9) -> None:
        # Weight applied to the previous row; the next row gets ``1 - alpha``.
        self.alpha = alpha

    def evaluate_all(self, values: list[pa.Array], num_rows: int) -> pa.Array:
        """Return one float64 value per row.

        Args:
            values: Two input arrays; ``values[0]`` is smoothed and ``values[1]``
                is compared against the threshold of 7.
            num_rows: Number of rows to produce.

        Returns:
            An array built from the ``num_rows`` resulting scalars.
        """
        results = []
        values_a = values[0]
        values_b = values[1]
        for idx in range(num_rows):
            above_threshold = values_b[idx].as_py() > 7
            if not above_threshold or num_rows == 1:
                # Either no smoothing is requested, or the partition has a single
                # row and therefore no neighbour to smooth from.
                results.append(values_a[idx].cast(pa.float64()))
            elif idx == 0:
                # First row: only a following neighbour exists.
                results.append(values_a[1].cast(pa.float64()))
            elif idx == num_rows - 1:
                # Last row: only a preceding neighbour exists.
                results.append(values_a[num_rows - 2].cast(pa.float64()))
            else:
                results.append(
                    pa.scalar(
                        values_a[idx - 1].as_py() * self.alpha
                        + values_a[idx + 1].as_py() * (1.0 - self.alpha)
                    )
                )

        return pa.array(results)
163+
164+
27165class SimpleWindowCount (WindowEvaluator ):
28166 """A simple window evaluator that counts rows."""
29167
@@ -44,7 +182,23 @@ def ctx():
44182
45183
@pytest.fixture
def df():
    """Provide a seven-row DataFrame with integer columns a, b and string column c.

    A new SessionContext is created on every use of the fixture.
    """
    column_names = ["a", "b", "c"]
    column_data = [
        pa.array([0, 1, 2, 3, 4, 5, 6]),
        pa.array([7, 4, 3, 8, 9, 1, 6]),
        pa.array(["A", "A", "A", "A", "B", "B", "B"]),
    ]
    # Wrap the arrays in a single RecordBatch and build the DataFrame from it.
    batch = pa.RecordBatch.from_arrays(column_data, names=column_names)
    session = SessionContext()
    return session.create_dataframe([[batch]])
198+
199+
200+ @pytest .fixture
201+ def simple_df (ctx ):
48202 # create a RecordBatch and a new DataFrame from it
49203 batch = pa .RecordBatch .from_arrays (
50204 [pa .array ([1 , 2 , 3 ]), pa .array ([4 , 4 , 6 ])],
@@ -53,7 +207,17 @@ def df(ctx):
53207 return ctx .create_dataframe ([[batch ]], name = "test_table" )
54208
55209
def test_udwf_errors(df):
    """Check that udwf raises TypeError for NotSubclassOfWindowEvaluator.

    NOTE(review): NotSubclassOfWindowEvaluator is defined outside this excerpt;
    presumably it is a class that does not derive from WindowEvaluator. The
    ``df`` fixture is requested but not used in the body.
    """
    with pytest.raises(TypeError):
        input_type = pa.float64()
        return_type = pa.float64()
        udwf(
            NotSubclassOfWindowEvaluator,
            input_type,
            return_type,
            volatility="immutable",
        )
218+
219+
220+ def test_udwf_errors_with_message ():
57221 """Test error cases for UDWF creation."""
58222 with pytest .raises (
59223 TypeError , match = "`func` must implement the abstract base class WindowEvaluator"
@@ -63,13 +227,13 @@ def test_udwf_errors():
63227 )
64228
65229
66- def test_udwf_basic_usage (df ):
230+ def test_udwf_basic_usage (simple_df ):
67231 """Test basic UDWF usage with a simple counting window function."""
68232 simple_count = udwf (
69233 SimpleWindowCount , pa .int64 (), pa .int64 (), volatility = "immutable"
70234 )
71235
72- df = df .select (
236+ df = simple_df .select (
73237 simple_count (column ("a" ))
74238 .window_frame (WindowFrame ("rows" , None , None ))
75239 .build ()
@@ -79,13 +243,13 @@ def test_udwf_basic_usage(df):
79243 assert result .column (0 ) == pa .array ([0 , 1 , 2 ])
80244
81245
82- def test_udwf_with_args (df ):
246+ def test_udwf_with_args (simple_df ):
83247 """Test UDWF with constructor arguments."""
84248 count_base10 = udwf (
85249 lambda : SimpleWindowCount (10 ), pa .int64 (), pa .int64 (), volatility = "immutable"
86250 )
87251
88- df = df .select (
252+ df = simple_df .select (
89253 count_base10 (column ("a" ))
90254 .window_frame (WindowFrame ("rows" , None , None ))
91255 .build ()
@@ -95,14 +259,14 @@ def test_udwf_with_args(df):
95259 assert result .column (0 ) == pa .array ([10 , 11 , 12 ])
96260
97261
98- def test_udwf_decorator_basic (df ):
262+ def test_udwf_decorator_basic (simple_df ):
99263 """Test UDWF used as a decorator."""
100264
101265 @udwf ([pa .int64 ()], pa .int64 (), "immutable" )
102266 def window_count () -> WindowEvaluator :
103267 return SimpleWindowCount ()
104268
105- df = df .select (
269+ df = simple_df .select (
106270 window_count (column ("a" ))
107271 .window_frame (WindowFrame ("rows" , None , None ))
108272 .build ()
@@ -112,14 +276,14 @@ def window_count() -> WindowEvaluator:
112276 assert result .column (0 ) == pa .array ([0 , 1 , 2 ])
113277
114278
115- def test_udwf_decorator_with_args (df ):
279+ def test_udwf_decorator_with_args (simple_df ):
116280 """Test UDWF decorator with constructor arguments."""
117281
118282 @udwf ([pa .int64 ()], pa .int64 (), "immutable" )
119283 def window_count_base10 () -> WindowEvaluator :
120284 return SimpleWindowCount (10 )
121285
122- df = df .select (
286+ df = simple_df .select (
123287 window_count_base10 (column ("a" ))
124288 .window_frame (WindowFrame ("rows" , None , None ))
125289 .build ()
@@ -129,7 +293,7 @@ def window_count_base10() -> WindowEvaluator:
129293 assert result .column (0 ) == pa .array ([10 , 11 , 12 ])
130294
131295
132- def test_register_udwf (ctx , df ):
296+ def test_register_udwf (ctx , simple_df ):
133297 """Test registering and using UDWF in SQL context."""
134298 window_count = udwf (
135299 SimpleWindowCount ,
@@ -144,3 +308,119 @@ def test_register_udwf(ctx, df):
144308 "SELECT window_count(a) OVER (ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) FROM test_table"
145309 ).collect ()[0 ]
146310 assert result .column (0 ) == pa .array ([0 , 1 , 2 ])
311+
312+
# Window UDFs shared by the parametrized cases below.

# Evaluator class passed directly; it is constructed with its default alpha.
smooth_default = udwf(
    ExponentialSmoothDefault,
    pa.float64(),
    pa.float64(),
    volatility="immutable",
)

# A zero-argument callable lets the evaluator be built with a custom alpha.
smooth_w_arguments = udwf(
    lambda: ExponentialSmoothDefault(0.8),
    pa.float64(),
    pa.float64(),
    volatility="immutable",
)

# Evaluator that opts into bounded execution and supplies its own row range.
smooth_bounded = udwf(
    ExponentialSmoothBounded,
    pa.float64(),
    pa.float64(),
    volatility="immutable",
)

# Rank-based evaluator; declared with a string input column.
smooth_rank = udwf(
    ExponentialSmoothRank,
    pa.utf8(),
    pa.float64(),
    volatility="immutable",
)

# Evaluator whose result depends on the window frame.
smooth_frame = udwf(
    ExponentialSmoothFrame,
    pa.float64(),
    pa.float64(),
    volatility="immutable",
)

# Evaluator taking two integer input columns.
smooth_two_col = udwf(
    SmoothTwoColumn,
    [pa.int64(), pa.int64()],
    pa.float64(),
    volatility="immutable",
)
354+
# Cases for test_udwf_functions: (case name, window expression, expected values).
# The case name doubles as the alias of the output column, and expected values
# are listed in ascending order of column "a" after rounding to 3 decimals.
# Every case name is unique so that a failing case can be identified unambiguously.
data_test_udwf_functions = [
    (
        "default_udwf_no_arguments",
        smooth_default(column("a")),
        [0, 0.9, 1.89, 2.889, 3.889, 4.889, 5.889],
    ),
    (
        "default_udwf_w_arguments",
        smooth_w_arguments(column("a")),
        [0, 0.8, 1.76, 2.752, 3.75, 4.75, 5.75],
    ),
    (
        "default_udwf_partitioned",
        smooth_default(column("a")).partition_by(column("c")).build(),
        [0, 0.9, 1.89, 2.889, 4.0, 4.9, 5.89],
    ),
    (
        "default_udwf_ordered",
        smooth_default(column("a")).order_by(column("b")).build(),
        [0.551, 1.13, 2.3, 2.755, 3.876, 5.0, 5.513],
    ),
    (
        "bounded_udwf",
        smooth_bounded(column("a")),
        [0, 0.9, 1.9, 2.9, 3.9, 4.9, 5.9],
    ),
    (
        "bounded_udwf_ignores_frame",
        smooth_bounded(column("a"))
        .window_frame(WindowFrame("rows", None, None))
        .build(),
        [0, 0.9, 1.9, 2.9, 3.9, 4.9, 5.9],
    ),
    (
        "rank_udwf",
        smooth_rank(column("c")).order_by(column("c")).build(),
        [1, 1, 1, 1, 1.9, 2, 2],
    ),
    (
        "frame_unbounded_udwf",
        smooth_frame(column("a")).window_frame(WindowFrame("rows", None, None)).build(),
        [5.889, 5.889, 5.889, 5.889, 5.889, 5.889, 5.889],
    ),
    (
        "frame_bounded_udwf",
        smooth_frame(column("a")).window_frame(WindowFrame("rows", None, 0)).build(),
        [0.0, 0.9, 1.89, 2.889, 3.889, 4.889, 5.889],
    ),
    (
        "frame_bounded_udwf_ordered",
        smooth_frame(column("a"))
        .window_frame(WindowFrame("rows", None, 0))
        .order_by(column("b"))
        .build(),
        [0.551, 1.13, 2.3, 2.755, 3.876, 5.0, 5.513],
    ),
    (
        "two_column_udwf",
        smooth_two_col(column("a"), column("b")),
        [0.0, 1.0, 2.0, 2.2, 3.2, 5.0, 6.0],
    ),
]
417+
418+
@pytest.mark.parametrize(("name", "expr", "expected"), data_test_udwf_functions)
def test_udwf_functions(df, name, expr, expected):
    """Evaluate one window expression and compare it with the expected column.

    The expression is rounded to three decimals and aliased to ``name``; rows
    are sorted by column ``a`` before the comparison.
    """
    rounded = f.round(expr, lit(3)).alias(name)
    with_result = df.select("a", "b", rounded)

    # Only the first collected batch is inspected.
    ordered = with_result.sort(column("a"))
    batches = ordered.select(column(name)).collect()
    result = batches[0]

    assert result.column(0) == pa.array(expected)
0 commit comments