Skip to content

Commit 660035d

Browse files
committed
Implementation of udf and udaf decorators
1 parent 69ebf70 commit 660035d

File tree

4 files changed

+134
-7
lines changed

4 files changed

+134
-7
lines changed

python/datafusion/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -118,7 +118,9 @@ def lit(value):
118118

119119

120120
udf = ScalarUDF.udf
121+
udf_decorator = ScalarUDF.udf_decorator
121122

122123
udaf = AggregateUDF.udaf
124+
udaf_decorator = AggregateUDF.udaf_decorator
123125

124126
udwf = WindowUDF.udwf

python/datafusion/udf.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
from abc import ABCMeta, abstractmethod
2323
from enum import Enum
2424
from typing import TYPE_CHECKING, Callable, List, Optional, TypeVar
25+
import functools
2526

2627
import pyarrow
2728

@@ -148,6 +149,27 @@ def udf(
148149
volatility=volatility,
149150
)
150151

152+
@staticmethod
153+
def udf_decorator(
154+
input_types: list[pyarrow.DataType],
155+
return_type: _R,
156+
volatility: Volatility | str,
157+
name: Optional[str] = None
158+
):
159+
def decorator(func):
160+
udf_caller = ScalarUDF.udf(
161+
func,
162+
input_types,
163+
return_type,
164+
volatility,
165+
name
166+
)
167+
168+
@functools.wraps(func)
169+
def wrapper(*args, **kwargs):
170+
return udf_caller(*args, **kwargs)
171+
return wrapper
172+
return decorator
151173

152174
class Accumulator(metaclass=ABCMeta):
153175
"""Defines how an :py:class:`AggregateUDF` accumulates values."""
@@ -287,6 +309,31 @@ def sum_bias_10() -> Summarize:
287309
state_type=state_type,
288310
volatility=volatility,
289311
)
312+
313+
@staticmethod
def udaf_decorator(
    input_types: pyarrow.DataType | list[pyarrow.DataType],
    return_type: pyarrow.DataType,
    state_type: list[pyarrow.DataType],
    volatility: Volatility | str,
    name: Optional[str] = None,
):
    """Create a decorator that turns an accumulator factory into an aggregate UDF.

    Args:
        input_types: Arrow data type(s) of the aggregated column(s).
        return_type: Arrow data type of the aggregate result.
        state_type: Arrow data types of the accumulator's internal state.
        volatility: Volatility class of the UDF, or its string name.
        name: Optional SQL-visible name; when ``None`` the wrapped
            factory's own name is used (see :py:func:`AggregateUDF.udaf`).

    Returns:
        A decorator that replaces the accumulator factory with the
        equivalent :py:class:`AggregateUDF`.
    """

    def decorator(accum: Callable[[], Accumulator]):
        # Build the AggregateUDF once at decoration time and return it
        # directly, mirroring ScalarUDF.udf_decorator.  A plain forwarding
        # wrapper would hide the AggregateUDF instance and prevent
        # registration via ``SessionContext.register_udaf``.  AggregateUDF
        # instances are callable, so ``summarize(column("a"))`` works
        # unchanged.
        return AggregateUDF.udaf(
            accum,
            input_types,
            return_type,
            state_type,
            volatility,
            name,
        )

    return decorator
336+
290337

291338

292339
class WindowEvaluator(metaclass=ABCMeta):

python/tests/test_udaf.py

Lines changed: 49 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
import pyarrow as pa
2121
import pyarrow.compute as pc
2222
import pytest
23-
from datafusion import Accumulator, column, udaf
23+
from datafusion import Accumulator, column, udaf, udaf_decorator
2424

2525

2626
class Summarize(Accumulator):
@@ -116,6 +116,29 @@ def test_udaf_aggregate(df):
116116

117117
assert result.column(0) == pa.array([1.0 + 2.0 + 3.0])
118118

119+
def test_udaf_decorator_aggregate(df):
    """An aggregate UDF declared via @udaf_decorator sums column "a"."""

    @udaf_decorator(pa.float64(),
                    pa.float64(),
                    [pa.float64()],
                    "immutable")
    def summarize():
        return Summarize()

    expected = pa.array([1.0 + 2.0 + 3.0])

    # Execute and collect the first (and only) batch.
    batch = df.aggregate([], [summarize(column("a"))]).collect()[0]
    assert batch.column(0) == expected

    # Aggregate a second time to ensure the accumulator state is
    # properly reset between runs.
    batch = df.aggregate([], [summarize(column("a"))]).collect()[0]
    assert batch.column(0) == expected
119142

