@@ -111,65 +111,96 @@ def __call__(self, *args: Expr) -> Expr:
111111 args_raw = [arg .expr for arg in args ]
112112 return Expr (self ._udf .__call__ (* args_raw ))
113113
114- @ staticmethod
115- def udf (
116- func : Callable [..., _R ],
117- input_types : list [ pyarrow . DataType ],
118- return_type : _R ,
119- volatility : Volatility | str ,
120- name : Optional [ str ] = None ,
121- ) -> ScalarUDF :
122- """Create a new User-Defined Function .
114+ class udf :
115+ """Create a new User-Defined Function (UDF).
116+
117+ This class can be used both as a **function** and as a **decorator**.
118+
119+ Usage:
120+ - **As a function**: Call `udf(func, input_types, return_type, volatility, name)`.
121+ - **As a decorator**: Use `@udf(input_types, return_type, volatility, name)`.
122+ In this case, do **not** pass `func` explicitly .
123123
124124 Args:
125- func: A callable python function.
126- input_types: The data types of the arguments to ``func``. This list
127- must be of the same length as the number of arguments.
128- return_type: The data type of the return value from the python
129- function.
130- volatility: See `` Volatility` ` for allowed values.
131- name: A descriptive name for the function.
125+ func (Callable, optional): **Only needed when calling as a function.**
126+ Skip this argument when using `udf` as a decorator.
127+ input_types (list[pyarrow.DataType]): The data types of the arguments
128+ to `func`. This list must be of the same length as the number of arguments.
129+ return_type (_R): The data type of the return value from the function.
130+ volatility (Volatility | str) : See `Volatility` for allowed values.
131+ name (Optional[str]) : A descriptive name for the function.
132132
133133 Returns:
134- A user-defined aggregate function, which can be used in either data
135- aggregation or window function calls.
134+ A user-defined function that can be used in SQL expressions,
135+ data aggregation, or window function calls.
136+
137+ Example:
138+ **Using `udf` as a function:**
139+ ```python
140+ def double_func(x):
141+ return x * 2
142+ double_udf = udf(double_func, [pyarrow.int32()], pyarrow.int32(), "volatile", "double_it")
143+ ```
144+
145+ **Using `udf` as a decorator:**
146+ ```python
147+ @udf([pyarrow.int32()], pyarrow.int32(), "volatile", "double_it")
148+ def double_udf(x):
149+ return x * 2
150+ ```
136151 """
137- if not callable (func ):
138- raise TypeError ("`func` argument must be callable" )
139- if name is None :
140- if hasattr (func , "__qualname__" ):
141- name = func .__qualname__ .lower ()
152+ def __new__ (cls , * args , ** kwargs ):
153+ if args and callable (args [0 ]):
154+ # Case 1: Used as a function, require the first parameter to be callable
155+ return cls ._function (* args , ** kwargs )
142156 else :
143- name = func .__class__ .__name__ .lower ()
144- return ScalarUDF (
145- name = name ,
146- func = func ,
147- input_types = input_types ,
148- return_type = return_type ,
149- volatility = volatility ,
150- )
151-
152- @staticmethod
153- def udf_decorator (
154- input_types : list [pyarrow .DataType ],
155- return_type : _R ,
156- volatility : Volatility | str ,
157- name : Optional [str ] = None
158- ):
159- def decorator (func ):
160- udf_caller = ScalarUDF .udf (
161- func ,
162- input_types ,
163- return_type ,
164- volatility ,
165- name
157+ # Case 2: Used as a decorator with parameters
158+ return cls ._decorator (* args , ** kwargs )
159+
160+ @staticmethod
161+ def _function (
162+ func : Callable [..., _R ],
163+ input_types : list [pyarrow .DataType ],
164+ return_type : _R ,
165+ volatility : Volatility | str ,
166+ name : Optional [str ] = None ,
167+ ) -> ScalarUDF :
168+ if not callable (func ):
169+ raise TypeError ("`func` argument must be callable" )
170+ if name is None :
171+ if hasattr (func , "__qualname__" ):
172+ name = func .__qualname__ .lower ()
173+ else :
174+ name = func .__class__ .__name__ .lower ()
175+ return ScalarUDF (
176+ name = name ,
177+ func = func ,
178+ input_types = input_types ,
179+ return_type = return_type ,
180+ volatility = volatility ,
166181 )
167-
168- @functools .wraps (func )
169- def wrapper (* args , ** kwargs ):
170- return udf_caller (* args , ** kwargs )
171- return wrapper
172- return decorator
182+
183+ @staticmethod
184+ def _decorator (
185+ input_types : list [pyarrow .DataType ],
186+ return_type : _R ,
187+ volatility : Volatility | str ,
188+ name : Optional [str ] = None
189+ ):
190+ def decorator (func ):
191+ udf_caller = ScalarUDF .udf (
192+ func ,
193+ input_types ,
194+ return_type ,
195+ volatility ,
196+ name
197+ )
198+
199+ @functools .wraps (func )
200+ def wrapper (* args , ** kwargs ):
201+ return udf_caller (* args , ** kwargs )
202+ return wrapper
203+ return decorator
173204
174205class Accumulator (metaclass = ABCMeta ):
175206 """Defines how an :py:class:`AggregateUDF` accumulates values."""
@@ -234,25 +265,27 @@ def __call__(self, *args: Expr) -> Expr:
234265 args_raw = [arg .expr for arg in args ]
235266 return Expr (self ._udaf .__call__ (* args_raw ))
236267
237- @staticmethod
238- def udaf (
239- accum : Callable [[], Accumulator ],
240- input_types : pyarrow .DataType | list [pyarrow .DataType ],
241- return_type : pyarrow .DataType ,
242- state_type : list [pyarrow .DataType ],
243- volatility : Volatility | str ,
244- name : Optional [str ] = None ,
245- ) -> AggregateUDF :
246- """Create a new User-Defined Aggregate Function.
268+ class udaf :
269+ """Create a new User-Defined Aggregate Function (UDAF).
247270
248- If your :py:class:`Accumulator` can be instantiated with no arguments, you
249- can simply pass it's type as ``accum``. If you need to pass additional arguments
250- to it's constructor, you can define a lambda or a factory method. During runtime
251- the :py:class:`Accumulator` will be constructed for every instance in
252- which this UDAF is used. The following examples are all valid.
271+ This class allows you to define an **aggregate function** that can be used in data
272+ aggregation or window function calls.
253273
254- .. code-block:: python
274+ Usage:
275+ - **As a function**: Call `udaf(accum, input_types, return_type, state_type,
276+ volatility, name)`.
277+ - **As a decorator**: Use `@udaf(input_types, return_type, state_type,
278+ volatility, name)`.
279+ When using `udaf` as a decorator, **do not pass `accum` explicitly**.
255280
281+ **Function example:**
282+
283+ If your :py:class:`Accumulator` can be instantiated with no arguments, you can
284+ simply pass its type as `accum`. If you need to pass additional arguments to
285+ its constructor, you can define a lambda or a factory method. During runtime the
286+ :py:class:`Accumulator` will be constructed for every instance in which this UDAF is
287+ used. The following examples are all valid.
288+ ```
256289 import pyarrow as pa
257290 import pyarrow.compute as pc
258291
@@ -278,61 +311,89 @@ def sum_bias_10() -> Summarize:
278311 udaf1 = udaf(Summarize, pa.float64(), pa.float64(), [pa.float64()], "immutable")
279312 udaf2 = udaf(sum_bias_10, pa.float64(), pa.float64(), [pa.float64()], "immutable")
280313 udaf3 = udaf(lambda: Summarize(20.0), pa.float64(), pa.float64(), [pa.float64()], "immutable")
314+ ```
315+
316+ **Decorator example:**
317+ ```
318+ @udaf(pa.float64(), pa.float64(), [pa.float64()], "immutable")
319+ def udaf4() -> Summarize:
320+ return Summarize(10.0)
321+ ```
281322
282323 Args:
283- accum: The accumulator python function.
284- input_types: The data types of the arguments to `` accum`` .
324+ accum: The accumulator python function. **Only needed when calling as a function. Skip this argument when using `udaf` as a decorator.**
325+ input_types: The data types of the arguments to `accum`.
285326 return_type: The data type of the return value.
286327 state_type: The data types of the intermediate accumulation.
287- volatility: See :py:class:` Volatility` for allowed values.
328+ volatility: See :py:class:`Volatility` for allowed values.
288329 name: A descriptive name for the function.
289330
290331 Returns:
291332 A user-defined aggregate function, which can be used in either data
292333 aggregation or window function calls.
293- """ # noqa W505
294- if not callable (accum ):
295- raise TypeError ("`func` must be callable." )
296- if not isinstance (accum .__call__ (), Accumulator ):
297- raise TypeError (
298- "Accumulator must implement the abstract base class Accumulator"
299- )
300- if name is None :
301- name = accum .__call__ ().__class__ .__qualname__ .lower ()
302- if isinstance (input_types , pyarrow .DataType ):
303- input_types = [input_types ]
304- return AggregateUDF (
305- name = name ,
306- accumulator = accum ,
307- input_types = input_types ,
308- return_type = return_type ,
309- state_type = state_type ,
310- volatility = volatility ,
311- )
312-
313- @staticmethod
314- def udaf_decorator (
315- input_types : pyarrow .DataType | list [pyarrow .DataType ],
316- return_type : pyarrow .DataType ,
317- state_type : list [pyarrow .DataType ],
318- volatility : Volatility | str ,
319- name : Optional [str ] = None
320- ):
321- def decorator (accum : Callable [[], Accumulator ]):
322- udaf_caller = AggregateUDF .udaf (
323- accum ,
324- input_types ,
325- return_type ,
326- state_type ,
327- volatility ,
328- name
334+ """
335+
336+
337+
338+ def __new__ (cls , * args , ** kwargs ):
339+ if args and callable (args [0 ]):
340+ # Case 1: Used as a function, require the first parameter to be callable
341+ return cls ._function (* args , ** kwargs )
342+ else :
343+ # Case 2: Used as a decorator with parameters
344+ return cls ._decorator (* args , ** kwargs )
345+
346+ @staticmethod
347+ def _function (
348+ accum : Callable [[], Accumulator ],
349+ input_types : pyarrow .DataType | list [pyarrow .DataType ],
350+ return_type : pyarrow .DataType ,
351+ state_type : list [pyarrow .DataType ],
352+ volatility : Volatility | str ,
353+ name : Optional [str ] = None ,
354+ ) -> AggregateUDF :
355+ if not callable (accum ):
356+ raise TypeError ("`func` must be callable." )
357+ if not isinstance (accum .__call__ (), Accumulator ):
358+ raise TypeError (
359+ "Accumulator must implement the abstract base class Accumulator"
360+ )
361+ if name is None :
362+ name = accum .__call__ ().__class__ .__qualname__ .lower ()
363+ if isinstance (input_types , pyarrow .DataType ):
364+ input_types = [input_types ]
365+ return AggregateUDF (
366+ name = name ,
367+ accumulator = accum ,
368+ input_types = input_types ,
369+ return_type = return_type ,
370+ state_type = state_type ,
371+ volatility = volatility ,
329372 )
330373
331- @functools .wraps (accum )
332- def wrapper (* args , ** kwargs ):
333- return udaf_caller (* args , ** kwargs )
334- return wrapper
335- return decorator
374+ @staticmethod
375+ def _decorator (
376+ input_types : pyarrow .DataType | list [pyarrow .DataType ],
377+ return_type : pyarrow .DataType ,
378+ state_type : list [pyarrow .DataType ],
379+ volatility : Volatility | str ,
380+ name : Optional [str ] = None
381+ ):
382+ def decorator (accum : Callable [[], Accumulator ]):
383+ udaf_caller = AggregateUDF .udaf (
384+ accum ,
385+ input_types ,
386+ return_type ,
387+ state_type ,
388+ volatility ,
389+ name
390+ )
391+
392+ @functools .wraps (accum )
393+ def wrapper (* args , ** kwargs ):
394+ return udaf_caller (* args , ** kwargs )
395+ return wrapper
396+ return decorator
336397
337398
338399
0 commit comments