@@ -646,7 +646,7 @@ def udwf(
646646 ) -> WindowUDF : ...
647647
648648 @staticmethod
649- def udwf (* args : Any , ** kwargs : Any ): # noqa: D417, C901
649+ def udwf (* args : Any , ** kwargs : Any ): # noqa: D417
650650 """Create a new User-Defined Window Function (UDWF).
651651
652652 This class can be used both as a **function** and as a **decorator**.
@@ -697,59 +697,78 @@ def biased_numbers() -> BiasedNumbers:
697697 Returns:
698698 A user-defined window function that can be used in window function calls.
699699 """
700+ if args and callable (args [0 ]):
701+ # Case 1: Used as a function, require the first parameter to be callable
702+ return WindowUDF ._create_window_udf (* args , ** kwargs )
703+ # Case 2: Used as a decorator with parameters
704+ return WindowUDF ._create_window_udf_decorator (* args , ** kwargs )
700705
701- def _function (
702- func : Callable [[], WindowEvaluator ],
703- input_types : pa .DataType | list [pa .DataType ],
704- return_type : _R ,
705- volatility : Volatility | str ,
706- name : Optional [str ] = None ,
707- ) -> WindowUDF :
708- if not callable (func ):
709- msg = "`func` argument must be callable"
710- raise TypeError (msg )
711- if not isinstance (func (), WindowEvaluator ):
712- msg = "`func` must implement the abstract base class WindowEvaluator"
713- raise TypeError (msg )
714- if name is None :
715- if hasattr (func , "__qualname__" ):
716- name = func .__qualname__ .lower ()
717- else :
718- name = func .__class__ .__name__ .lower ()
719- if isinstance (input_types , pa .DataType ):
720- input_types = [input_types ]
721- return WindowUDF (
722- name = name ,
723- func = func ,
724- input_types = input_types ,
725- return_type = return_type ,
726- volatility = volatility ,
727- )
706+ @staticmethod
707+ def _create_window_udf (
708+ func : Callable [[], WindowEvaluator ],
709+ input_types : pa .DataType | list [pa .DataType ],
710+ return_type : _R ,
711+ volatility : Volatility | str ,
712+ name : Optional [str ] = None ,
713+ ) -> WindowUDF :
714+ """Create a WindowUDF instance from function arguments."""
715+ if not callable (func ):
716+ msg = "`func` argument must be callable"
717+ raise TypeError (msg )
718+ if not isinstance (func (), WindowEvaluator ):
719+ msg = "`func` must implement the abstract base class WindowEvaluator"
720+ raise TypeError (msg )
721+
722+ if name is None :
723+ name = WindowUDF ._get_default_name (func )
724+
725+ input_types_list = WindowUDF ._normalize_input_types (input_types )
726+
727+ return WindowUDF (
728+ name = name ,
729+ func = func ,
730+ input_types = input_types_list ,
731+ return_type = return_type ,
732+ volatility = volatility ,
733+ )
728734
729- def _decorator (
730- input_types : pa .DataType | list [pa .DataType ],
731- return_type : _R ,
732- volatility : Volatility | str ,
733- name : Optional [str ] = None ,
734- ) -> Callable [..., Callable [..., Expr ]]:
735- def decorator (func : Callable [[], WindowEvaluator ]) -> Callable [..., Expr ]:
736- udwf_caller = WindowUDF .udwf (
737- func , input_types , return_type , volatility , name
738- )
735+ @staticmethod
736+ def _get_default_name (func : Callable ) -> str :
737+ """Get the default name for a function based on its attributes."""
738+ if hasattr (func , "__qualname__" ):
739+ return func .__qualname__ .lower ()
740+ return func .__class__ .__name__ .lower ()
739741
740- @functools .wraps (func )
741- def wrapper (* args : Any , ** kwargs : Any ) -> Expr :
742- return udwf_caller (* args , ** kwargs )
742+ @staticmethod
743+ def _normalize_input_types (
744+ input_types : pa .DataType | list [pa .DataType ],
745+ ) -> list [pa .DataType ]:
746+ """Convert a single DataType to a list if needed."""
747+ if isinstance (input_types , pa .DataType ):
748+ return [input_types ]
749+ return input_types
743750
744- return wrapper
751+ @staticmethod
752+ def _create_window_udf_decorator (
753+ input_types : pa .DataType | list [pa .DataType ],
754+ return_type : _R ,
755+ volatility : Volatility | str ,
756+ name : Optional [str ] = None ,
757+ ) -> Callable [..., Callable [..., Expr ]]:
758+ """Create a decorator for a WindowUDF."""
745759
746- return decorator
760+ def decorator (func : Callable [[], WindowEvaluator ]) -> Callable [..., Expr ]:
761+ udwf_caller = WindowUDF ._create_window_udf (
762+ func , input_types , return_type , volatility , name
763+ )
747764
748- if args and callable (args [0 ]):
749- # Case 1: Used as a function, require the first parameter to be callable
750- return _function (* args , ** kwargs )
751- # Case 2: Used as a decorator with parameters
752- return _decorator (* args , ** kwargs )
765+ @functools .wraps (func )
766+ def wrapper (* args : Any , ** kwargs : Any ) -> Expr :
767+ return udwf_caller (* args , ** kwargs )
768+
769+ return wrapper
770+
771+ return decorator
753772
754773
755774# Convenience exports so we can import instead of treating as
0 commit comments