79 changes: 74 additions & 5 deletions python/datafusion/user_defined.py
@@ -22,15 +22,31 @@
import functools
from abc import ABCMeta, abstractmethod
from enum import Enum
from typing import TYPE_CHECKING, Any, Callable, Optional, Protocol, TypeVar, overload
from typing import (
TYPE_CHECKING,
Any,
Callable,
Optional,
Protocol,
TypeVar,
cast,
overload,
)

import pyarrow as pa
from typing_extensions import TypeGuard

import datafusion._internal as df_internal
from datafusion.expr import Expr

if TYPE_CHECKING:
from _typeshed import CapsuleType as _PyCapsule

_R = TypeVar("_R", bound=pa.DataType)
else:

class _PyCapsule:
"""Lightweight typing proxy for CPython ``PyCapsule`` objects."""


class Volatility(Enum):
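The `_PyCapsule` alias introduced in the import block above is only resolvable by type checkers: `_typeshed` is a stub-only package, so a runtime fallback class is needed. A minimal, purely illustrative sketch of why:

```python
# Sketch: `_typeshed` ships only type stubs, so this import succeeds for a
# type checker but fails when the module is actually executed, which is why
# the `else` branch defines a lightweight proxy class instead.
try:
    from _typeshed import CapsuleType  # resolvable only during type checking
except ModuleNotFoundError:
    print("_typeshed is unavailable at runtime; falling back to a proxy class")
```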
@@ -83,6 +99,11 @@ class ScalarUDFExportable(Protocol):
def __datafusion_scalar_udf__(self) -> object: ... # noqa: D105


def _is_pycapsule(value: object) -> TypeGuard[_PyCapsule]:
"""Return ``True`` when ``value`` is a CPython ``PyCapsule``."""
return value.__class__.__name__ == "PyCapsule"
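A quick check of the duck-typed capsule detection, assuming pyarrow >= 14 (which exposes the Arrow PyCapsule interface); the schema object here is just a convenient source of a real `PyCapsule`:

```python
import pyarrow as pa

schema = pa.schema([("x", pa.int64())])
capsule = schema.__arrow_c_schema__()  # returns a CPython PyCapsule

print(type(capsule).__name__)   # "PyCapsule"
print(_is_pycapsule(capsule))   # True
print(_is_pycapsule(schema))    # False -- an exporter object, not a capsule
```

The check compares the class name because the `PyCapsule` type is not exposed under a public Python name at runtime.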


class ScalarUDF:
"""Class for performing scalar user-defined functions (UDF).

@@ -290,6 +311,7 @@ class AggregateUDF:
also :py:class:`ScalarUDF` for operating on a row by row basis.
"""

@overload
def __init__(
self,
name: str,
@@ -298,6 +320,27 @@
return_type: pa.DataType,
state_type: list[pa.DataType],
volatility: Volatility | str,
) -> None: ...

@overload
def __init__(
self,
name: str,
accumulator: AggregateUDFExportable,
input_types: None = ...,
return_type: None = ...,
state_type: None = ...,
volatility: None = ...,
) -> None: ...

def __init__(
self,
name: str,
accumulator: Callable[[], Accumulator] | AggregateUDFExportable,
input_types: list[pa.DataType] | None,
return_type: pa.DataType | None,
state_type: list[pa.DataType] | None,
volatility: Volatility | str | None,
) -> None:
"""Instantiate a user-defined aggregate function (UDAF).

@@ -307,6 +350,18 @@ def __init__(
if hasattr(accumulator, "__datafusion_aggregate_udf__"):
self._udaf = df_internal.AggregateUDF.from_pycapsule(accumulator)
return
if (
input_types is None
or return_type is None
or state_type is None
or volatility is None
):
msg = (
"`input_types`, `return_type`, `state_type`, and `volatility` "
"must be provided when `accumulator` is callable."
)
raise TypeError(msg)

self._udaf = df_internal.AggregateUDF(
name,
accumulator,
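A hedged sketch of the two construction paths that the overloads and the new guard describe; `MyAccumulator` is a hypothetical `Accumulator` subclass:

```python
# Callable factory: all type metadata must be supplied, per the guard above.
udf_ok = AggregateUDF(
    name="my_sum",
    accumulator=MyAccumulator,          # hypothetical Accumulator subclass
    input_types=[pa.int64()],
    return_type=pa.int64(),
    state_type=[pa.int64()],
    volatility="immutable",
)

# Omitting the metadata with a callable now fails fast with a clear TypeError
# instead of surfacing an error from the Rust bindings.
AggregateUDF("bad", MyAccumulator, None, None, None, None)  # raises TypeError
```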
@@ -350,6 +405,14 @@ def udaf(
name: Optional[str] = None,
) -> AggregateUDF: ...

@overload
@staticmethod
def udaf(accum: AggregateUDFExportable) -> AggregateUDF: ...

@overload
@staticmethod
def udaf(accum: _PyCapsule) -> AggregateUDF: ...

@staticmethod
def udaf(*args: Any, **kwargs: Any): # noqa: D417, C901
"""Create a new User-Defined Aggregate Function (UDAF).
Expand Down Expand Up @@ -470,7 +533,7 @@ def wrapper(*args: Any, **kwargs: Any) -> Expr:

return decorator

if hasattr(args[0], "__datafusion_aggregate_udf__"):
if hasattr(args[0], "__datafusion_aggregate_udf__") or _is_pycapsule(args[0]):
return AggregateUDF.from_pycapsule(args[0])

if args and callable(args[0]):
@@ -480,16 +543,22 @@ def wrapper(*args: Any, **kwargs: Any) -> Expr:
return _decorator(*args, **kwargs)
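For completeness, a hedged usage sketch of the dispatch above; `exported_provider` and `raw_udaf_capsule` are hypothetical stand-ins for an FFI exporter object and a bare capsule:

```python
# Object exposing __datafusion_aggregate_udf__ (e.g. from another library's
# FFI bindings) -- handled by the hasattr() branch.
udf_a = AggregateUDF.udaf(exported_provider)

# Bare PyCapsule -- now recognized by the _is_pycapsule() branch and routed
# through the same from_pycapsule() constructor.
udf_b = AggregateUDF.udaf(raw_udaf_capsule)
```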

@staticmethod
def from_pycapsule(func: AggregateUDFExportable) -> AggregateUDF:
def from_pycapsule(func: AggregateUDFExportable | _PyCapsule) -> AggregateUDF:
"""Create an Aggregate UDF from AggregateUDF PyCapsule object.

This function will instantiate a Aggregate UDF that uses a DataFusion
AggregateUDF that is exported via the FFI bindings.
"""
name = str(func.__class__)
if _is_pycapsule(func):
aggregate = cast(AggregateUDF, object.__new__(AggregateUDF))
aggregate._udaf = df_internal.AggregateUDF.from_pycapsule(func)
return aggregate

capsule = cast(AggregateUDFExportable, func)
name = str(capsule.__class__)
return AggregateUDF(
name=name,
accumulator=func,
accumulator=capsule,
input_types=None,
return_type=None,
state_type=None,
3 changes: 1 addition & 2 deletions python/tests/test_pyclass_frozen.py
@@ -32,8 +32,7 @@
r"(?P<key>[A-Za-z_][A-Za-z0-9_]*)\s*=\s*\"(?P<value>[^\"]+)\"",
)
STRUCT_NAME_RE = re.compile(
r"\b(?:pub\s+)?(?:struct|enum)\s+"
r"(?P<name>[A-Za-z_][A-Za-z0-9_]*)",
Comment on lines -35 to -36 (Contributor Author): not related to this PR, but this came up as a Ruff error.
r"\b(?:pub\s+)?(?:struct|enum)\s+" r"(?P<name>[A-Za-z_][A-Za-z0-9_]*)",
)
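A quick sanity check that the reflowed pattern is behaviorally unchanged; adjacent string literals concatenate, so only the source layout differs:

```python
import re

STRUCT_NAME_RE = re.compile(
    r"\b(?:pub\s+)?(?:struct|enum)\s+" r"(?P<name>[A-Za-z_][A-Za-z0-9_]*)",
)

match = STRUCT_NAME_RE.search("pub struct PyAggregateUDF {")
print(match.group("name"))  # PyAggregateUDF
```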


33 changes: 21 additions & 12 deletions src/udaf.rs
@@ -154,6 +154,15 @@ pub fn to_rust_accumulator(accum: PyObject) -> AccumulatorFactoryFunction {
})
}

fn aggregate_udf_from_capsule(capsule: &Bound<'_, PyCapsule>) -> PyDataFusionResult<AggregateUDF> {
validate_pycapsule(capsule, "datafusion_aggregate_udf")?;

let udaf = unsafe { capsule.reference::<FFI_AggregateUDF>() };
let udaf: ForeignAggregateUDF = udaf.try_into()?;

Ok(udaf.into())
}

/// Represents an AggregateUDF
#[pyclass(frozen, name = "AggregateUDF", module = "datafusion", subclass)]
#[derive(Debug, Clone)]
@@ -186,22 +195,22 @@ impl PyAggregateUDF {

#[staticmethod]
pub fn from_pycapsule(func: Bound<'_, PyAny>) -> PyDataFusionResult<Self> {
if func.is_instance_of::<PyCapsule>() {
let capsule = func.downcast::<PyCapsule>().map_err(py_datafusion_err)?;
let function = aggregate_udf_from_capsule(&capsule)?;
return Ok(Self { function });
}

if func.hasattr("__datafusion_aggregate_udf__")? {
let capsule = func.getattr("__datafusion_aggregate_udf__")?.call0()?;
let capsule = capsule.downcast::<PyCapsule>().map_err(py_datafusion_err)?;
validate_pycapsule(capsule, "datafusion_aggregate_udf")?;

let udaf = unsafe { capsule.reference::<FFI_AggregateUDF>() };
let udaf: ForeignAggregateUDF = udaf.try_into()?;

Ok(Self {
function: udaf.into(),
})
} else {
Err(crate::errors::PyDataFusionError::Common(
"__datafusion_aggregate_udf__ does not exist on AggregateUDF object.".to_string(),
))
let function = aggregate_udf_from_capsule(&capsule)?;
return Ok(Self { function });
}

Err(crate::errors::PyDataFusionError::Common(
"__datafusion_aggregate_udf__ does not exist on AggregateUDF object.".to_string(),
))
}
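The extracted helper rejects capsules whose name is not `datafusion_aggregate_udf`. A hedged Python-side sketch of how that name can be inspected; `raw_udaf_capsule` is a hypothetical capsule produced by the FFI bindings:

```python
import ctypes

PyCapsule_GetName = ctypes.pythonapi.PyCapsule_GetName
PyCapsule_GetName.restype = ctypes.c_char_p
PyCapsule_GetName.argtypes = [ctypes.py_object]

# A capsule accepted by from_pycapsule is expected to carry this name;
# validate_pycapsule() on the Rust side errors out otherwise.
print(PyCapsule_GetName(raw_udaf_capsule))  # b"datafusion_aggregate_udf"
```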

/// creates a new PyExpr with the call of the udf