1919
2020from  __future__ import  annotations 
2121
22+ import  functools 
2223from  abc  import  ABCMeta , abstractmethod 
2324from  enum  import  Enum 
2425from  typing  import  TYPE_CHECKING , Callable , List , Optional , TypeVar 
@@ -110,43 +111,102 @@ def __call__(self, *args: Expr) -> Expr:
110111        args_raw  =  [arg .expr  for  arg  in  args ]
111112        return  Expr (self ._udf .__call__ (* args_raw ))
112113
113-     @staticmethod  
114-     def  udf (
115-         func : Callable [..., _R ],
116-         input_types : list [pyarrow .DataType ],
117-         return_type : _R ,
118-         volatility : Volatility  |  str ,
119-         name : Optional [str ] =  None ,
120-     ) ->  ScalarUDF :
121-         """Create a new User-Defined Function. 
114+     class  udf :
115+         """Create a new User-Defined Function (UDF). 
116+ 
117+         This class can be used both as a **function** and as a **decorator**. 
118+ 
119+         Usage: 
120+             - **As a function**: Call `udf(func, input_types, return_type, volatility, 
121+               name)`. 
122+             - **As a decorator**: Use `@udf(input_types, return_type, volatility, 
123+               name)`. In this case, do **not** pass `func` explicitly. 
122124
123125        Args: 
124-             func: A callable python function. 
125-             input_types: The data types of the arguments to ``func``. This list 
126-                 must be of the same length as the number of arguments. 
127-             return_type: The data type of the return value from the python 
128-                 function. 
129-             volatility: See ``Volatility`` for allowed values. 
130-             name: A descriptive name for the function. 
126+             func (Callable, optional): **Only needed when calling as a function.** 
127+                 Skip this argument when using `udf` as a decorator. 
128+             input_types (list[pyarrow.DataType]): The data types of the arguments 
129+                 to `func`. This list must be of the same length as the number of 
130+                 arguments. 
131+             return_type (_R): The data type of the return value from the function. 
132+             volatility (Volatility | str): See `Volatility` for allowed values. 
133+             name (Optional[str]): A descriptive name for the function. 
131134
132135        Returns: 
133-             A user-defined aggregate function, which can be used in either data 
134-                 aggregation or window function calls. 
136+             A user-defined function that can be used in SQL expressions, 
137+             data aggregation, or window function calls. 
138+ 
139+         Example: 
140+             **Using `udf` as a function:** 
141+             ``` 
142+             def double_func(x): 
143+                 return x * 2 
144+             double_udf = udf(double_func, [pyarrow.int32()], pyarrow.int32(), 
145+             "volatile", "double_it") 
146+             ``` 
147+ 
148+             **Using `udf` as a decorator:** 
149+             ``` 
150+             @udf([pyarrow.int32()], pyarrow.int32(), "volatile", "double_it") 
151+             def double_udf(x): 
152+                 return x * 2 
153+             ``` 
135154        """ 
136-         if  not  callable (func ):
137-             raise  TypeError ("`func` argument must be callable" )
138-         if  name  is  None :
139-             if  hasattr (func , "__qualname__" ):
140-                 name  =  func .__qualname__ .lower ()
155+ 
156+         def  __new__ (cls , * args , ** kwargs ):
157+             """Create a new UDF. 
158+ 
159+             Trigger UDF function or decorator depending on if the first args is callable 
160+             """ 
161+             if  args  and  callable (args [0 ]):
162+                 # Case 1: Used as a function, require the first parameter to be callable 
163+                 return  cls ._function (* args , ** kwargs )
141164            else :
142-                 name  =  func .__class__ .__name__ .lower ()
143-         return  ScalarUDF (
144-             name = name ,
145-             func = func ,
146-             input_types = input_types ,
147-             return_type = return_type ,
148-             volatility = volatility ,
149-         )
165+                 # Case 2: Used as a decorator with parameters 
166+                 return  cls ._decorator (* args , ** kwargs )
167+ 
168+         @staticmethod  
169+         def  _function (
170+             func : Callable [..., _R ],
171+             input_types : list [pyarrow .DataType ],
172+             return_type : _R ,
173+             volatility : Volatility  |  str ,
174+             name : Optional [str ] =  None ,
175+         ) ->  ScalarUDF :
176+             if  not  callable (func ):
177+                 raise  TypeError ("`func` argument must be callable" )
178+             if  name  is  None :
179+                 if  hasattr (func , "__qualname__" ):
180+                     name  =  func .__qualname__ .lower ()
181+                 else :
182+                     name  =  func .__class__ .__name__ .lower ()
183+             return  ScalarUDF (
184+                 name = name ,
185+                 func = func ,
186+                 input_types = input_types ,
187+                 return_type = return_type ,
188+                 volatility = volatility ,
189+             )
190+ 
191+         @staticmethod  
192+         def  _decorator (
193+             input_types : list [pyarrow .DataType ],
194+             return_type : _R ,
195+             volatility : Volatility  |  str ,
196+             name : Optional [str ] =  None ,
197+         ):
198+             def  decorator (func ):
199+                 udf_caller  =  ScalarUDF .udf (
200+                     func , input_types , return_type , volatility , name 
201+                 )
202+ 
203+                 @functools .wraps (func ) 
204+                 def  wrapper (* args , ** kwargs ):
205+                     return  udf_caller (* args , ** kwargs )
206+ 
207+                 return  wrapper 
208+ 
209+             return  decorator 
150210
151211
152212class  Accumulator (metaclass = ABCMeta ):
@@ -212,25 +272,27 @@ def __call__(self, *args: Expr) -> Expr:
212272        args_raw  =  [arg .expr  for  arg  in  args ]
213273        return  Expr (self ._udaf .__call__ (* args_raw ))
214274
215-     @staticmethod  
216-     def  udaf (
217-         accum : Callable [[], Accumulator ],
218-         input_types : pyarrow .DataType  |  list [pyarrow .DataType ],
219-         return_type : pyarrow .DataType ,
220-         state_type : list [pyarrow .DataType ],
221-         volatility : Volatility  |  str ,
222-         name : Optional [str ] =  None ,
223-     ) ->  AggregateUDF :
224-         """Create a new User-Defined Aggregate Function. 
275+     class  udaf :
276+         """Create a new User-Defined Aggregate Function (UDAF). 
225277
226-         If your :py:class:`Accumulator` can be instantiated with no arguments, you 
227-         can simply pass it's type as ``accum``. If you need to pass additional arguments 
228-         to it's constructor, you can define a lambda or a factory method. During runtime 
229-         the :py:class:`Accumulator` will be constructed for every instance in 
230-         which this UDAF is used. The following examples are all valid. 
278+         This class allows you to define an **aggregate function** that can be used in 
279+         data aggregation or window function calls. 
231280
232-         .. code-block:: python 
281+         Usage: 
282+             - **As a function**: Call `udaf(accum, input_types, return_type, state_type, 
283+                 volatility, name)`. 
284+             - **As a decorator**: Use `@udaf(input_types, return_type, state_type, 
285+                 volatility, name)`. 
286+             When using `udaf` as a decorator, **do not pass `accum` explicitly**. 
233287
288+         **Function example:** 
289+ 
290+             If your :py:class:`Accumulator` can be instantiated with no arguments, you
291+             can simply pass its type as `accum`. If you need to pass additional
292+             arguments to its constructor, you can define a lambda or a factory method.
293+             At runtime the :py:class:`Accumulator` will be constructed for every
294+             instance in which this UDAF is used. The following examples are all valid.
295+             ``` 
234296            import pyarrow as pa 
235297            import pyarrow.compute as pc 
236298
@@ -253,12 +315,24 @@ def evaluate(self) -> pa.Scalar:
253315            def sum_bias_10() -> Summarize: 
254316                return Summarize(10.0) 
255317
256-             udaf1 = udaf(Summarize, pa.float64(), pa.float64(), [pa.float64()], "immutable") 
257-             udaf2 = udaf(sum_bias_10, pa.float64(), pa.float64(), [pa.float64()], "immutable") 
258-             udaf3 = udaf(lambda: Summarize(20.0), pa.float64(), pa.float64(), [pa.float64()], "immutable") 
318+             udaf1 = udaf(Summarize, pa.float64(), pa.float64(), [pa.float64()], 
319+                 "immutable") 
320+             udaf2 = udaf(sum_bias_10, pa.float64(), pa.float64(), [pa.float64()], 
321+                 "immutable") 
322+             udaf3 = udaf(lambda: Summarize(20.0), pa.float64(), pa.float64(), 
323+                 [pa.float64()], "immutable") 
324+             ``` 
325+ 
326+         **Decorator example:** 
327+             ``` 
328+             @udaf(pa.float64(), pa.float64(), [pa.float64()], "immutable") 
329+             def udf4() -> Summarize: 
330+                 return Summarize(10.0) 
331+             ``` 
259332
260333        Args: 
261-             accum: The accumulator python function. 
334+             accum: The accumulator python function. **Only needed when calling as a 
335+                 function. Skip this argument when using `udaf` as a decorator.** 
262336            input_types: The data types of the arguments to ``accum``. 
263337            return_type: The data type of the return value. 
264338            state_type: The data types of the intermediate accumulation. 
@@ -268,26 +342,69 @@ def sum_bias_10() -> Summarize:
268342        Returns: 
269343            A user-defined aggregate function, which can be used in either data 
270344            aggregation or window function calls. 
271-         """   # noqa W505 
272-         if  not  callable (accum ):
273-             raise  TypeError ("`func` must be callable." )
274-         if  not  isinstance (accum .__call__ (), Accumulator ):
275-             raise  TypeError (
276-                 "Accumulator must implement the abstract base class Accumulator" 
345+         """ 
346+ 
347+         def  __new__ (cls , * args , ** kwargs ):
348+             """Create a new UDAF. 
349+ 
350+             Trigger UDAF function or decorator depending on if the first args is 
351+             callable 
352+             """ 
353+             if  args  and  callable (args [0 ]):
354+                 # Case 1: Used as a function, require the first parameter to be callable 
355+                 return  cls ._function (* args , ** kwargs )
356+             else :
357+                 # Case 2: Used as a decorator with parameters 
358+                 return  cls ._decorator (* args , ** kwargs )
359+ 
360+         @staticmethod  
361+         def  _function (
362+             accum : Callable [[], Accumulator ],
363+             input_types : pyarrow .DataType  |  list [pyarrow .DataType ],
364+             return_type : pyarrow .DataType ,
365+             state_type : list [pyarrow .DataType ],
366+             volatility : Volatility  |  str ,
367+             name : Optional [str ] =  None ,
368+         ) ->  AggregateUDF :
369+             if  not  callable (accum ):
370+                 raise  TypeError ("`func` must be callable." )
371+             if  not  isinstance (accum .__call__ (), Accumulator ):
372+                 raise  TypeError (
373+                     "Accumulator must implement the abstract base class Accumulator" 
374+                 )
375+             if  name  is  None :
376+                 name  =  accum .__call__ ().__class__ .__qualname__ .lower ()
377+             if  isinstance (input_types , pyarrow .DataType ):
378+                 input_types  =  [input_types ]
379+             return  AggregateUDF (
380+                 name = name ,
381+                 accumulator = accum ,
382+                 input_types = input_types ,
383+                 return_type = return_type ,
384+                 state_type = state_type ,
385+                 volatility = volatility ,
277386            )
278-         if  name  is  None :
279-             name  =  accum .__call__ ().__class__ .__qualname__ .lower ()
280-         assert  name  is  not   None 
281-         if  isinstance (input_types , pyarrow .DataType ):
282-             input_types  =  [input_types ]
283-         return  AggregateUDF (
284-             name = name ,
285-             accumulator = accum ,
286-             input_types = input_types ,
287-             return_type = return_type ,
288-             state_type = state_type ,
289-             volatility = volatility ,
290-         )
387+ 
388+         @staticmethod  
389+         def  _decorator (
390+             input_types : pyarrow .DataType  |  list [pyarrow .DataType ],
391+             return_type : pyarrow .DataType ,
392+             state_type : list [pyarrow .DataType ],
393+             volatility : Volatility  |  str ,
394+             name : Optional [str ] =  None ,
395+         ):
396+             def  decorator (accum : Callable [[], Accumulator ]):
397+                 udaf_caller  =  AggregateUDF .udaf (
398+                     accum , input_types , return_type , state_type , volatility , name 
399+                 )
400+ 
401+                 @functools .wraps (accum ) 
402+                 def  wrapper (* args , ** kwargs ):
403+                     return  udaf_caller (* args , ** kwargs )
404+ 
405+                 return  wrapper 
406+ 
407+             return  decorator 
291408
292409
293410class  WindowEvaluator (metaclass = ABCMeta ):
0 commit comments