Skip to content

Commit f68c0ef

Browse files
committed
Rename decorators back to udf and udaf, update documentations
1 parent 660035d commit f68c0ef

File tree

4 files changed

+180
-124
lines changed

4 files changed

+180
-124
lines changed

python/datafusion/__init__.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -118,9 +118,7 @@ def lit(value):
118118

119119

120120
udf = ScalarUDF.udf
121-
udf_decorator = ScalarUDF.udf_decorator
122121

123122
udaf = AggregateUDF.udaf
124-
udaf_decorator = AggregateUDF.udaf_decorator
125123

126124
udwf = WindowUDF.udwf

python/datafusion/udf.py

Lines changed: 173 additions & 112 deletions
Original file line numberDiff line numberDiff line change
@@ -111,65 +111,96 @@ def __call__(self, *args: Expr) -> Expr:
111111
args_raw = [arg.expr for arg in args]
112112
return Expr(self._udf.__call__(*args_raw))
113113

114-
@staticmethod
115-
def udf(
116-
func: Callable[..., _R],
117-
input_types: list[pyarrow.DataType],
118-
return_type: _R,
119-
volatility: Volatility | str,
120-
name: Optional[str] = None,
121-
) -> ScalarUDF:
122-
"""Create a new User-Defined Function.
114+
class udf:
115+
"""Create a new User-Defined Function (UDF).
116+
117+
This class can be used both as a **function** and as a **decorator**.
118+
119+
Usage:
120+
- **As a function**: Call `udf(func, input_types, return_type, volatility, name)`.
121+
- **As a decorator**: Use `@udf(input_types, return_type, volatility, name)`.
122+
In this case, do **not** pass `func` explicitly.
123123
124124
Args:
125-
func: A callable python function.
126-
input_types: The data types of the arguments to ``func``. This list
127-
must be of the same length as the number of arguments.
128-
return_type: The data type of the return value from the python
129-
function.
130-
volatility: See ``Volatility`` for allowed values.
131-
name: A descriptive name for the function.
125+
func (Callable, optional): **Only needed when calling as a function.**
126+
Skip this argument when using `udf` as a decorator.
127+
input_types (list[pyarrow.DataType]): The data types of the arguments
128+
to `func`. This list must be of the same length as the number of arguments.
129+
return_type (_R): The data type of the return value from the function.
130+
volatility (Volatility | str): See `Volatility` for allowed values.
131+
name (Optional[str]): A descriptive name for the function.
132132
133133
Returns:
134-
A user-defined aggregate function, which can be used in either data
135-
aggregation or window function calls.
134+
A user-defined function that can be used in SQL expressions,
135+
data aggregation, or window function calls.
136+
137+
Example:
138+
**Using `udf` as a function:**
139+
```python
140+
def double_func(x):
141+
return x * 2
142+
double_udf = udf(double_func, [pyarrow.int32()], pyarrow.int32(), "volatile", "double_it")
143+
```
144+
145+
**Using `udf` as a decorator:**
146+
```python
147+
@udf([pyarrow.int32()], pyarrow.int32(), "volatile", "double_it")
148+
def double_udf(x):
149+
return x * 2
150+
```
136151
"""
137-
if not callable(func):
138-
raise TypeError("`func` argument must be callable")
139-
if name is None:
140-
if hasattr(func, "__qualname__"):
141-
name = func.__qualname__.lower()
152+
def __new__(cls, *args, **kwargs):
153+
if args and callable(args[0]):
154+
# Case 1: Used as a function, require the first parameter to be callable
155+
return cls._function(*args, **kwargs)
142156
else:
143-
name = func.__class__.__name__.lower()
144-
return ScalarUDF(
145-
name=name,
146-
func=func,
147-
input_types=input_types,
148-
return_type=return_type,
149-
volatility=volatility,
150-
)
151-
152-
@staticmethod
153-
def udf_decorator(
154-
input_types: list[pyarrow.DataType],
155-
return_type: _R,
156-
volatility: Volatility | str,
157-
name: Optional[str] = None
158-
):
159-
def decorator(func):
160-
udf_caller = ScalarUDF.udf(
161-
func,
162-
input_types,
163-
return_type,
164-
volatility,
165-
name
157+
# Case 2: Used as a decorator with parameters
158+
return cls._decorator(*args, **kwargs)
159+
160+
@staticmethod
161+
def _function(
162+
func: Callable[..., _R],
163+
input_types: list[pyarrow.DataType],
164+
return_type: _R,
165+
volatility: Volatility | str,
166+
name: Optional[str] = None,
167+
) -> ScalarUDF:
168+
if not callable(func):
169+
raise TypeError("`func` argument must be callable")
170+
if name is None:
171+
if hasattr(func, "__qualname__"):
172+
name = func.__qualname__.lower()
173+
else:
174+
name = func.__class__.__name__.lower()
175+
return ScalarUDF(
176+
name=name,
177+
func=func,
178+
input_types=input_types,
179+
return_type=return_type,
180+
volatility=volatility,
166181
)
167-
168-
@functools.wraps(func)
169-
def wrapper(*args, **kwargs):
170-
return udf_caller(*args, **kwargs)
171-
return wrapper
172-
return decorator
182+
183+
@staticmethod
184+
def _decorator(
185+
input_types: list[pyarrow.DataType],
186+
return_type: _R,
187+
volatility: Volatility | str,
188+
name: Optional[str] = None
189+
):
190+
def decorator(func):
191+
udf_caller = ScalarUDF.udf(
192+
func,
193+
input_types,
194+
return_type,
195+
volatility,
196+
name
197+
)
198+
199+
@functools.wraps(func)
200+
def wrapper(*args, **kwargs):
201+
return udf_caller(*args, **kwargs)
202+
return wrapper
203+
return decorator
173204

174205
class Accumulator(metaclass=ABCMeta):
175206
"""Defines how an :py:class:`AggregateUDF` accumulates values."""
@@ -234,25 +265,27 @@ def __call__(self, *args: Expr) -> Expr:
234265
args_raw = [arg.expr for arg in args]
235266
return Expr(self._udaf.__call__(*args_raw))
236267

237-
@staticmethod
238-
def udaf(
239-
accum: Callable[[], Accumulator],
240-
input_types: pyarrow.DataType | list[pyarrow.DataType],
241-
return_type: pyarrow.DataType,
242-
state_type: list[pyarrow.DataType],
243-
volatility: Volatility | str,
244-
name: Optional[str] = None,
245-
) -> AggregateUDF:
246-
"""Create a new User-Defined Aggregate Function.
268+
class udaf:
269+
"""Create a new User-Defined Aggregate Function (UDAF).
247270
248-
If your :py:class:`Accumulator` can be instantiated with no arguments, you
249-
can simply pass it's type as ``accum``. If you need to pass additional arguments
250-
to it's constructor, you can define a lambda or a factory method. During runtime
251-
the :py:class:`Accumulator` will be constructed for every instance in
252-
which this UDAF is used. The following examples are all valid.
271+
This class allows you to define an **aggregate function** that can be used in data
272+
aggregation or window function calls.
253273
254-
.. code-block:: python
274+
Usage:
275+
- **As a function**: Call `udaf(accum, input_types, return_type, state_type,
276+
volatility, name)`.
277+
- **As a decorator**: Use `@udaf(input_types, return_type, state_type,
278+
volatility, name)`.
279+
When using `udaf` as a decorator, **do not pass `accum` explicitly**.
255280
281+
**Function example:**
282+
283+
If your `:py:class:Accumulator` can be instantiated with no arguments, you can
284+
simply pass it's type as `accum`. If you need to pass additional arguments to
285+
it's constructor, you can define a lambda or a factory method. During runtime the
286+
`:py:class:Accumulator` will be constructed for every instance in which this UDAF is
287+
used. The following examples are all valid.
288+
```
256289
import pyarrow as pa
257290
import pyarrow.compute as pc
258291
@@ -278,61 +311,89 @@ def sum_bias_10() -> Summarize:
278311
udaf1 = udaf(Summarize, pa.float64(), pa.float64(), [pa.float64()], "immutable")
279312
udaf2 = udaf(sum_bias_10, pa.float64(), pa.float64(), [pa.float64()], "immutable")
280313
udaf3 = udaf(lambda: Summarize(20.0), pa.float64(), pa.float64(), [pa.float64()], "immutable")
314+
```
315+
316+
**Decorator example:**
317+
```
318+
@udaf(pa.float64(), pa.float64(), [pa.float64()], "immutable")
319+
def udf4() -> Summarize:
320+
return Summarize(10.0)
321+
```
281322
282323
Args:
283-
accum: The accumulator python function.
284-
input_types: The data types of the arguments to ``accum``.
324+
accum: The accumulator python function. **Only needed when calling as a function. Skip this argument when using `udaf` as a decorator.**
325+
input_types: The data types of the arguments to `accum.
285326
return_type: The data type of the return value.
286327
state_type: The data types of the intermediate accumulation.
287-
volatility: See :py:class:`Volatility` for allowed values.
328+
volatility: See :py:class:Volatility for allowed values.
288329
name: A descriptive name for the function.
289330
290331
Returns:
291332
A user-defined aggregate function, which can be used in either data
292333
aggregation or window function calls.
293-
""" # noqa W505
294-
if not callable(accum):
295-
raise TypeError("`func` must be callable.")
296-
if not isinstance(accum.__call__(), Accumulator):
297-
raise TypeError(
298-
"Accumulator must implement the abstract base class Accumulator"
299-
)
300-
if name is None:
301-
name = accum.__call__().__class__.__qualname__.lower()
302-
if isinstance(input_types, pyarrow.DataType):
303-
input_types = [input_types]
304-
return AggregateUDF(
305-
name=name,
306-
accumulator=accum,
307-
input_types=input_types,
308-
return_type=return_type,
309-
state_type=state_type,
310-
volatility=volatility,
311-
)
312-
313-
@staticmethod
314-
def udaf_decorator(
315-
input_types: pyarrow.DataType | list[pyarrow.DataType],
316-
return_type: pyarrow.DataType,
317-
state_type: list[pyarrow.DataType],
318-
volatility: Volatility | str,
319-
name: Optional[str] = None
320-
):
321-
def decorator(accum: Callable[[], Accumulator]):
322-
udaf_caller = AggregateUDF.udaf(
323-
accum,
324-
input_types,
325-
return_type,
326-
state_type,
327-
volatility,
328-
name
334+
"""
335+
336+
337+
338+
def __new__(cls, *args, **kwargs):
339+
if args and callable(args[0]):
340+
# Case 1: Used as a function, require the first parameter to be callable
341+
return cls._function(*args, **kwargs)
342+
else:
343+
# Case 2: Used as a decorator with parameters
344+
return cls._decorator(*args, **kwargs)
345+
346+
@staticmethod
347+
def _function(
348+
accum: Callable[[], Accumulator],
349+
input_types: pyarrow.DataType | list[pyarrow.DataType],
350+
return_type: pyarrow.DataType,
351+
state_type: list[pyarrow.DataType],
352+
volatility: Volatility | str,
353+
name: Optional[str] = None,
354+
) -> AggregateUDF:
355+
if not callable(accum):
356+
raise TypeError("`func` must be callable.")
357+
if not isinstance(accum.__call__(), Accumulator):
358+
raise TypeError(
359+
"Accumulator must implement the abstract base class Accumulator"
360+
)
361+
if name is None:
362+
name = accum.__call__().__class__.__qualname__.lower()
363+
if isinstance(input_types, pyarrow.DataType):
364+
input_types = [input_types]
365+
return AggregateUDF(
366+
name=name,
367+
accumulator=accum,
368+
input_types=input_types,
369+
return_type=return_type,
370+
state_type=state_type,
371+
volatility=volatility,
329372
)
330373

331-
@functools.wraps(accum)
332-
def wrapper(*args, **kwargs):
333-
return udaf_caller(*args, **kwargs)
334-
return wrapper
335-
return decorator
374+
@staticmethod
375+
def _decorator(
376+
input_types: pyarrow.DataType | list[pyarrow.DataType],
377+
return_type: pyarrow.DataType,
378+
state_type: list[pyarrow.DataType],
379+
volatility: Volatility | str,
380+
name: Optional[str] = None
381+
):
382+
def decorator(accum: Callable[[], Accumulator]):
383+
udaf_caller = AggregateUDF.udaf(
384+
accum,
385+
input_types,
386+
return_type,
387+
state_type,
388+
volatility,
389+
name
390+
)
391+
392+
@functools.wraps(accum)
393+
def wrapper(*args, **kwargs):
394+
return udaf_caller(*args, **kwargs)
395+
return wrapper
396+
return decorator
336397

337398

338399

python/tests/test_udaf.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
import pyarrow as pa
2121
import pyarrow.compute as pc
2222
import pytest
23-
from datafusion import Accumulator, column, udaf, udaf_decorator
23+
from datafusion import Accumulator, column, udaf
2424

2525

2626
class Summarize(Accumulator):
@@ -118,10 +118,7 @@ def test_udaf_aggregate(df):
118118

119119
def test_udaf_decorator_aggregate(df):
120120

121-
@udaf_decorator(pa.float64(),
122-
pa.float64(),
123-
[pa.float64()],
124-
"immutable")
121+
@udaf(pa.float64(), pa.float64(), [pa.float64()], "immutable")
125122
def summarize():
126123
return Summarize()
127124

@@ -169,7 +166,7 @@ def test_udaf_aggregate_with_arguments(df):
169166
def test_udaf_decorator_aggregate_with_arguments(df):
170167
bias = 10.0
171168

172-
@udaf_decorator(pa.float64(),
169+
@udaf(pa.float64(),
173170
pa.float64(),
174171
[pa.float64()],
175172
"immutable")

0 commit comments

Comments
 (0)