@@ -621,31 +621,48 @@ def __call__(self, *args: Expr) -> Expr:
621621 args_raw = [arg .expr for arg in args ]
622622 return Expr (self ._udwf .__call__ (* args_raw ))
623623
# Decorator form: ``func`` is omitted, and the call returns a decorator that
# turns a zero-argument WindowEvaluator factory into a callable producing an
# ``Expr``.
@overload
@staticmethod
def udwf(
    input_types: pa.DataType | list[pa.DataType],
    return_type: pa.DataType,
    volatility: Volatility | str,
    name: Optional[str] = None,
) -> Callable[[Callable[[], WindowEvaluator]], Callable[..., Expr]]: ...
632+
# Function form: ``func`` is given up front and a ready-to-use ``WindowUDF``
# is returned directly.
@overload
@staticmethod
def udwf(
    func: Callable[[], WindowEvaluator],
    input_types: pa.DataType | list[pa.DataType],
    return_type: pa.DataType,
    volatility: Volatility | str,
    name: Optional[str] = None,
) -> WindowUDF: ...
633642
634- If your :py:class:`WindowEvaluator` can be instantiated with no arguments, you
635- can simply pass it's type as ``func``. If you need to pass additional arguments
636- to it's constructor, you can define a lambda or a factory method. During runtime
637- the :py:class:`WindowEvaluator` will be constructed for every instance in
638- which this UDWF is used. The following examples are all valid.
@staticmethod
def udwf(*args: Any, **kwargs: Any):  # noqa: D417
    """Create a new User-Defined Window Function (UDWF).

    This method can be used both as a plain function and as a decorator
    factory:

    - **As a function**: call
      ``udwf(func, input_types, return_type, volatility, name)``.
    - **As a decorator**: use
      ``@udwf(input_types, return_type, volatility, name)`` and do not pass
      ``func``; the decorated callable is used as ``func``.

    The function form is selected when ``func`` is supplied, either as a
    callable first positional argument or by keyword. Every other call is
    treated as the decorator form.

    Function example:

    .. code-block:: python

        import pyarrow as pa

        class BiasedNumbers(WindowEvaluator):
            def __init__(self, start: int = 0) -> None:
                self.start = start

            def evaluate_all(self, values: list[pa.Array],
                             num_rows: int) -> pa.Array:
                return pa.array([self.start + i for i in range(num_rows)])

        def bias_10() -> BiasedNumbers:
            return BiasedNumbers(10)

        udwf1 = udwf(BiasedNumbers, pa.int64(), pa.int64(), "immutable")
        udwf2 = udwf(bias_10, pa.int64(), pa.int64(), "immutable")
        udwf3 = udwf(
            lambda: BiasedNumbers(20), pa.int64(), pa.int64(), "immutable"
        )

    Decorator example:

    .. code-block:: python

        @udwf(pa.int64(), pa.int64(), "immutable")
        def biased_numbers() -> BiasedNumbers:
            return BiasedNumbers(10)

    Args:
        func: A zero-argument callable that returns a
            :py:class:`WindowEvaluator`. Only needed when calling as a
            function; skip this argument when using ``udwf`` as a decorator.
        input_types: The data types of the arguments. A single data type is
            accepted as shorthand for a one-element list.
        return_type: The data type of the return value.
        volatility: See :py:class:`Volatility` for allowed values.
        name: A descriptive name for the function.

    Returns:
        In the function form, a user-defined window function that can be used
        in window function calls. In the decorator form, a decorator that
        wraps the decorated factory in a callable returning an expression.
    """
    # Case 1: used as a function. ``func`` may arrive positionally or by
    # keyword; checking only ``args[0]`` would send ``udwf(func=...)`` down
    # the decorator path, which rejects the ``func`` keyword.
    if "func" in kwargs or (args and callable(args[0])):
        return WindowUDF._create_window_udf(*args, **kwargs)
    # Case 2: used as a decorator with parameters.
    return WindowUDF._create_window_udf_decorator(*args, **kwargs)
700+
@staticmethod
def _create_window_udf(
    func: Callable[[], WindowEvaluator],
    input_types: pa.DataType | list[pa.DataType],
    return_type: pa.DataType,
    volatility: Volatility | str,
    name: Optional[str] = None,
) -> WindowUDF:
    """Create a WindowUDF instance from function arguments.

    Args:
        func: A zero-argument callable returning a :py:class:`WindowEvaluator`.
            It is invoked once here to validate the type of its result.
        input_types: A single data type or a list of data types for the
            arguments; a single type is wrapped in a list.
        return_type: The data type of the return value.
        volatility: See :py:class:`Volatility` for allowed values.
        name: Name for the function. When omitted (or empty), a default is
            derived from ``func``.

    Returns:
        The constructed :py:class:`WindowUDF`.

    Raises:
        TypeError: If ``func`` is not callable, or if calling it does not
            produce a :py:class:`WindowEvaluator`.
    """
    if not callable(func):
        msg = "`func` must be callable."
        raise TypeError(msg)
    if not isinstance(func(), WindowEvaluator):
        msg = "`func` must implement the abstract base class WindowEvaluator"
        raise TypeError(msg)

    # Delegate to the shared helpers instead of reading ``func.__qualname__``
    # directly: callables such as ``functools.partial`` objects or callable
    # instances have no ``__qualname__`` and would raise AttributeError.
    resolved_name = name or WindowUDF._get_default_name(func)
    normalized_types = WindowUDF._normalize_input_types(input_types)

    return WindowUDF(
        resolved_name, func, normalized_types, return_type, volatility
    )
723+
724+ @staticmethod
725+ def _get_default_name (func : Callable ) -> str :
726+ """Get the default name for a function based on its attributes."""
727+ if hasattr (func , "__qualname__" ):
728+ return func .__qualname__ .lower ()
729+ return func .__class__ .__name__ .lower ()
730+
@staticmethod
def _normalize_input_types(
    input_types: pa.DataType | list[pa.DataType],
) -> list[pa.DataType]:
    """Return ``input_types`` as a list, wrapping a lone ``pa.DataType``.

    Any value that is not a ``pa.DataType`` instance is returned unchanged.
    """
    is_single_type = isinstance(input_types, pa.DataType)
    return [input_types] if is_single_type else input_types
739+
@staticmethod
def _create_window_udf_decorator(
    input_types: pa.DataType | list[pa.DataType],
    return_type: pa.DataType,
    volatility: Volatility | str,
    name: Optional[str] = None,
) -> Callable[[Callable[[], WindowEvaluator]], Callable[..., Expr]]:
    """Build a decorator that turns an evaluator factory into a UDWF caller.

    The returned decorator creates the window UDF once, when it is applied,
    and replaces the decorated factory with a function that forwards its
    arguments to that UDF. The replacement carries the factory's metadata
    (name, docstring, ...) via ``functools``.
    """

    def decorator(func: Callable[[], WindowEvaluator]) -> Callable[..., Expr]:
        # Built eagerly so validation errors surface at decoration time.
        window_udf = WindowUDF._create_window_udf(
            func, input_types, return_type, volatility, name
        )

        def invoke(*args: Any, **kwargs: Any) -> Expr:
            return window_udf(*args, **kwargs)

        return functools.update_wrapper(invoke, func)

    return decorator
761+
687762
688763# Convenience exports so we can import instead of treating as
689764# variables at the package root
0 commit comments