1919
2020from __future__ import annotations
2121
22+ import functools
2223from abc import ABCMeta , abstractmethod
2324from enum import Enum
2425from typing import TYPE_CHECKING , Callable , List , Optional , TypeVar
@@ -110,43 +111,102 @@ def __call__(self, *args: Expr) -> Expr:
110111 args_raw = [arg .expr for arg in args ]
111112 return Expr (self ._udf .__call__ (* args_raw ))
112113
113- @staticmethod
114- def udf (
115- func : Callable [..., _R ],
116- input_types : list [pyarrow .DataType ],
117- return_type : _R ,
118- volatility : Volatility | str ,
119- name : Optional [str ] = None ,
120- ) -> ScalarUDF :
121- """Create a new User-Defined Function.
114+ class udf :
115+ """Create a new User-Defined Function (UDF).
116+
117+ This class can be used both as a **function** and as a **decorator**.
118+
119+ Usage:
120+ - **As a function**: Call `udf(func, input_types, return_type, volatility,
121+ name)`.
122+ - **As a decorator**: Use `@udf(input_types, return_type, volatility,
123+ name)`. In this case, do **not** pass `func` explicitly.
122124
123125 Args:
124- func: A callable python function.
125- input_types: The data types of the arguments to ``func``. This list
126- must be of the same length as the number of arguments.
127- return_type: The data type of the return value from the python
128- function.
129- volatility: See ``Volatility`` for allowed values.
130- name: A descriptive name for the function.
126+ func (Callable, optional): **Only needed when calling as a function.**
127+ Skip this argument when using `udf` as a decorator.
128+ input_types (list[pyarrow.DataType]): The data types of the arguments
129+ to `func`. This list must be of the same length as the number of
130+ arguments.
131+ return_type (_R): The data type of the return value from the function.
132+ volatility (Volatility | str): See `Volatility` for allowed values.
133+ name (Optional[str]): A descriptive name for the function.
131134
132135 Returns:
133- A user-defined aggregate function, which can be used in either data
134- aggregation or window function calls.
136+ A user-defined function that can be used in SQL expressions,
137+ data aggregation, or window function calls.
138+
139+ Example:
140+ **Using `udf` as a function:**
141+ ```
142+ def double_func(x):
143+ return x * 2
144+ double_udf = udf(double_func, [pyarrow.int32()], pyarrow.int32(),
145+ "volatile", "double_it")
146+ ```
147+
148+ **Using `udf` as a decorator:**
149+ ```
150+ @udf([pyarrow.int32()], pyarrow.int32(), "volatile", "double_it")
151+ def double_udf(x):
152+ return x * 2
153+ ```
135154 """
136- if not callable (func ):
137- raise TypeError ("`func` argument must be callable" )
138- if name is None :
139- if hasattr (func , "__qualname__" ):
140- name = func .__qualname__ .lower ()
155+
156+ def __new__ (cls , * args , ** kwargs ):
157+ """Create a new UDF.
158+
159+ Act as function or decorator based on whether the first argument is callable.
160+ """
161+ if args and callable (args [0 ]):
162+ # Case 1: Used as a function, require the first parameter to be callable
163+ return cls ._function (* args , ** kwargs )
141164 else :
142- name = func .__class__ .__name__ .lower ()
143- return ScalarUDF (
144- name = name ,
145- func = func ,
146- input_types = input_types ,
147- return_type = return_type ,
148- volatility = volatility ,
149- )
165+ # Case 2: Used as a decorator with parameters
166+ return cls ._decorator (* args , ** kwargs )
167+
168+ @staticmethod
169+ def _function (
170+ func : Callable [..., _R ],
171+ input_types : list [pyarrow .DataType ],
172+ return_type : _R ,
173+ volatility : Volatility | str ,
174+ name : Optional [str ] = None ,
175+ ) -> ScalarUDF :
176+ if not callable (func ):
177+ raise TypeError ("`func` argument must be callable" )
178+ if name is None :
179+ if hasattr (func , "__qualname__" ):
180+ name = func .__qualname__ .lower ()
181+ else :
182+ name = func .__class__ .__name__ .lower ()
183+ return ScalarUDF (
184+ name = name ,
185+ func = func ,
186+ input_types = input_types ,
187+ return_type = return_type ,
188+ volatility = volatility ,
189+ )
190+
191+ @staticmethod
192+ def _decorator (
193+ input_types : list [pyarrow .DataType ],
194+ return_type : _R ,
195+ volatility : Volatility | str ,
196+ name : Optional [str ] = None ,
197+ ):
198+ def decorator (func ):
199+ udf_caller = ScalarUDF .udf (
200+ func , input_types , return_type , volatility , name
201+ )
202+
203+ @functools .wraps (func )
204+ def wrapper (* args , ** kwargs ):
205+ return udf_caller (* args , ** kwargs )
206+
207+ return wrapper
208+
209+ return decorator
150210
151211
152212class Accumulator (metaclass = ABCMeta ):
@@ -212,25 +272,27 @@ def __call__(self, *args: Expr) -> Expr:
212272 args_raw = [arg .expr for arg in args ]
213273 return Expr (self ._udaf .__call__ (* args_raw ))
214274
215- @staticmethod
216- def udaf (
217- accum : Callable [[], Accumulator ],
218- input_types : pyarrow .DataType | list [pyarrow .DataType ],
219- return_type : pyarrow .DataType ,
220- state_type : list [pyarrow .DataType ],
221- volatility : Volatility | str ,
222- name : Optional [str ] = None ,
223- ) -> AggregateUDF :
224- """Create a new User-Defined Aggregate Function.
275+ class udaf :
276+ """Create a new User-Defined Aggregate Function (UDAF).
225277
226- If your :py:class:`Accumulator` can be instantiated with no arguments, you
227- can simply pass it's type as ``accum``. If you need to pass additional arguments
228- to it's constructor, you can define a lambda or a factory method. During runtime
229- the :py:class:`Accumulator` will be constructed for every instance in
230- which this UDAF is used. The following examples are all valid.
278+ This class allows you to define an **aggregate function** that can be used in
279+ data aggregation or window function calls.
231280
232- .. code-block:: python
281+ Usage:
282+ - **As a function**: Call `udaf(accum, input_types, return_type, state_type,
283+ volatility, name)`.
284+ - **As a decorator**: Use `@udaf(input_types, return_type, state_type,
285+ volatility, name)`.
286+ When using `udaf` as a decorator, **do not pass `accum` explicitly**.
233287
288+ **Function example:**
289+
290+ If your :py:class:`Accumulator` can be instantiated with no arguments, you
291+ can simply pass its type as `accum`. If you need to pass additional
292+ arguments to its constructor, you can define a lambda or a factory method.
293+ During runtime the :py:class:`Accumulator` will be constructed for every
294+ instance in which this UDAF is used. The following examples are all valid.
295+ ```
234296 import pyarrow as pa
235297 import pyarrow.compute as pc
236298
@@ -253,12 +315,24 @@ def evaluate(self) -> pa.Scalar:
253315 def sum_bias_10() -> Summarize:
254316 return Summarize(10.0)
255317
256- udaf1 = udaf(Summarize, pa.float64(), pa.float64(), [pa.float64()], "immutable")
257- udaf2 = udaf(sum_bias_10, pa.float64(), pa.float64(), [pa.float64()], "immutable")
258- udaf3 = udaf(lambda: Summarize(20.0), pa.float64(), pa.float64(), [pa.float64()], "immutable")
318+ udaf1 = udaf(Summarize, pa.float64(), pa.float64(), [pa.float64()],
319+ "immutable")
320+ udaf2 = udaf(sum_bias_10, pa.float64(), pa.float64(), [pa.float64()],
321+ "immutable")
322+ udaf3 = udaf(lambda: Summarize(20.0), pa.float64(), pa.float64(),
323+ [pa.float64()], "immutable")
324+ ```
325+
326+ **Decorator example:**
327+ ```
328+ @udaf(pa.float64(), pa.float64(), [pa.float64()], "immutable")
329+ def udaf4() -> Summarize:
330+ return Summarize(10.0)
331+ ```
259332
260333 Args:
261- accum: The accumulator python function.
334+ accum: The accumulator python function. **Only needed when calling as a
335+ function. Skip this argument when using `udaf` as a decorator.**
262336 input_types: The data types of the arguments to ``accum``.
263337 return_type: The data type of the return value.
264338 state_type: The data types of the intermediate accumulation.
@@ -268,26 +342,69 @@ def sum_bias_10() -> Summarize:
268342 Returns:
269343 A user-defined aggregate function, which can be used in either data
270344 aggregation or window function calls.
271- """ # noqa W505
272- if not callable (accum ):
273- raise TypeError ("`func` must be callable." )
274- if not isinstance (accum .__call__ (), Accumulator ):
275- raise TypeError (
276- "Accumulator must implement the abstract base class Accumulator"
345+ """
346+
347+ def __new__ (cls , * args , ** kwargs ):
348+ """Create a new UDAF.
349+
350+ Act as function or decorator based on whether the first argument is
351+ callable.
352+ """
353+ if args and callable (args [0 ]):
354+ # Case 1: Used as a function, require the first parameter to be callable
355+ return cls ._function (* args , ** kwargs )
356+ else :
357+ # Case 2: Used as a decorator with parameters
358+ return cls ._decorator (* args , ** kwargs )
359+
360+ @staticmethod
361+ def _function (
362+ accum : Callable [[], Accumulator ],
363+ input_types : pyarrow .DataType | list [pyarrow .DataType ],
364+ return_type : pyarrow .DataType ,
365+ state_type : list [pyarrow .DataType ],
366+ volatility : Volatility | str ,
367+ name : Optional [str ] = None ,
368+ ) -> AggregateUDF :
369+ if not callable (accum ):
370+ raise TypeError ("`func` must be callable." )
371+ if not isinstance (accum .__call__ (), Accumulator ):
372+ raise TypeError (
373+ "Accumulator must implement the abstract base class Accumulator"
374+ )
375+ if name is None :
376+ name = accum .__call__ ().__class__ .__qualname__ .lower ()
377+ if isinstance (input_types , pyarrow .DataType ):
378+ input_types = [input_types ]
379+ return AggregateUDF (
380+ name = name ,
381+ accumulator = accum ,
382+ input_types = input_types ,
383+ return_type = return_type ,
384+ state_type = state_type ,
385+ volatility = volatility ,
277386 )
278- if name is None :
279- name = accum .__call__ ().__class__ .__qualname__ .lower ()
280- assert name is not None
281- if isinstance (input_types , pyarrow .DataType ):
282- input_types = [input_types ]
283- return AggregateUDF (
284- name = name ,
285- accumulator = accum ,
286- input_types = input_types ,
287- return_type = return_type ,
288- state_type = state_type ,
289- volatility = volatility ,
290- )
387+
388+ @staticmethod
389+ def _decorator (
390+ input_types : pyarrow .DataType | list [pyarrow .DataType ],
391+ return_type : pyarrow .DataType ,
392+ state_type : list [pyarrow .DataType ],
393+ volatility : Volatility | str ,
394+ name : Optional [str ] = None ,
395+ ):
396+ def decorator (accum : Callable [[], Accumulator ]):
397+ udaf_caller = AggregateUDF .udaf (
398+ accum , input_types , return_type , state_type , volatility , name
399+ )
400+
401+ @functools .wraps (accum )
402+ def wrapper (* args , ** kwargs ):
403+ return udaf_caller (* args , ** kwargs )
404+
405+ return wrapper
406+
407+ return decorator
291408
292409
293410class WindowEvaluator (metaclass = ABCMeta ):
0 commit comments