2323from datafusion .expr import Expr
2424from typing import Callable , TYPE_CHECKING , TypeVar
2525from abc import ABCMeta , abstractmethod
26- from typing import List
26+ from typing import List , Any , Optional
2727from enum import Enum
2828import pyarrow
2929
@@ -186,14 +186,21 @@ def __init__(
186186 return_type : _R ,
187187 state_type : list [pyarrow .DataType ],
188188 volatility : Volatility | str ,
189+ arguments : list [Any ],
189190 ) -> None :
190191 """Instantiate a user-defined aggregate function (UDAF).
191192
192193 See :py:func:`udaf` for a convenience function and argument
193194 descriptions.
194195 """
195196 self ._udaf = df_internal .AggregateUDF (
196- name , accumulator , input_types , return_type , state_type , str (volatility )
197+ name ,
198+ accumulator ,
199+ input_types ,
200+ return_type ,
201+ state_type ,
202+ str (volatility ),
203+ arguments ,
197204 )
198205
199206 def __call__ (self , * args : Expr ) -> Expr :
@@ -212,6 +219,7 @@ def udaf(
212219 return_type : _R ,
213220 state_type : list [pyarrow .DataType ],
214221 volatility : Volatility | str ,
222+ arguments : Optional [list [Any ]] = None ,
215223 name : str | None = None ,
216224 ) -> AggregateUDF :
217225 """Create a new User-Defined Aggregate Function.
@@ -224,6 +232,7 @@ def udaf(
224232 return_type: The data type of the return value.
225233 state_type: The data types of the intermediate accumulation.
226234 volatility: See :py:class:`Volatility` for allowed values.
235+ arguments: A list of arguments to pass in to the __init__ method for accum.
227236 name: A descriptive name for the function.
228237
229238 Returns:
@@ -238,13 +247,15 @@ def udaf(
238247 name = accum .__qualname__ .lower ()
239248 if isinstance (input_types , pyarrow .lib .DataType ):
240249 input_types = [input_types ]
250+ arguments = [] if arguments is None else arguments
241251 return AggregateUDF (
242252 name = name ,
243253 accumulator = accum ,
244254 input_types = input_types ,
245255 return_type = return_type ,
246256 state_type = state_type ,
247257 volatility = volatility ,
258+ arguments = arguments ,
248259 )
249260
250261
0 commit comments