Skip to content

Commit e3b4acc

Browse files
committed
Improve type hinting for udaf and fix one pylance warning
1 parent 14fc166 commit e3b4acc

File tree

1 file changed

+5
-9
lines changed

1 file changed

+5
-9
lines changed

python/datafusion/udf.py

Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121

2222
import datafusion._internal as df_internal
2323
from datafusion.expr import Expr
24-
from typing import Callable, TYPE_CHECKING, TypeVar
24+
from typing import Callable, TYPE_CHECKING, TypeVar, Type
2525
from abc import ABCMeta, abstractmethod
2626
from typing import List, Any, Optional
2727
from enum import Enum
@@ -167,10 +167,6 @@ def evaluate(self) -> pyarrow.Scalar:
167167
pass
168168

169169

170-
if TYPE_CHECKING:
171-
_A = TypeVar("_A", bound=(Callable[..., _R], Accumulator))
172-
173-
174170
class AggregateUDF:
175171
"""Class for performing scalar user-defined functions (UDF).
176172
@@ -181,9 +177,9 @@ class AggregateUDF:
181177
def __init__(
182178
self,
183179
name: str | None,
184-
accumulator: _A,
180+
accumulator: Type[Accumulator],
185181
input_types: list[pyarrow.DataType],
186-
return_type: _R,
182+
return_type: pyarrow.DataType,
187183
state_type: list[pyarrow.DataType],
188184
volatility: Volatility | str,
189185
arguments: list[Any],
@@ -214,9 +210,9 @@ def __call__(self, *args: Expr) -> Expr:
214210

215211
@staticmethod
216212
def udaf(
217-
accum: _A,
213+
accum: Type[Accumulator],
218214
input_types: pyarrow.DataType | list[pyarrow.DataType],
219-
return_type: _R,
215+
return_type: pyarrow.DataType,
220216
state_type: list[pyarrow.DataType],
221217
volatility: Volatility | str,
222218
arguments: Optional[list[Any]] = None,

0 commit comments

Comments
 (0)