2121
2222import datafusion ._internal as df_internal
2323from datafusion .expr import Expr
24- from typing import Callable , TYPE_CHECKING , TypeVar
24+ from typing import Callable , TYPE_CHECKING , TypeVar , Type
2525from abc import ABCMeta , abstractmethod
2626from typing import List , Any , Optional
2727from enum import Enum
@@ -167,10 +167,6 @@ def evaluate(self) -> pyarrow.Scalar:
167167 pass
168168
169169
170- if TYPE_CHECKING :
171- _A = TypeVar ("_A" , bound = (Callable [..., _R ], Accumulator ))
172-
173-
174170class AggregateUDF :
175171 """Class for performing scalar user-defined functions (UDF).
176172
@@ -181,9 +177,9 @@ class AggregateUDF:
181177 def __init__ (
182178 self ,
183179 name : str | None ,
184- accumulator : _A ,
180+ accumulator : Type [ Accumulator ] ,
185181 input_types : list [pyarrow .DataType ],
186- return_type : _R ,
182+ return_type : pyarrow . DataType ,
187183 state_type : list [pyarrow .DataType ],
188184 volatility : Volatility | str ,
189185 arguments : list [Any ],
@@ -214,9 +210,9 @@ def __call__(self, *args: Expr) -> Expr:
214210
215211 @staticmethod
216212 def udaf (
217- accum : _A ,
213+ accum : Type [ Accumulator ] ,
218214 input_types : pyarrow .DataType | list [pyarrow .DataType ],
219- return_type : _R ,
215+ return_type : pyarrow . DataType ,
220216 state_type : list [pyarrow .DataType ],
221217 volatility : Volatility | str ,
222218 arguments : Optional [list [Any ]] = None ,
0 commit comments