2222import functools
2323from abc import ABCMeta , abstractmethod
2424from enum import Enum
25- from typing import TYPE_CHECKING , Any , Callable , Optional , TypeVar , overload
25+ from typing import TYPE_CHECKING , Any , Callable , Optional , Protocol , TypeVar , overload
2626
2727import pyarrow as pa
2828
@@ -77,6 +77,15 @@ def __str__(self) -> str:
7777 return self .name .lower ()
7878
7979
80+ class ScalarUDFExportable (Protocol ):
81+ """Type hint for object that has __datafusion_table_provider__ PyCapsule.
82+
83+ https://datafusion.apache.org/python/user-guide/io/table_provider.html
84+ """
85+
86+ def __datafusion_scalar_udf__ (self ) -> object : ... # noqa: D105
87+
88+
8089class ScalarUDF :
8190 """Class for performing scalar user-defined functions (UDF).
8291
@@ -133,6 +142,10 @@ def udf(
133142 name : Optional [str ] = None ,
134143 ) -> ScalarUDF : ...
135144
145+ @overload
146+ @staticmethod
147+ def udf (func : ScalarUDFExportable ) -> ScalarUDF : ...
148+
136149 @staticmethod
137150 def udf (* args : Any , ** kwargs : Any ): # noqa: D417
138151 """Create a new User-Defined Function (UDF).
@@ -147,7 +160,10 @@ def udf(*args: Any, **kwargs: Any): # noqa: D417
147160
148161 Args:
149162 func (Callable, optional): **Only needed when calling as a function.**
150- Skip this argument when using `udf` as a decorator.
163+ Skip this argument when using `udf` as a decorator. If you have a Rust
164+ backed ScalarUDF within a PyCapsule, you can pass this parameter
165+ and ignore the rest. They will be determined directly from the
166+ underlying function. See the online documentation for more information.
151167 input_types (list[pa.DataType]): The data types of the arguments
152168 to `func`. This list must be of the same length as the number of
153169 arguments.
@@ -219,21 +235,30 @@ def wrapper(*args: Any, **kwargs: Any):
219235 return decorator
220236
221237 if hasattr (args [0 ], "__datafusion_scalar_udf__" ):
222- name = str (args [0 ].__class__ )
223- return ScalarUDF (
224- name = name ,
225- func = args [0 ],
226- input_types = None ,
227- return_type = None ,
228- volatility = None ,
229- )
238+ return ScalarUDF .from_pycapsule (args [0 ])
230239
231240 if args and callable (args [0 ]):
232241 # Case 1: Used as a function, require the first parameter to be callable
233242 return _function (* args , ** kwargs )
234243 # Case 2: Used as a decorator with parameters
235244 return _decorator (* args , ** kwargs )
236245
246+ @staticmethod
247+ def from_pycapsule (func : ScalarUDFExportable ) -> ScalarUDF :
248+ """Create a Scalar UDF from ScalarUDF PyCapsule object.
249+
250+ This function will instantiate a Scalar UDF that uses a DataFusion
251+ ScalarUDF that is exported via the FFI bindings.
252+ """
253+ name = str (udf .__class__ )
254+ return ScalarUDF (
255+ name = name ,
256+ func = func ,
257+ input_types = None ,
258+ return_type = None ,
259+ volatility = None ,
260+ )
261+
237262
238263class Accumulator (metaclass = ABCMeta ):
239264 """Defines how an :py:class:`AggregateUDF` accumulates values."""
0 commit comments