Commit 250baea

Switching to use factory methods for udaf and udwf
1 parent b7468dc commit 250baea

File tree

5 files changed: +114 additions, -77 deletions
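In short, `udaf` and `udwf` no longer take an `arguments` list for the constructor; the first argument may now be any zero-argument callable (the class itself, a factory function, or a lambda) that returns the `Accumulator` or `WindowEvaluator` instance. Below is a minimal sketch of the new call shape, adapted from the docstring example added in this commit; the top-level `datafusion` imports are assumed from the package's public re-exports.

    import pyarrow as pa
    import pyarrow.compute as pc
    from datafusion import udaf, Accumulator  # assumed public re-exports

    class Summarize(Accumulator):
        """Sums a float64 column, starting from an optional bias."""

        def __init__(self, bias: float = 0.0) -> None:
            self._sum = pa.scalar(bias)

        def state(self) -> list[pa.Scalar]:
            return [self._sum]

        def update(self, values: pa.Array) -> None:
            self._sum = pa.scalar(self._sum.as_py() + pc.sum(values).as_py())

        def merge(self, states: list[pa.Array]) -> None:
            self._sum = pa.scalar(self._sum.as_py() + pc.sum(states[0]).as_py())

        def evaluate(self) -> pa.Scalar:
            return self._sum

    # Before this commit: udaf(Summarize, ..., arguments=[10.0])
    # After: pass a factory; it is called to build a fresh Accumulator per use.
    summarize = udaf(
        lambda: Summarize(10.0),
        pa.float64(),
        pa.float64(),
        [pa.float64()],
        volatility="immutable",
    )

Note that, per the new validation added in this commit, the factory is invoked once at definition time to confirm it actually returns an `Accumulator` (or `WindowEvaluator`).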

python/datafusion/udf.py

Lines changed: 81 additions & 32 deletions
@@ -21,9 +21,9 @@
 
 import datafusion._internal as df_internal
 from datafusion.expr import Expr
-from typing import Callable, TYPE_CHECKING, TypeVar, Type
+from typing import Callable, TYPE_CHECKING, TypeVar
 from abc import ABCMeta, abstractmethod
-from typing import List, Any, Optional
+from typing import List, Optional
 from enum import Enum
 import pyarrow
 
@@ -84,7 +84,7 @@ class ScalarUDF:
 
     def __init__(
         self,
-        name: str | None,
+        name: Optional[str],
         func: Callable[..., _R],
         input_types: pyarrow.DataType | list[pyarrow.DataType],
         return_type: _R,
@@ -115,7 +115,7 @@ def udf(
         input_types: list[pyarrow.DataType],
         return_type: _R,
         volatility: Volatility | str,
-        name: str | None = None,
+        name: Optional[str] = None,
     ) -> ScalarUDF:
         """Create a new User-Defined Function.
 
@@ -181,13 +181,12 @@ class AggregateUDF:
 
     def __init__(
         self,
-        name: str | None,
-        accumulator: Type[Accumulator],
+        name: Optional[str],
+        accumulator: Callable[[], Accumulator],
         input_types: list[pyarrow.DataType],
         return_type: pyarrow.DataType,
         state_type: list[pyarrow.DataType],
         volatility: Volatility | str,
-        arguments: list[Any],
     ) -> None:
         """Instantiate a user-defined aggregate function (UDAF).
 
@@ -201,7 +200,6 @@ def __init__(
             return_type,
             state_type,
             str(volatility),
-            arguments,
         )
 
     def __call__(self, *args: Expr) -> Expr:
@@ -215,48 +213,77 @@ def __call__(self, *args: Expr) -> Expr:
 
     @staticmethod
     def udaf(
-        accum: Type[Accumulator],
+        accum: Callable[[], Accumulator],
         input_types: pyarrow.DataType | list[pyarrow.DataType],
         return_type: pyarrow.DataType,
         state_type: list[pyarrow.DataType],
         volatility: Volatility | str,
-        arguments: Optional[list[Any]] = None,
-        name: str | None = None,
+        name: Optional[str] = None,
     ) -> AggregateUDF:
         """Create a new User-Defined Aggregate Function.
 
-        The accumulator function must be callable and implement :py:class:`Accumulator`.
+        If your :py:class:`Accumulator` can be instantiated with no arguments, you
+        can simply pass it's type as ``accum``. If you need to pass additional arguments
+        to it's constructor, you can define a lambda or a factory method. During runtime
+        the :py:class:`Accumulator` will be constructed for every instance in
+        which this UDAF is used. The following examples are all valid.
+
+        .. code-block:: python
+            import pyarrow as pa
+            import pyarrow.compute as pc
+
+            class Summarize(Accumulator):
+                def __init__(self, bias: float = 0.0):
+                    self._sum = pa.scalar(bias)
+
+                def state(self) -> List[pa.Scalar]:
+                    return [self._sum]
+
+                def update(self, values: pa.Array) -> None:
+                    self._sum = pa.scalar(self._sum.as_py() + pc.sum(values).as_py())
+
+                def merge(self, states: List[pa.Array]) -> None:
+                    self._sum = pa.scalar(self._sum.as_py() + pc.sum(states[0]).as_py())
+
+                def evaluate(self) -> pa.Scalar:
+                    return self._sum
+
+            def sum_bias_10() -> Summarize:
+                return Summarize(10.0)
+
+            udaf1 = udaf(Summarize, pa.float64(), pa.float64(), [pa.float64()], "immutable")
+            udaf2 = udaf(sum_bias_10, pa.float64(), pa.float64(), [pa.float64()], "immutable")
+            udaf3 = udaf(lambda: Summarize(20.0), pa.float64(), pa.float64(), [pa.float64()], "immutable")
 
         Args:
             accum: The accumulator python function.
             input_types: The data types of the arguments to ``accum``.
             return_type: The data type of the return value.
             state_type: The data types of the intermediate accumulation.
             volatility: See :py:class:`Volatility` for allowed values.
-            arguments: A list of arguments to pass in to the __init__ method for accum.
             name: A descriptive name for the function.
 
         Returns:
             A user-defined aggregate function, which can be used in either data
             aggregation or window function calls.
-        """
-        if not issubclass(accum, Accumulator):
+        """ # noqa W505
+        if not callable(accum):
+            raise TypeError("`func` must be callable.")
+        if not isinstance(accum.__call__(), Accumulator):
             raise TypeError(
-                "`accum` must implement the abstract base class Accumulator"
+                "Accumulator must implement the abstract base class Accumulator"
             )
         if name is None:
-            name = accum.__qualname__.lower()
+            name = accum.__call__().__class__.__qualname__.lower()
        if isinstance(input_types, pyarrow.DataType):
            input_types = [input_types]
-        arguments = [] if arguments is None else arguments
         return AggregateUDF(
             name=name,
             accumulator=accum,
             input_types=input_types,
             return_type=return_type,
             state_type=state_type,
             volatility=volatility,
-            arguments=arguments,
         )
 
 
@@ -433,20 +460,19 @@ class WindowUDF:
 
     def __init__(
         self,
-        name: str | None,
-        func: Type[WindowEvaluator],
+        name: Optional[str],
+        func: Callable[[], WindowEvaluator],
         input_types: list[pyarrow.DataType],
         return_type: pyarrow.DataType,
         volatility: Volatility | str,
-        arguments: list[Any],
     ) -> None:
         """Instantiate a user-defined window function (UDWF).
 
         See :py:func:`udwf` for a convenience function and argument
         descriptions.
         """
         self._udwf = df_internal.WindowUDF(
-            name, func, input_types, return_type, str(volatility), arguments
+            name, func, input_types, return_type, str(volatility)
         )
 
     def __call__(self, *args: Expr) -> Expr:
@@ -460,17 +486,40 @@ def __call__(self, *args: Expr) -> Expr:
 
     @staticmethod
     def udwf(
-        func: Type[WindowEvaluator],
+        func: Callable[[], WindowEvaluator],
         input_types: pyarrow.DataType | list[pyarrow.DataType],
         return_type: pyarrow.DataType,
         volatility: Volatility | str,
-        arguments: Optional[list[Any]] = None,
-        name: str | None = None,
+        name: Optional[str] = None,
     ) -> WindowUDF:
         """Create a new User-Defined Window Function.
 
+        If your :py:class:`WindowEvaluator` can be instantiated with no arguments, you
+        can simply pass it's type as ``func``. If you need to pass additional arguments
+        to it's constructor, you can define a lambda or a factory method. During runtime
+        the :py:class:`WindowEvaluator` will be constructed for every instance in
+        which this UDWF is used. The following examples are all valid.
+
+        .. code-block:: python
+
+            import pyarrow as pa
+
+            class BiasedNumbers(WindowEvaluator):
+                def __init__(self, start: int = 0) -> None:
+                    self.start = start
+
+                def evaluate_all(self, values: list[pa.Array], num_rows: int) -> pa.Array:
+                    return pa.array([self.start + i for i in range(num_rows)])
+
+            def bias_10() -> BiasedNumbers:
+                return BiasedNumbers(10)
+
+            udwf1 = udwf(BiasedNumbers, pa.int64(), pa.int64(), "immutable")
+            udwf2 = udwf(bias_10, pa.int64(), pa.int64(), "immutable")
+            udwf3 = udwf(lambda: BiasedNumbers(20), pa.int64(), pa.int64(), "immutable")
+
         Args:
-            func: The python function.
+            func: A callable to create the window function.
             input_types: The data types of the arguments to ``func``.
             return_type: The data type of the return value.
             volatility: See :py:class:`Volatility` for allowed values.
@@ -479,21 +528,21 @@ def udwf(
 
         Returns:
             A user-defined window function.
-        """
-        if not issubclass(func, WindowEvaluator):
+        """ # noqa W505
+        if not callable(func):
+            raise TypeError("`func` must be callable.")
+        if not isinstance(func.__call__(), WindowEvaluator):
             raise TypeError(
                 "`func` must implement the abstract base class WindowEvaluator"
             )
         if name is None:
-            name = func.__class__.__qualname__.lower()
+            name = func.__call__().__class__.__qualname__.lower()
         if isinstance(input_types, pyarrow.DataType):
             input_types = [input_types]
-        arguments = [] if arguments is None else arguments
         return WindowUDF(
             name=name,
             func=func,
             input_types=input_types,
             return_type=return_type,
             volatility=volatility,
-            arguments=arguments,
         )

python/tests/test_udaf.py

Lines changed: 15 additions & 12 deletions
@@ -79,22 +79,19 @@ def test_errors(df):
         volatility="immutable",
     )
 
-    accum = udaf(
-        MissingMethods,
-        pa.int64(),
-        pa.int64(),
-        [pa.int64()],
-        volatility="immutable",
-    )
-    df = df.aggregate([], [accum(column("a"))])
-
     msg = (
         "Can't instantiate abstract class MissingMethods (without an implementation "
         "for abstract methods 'evaluate', 'merge', 'update'|with abstract methods "
         "evaluate, merge, update)"
     )
     with pytest.raises(Exception, match=msg):
-        df.collect()
+        accum = udaf( # noqa F841
+            MissingMethods,
+            pa.int64(),
+            pa.int64(),
+            [pa.int64()],
+            volatility="immutable",
+        )
 
 
 def test_udaf_aggregate(df):
@@ -125,12 +122,11 @@ def test_udaf_aggregate_with_arguments(df):
     bias = 10.0
 
     summarize = udaf(
-        Summarize,
+        lambda: Summarize(bias),
         pa.float64(),
         pa.float64(),
         [pa.float64()],
         volatility="immutable",
-        arguments=[bias],
     )
 
     df1 = df.aggregate([], [summarize(column("a"))])
@@ -140,6 +136,13 @@ def test_udaf_aggregate_with_arguments(df):
 
     assert result.column(0) == pa.array([bias + 1.0 + 2.0 + 3.0])
 
+    df2 = df.aggregate([], [summarize(column("a"))])
+
+    # Run a second time to ensure the state is properly reset
+    result = df2.collect()[0]
+
+    assert result.column(0) == pa.array([bias + 1.0 + 2.0 + 3.0])
+
 
 def test_group_by(df):
     summarize = udaf(
