2323from datafusion .expr import Expr
2424from typing import Callable , TYPE_CHECKING , TypeVar
2525from abc import ABCMeta , abstractmethod
26- from typing import List
26+ from typing import List , Optional
2727from enum import Enum
2828import pyarrow
2929
@@ -84,16 +84,18 @@ class ScalarUDF:
8484
8585 def __init__ (
8686 self ,
87- name : str | None ,
87+ name : Optional [ str ] ,
8888 func : Callable [..., _R ],
89- input_types : list [pyarrow .DataType ],
89+ input_types : pyarrow . DataType | list [pyarrow .DataType ],
9090 return_type : _R ,
9191 volatility : Volatility | str ,
9292 ) -> None :
9393 """Instantiate a scalar user-defined function (UDF).
9494
9595 See helper method :py:func:`udf` for argument details.
9696 """
97+ if isinstance (input_types , pyarrow .DataType ):
98+ input_types = [input_types ]
9799 self ._udf = df_internal .ScalarUDF (
98100 name , func , input_types , return_type , str (volatility )
99101 )
@@ -104,16 +106,16 @@ def __call__(self, *args: Expr) -> Expr:
104106 This function is not typically called by an end user. These calls will
105107 occur during the evaluation of the dataframe.
106108 """
107- args = [arg .expr for arg in args ]
108- return Expr (self ._udf .__call__ (* args ))
109+ args_raw = [arg .expr for arg in args ]
110+ return Expr (self ._udf .__call__ (* args_raw ))
109111
110112 @staticmethod
111113 def udf (
112114 func : Callable [..., _R ],
113115 input_types : list [pyarrow .DataType ],
114116 return_type : _R ,
115117 volatility : Volatility | str ,
116- name : str | None = None ,
118+ name : Optional [ str ] = None ,
117119 ) -> ScalarUDF :
118120 """Create a new User-Defined Function.
119121
@@ -133,7 +135,10 @@ def udf(
133135 if not callable (func ):
134136 raise TypeError ("`func` argument must be callable" )
135137 if name is None :
136- name = func .__qualname__ .lower ()
138+ if hasattr (func , "__qualname__" ):
139+ name = func .__qualname__ .lower ()
140+ else :
141+ name = func .__class__ .__name__ .lower ()
137142 return ScalarUDF (
138143 name = name ,
139144 func = func ,
@@ -167,10 +172,6 @@ def evaluate(self) -> pyarrow.Scalar:
167172 pass
168173
169174
170- if TYPE_CHECKING :
171- _A = TypeVar ("_A" , bound = (Callable [..., _R ], Accumulator ))
172-
173-
174175class AggregateUDF :
175176 """Class for performing scalar user-defined functions (UDF).
176177
@@ -180,10 +181,10 @@ class AggregateUDF:
180181
181182 def __init__ (
182183 self ,
183- name : str | None ,
184- accumulator : _A ,
184+ name : Optional [ str ] ,
185+ accumulator : Callable [[], Accumulator ] ,
185186 input_types : list [pyarrow .DataType ],
186- return_type : _R ,
187+ return_type : pyarrow . DataType ,
187188 state_type : list [pyarrow .DataType ],
188189 volatility : Volatility | str ,
189190 ) -> None :
@@ -193,7 +194,12 @@ def __init__(
193194 descriptions.
194195 """
195196 self ._udaf = df_internal .AggregateUDF (
196- name , accumulator , input_types , return_type , state_type , str (volatility )
197+ name ,
198+ accumulator ,
199+ input_types ,
200+ return_type ,
201+ state_type ,
202+ str (volatility ),
197203 )
198204
199205 def __call__ (self , * args : Expr ) -> Expr :
@@ -202,21 +208,52 @@ def __call__(self, *args: Expr) -> Expr:
202208 This function is not typically called by an end user. These calls will
203209 occur during the evaluation of the dataframe.
204210 """
205- args = [arg .expr for arg in args ]
206- return Expr (self ._udaf .__call__ (* args ))
211+ args_raw = [arg .expr for arg in args ]
212+ return Expr (self ._udaf .__call__ (* args_raw ))
207213
208214 @staticmethod
209215 def udaf (
210- accum : _A ,
211- input_types : list [pyarrow .DataType ],
212- return_type : _R ,
216+ accum : Callable [[], Accumulator ] ,
217+ input_types : pyarrow . DataType | list [pyarrow .DataType ],
218+ return_type : pyarrow . DataType ,
213219 state_type : list [pyarrow .DataType ],
214220 volatility : Volatility | str ,
215- name : str | None = None ,
221+ name : Optional [ str ] = None ,
216222 ) -> AggregateUDF :
217223 """Create a new User-Defined Aggregate Function.
218224
219- The accumulator function must be callable and implement :py:class:`Accumulator`.
225+ If your :py:class:`Accumulator` can be instantiated with no arguments, you
226+ can simply pass it's type as ``accum``. If you need to pass additional arguments
227+ to it's constructor, you can define a lambda or a factory method. During runtime
228+ the :py:class:`Accumulator` will be constructed for every instance in
229+ which this UDAF is used. The following examples are all valid.
230+
231+ .. code-block:: python
232+ import pyarrow as pa
233+ import pyarrow.compute as pc
234+
235+ class Summarize(Accumulator):
236+ def __init__(self, bias: float = 0.0):
237+ self._sum = pa.scalar(bias)
238+
239+ def state(self) -> List[pa.Scalar]:
240+ return [self._sum]
241+
242+ def update(self, values: pa.Array) -> None:
243+ self._sum = pa.scalar(self._sum.as_py() + pc.sum(values).as_py())
244+
245+ def merge(self, states: List[pa.Array]) -> None:
246+ self._sum = pa.scalar(self._sum.as_py() + pc.sum(states[0]).as_py())
247+
248+ def evaluate(self) -> pa.Scalar:
249+ return self._sum
250+
251+ def sum_bias_10() -> Summarize:
252+ return Summarize(10.0)
253+
254+ udaf1 = udaf(Summarize, pa.float64(), pa.float64(), [pa.float64()], "immutable")
255+ udaf2 = udaf(sum_bias_10, pa.float64(), pa.float64(), [pa.float64()], "immutable")
256+ udaf3 = udaf(lambda: Summarize(20.0), pa.float64(), pa.float64(), [pa.float64()], "immutable")
220257
221258 Args:
222259 accum: The accumulator python function.
@@ -229,14 +266,16 @@ def udaf(
229266 Returns:
230267 A user-defined aggregate function, which can be used in either data
231268 aggregation or window function calls.
232- """
233- if not issubclass (accum , Accumulator ):
269+ """ # noqa W505
270+ if not callable (accum ):
271+ raise TypeError ("`func` must be callable." )
272+ if not isinstance (accum .__call__ (), Accumulator ):
234273 raise TypeError (
235- "`accum` must implement the abstract base class Accumulator"
274+ "Accumulator must implement the abstract base class Accumulator"
236275 )
237276 if name is None :
238- name = accum .__qualname__ .lower ()
239- if isinstance (input_types , pyarrow .lib . DataType ):
277+ name = accum .__call__ (). __class__ . __qualname__ .lower ()
278+ if isinstance (input_types , pyarrow .DataType ):
240279 input_types = [input_types ]
241280 return AggregateUDF (
242281 name = name ,
@@ -421,8 +460,8 @@ class WindowUDF:
421460
422461 def __init__ (
423462 self ,
424- name : str | None ,
425- func : WindowEvaluator ,
463+ name : Optional [ str ] ,
464+ func : Callable [[], WindowEvaluator ] ,
426465 input_types : list [pyarrow .DataType ],
427466 return_type : pyarrow .DataType ,
428467 volatility : Volatility | str ,
@@ -447,30 +486,57 @@ def __call__(self, *args: Expr) -> Expr:
447486
448487 @staticmethod
449488 def udwf (
450- func : WindowEvaluator ,
489+ func : Callable [[], WindowEvaluator ] ,
451490 input_types : pyarrow .DataType | list [pyarrow .DataType ],
452491 return_type : pyarrow .DataType ,
453492 volatility : Volatility | str ,
454- name : str | None = None ,
493+ name : Optional [ str ] = None ,
455494 ) -> WindowUDF :
456495 """Create a new User-Defined Window Function.
457496
497+ If your :py:class:`WindowEvaluator` can be instantiated with no arguments, you
498+ can simply pass it's type as ``func``. If you need to pass additional arguments
499+ to it's constructor, you can define a lambda or a factory method. During runtime
500+ the :py:class:`WindowEvaluator` will be constructed for every instance in
501+ which this UDWF is used. The following examples are all valid.
502+
503+ .. code-block:: python
504+
505+ import pyarrow as pa
506+
507+ class BiasedNumbers(WindowEvaluator):
508+ def __init__(self, start: int = 0) -> None:
509+ self.start = start
510+
511+ def evaluate_all(self, values: list[pa.Array], num_rows: int) -> pa.Array:
512+ return pa.array([self.start + i for i in range(num_rows)])
513+
514+ def bias_10() -> BiasedNumbers:
515+ return BiasedNumbers(10)
516+
517+ udwf1 = udwf(BiasedNumbers, pa.int64(), pa.int64(), "immutable")
518+ udwf2 = udwf(bias_10, pa.int64(), pa.int64(), "immutable")
519+ udwf3 = udwf(lambda: BiasedNumbers(20), pa.int64(), pa.int64(), "immutable")
520+
458521 Args:
459- func: The python function.
522+ func: A callable to create the window function.
460523 input_types: The data types of the arguments to ``func``.
461524 return_type: The data type of the return value.
462525 volatility: See :py:class:`Volatility` for allowed values.
526+ arguments: A list of arguments to pass in to the __init__ method for accum.
463527 name: A descriptive name for the function.
464528
465529 Returns:
466530 A user-defined window function.
467- """
468- if not isinstance (func , WindowEvaluator ):
531+ """ # noqa W505
532+ if not callable (func ):
533+ raise TypeError ("`func` must be callable." )
534+ if not isinstance (func .__call__ (), WindowEvaluator ):
469535 raise TypeError (
470536 "`func` must implement the abstract base class WindowEvaluator"
471537 )
472538 if name is None :
473- name = func .__class__ .__qualname__ .lower ()
539+ name = func .__call__ (). __class__ .__qualname__ .lower ()
474540 if isinstance (input_types , pyarrow .DataType ):
475541 input_types = [input_types ]
476542 return WindowUDF (
0 commit comments