@@ -104,8 +104,8 @@ def __call__(self, *args: Expr) -> Expr:
104104 This function is not typically called by an end user. These calls will
105105 occur during the evaluation of the dataframe.
106106 """
107- args = [arg .expr for arg in args ]
108- return Expr (self ._udf .__call__ (* args ))
107+ args_raw = [arg .expr for arg in args ]
108+ return Expr (self ._udf .__call__ (* args_raw ))
109109
110110 @staticmethod
111111 def udf (
@@ -209,13 +209,13 @@ def __call__(self, *args: Expr) -> Expr:
209209 This function is not typically called by an end user. These calls will
210210 occur during the evaluation of the dataframe.
211211 """
212- args = [arg .expr for arg in args ]
213- return Expr (self ._udaf .__call__ (* args ))
212+ args_raw = [arg .expr for arg in args ]
213+ return Expr (self ._udaf .__call__ (* args_raw ))
214214
215215 @staticmethod
216216 def udaf (
217217 accum : _A ,
218- input_types : list [pyarrow .DataType ],
218+ input_types : pyarrow . DataType | list [pyarrow .DataType ],
219219 return_type : _R ,
220220 state_type : list [pyarrow .DataType ],
221221 volatility : Volatility | str ,
@@ -245,7 +245,7 @@ def udaf(
245245 )
246246 if name is None :
247247 name = accum .__qualname__ .lower ()
248- if isinstance (input_types , pyarrow .lib . DataType ):
248+ if isinstance (input_types , pyarrow .DataType ):
249249 input_types = [input_types ]
250250 arguments = [] if arguments is None else arguments
251251 return AggregateUDF (
0 commit comments