Skip to content

Commit 9885dae

Browse files
committed
Add unit tests for UDF showing callable class
1 parent 03df7fe commit 9885dae

File tree

3 files changed

+80
-34
lines changed

3 files changed

+80
-34
lines changed

python/datafusion/tests/test_dataframe.py

Lines changed: 0 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,6 @@
2929
WindowFrame,
3030
column,
3131
literal,
32-
udf,
3332
)
3433
from datafusion.expr import Window
3534

@@ -236,21 +235,6 @@ def test_unnest_without_nulls(nested_df):
236235
assert result.column(1) == pa.array([7, 8, 8, 9, 9, 9])
237236

238237

239-
def test_udf(df):
240-
# is_null is a pa function over arrays
241-
is_null = udf(
242-
lambda x: x.is_null(),
243-
[pa.int64()],
244-
pa.bool_(),
245-
volatility="immutable",
246-
)
247-
248-
df = df.select(is_null(column("a")))
249-
result = df.collect()[0].column(0)
250-
251-
assert result == pa.array([False, False, False])
252-
253-
254238
def test_join():
255239
ctx = SessionContext()
256240

python/datafusion/tests/test_udaf.py

Lines changed: 1 addition & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
import pyarrow.compute as pc
2222
import pytest
2323

24-
from datafusion import Accumulator, column, udaf, udf
24+
from datafusion import Accumulator, column, udaf
2525

2626

2727
class Summarize(Accumulator):
@@ -173,20 +173,3 @@ def test_register_udaf(ctx, df) -> None:
173173
df_result = ctx.sql("select summarize(b) from test_table")
174174

175175
assert df_result.collect()[0][0][0].as_py() == 14.0
176-
177-
178-
def test_register_udf(ctx, df) -> None:
179-
is_null = udf(
180-
lambda x: x.is_null(),
181-
[pa.float64()],
182-
pa.bool_(),
183-
volatility="immutable",
184-
name="is_null",
185-
)
186-
187-
ctx.register_udf(is_null)
188-
189-
df_result = ctx.sql("select is_null(a) from test_table")
190-
result = df_result.collect()[0].column(0)
191-
192-
assert result == pa.array([False, False, False])
Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
from datafusion import udf, column
2+
import pyarrow as pa
3+
import pytest
4+
5+
6+
@pytest.fixture
7+
def df(ctx):
8+
# create a RecordBatch and a new DataFrame from it
9+
batch = pa.RecordBatch.from_arrays(
10+
[pa.array([1, 2, 3]), pa.array([4, 4, 6])],
11+
names=["a", "b"],
12+
)
13+
return ctx.create_dataframe([[batch]], name="test_table")
14+
15+
16+
def test_udf(df):
17+
# is_null is a pa function over arrays
18+
is_null = udf(
19+
lambda x: x.is_null(),
20+
[pa.int64()],
21+
pa.bool_(),
22+
volatility="immutable",
23+
)
24+
25+
df = df.select(is_null(column("a")))
26+
result = df.collect()[0].column(0)
27+
28+
assert result == pa.array([False, False, False])
29+
30+
31+
def test_register_udf(ctx, df) -> None:
32+
is_null = udf(
33+
lambda x: x.is_null(),
34+
[pa.float64()],
35+
pa.bool_(),
36+
volatility="immutable",
37+
name="is_null",
38+
)
39+
40+
ctx.register_udf(is_null)
41+
42+
df_result = ctx.sql("select is_null(a) from test_table")
43+
result = df_result.collect()[0].column(0)
44+
45+
assert result == pa.array([False, False, False])
46+
47+
48+
class OverThresholdUDF:
49+
def __init__(self, threshold: int = 0) -> None:
50+
self.threshold = threshold
51+
52+
def __call__(self, values: pa.Array) -> pa.Array:
53+
return pa.array(v.as_py() >= self.threshold for v in values)
54+
55+
56+
def test_udf_with_parameters(df) -> None:
57+
udf_no_param = udf(
58+
OverThresholdUDF(),
59+
pa.int64(),
60+
pa.bool_(),
61+
volatility="immutable",
62+
)
63+
64+
df1 = df.select(udf_no_param(column("a")))
65+
result = df1.collect()[0].column(0)
66+
67+
assert result == pa.array([True, True, True])
68+
69+
udf_with_param = udf(
70+
OverThresholdUDF(2),
71+
pa.int64(),
72+
pa.bool_(),
73+
volatility="immutable",
74+
)
75+
76+
df2 = df.select(udf_with_param(column("a")))
77+
result = df2.collect()[0].column(0)
78+
79+
assert result == pa.array([False, True, True])

0 commit comments

Comments
 (0)