2121
2222import  datafusion ._internal  as  df_internal 
2323from  datafusion .expr  import  Expr 
24- from  typing  import  Callable , TYPE_CHECKING , TypeVar ,  Type 
24+ from  typing  import  Callable , TYPE_CHECKING , TypeVar 
2525from  abc  import  ABCMeta , abstractmethod 
26- from  typing  import  List , Any ,  Optional 
26+ from  typing  import  List , Optional 
2727from  enum  import  Enum 
2828import  pyarrow 
2929
@@ -84,7 +84,7 @@ class ScalarUDF:
8484
8585    def  __init__ (
8686        self ,
87-         name : str   |   None ,
87+         name : Optional [ str ] ,
8888        func : Callable [..., _R ],
8989        input_types : pyarrow .DataType  |  list [pyarrow .DataType ],
9090        return_type : _R ,
@@ -115,7 +115,7 @@ def udf(
115115        input_types : list [pyarrow .DataType ],
116116        return_type : _R ,
117117        volatility : Volatility  |  str ,
118-         name : str   |   None  =  None ,
118+         name : Optional [ str ]  =  None ,
119119    ) ->  ScalarUDF :
120120        """Create a new User-Defined Function. 
121121
@@ -181,13 +181,12 @@ class AggregateUDF:
181181
182182    def  __init__ (
183183        self ,
184-         name : str   |   None ,
185-         accumulator : Type [ Accumulator ],
184+         name : Optional [ str ] ,
185+         accumulator : Callable [[],  Accumulator ],
186186        input_types : list [pyarrow .DataType ],
187187        return_type : pyarrow .DataType ,
188188        state_type : list [pyarrow .DataType ],
189189        volatility : Volatility  |  str ,
190-         arguments : list [Any ],
191190    ) ->  None :
192191        """Instantiate a user-defined aggregate function (UDAF). 
193192
@@ -201,7 +200,6 @@ def __init__(
201200            return_type ,
202201            state_type ,
203202            str (volatility ),
204-             arguments ,
205203        )
206204
207205    def  __call__ (self , * args : Expr ) ->  Expr :
@@ -215,48 +213,77 @@ def __call__(self, *args: Expr) -> Expr:
215213
216214    @staticmethod  
217215    def  udaf (
218-         accum : Type [ Accumulator ],
216+         accum : Callable [[],  Accumulator ],
219217        input_types : pyarrow .DataType  |  list [pyarrow .DataType ],
220218        return_type : pyarrow .DataType ,
221219        state_type : list [pyarrow .DataType ],
222220        volatility : Volatility  |  str ,
223-         arguments : Optional [list [Any ]] =  None ,
224-         name : str  |  None  =  None ,
221+         name : Optional [str ] =  None ,
225222    ) ->  AggregateUDF :
226223        """Create a new User-Defined Aggregate Function. 
227224
228-         The accumulator function must be callable and implement :py:class:`Accumulator`. 
225+         If your :py:class:`Accumulator` can be instantiated with no arguments, you 
226+         can simply pass it's type as ``accum``. If you need to pass additional arguments 
227+         to it's constructor, you can define a lambda or a factory method. During runtime 
228+         the :py:class:`Accumulator` will be constructed for every instance in 
229+         which this UDAF is used. The following examples are all valid. 
230+ 
231+         .. code-block:: python 
232+             import pyarrow as pa 
233+             import pyarrow.compute as pc 
234+ 
235+             class Summarize(Accumulator): 
236+                 def __init__(self, bias: float = 0.0): 
237+                     self._sum = pa.scalar(bias) 
238+ 
239+                 def state(self) -> List[pa.Scalar]: 
240+                     return [self._sum] 
241+ 
242+                 def update(self, values: pa.Array) -> None: 
243+                     self._sum = pa.scalar(self._sum.as_py() + pc.sum(values).as_py()) 
244+ 
245+                 def merge(self, states: List[pa.Array]) -> None: 
246+                     self._sum = pa.scalar(self._sum.as_py() + pc.sum(states[0]).as_py()) 
247+ 
248+                 def evaluate(self) -> pa.Scalar: 
249+                     return self._sum 
250+ 
251+             def sum_bias_10() -> Summarize: 
252+                 return Summarize(10.0) 
253+ 
254+             udaf1 = udaf(Summarize, pa.float64(), pa.float64(), [pa.float64()], "immutable") 
255+             udaf2 = udaf(sum_bias_10, pa.float64(), pa.float64(), [pa.float64()], "immutable") 
256+             udaf3 = udaf(lambda: Summarize(20.0), pa.float64(), pa.float64(), [pa.float64()], "immutable") 
229257
230258        Args: 
231259            accum: The accumulator python function. 
232260            input_types: The data types of the arguments to ``accum``. 
233261            return_type: The data type of the return value. 
234262            state_type: The data types of the intermediate accumulation. 
235263            volatility: See :py:class:`Volatility` for allowed values. 
236-             arguments: A list of arguments to pass in to the __init__ method for accum. 
237264            name: A descriptive name for the function. 
238265
239266        Returns: 
240267            A user-defined aggregate function, which can be used in either data 
241268            aggregation or window function calls. 
242-         """ 
243-         if  not  issubclass (accum , Accumulator ):
269+         """   # noqa W505 
270+         if  not  callable (accum ):
271+             raise  TypeError ("`func` must be callable." )
272+         if  not  isinstance (accum .__call__ (), Accumulator ):
244273            raise  TypeError (
245-                 "`accum`  must implement the abstract base class Accumulator" 
274+                 "Accumulator  must implement the abstract base class Accumulator" 
246275            )
247276        if  name  is  None :
248-             name  =  accum .__qualname__ .lower ()
277+             name  =  accum .__call__ (). __class__ . __qualname__ .lower ()
249278        if  isinstance (input_types , pyarrow .DataType ):
250279            input_types  =  [input_types ]
251-         arguments  =  [] if  arguments  is  None  else  arguments 
252280        return  AggregateUDF (
253281            name = name ,
254282            accumulator = accum ,
255283            input_types = input_types ,
256284            return_type = return_type ,
257285            state_type = state_type ,
258286            volatility = volatility ,
259-             arguments = arguments ,
260287        )
261288
262289
@@ -433,20 +460,19 @@ class WindowUDF:
433460
434461    def  __init__ (
435462        self ,
436-         name : str   |   None ,
437-         func : Type [ WindowEvaluator ],
463+         name : Optional [ str ] ,
464+         func : Callable [[],  WindowEvaluator ],
438465        input_types : list [pyarrow .DataType ],
439466        return_type : pyarrow .DataType ,
440467        volatility : Volatility  |  str ,
441-         arguments : list [Any ],
442468    ) ->  None :
443469        """Instantiate a user-defined window function (UDWF). 
444470
445471        See :py:func:`udwf` for a convenience function and argument 
446472        descriptions. 
447473        """ 
448474        self ._udwf  =  df_internal .WindowUDF (
449-             name , func , input_types , return_type , str (volatility ),  arguments 
475+             name , func , input_types , return_type , str (volatility )
450476        )
451477
452478    def  __call__ (self , * args : Expr ) ->  Expr :
@@ -460,17 +486,40 @@ def __call__(self, *args: Expr) -> Expr:
460486
461487    @staticmethod  
462488    def  udwf (
463-         func : Type [ WindowEvaluator ],
489+         func : Callable [[],  WindowEvaluator ],
464490        input_types : pyarrow .DataType  |  list [pyarrow .DataType ],
465491        return_type : pyarrow .DataType ,
466492        volatility : Volatility  |  str ,
467-         arguments : Optional [list [Any ]] =  None ,
468-         name : str  |  None  =  None ,
493+         name : Optional [str ] =  None ,
469494    ) ->  WindowUDF :
470495        """Create a new User-Defined Window Function. 
471496
497+         If your :py:class:`WindowEvaluator` can be instantiated with no arguments, you 
498+         can simply pass it's type as ``func``. If you need to pass additional arguments 
499+         to it's constructor, you can define a lambda or a factory method. During runtime 
500+         the :py:class:`WindowEvaluator` will be constructed for every instance in 
501+         which this UDWF is used. The following examples are all valid. 
502+ 
503+         .. code-block:: python 
504+ 
505+             import pyarrow as pa 
506+ 
507+             class BiasedNumbers(WindowEvaluator): 
508+                 def __init__(self, start: int = 0) -> None: 
509+                     self.start = start 
510+ 
511+                 def evaluate_all(self, values: list[pa.Array], num_rows: int) -> pa.Array: 
512+                     return pa.array([self.start + i for i in range(num_rows)]) 
513+ 
514+             def bias_10() -> BiasedNumbers: 
515+                 return BiasedNumbers(10) 
516+ 
517+             udwf1 = udwf(BiasedNumbers, pa.int64(), pa.int64(), "immutable") 
518+             udwf2 = udwf(bias_10, pa.int64(), pa.int64(), "immutable") 
519+             udwf3 = udwf(lambda: BiasedNumbers(20), pa.int64(), pa.int64(), "immutable") 
520+ 
472521        Args: 
473-             func: The python  function. 
522+             func: A callable to create the window  function. 
474523            input_types: The data types of the arguments to ``func``. 
475524            return_type: The data type of the return value. 
476525            volatility: See :py:class:`Volatility` for allowed values. 
@@ -479,21 +528,21 @@ def udwf(
479528
480529        Returns: 
481530            A user-defined window function. 
482-         """ 
483-         if  not  issubclass (func , WindowEvaluator ):
531+         """   # noqa W505 
532+         if  not  callable (func ):
533+             raise  TypeError ("`func` must be callable." )
534+         if  not  isinstance (func .__call__ (), WindowEvaluator ):
484535            raise  TypeError (
485536                "`func` must implement the abstract base class WindowEvaluator" 
486537            )
487538        if  name  is  None :
488-             name  =  func .__class__ .__qualname__ .lower ()
539+             name  =  func .__call__ (). __class__ .__qualname__ .lower ()
489540        if  isinstance (input_types , pyarrow .DataType ):
490541            input_types  =  [input_types ]
491-         arguments  =  [] if  arguments  is  None  else  arguments 
492542        return  WindowUDF (
493543            name = name ,
494544            func = func ,
495545            input_types = input_types ,
496546            return_type = return_type ,
497547            volatility = volatility ,
498-             arguments = arguments ,
499548        )
0 commit comments