diff --git a/python/datafusion/user_defined.py b/python/datafusion/user_defined.py index 67568e313..767878978 100644 --- a/python/datafusion/user_defined.py +++ b/python/datafusion/user_defined.py @@ -22,15 +22,31 @@ import functools from abc import ABCMeta, abstractmethod from enum import Enum -from typing import TYPE_CHECKING, Any, Callable, Optional, Protocol, TypeVar, overload +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Optional, + Protocol, + TypeVar, + cast, + overload, +) import pyarrow as pa +from typing_extensions import TypeGuard import datafusion._internal as df_internal from datafusion.expr import Expr if TYPE_CHECKING: + from _typeshed import CapsuleType as _PyCapsule + _R = TypeVar("_R", bound=pa.DataType) +else: + + class _PyCapsule: + """Lightweight typing proxy for CPython ``PyCapsule`` objects.""" class Volatility(Enum): @@ -83,6 +99,11 @@ class ScalarUDFExportable(Protocol): def __datafusion_scalar_udf__(self) -> object: ... # noqa: D105 +def _is_pycapsule(value: object) -> TypeGuard[_PyCapsule]: + """Return ``True`` when ``value`` is a CPython ``PyCapsule``.""" + return value.__class__.__name__ == "PyCapsule" + + class ScalarUDF: """Class for performing scalar user-defined functions (UDF). @@ -290,6 +311,7 @@ class AggregateUDF: also :py:class:`ScalarUDF` for operating on a row by row basis. """ + @overload def __init__( self, name: str, @@ -298,6 +320,27 @@ def __init__( return_type: pa.DataType, state_type: list[pa.DataType], volatility: Volatility | str, + ) -> None: ... + + @overload + def __init__( + self, + name: str, + accumulator: AggregateUDFExportable, + input_types: None = ..., + return_type: None = ..., + state_type: None = ..., + volatility: None = ..., + ) -> None: ... + + def __init__( + self, + name: str, + accumulator: Callable[[], Accumulator] | AggregateUDFExportable, + input_types: list[pa.DataType] | None, + return_type: pa.DataType | None, + state_type: list[pa.DataType] | None, + volatility: Volatility | str | None, ) -> None: """Instantiate a user-defined aggregate function (UDAF). @@ -307,6 +350,18 @@ def __init__( if hasattr(accumulator, "__datafusion_aggregate_udf__"): self._udaf = df_internal.AggregateUDF.from_pycapsule(accumulator) return + if ( + input_types is None + or return_type is None + or state_type is None + or volatility is None + ): + msg = ( + "`input_types`, `return_type`, `state_type`, and `volatility` " + "must be provided when `accumulator` is callable." + ) + raise TypeError(msg) + self._udaf = df_internal.AggregateUDF( name, accumulator, @@ -350,6 +405,14 @@ def udaf( name: Optional[str] = None, ) -> AggregateUDF: ... + @overload + @staticmethod + def udaf(accum: AggregateUDFExportable) -> AggregateUDF: ... + + @overload + @staticmethod + def udaf(accum: _PyCapsule) -> AggregateUDF: ... + @staticmethod def udaf(*args: Any, **kwargs: Any): # noqa: D417, C901 """Create a new User-Defined Aggregate Function (UDAF). @@ -470,7 +533,7 @@ def wrapper(*args: Any, **kwargs: Any) -> Expr: return decorator - if hasattr(args[0], "__datafusion_aggregate_udf__"): + if hasattr(args[0], "__datafusion_aggregate_udf__") or _is_pycapsule(args[0]): return AggregateUDF.from_pycapsule(args[0]) if args and callable(args[0]): @@ -480,16 +543,22 @@ def wrapper(*args: Any, **kwargs: Any) -> Expr: return _decorator(*args, **kwargs) @staticmethod - def from_pycapsule(func: AggregateUDFExportable) -> AggregateUDF: + def from_pycapsule(func: AggregateUDFExportable | _PyCapsule) -> AggregateUDF: """Create an Aggregate UDF from AggregateUDF PyCapsule object. This function will instantiate a Aggregate UDF that uses a DataFusion AggregateUDF that is exported via the FFI bindings. """ - name = str(func.__class__) + if _is_pycapsule(func): + aggregate = cast(AggregateUDF, object.__new__(AggregateUDF)) + aggregate._udaf = df_internal.AggregateUDF.from_pycapsule(func) + return aggregate + + capsule = cast(AggregateUDFExportable, func) + name = str(capsule.__class__) return AggregateUDF( name=name, - accumulator=func, + accumulator=capsule, input_types=None, return_type=None, state_type=None, diff --git a/python/tests/test_pyclass_frozen.py b/python/tests/test_pyclass_frozen.py index 189ea8dec..428e5e98b 100644 --- a/python/tests/test_pyclass_frozen.py +++ b/python/tests/test_pyclass_frozen.py @@ -32,8 +32,7 @@ r"(?P[A-Za-z_][A-Za-z0-9_]*)\s*=\s*\"(?P[^\"]+)\"", ) STRUCT_NAME_RE = re.compile( - r"\b(?:pub\s+)?(?:struct|enum)\s+" - r"(?P[A-Za-z_][A-Za-z0-9_]*)", + r"\b(?:pub\s+)?(?:struct|enum)\s+" r"(?P[A-Za-z_][A-Za-z0-9_]*)", ) diff --git a/src/udaf.rs b/src/udaf.rs index eab4581df..e48e35f8d 100644 --- a/src/udaf.rs +++ b/src/udaf.rs @@ -154,6 +154,15 @@ pub fn to_rust_accumulator(accum: PyObject) -> AccumulatorFactoryFunction { }) } +fn aggregate_udf_from_capsule(capsule: &Bound<'_, PyCapsule>) -> PyDataFusionResult { + validate_pycapsule(capsule, "datafusion_aggregate_udf")?; + + let udaf = unsafe { capsule.reference::() }; + let udaf: ForeignAggregateUDF = udaf.try_into()?; + + Ok(udaf.into()) +} + /// Represents an AggregateUDF #[pyclass(frozen, name = "AggregateUDF", module = "datafusion", subclass)] #[derive(Debug, Clone)] @@ -186,22 +195,22 @@ impl PyAggregateUDF { #[staticmethod] pub fn from_pycapsule(func: Bound<'_, PyAny>) -> PyDataFusionResult { + if func.is_instance_of::() { + let capsule = func.downcast::().map_err(py_datafusion_err)?; + let function = aggregate_udf_from_capsule(capsule)?; + return Ok(Self { function }); + } + if func.hasattr("__datafusion_aggregate_udf__")? { let capsule = func.getattr("__datafusion_aggregate_udf__")?.call0()?; let capsule = capsule.downcast::().map_err(py_datafusion_err)?; - validate_pycapsule(capsule, "datafusion_aggregate_udf")?; - - let udaf = unsafe { capsule.reference::() }; - let udaf: ForeignAggregateUDF = udaf.try_into()?; - - Ok(Self { - function: udaf.into(), - }) - } else { - Err(crate::errors::PyDataFusionError::Common( - "__datafusion_aggregate_udf__ does not exist on AggregateUDF object.".to_string(), - )) + let function = aggregate_udf_from_capsule(capsule)?; + return Ok(Self { function }); } + + Err(crate::errors::PyDataFusionError::Common( + "__datafusion_aggregate_udf__ does not exist on AggregateUDF object.".to_string(), + )) } /// creates a new PyExpr with the call of the udf