120143
def test_udaf_aggregate_with_arguments(df):
121144
bias = 10.0
@@ -143,6 +166,31 @@ def test_udaf_aggregate_with_arguments(df):
143166
assert result.column(0) == pa.array([bias + 1.0 + 2.0 + 3.0])
144167

145168

169+
def test_udaf_decorator_aggregate_with_arguments(df):
    """The accumulator factory may close over arguments (here a bias term)."""
    bias = 10.0

    @udaf_decorator(pa.float64(),
                    pa.float64(),
                    [pa.float64()],
                    "immutable")
    def summarize():
        return Summarize(bias)

    expected = pa.array([bias + 1.0 + 2.0 + 3.0])

    # Execute and collect the first (and only) batch.
    batch = df.aggregate([], [summarize(column("a"))]).collect()[0]
    assert batch.column(0) == expected

    # Aggregate a second time to ensure the accumulator state is
    # properly reset between runs.
    batch = df.aggregate([], [summarize(column("a"))]).collect()[0]
    assert batch.column(0) == expected
192+
193+
146194
def test_group_by(df):
147195
summarize = udaf(
148196
Summarize,

python/tests/test_udf.py

Lines changed: 36 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -17,14 +17,14 @@
1717

1818
import pyarrow as pa
1919
import pytest
20-
from datafusion import column, udf
20+
from datafusion import column, udf, udf_decorator
2121

2222

2323
@pytest.fixture
def df(ctx):
    """DataFrame fixture: column "a" is null-free, column "b" ends in a null."""
    # Build one RecordBatch and wrap it in a registered DataFrame.
    col_a = pa.array([1, 2, 3])
    col_b = pa.array([4, 4, None])
    batch = pa.RecordBatch.from_arrays([col_a, col_b], names=["a", "b"])
    return ctx.create_dataframe([[batch]], name="test_table")
@@ -39,10 +39,20 @@ def test_udf(df):
3939
volatility="immutable",
4040
)
4141

42-
df = df.select(is_null(column("a")))
42+
df = df.select(is_null(column("b")))
4343
result = df.collect()[0].column(0)
4444

45-
assert result == pa.array([False, False, False])
45+
assert result == pa.array([False, False, True])
46+
47+
48+
def test_udf_decorator(df):
    """A UDF declared via @udf_decorator behaves like one built with udf()."""

    @udf_decorator([pa.int64()], pa.bool_(), "immutable")
    def is_null(x: pa.Array) -> pa.Array:
        return x.is_null()

    # Column "b" is [4, 4, None], so only the last entry is null.
    result = df.select(is_null(column("b"))).collect()[0].column(0)

    assert result == pa.array([False, False, True])
4656

4757

4858
def test_register_udf(ctx, df) -> None:
@@ -56,10 +66,10 @@ def test_register_udf(ctx, df) -> None:
5666

5767
ctx.register_udf(is_null)
5868

59-
df_result = ctx.sql("select is_null(a) from test_table")
69+
df_result = ctx.sql("select is_null(b) from test_table")
6070
result = df_result.collect()[0].column(0)
6171

62-
assert result == pa.array([False, False, False])
72+
assert result == pa.array([False, False, True])
6373

6474

6575
class OverThresholdUDF:
@@ -94,3 +104,23 @@ def test_udf_with_parameters(df) -> None:
94104
result = df2.collect()[0].column(0)
95105

96106
assert result == pa.array([False, True, True])
107+
108+
109+
def test_udf_decorator_with_parameters(df) -> None:
    """Decorator-built UDFs can wrap callables constructed with parameters.

    NOTE(review): renamed from ``test_udf_with_parameters`` — this file
    already defines a test with that name earlier, and the duplicate
    definition shadowed it, so pytest silently never ran the original.
    """

    @udf_decorator([pa.int64()], pa.bool_(), "immutable")
    def udf_no_param(values: pa.Array) -> pa.Array:
        # Default-constructed OverThresholdUDF; with a = [1, 2, 3] every
        # value is expected to pass.
        return OverThresholdUDF()(values)

    result = df.select(udf_no_param(column("a"))).collect()[0].column(0)

    assert result == pa.array([True, True, True])

    @udf_decorator([pa.int64()], pa.bool_(), "immutable")
    def udf_with_param(values: pa.Array) -> pa.Array:
        # Threshold of 2 is expected to reject only the first value (1).
        return OverThresholdUDF(2)(values)

    result = df.select(udf_with_param(column("a"))).collect()[0].column(0)

    assert result == pa.array([False, True, True])

0 commit comments

Comments
 (0)