Skip to content

Commit 044bbe2

Browse files
authored
Fix regression on register_udaf (#878)
* Test no longer hangs, and updated error string to match latest * Add unit tests for registering udf and udaf * Resolve error on registering udaf #874 * remove stale comment * Update unit test text to match in multiple versions of python * Regex for exception that is compatible with python 3.10 and 3.12
1 parent a00cfbf commit 044bbe2

File tree

2 files changed

+41
-10
lines changed

2 files changed

+41
-10
lines changed

python/datafusion/tests/test_udaf.py

Lines changed: 39 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
import pyarrow.compute as pc
2222
import pytest
2323

24-
from datafusion import Accumulator, SessionContext, column, udaf
24+
from datafusion import Accumulator, column, udaf, udf
2525

2626

2727
class Summarize(Accumulator):
@@ -60,18 +60,15 @@ def state(self) -> List[pa.Scalar]:
6060

6161

6262
@pytest.fixture
63-
def df():
64-
ctx = SessionContext()
65-
63+
def df(ctx):
6664
# create a RecordBatch and a new DataFrame from it
6765
batch = pa.RecordBatch.from_arrays(
6866
[pa.array([1, 2, 3]), pa.array([4, 4, 6])],
6967
names=["a", "b"],
7068
)
71-
return ctx.create_dataframe([[batch]])
69+
return ctx.create_dataframe([[batch]], name="test_table")
7270

7371

74-
@pytest.mark.skip(reason="df.collect() will hang, need more investigations")
7572
def test_errors(df):
7673
with pytest.raises(TypeError):
7774
udaf(
@@ -92,8 +89,9 @@ def test_errors(df):
9289
df = df.aggregate([], [accum(column("a"))])
9390

9491
msg = (
95-
"Can't instantiate abstract class MissingMethods with abstract "
96-
"methods evaluate, merge, update"
92+
"Can't instantiate abstract class MissingMethods (without an implementation "
93+
"for abstract methods 'evaluate', 'merge', 'update'|with abstract methods "
94+
"evaluate, merge, update)"
9795
)
9896
with pytest.raises(Exception, match=msg):
9997
df.collect()
@@ -132,3 +130,36 @@ def test_group_by(df):
132130
arrays = [batch.column(1) for batch in batches]
133131
joined = pa.concat_arrays(arrays)
134132
assert joined == pa.array([1.0 + 2.0, 3.0])
133+
134+
135+
def test_register_udaf(ctx, df) -> None:
136+
summarize = udaf(
137+
Summarize,
138+
pa.float64(),
139+
pa.float64(),
140+
[pa.float64()],
141+
volatility="immutable",
142+
)
143+
144+
ctx.register_udaf(summarize)
145+
146+
df_result = ctx.sql("select summarize(b) from test_table")
147+
148+
assert df_result.collect()[0][0][0].as_py() == 14.0
149+
150+
151+
def test_register_udf(ctx, df) -> None:
152+
is_null = udf(
153+
lambda x: x.is_null(),
154+
[pa.float64()],
155+
pa.bool_(),
156+
volatility="immutable",
157+
name="is_null",
158+
)
159+
160+
ctx.register_udf(is_null)
161+
162+
df_result = ctx.sql("select is_null(a) from test_table")
163+
result = df_result.collect()[0].column(0)
164+
165+
assert result == pa.array([False, False, False])

python/datafusion/udf.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -192,7 +192,7 @@ def __init__(
192192
See :py:func:`udaf` for a convenience function and argument
193193
descriptions.
194194
"""
195-
self._udf = df_internal.AggregateUDF(
195+
self._udaf = df_internal.AggregateUDF(
196196
name, accumulator, input_types, return_type, state_type, str(volatility)
197197
)
198198

@@ -203,7 +203,7 @@ def __call__(self, *args: Expr) -> Expr:
203203
occur during the evaluation of the dataframe.
204204
"""
205205
args = [arg.expr for arg in args]
206-
return Expr(self._udf.__call__(*args))
206+
return Expr(self._udaf.__call__(*args))
207207

208208
@staticmethod
209209
def udaf(

0 commit comments

Comments
 (0)