Skip to content

Commit d29acf6

Browse files
committed
refactor: Remove state_type from udwf method signature and update return type handling
- Eliminated the state_type parameter from the udwf method to simplify the function signature. - Updated return type handling in the _function and _decorator methods to use a generic type _R for better type flexibility. - Enhanced the decorator to wrap the original function, allowing for improved argument handling and expression return.
1 parent 1164374 commit d29acf6

File tree

1 file changed

+14
-7
lines changed

1 file changed

+14
-7
lines changed

python/datafusion/udf.py

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919

2020
from __future__ import annotations
2121

22+
from ast import Call
2223
import functools
2324
from abc import ABCMeta, abstractmethod
2425
from enum import Enum
@@ -631,7 +632,6 @@ def __call__(self, *args: Expr) -> Expr:
631632
def udwf(
632633
input_type: pa.DataType | list[pa.DataType],
633634
return_type: pa.DataType,
634-
state_type: list[pa.DataType],
635635
volatility: str,
636636
name: Optional[str] = None,
637637
) -> Callable[..., WindowUDF]: ...
@@ -642,7 +642,6 @@ def udwf(
642642
func: Callable[[], WindowEvaluator],
643643
input_type: pa.DataType | list[pa.DataType],
644644
return_type: pa.DataType,
645-
state_type: list[pa.DataType],
646645
volatility: str,
647646
name: Optional[str] = None,
648647
) -> WindowUDF: ...
@@ -700,7 +699,7 @@ def biased_numbers() -> BiasedNumbers:
700699
def _function(
701700
func: Callable[[], WindowEvaluator],
702701
input_types: pa.DataType | list[pa.DataType],
703-
return_type: pa.DataType,
702+
return_type: _R,
704703
volatility: Volatility | str,
705704
name: Optional[str] = None,
706705
) -> WindowUDF:
@@ -727,12 +726,20 @@ def _function(
727726

728727
def _decorator(
729728
input_types: pa.DataType | list[pa.DataType],
730-
return_type: pa.DataType,
729+
return_type: _R,
731730
volatility: Volatility | str,
732731
name: Optional[str] = None,
733-
) -> Callable[[Callable[[], WindowEvaluator]], WindowUDF]:
734-
def decorator(func: Callable[[], WindowEvaluator]) -> WindowUDF:
735-
return _function(func, input_types, return_type, volatility, name)
732+
) -> Callable[..., Callable[..., Expr]]:
733+
def decorator(func: Callable[[], WindowEvaluator]) -> Callable[..., Expr]:
734+
udwf_caller = WindowUDF.udwf(
735+
func, input_types, return_type, volatility, name
736+
)
737+
738+
@functools.wraps(func)
739+
def wrapper(*args: Any, **kwargs: Any) -> Expr:
740+
return udwf_caller(*args, **kwargs)
741+
742+
return wrapper
736743

737744
return decorator
738745

0 commit comments

Comments
 (0)