Skip to content

Commit 9430084

Browse files
committed
Switching to use factory methods for udaf and udwf
1 parent 6b935c8 commit 9430084

File tree

5 files changed

+114
-77
lines changed

5 files changed

+114
-77
lines changed

python/datafusion/tests/test_udaf.py

Lines changed: 15 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -79,22 +79,19 @@ def test_errors(df):
7979
volatility="immutable",
8080
)
8181

82-
accum = udaf(
83-
MissingMethods,
84-
pa.int64(),
85-
pa.int64(),
86-
[pa.int64()],
87-
volatility="immutable",
88-
)
89-
df = df.aggregate([], [accum(column("a"))])
90-
9182
msg = (
9283
"Can't instantiate abstract class MissingMethods (without an implementation "
9384
"for abstract methods 'evaluate', 'merge', 'update'|with abstract methods "
9485
"evaluate, merge, update)"
9586
)
9687
with pytest.raises(Exception, match=msg):
97-
df.collect()
88+
accum = udaf( # noqa F841
89+
MissingMethods,
90+
pa.int64(),
91+
pa.int64(),
92+
[pa.int64()],
93+
volatility="immutable",
94+
)
9895

9996

10097
def test_udaf_aggregate(df):
@@ -125,12 +122,11 @@ def test_udaf_aggregate_with_arguments(df):
125122
bias = 10.0
126123

127124
summarize = udaf(
128-
Summarize,
125+
lambda: Summarize(bias),
129126
pa.float64(),
130127
pa.float64(),
131128
[pa.float64()],
132129
volatility="immutable",
133-
arguments=[bias],
134130
)
135131

136132
df1 = df.aggregate([], [summarize(column("a"))])
@@ -140,6 +136,13 @@ def test_udaf_aggregate_with_arguments(df):
140136

141137
assert result.column(0) == pa.array([bias + 1.0 + 2.0 + 3.0])
142138

139+
df2 = df.aggregate([], [summarize(column("a"))])
140+
141+
# Run a second time to ensure the state is properly reset
142+
result = df2.collect()[0]
143+
144+
assert result.column(0) == pa.array([bias + 1.0 + 2.0 + 3.0])
145+
143146

144147
def test_group_by(df):
145148
summarize = udaf(

python/datafusion/tests/test_udwf.py

Lines changed: 10 additions & 15 deletions
Original file line number | Diff line number | Diff line change
@@ -24,7 +24,7 @@
2424

2525

2626
class ExponentialSmoothDefault(WindowEvaluator):
27-
def __init__(self, alpha: float = 0.8) -> None:
27+
def __init__(self, alpha: float = 0.9) -> None:
2828
self.alpha = alpha
2929

3030
def evaluate_all(self, values: list[pa.Array], num_rows: int) -> pa.Array:
@@ -44,7 +44,7 @@ def evaluate_all(self, values: list[pa.Array], num_rows: int) -> pa.Array:
4444

4545

4646
class ExponentialSmoothBounded(WindowEvaluator):
47-
def __init__(self, alpha: float) -> None:
47+
def __init__(self, alpha: float = 0.9) -> None:
4848
self.alpha = alpha
4949

5050
def supports_bounded_execution(self) -> bool:
@@ -75,7 +75,7 @@ def evaluate(
7575

7676

7777
class ExponentialSmoothRank(WindowEvaluator):
78-
def __init__(self, alpha: float) -> None:
78+
def __init__(self, alpha: float = 0.9) -> None:
7979
self.alpha = alpha
8080

8181
def include_rank(self) -> bool:
@@ -101,7 +101,7 @@ def evaluate_all_with_rank(
101101

102102

103103
class ExponentialSmoothFrame(WindowEvaluator):
104-
def __init__(self, alpha: float) -> None:
104+
def __init__(self, alpha: float = 0.9) -> None:
105105
self.alpha = alpha
106106

107107
def uses_window_frame(self) -> bool:
@@ -134,7 +134,7 @@ class SmoothTwoColumn(WindowEvaluator):
134134
the previous and next rows.
135135
"""
136136

137-
def __init__(self, alpha: float) -> None:
137+
def __init__(self, alpha: float = 0.9) -> None:
138138
self.alpha = alpha
139139

140140
def evaluate_all(self, values: list[pa.Array], num_rows: int) -> pa.Array:
@@ -195,11 +195,10 @@ def test_udwf_errors(df):
195195
pa.float64(),
196196
pa.float64(),
197197
volatility="immutable",
198-
arguments=[0.9],
199198
)
200199

201-
smooth_no_arugments = udwf(
202-
ExponentialSmoothDefault,
200+
smooth_w_arguments = udwf(
201+
lambda: ExponentialSmoothDefault(0.8),
203202
pa.float64(),
204203
pa.float64(),
205204
volatility="immutable",
@@ -210,42 +209,38 @@ def test_udwf_errors(df):
210209
pa.float64(),
211210
pa.float64(),
212211
volatility="immutable",
213-
arguments=[0.9],
214212
)
215213

216214
smooth_rank = udwf(
217215
ExponentialSmoothRank,
218216
pa.utf8(),
219217
pa.float64(),
220218
volatility="immutable",
221-
arguments=[0.9],
222219
)
223220

224221
smooth_frame = udwf(
225222
ExponentialSmoothFrame,
226223
pa.float64(),
227224
pa.float64(),
228225
volatility="immutable",
229-
arguments=[0.9],
230226
)
231227

232228
smooth_two_col = udwf(
233229
SmoothTwoColumn,
234230
[pa.int64(), pa.int64()],
235231
pa.float64(),
236232
volatility="immutable",
237-
arguments=[0.9],
238233
)
239234

240235
data_test_udwf_functions = [
241236
(
242-
"default_udwf",
237+
"default_udwf_no_arguments",
243238
smooth_default(column("a")),
244239
[0, 0.9, 1.89, 2.889, 3.889, 4.889, 5.889],
245240
),
246241
(
247-
"default_udwf_no_arguments",
248-
smooth_no_arugments(column("a")),
242+
"default_udwf_w_arguments",
243+
smooth_w_arguments(column("a")),
249244
[0, 0.8, 1.76, 2.752, 3.75, 4.75, 5.75],
250245
),
251246
(

0 commit comments

Comments
 (0)