diff --git a/docs/source/conf.py b/docs/source/conf.py index 0be03d81d..0ca124fd1 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -71,6 +71,7 @@ autoapi_member_order = "groupwise" suppress_warnings = ["autoapi.python_import_resolution"] autoapi_python_class_content = "both" +autoapi_keep_files = False # set to True for debugging generated files def autoapi_skip_member_fn(app, what, name, obj, skip, options) -> bool: # noqa: ARG001 diff --git a/docs/source/user-guide/common-operations/udf-and-udfa.rst b/docs/source/user-guide/common-operations/udf-and-udfa.rst index ffd7a05cb..e22338305 100644 --- a/docs/source/user-guide/common-operations/udf-and-udfa.rst +++ b/docs/source/user-guide/common-operations/udf-and-udfa.rst @@ -26,7 +26,7 @@ Scalar Functions When writing a user-defined function that can operate on a row by row basis, these are called Scalar Functions. You can define your own scalar function by calling -:py:func:`~datafusion.udf.ScalarUDF.udf` . +:py:func:`~datafusion.user_defined.ScalarUDF.udf` . The basic definition of a scalar UDF is a python function that takes one or more `pyarrow `_ arrays and returns a single array as @@ -93,9 +93,9 @@ converting to Python objects to do the evaluation. Aggregate Functions ------------------- -The :py:func:`~datafusion.udf.AggregateUDF.udaf` function allows you to define User-Defined +The :py:func:`~datafusion.user_defined.AggregateUDF.udaf` function allows you to define User-Defined Aggregate Functions (UDAFs). To use this you must implement an -:py:class:`~datafusion.udf.Accumulator` that determines how the aggregation is performed. +:py:class:`~datafusion.user_defined.Accumulator` that determines how the aggregation is performed. When defining a UDAF there are four methods you need to implement. The ``update`` function takes the array(s) of input and updates the internal state of the accumulator. You should define this function @@ -153,8 +153,8 @@ Window Functions ---------------- To implement a User-Defined Window Function (UDWF) you must call the -:py:func:`~datafusion.udf.WindowUDF.udwf` function using a class that implements the abstract -class :py:class:`~datafusion.udf.WindowEvaluator`. +:py:func:`~datafusion.user_defined.WindowUDF.udwf` function using a class that implements the abstract +class :py:class:`~datafusion.user_defined.WindowEvaluator`. There are three methods of evaluation of UDWFs. @@ -207,7 +207,7 @@ determine which evaluate functions are called. import pyarrow as pa from datafusion import udwf, col, SessionContext - from datafusion.udf import WindowEvaluator + from datafusion.user_defined import WindowEvaluator class ExponentialSmooth(WindowEvaluator): def __init__(self, alpha: float) -> None: diff --git a/examples/python-udwf.py b/examples/python-udwf.py index 98d118bf2..645ded188 100644 --- a/examples/python-udwf.py +++ b/examples/python-udwf.py @@ -22,7 +22,7 @@ from datafusion import col, lit, udwf from datafusion import functions as f from datafusion.expr import WindowFrame -from datafusion.udf import WindowEvaluator +from datafusion.user_defined import WindowEvaluator # This example creates five different examples of user defined window functions in order # to demonstrate the variety of ways a user may need to implement. diff --git a/pyproject.toml b/pyproject.toml index d86b657ec..728cedb2d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -63,6 +63,10 @@ exclude = [".github/**", "ci/**", ".asf.yaml"] locked = true features = ["substrait"] +[tool.pytest.ini_options] +asyncio_mode = "auto" +asyncio_default_fixture_loop_scope = "function" + # Enable docstring linting using the google style guide [tool.ruff.lint] select = ["ALL" ] diff --git a/python/datafusion/__init__.py b/python/datafusion/__init__.py index 60d0d61b4..9ae36fece 100644 --- a/python/datafusion/__init__.py +++ b/python/datafusion/__init__.py @@ -49,7 +49,15 @@ from .io import read_avro, read_csv, read_json, read_parquet from .plan import ExecutionPlan, LogicalPlan from .record_batch import RecordBatch, RecordBatchStream -from .udf import Accumulator, AggregateUDF, ScalarUDF, WindowUDF, udaf, udf, udwf +from .user_defined import ( + Accumulator, + AggregateUDF, + ScalarUDF, + WindowUDF, + udaf, + udf, + udwf, +) __version__ = importlib_metadata.version(__name__) diff --git a/python/datafusion/context.py b/python/datafusion/context.py index 1429a4975..940f597cc 100644 --- a/python/datafusion/context.py +++ b/python/datafusion/context.py @@ -30,7 +30,7 @@ from datafusion.dataframe import DataFrame from datafusion.expr import Expr, SortExpr, sort_list_to_raw_sort_list from datafusion.record_batch import RecordBatchStream -from datafusion.udf import AggregateUDF, ScalarUDF, WindowUDF +from datafusion.user_defined import AggregateUDF, ScalarUDF, WindowUDF from ._internal import RuntimeEnvBuilder as RuntimeEnvBuilderInternal from ._internal import SessionConfig as SessionConfigInternal diff --git a/python/datafusion/udf.py b/python/datafusion/udf.py index e93a34ca5..c7265fa09 100644 --- a/python/datafusion/udf.py +++ b/python/datafusion/udf.py @@ -15,753 +15,15 @@ # specific language governing permissions and limitations # under the License. -"""Provides the user-defined functions for evaluation of dataframes.""" +"""Deprecated module for user defined functions.""" -from __future__ import annotations +import warnings -import functools -from abc import ABCMeta, abstractmethod -from enum import Enum -from typing import TYPE_CHECKING, Any, Callable, Optional, TypeVar, overload +from datafusion.user_defined import * # noqa: F403 -import pyarrow as pa - -import datafusion._internal as df_internal -from datafusion.expr import Expr - -if TYPE_CHECKING: - _R = TypeVar("_R", bound=pa.DataType) - - -class Volatility(Enum): - """Defines how stable or volatile a function is. - - When setting the volatility of a function, you can either pass this - enumeration or a ``str``. The ``str`` equivalent is the lower case value of the - name (`"immutable"`, `"stable"`, or `"volatile"`). - """ - - Immutable = 1 - """An immutable function will always return the same output when given the - same input. - - DataFusion will attempt to inline immutable functions during planning. - """ - - Stable = 2 - """ - Returns the same value for a given input within a single queries. - - A stable function may return different values given the same input across - different queries but must return the same value for a given input within a - query. An example of this is the ``Now`` function. DataFusion will attempt to - inline ``Stable`` functions during planning, when possible. For query - ``select col1, now() from t1``, it might take a while to execute but ``now()`` - column will be the same for each output row, which is evaluated during - planning. - """ - - Volatile = 3 - """A volatile function may change the return value from evaluation to - evaluation. - - Multiple invocations of a volatile function may return different results - when used in the same query. An example of this is the random() function. - DataFusion can not evaluate such functions during planning. In the query - ``select col1, random() from t1``, ``random()`` function will be evaluated - for each output row, resulting in a unique random value for each row. - """ - - def __str__(self) -> str: - """Returns the string equivalent.""" - return self.name.lower() - - -class ScalarUDF: - """Class for performing scalar user-defined functions (UDF). - - Scalar UDFs operate on a row by row basis. See also :py:class:`AggregateUDF` for - operating on a group of rows. - """ - - def __init__( - self, - name: str, - func: Callable[..., _R], - input_types: pa.DataType | list[pa.DataType], - return_type: _R, - volatility: Volatility | str, - ) -> None: - """Instantiate a scalar user-defined function (UDF). - - See helper method :py:func:`udf` for argument details. - """ - if isinstance(input_types, pa.DataType): - input_types = [input_types] - self._udf = df_internal.ScalarUDF( - name, func, input_types, return_type, str(volatility) - ) - - def __call__(self, *args: Expr) -> Expr: - """Execute the UDF. - - This function is not typically called by an end user. These calls will - occur during the evaluation of the dataframe. - """ - args_raw = [arg.expr for arg in args] - return Expr(self._udf.__call__(*args_raw)) - - @overload - @staticmethod - def udf( - input_types: list[pa.DataType], - return_type: _R, - volatility: Volatility | str, - name: Optional[str] = None, - ) -> Callable[..., ScalarUDF]: ... - - @overload - @staticmethod - def udf( - func: Callable[..., _R], - input_types: list[pa.DataType], - return_type: _R, - volatility: Volatility | str, - name: Optional[str] = None, - ) -> ScalarUDF: ... - - @staticmethod - def udf(*args: Any, **kwargs: Any): # noqa: D417 - """Create a new User-Defined Function (UDF). - - This class can be used both as a **function** and as a **decorator**. - - Usage: - - **As a function**: Call `udf(func, input_types, return_type, volatility, - name)`. - - **As a decorator**: Use `@udf(input_types, return_type, volatility, - name)`. In this case, do **not** pass `func` explicitly. - - Args: - func (Callable, optional): **Only needed when calling as a function.** - Skip this argument when using `udf` as a decorator. - input_types (list[pa.DataType]): The data types of the arguments - to `func`. This list must be of the same length as the number of - arguments. - return_type (_R): The data type of the return value from the function. - volatility (Volatility | str): See `Volatility` for allowed values. - name (Optional[str]): A descriptive name for the function. - - Returns: - A user-defined function that can be used in SQL expressions, - data aggregation, or window function calls. - - Example: - **Using `udf` as a function:** - ``` - def double_func(x): - return x * 2 - double_udf = udf(double_func, [pa.int32()], pa.int32(), - "volatile", "double_it") - ``` - - **Using `udf` as a decorator:** - ``` - @udf([pa.int32()], pa.int32(), "volatile", "double_it") - def double_udf(x): - return x * 2 - ``` - """ - - def _function( - func: Callable[..., _R], - input_types: list[pa.DataType], - return_type: _R, - volatility: Volatility | str, - name: Optional[str] = None, - ) -> ScalarUDF: - if not callable(func): - msg = "`func` argument must be callable" - raise TypeError(msg) - if name is None: - if hasattr(func, "__qualname__"): - name = func.__qualname__.lower() - else: - name = func.__class__.__name__.lower() - return ScalarUDF( - name=name, - func=func, - input_types=input_types, - return_type=return_type, - volatility=volatility, - ) - - def _decorator( - input_types: list[pa.DataType], - return_type: _R, - volatility: Volatility | str, - name: Optional[str] = None, - ) -> Callable: - def decorator(func: Callable): - udf_caller = ScalarUDF.udf( - func, input_types, return_type, volatility, name - ) - - @functools.wraps(func) - def wrapper(*args: Any, **kwargs: Any): - return udf_caller(*args, **kwargs) - - return wrapper - - return decorator - - if args and callable(args[0]): - # Case 1: Used as a function, require the first parameter to be callable - return _function(*args, **kwargs) - # Case 2: Used as a decorator with parameters - return _decorator(*args, **kwargs) - - -class Accumulator(metaclass=ABCMeta): - """Defines how an :py:class:`AggregateUDF` accumulates values.""" - - @abstractmethod - def state(self) -> list[pa.Scalar]: - """Return the current state.""" - - @abstractmethod - def update(self, *values: pa.Array) -> None: - """Evaluate an array of values and update state.""" - - @abstractmethod - def merge(self, states: list[pa.Array]) -> None: - """Merge a set of states.""" - - @abstractmethod - def evaluate(self) -> pa.Scalar: - """Return the resultant value.""" - - -class AggregateUDF: - """Class for performing scalar user-defined functions (UDF). - - Aggregate UDFs operate on a group of rows and return a single value. See - also :py:class:`ScalarUDF` for operating on a row by row basis. - """ - - def __init__( - self, - name: str, - accumulator: Callable[[], Accumulator], - input_types: list[pa.DataType], - return_type: pa.DataType, - state_type: list[pa.DataType], - volatility: Volatility | str, - ) -> None: - """Instantiate a user-defined aggregate function (UDAF). - - See :py:func:`udaf` for a convenience function and argument - descriptions. - """ - self._udaf = df_internal.AggregateUDF( - name, - accumulator, - input_types, - return_type, - state_type, - str(volatility), - ) - - def __call__(self, *args: Expr) -> Expr: - """Execute the UDAF. - - This function is not typically called by an end user. These calls will - occur during the evaluation of the dataframe. - """ - args_raw = [arg.expr for arg in args] - return Expr(self._udaf.__call__(*args_raw)) - - @overload - @staticmethod - def udaf( - input_types: pa.DataType | list[pa.DataType], - return_type: pa.DataType, - state_type: list[pa.DataType], - volatility: Volatility | str, - name: Optional[str] = None, - ) -> Callable[..., AggregateUDF]: ... - - @overload - @staticmethod - def udaf( - accum: Callable[[], Accumulator], - input_types: pa.DataType | list[pa.DataType], - return_type: pa.DataType, - state_type: list[pa.DataType], - volatility: Volatility | str, - name: Optional[str] = None, - ) -> AggregateUDF: ... - - @staticmethod - def udaf(*args: Any, **kwargs: Any): # noqa: D417 - """Create a new User-Defined Aggregate Function (UDAF). - - This class allows you to define an **aggregate function** that can be used in - data aggregation or window function calls. - - Usage: - - **As a function**: Call `udaf(accum, input_types, return_type, state_type, - volatility, name)`. - - **As a decorator**: Use `@udaf(input_types, return_type, state_type, - volatility, name)`. - When using `udaf` as a decorator, **do not pass `accum` explicitly**. - - **Function example:** - - If your `:py:class:Accumulator` can be instantiated with no arguments, you - can simply pass it's type as `accum`. If you need to pass additional - arguments to it's constructor, you can define a lambda or a factory method. - During runtime the `:py:class:Accumulator` will be constructed for every - instance in which this UDAF is used. The following examples are all valid. - ``` - import pyarrow as pa - import pyarrow.compute as pc - - class Summarize(Accumulator): - def __init__(self, bias: float = 0.0): - self._sum = pa.scalar(bias) - - def state(self) -> list[pa.Scalar]: - return [self._sum] - - def update(self, values: pa.Array) -> None: - self._sum = pa.scalar(self._sum.as_py() + pc.sum(values).as_py()) - - def merge(self, states: list[pa.Array]) -> None: - self._sum = pa.scalar(self._sum.as_py() + pc.sum(states[0]).as_py()) - - def evaluate(self) -> pa.Scalar: - return self._sum - - def sum_bias_10() -> Summarize: - return Summarize(10.0) - - udaf1 = udaf(Summarize, pa.float64(), pa.float64(), [pa.float64()], - "immutable") - udaf2 = udaf(sum_bias_10, pa.float64(), pa.float64(), [pa.float64()], - "immutable") - udaf3 = udaf(lambda: Summarize(20.0), pa.float64(), pa.float64(), - [pa.float64()], "immutable") - ``` - - **Decorator example:** - ``` - @udaf(pa.float64(), pa.float64(), [pa.float64()], "immutable") - def udf4() -> Summarize: - return Summarize(10.0) - ``` - - Args: - accum: The accumulator python function. **Only needed when calling as a - function. Skip this argument when using `udaf` as a decorator.** - input_types: The data types of the arguments to ``accum``. - return_type: The data type of the return value. - state_type: The data types of the intermediate accumulation. - volatility: See :py:class:`Volatility` for allowed values. - name: A descriptive name for the function. - - Returns: - A user-defined aggregate function, which can be used in either data - aggregation or window function calls. - """ - - def _function( - accum: Callable[[], Accumulator], - input_types: pa.DataType | list[pa.DataType], - return_type: pa.DataType, - state_type: list[pa.DataType], - volatility: Volatility | str, - name: Optional[str] = None, - ) -> AggregateUDF: - if not callable(accum): - msg = "`func` must be callable." - raise TypeError(msg) - if not isinstance(accum(), Accumulator): - msg = "Accumulator must implement the abstract base class Accumulator" - raise TypeError(msg) - if name is None: - name = accum().__class__.__qualname__.lower() - if isinstance(input_types, pa.DataType): - input_types = [input_types] - return AggregateUDF( - name=name, - accumulator=accum, - input_types=input_types, - return_type=return_type, - state_type=state_type, - volatility=volatility, - ) - - def _decorator( - input_types: pa.DataType | list[pa.DataType], - return_type: pa.DataType, - state_type: list[pa.DataType], - volatility: Volatility | str, - name: Optional[str] = None, - ) -> Callable[..., Callable[..., Expr]]: - def decorator(accum: Callable[[], Accumulator]) -> Callable[..., Expr]: - udaf_caller = AggregateUDF.udaf( - accum, input_types, return_type, state_type, volatility, name - ) - - @functools.wraps(accum) - def wrapper(*args: Any, **kwargs: Any) -> Expr: - return udaf_caller(*args, **kwargs) - - return wrapper - - return decorator - - if args and callable(args[0]): - # Case 1: Used as a function, require the first parameter to be callable - return _function(*args, **kwargs) - # Case 2: Used as a decorator with parameters - return _decorator(*args, **kwargs) - - -class WindowEvaluator: - """Evaluator class for user-defined window functions (UDWF). - - It is up to the user to decide which evaluate function is appropriate. - - +------------------------+--------------------------------+------------------+---------------------------+ - | ``uses_window_frame`` | ``supports_bounded_execution`` | ``include_rank`` | function_to_implement | - +========================+================================+==================+===========================+ - | False (default) | False (default) | False (default) | ``evaluate_all`` | - +------------------------+--------------------------------+------------------+---------------------------+ - | False | True | False | ``evaluate`` | - +------------------------+--------------------------------+------------------+---------------------------+ - | False | True/False | True | ``evaluate_all_with_rank``| - +------------------------+--------------------------------+------------------+---------------------------+ - | True | True/False | True/False | ``evaluate`` | - +------------------------+--------------------------------+------------------+---------------------------+ - """ # noqa: W505, E501 - - def memoize(self) -> None: - """Perform a memoize operation to improve performance. - - When the window frame has a fixed beginning (e.g UNBOUNDED - PRECEDING), some functions such as FIRST_VALUE and - NTH_VALUE do not need the (unbounded) input once they have - seen a certain amount of input. - - `memoize` is called after each input batch is processed, and - such functions can save whatever they need - """ - - def get_range(self, idx: int, num_rows: int) -> tuple[int, int]: # noqa: ARG002 - """Return the range for the window fuction. - - If `uses_window_frame` flag is `false`. This method is used to - calculate required range for the window function during - stateful execution. - - Generally there is no required range, hence by default this - returns smallest range(current row). e.g seeing current row is - enough to calculate window result (such as row_number, rank, - etc) - - Args: - idx:: Current index - num_rows: Number of rows. - """ - return (idx, idx + 1) - - def is_causal(self) -> bool: - """Get whether evaluator needs future data for its result.""" - return False - - def evaluate_all(self, values: list[pa.Array], num_rows: int) -> pa.Array: - """Evaluate a window function on an entire input partition. - - This function is called once per input *partition* for window functions that - *do not use* values from the window frame, such as - :py:func:`~datafusion.functions.row_number`, - :py:func:`~datafusion.functions.rank`, - :py:func:`~datafusion.functions.dense_rank`, - :py:func:`~datafusion.functions.percent_rank`, - :py:func:`~datafusion.functions.cume_dist`, - :py:func:`~datafusion.functions.lead`, - and :py:func:`~datafusion.functions.lag`. - - It produces the result of all rows in a single pass. It - expects to receive the entire partition as the ``value`` and - must produce an output column with one output row for every - input row. - - ``num_rows`` is required to correctly compute the output in case - ``len(values) == 0`` - - Implementing this function is an optimization. Certain window - functions are not affected by the window frame definition or - the query doesn't have a frame, and ``evaluate`` skips the - (costly) window frame boundary calculation and the overhead of - calling ``evaluate`` for each output row. - - For example, the `LAG` built in window function does not use - the values of its window frame (it can be computed in one shot - on the entire partition with ``Self::evaluate_all`` regardless of the - window defined in the ``OVER`` clause) - - .. code-block:: text - - lag(x, 1) OVER (ORDER BY z ROWS BETWEEN 2 PRECEDING AND 3 FOLLOWING) - - However, ``avg()`` computes the average in the window and thus - does use its window frame. - - .. code-block:: text - - avg(x) OVER (PARTITION BY y ORDER BY z ROWS BETWEEN 2 PRECEDING AND 3 FOLLOWING) - """ # noqa: W505, E501 - - def evaluate( - self, values: list[pa.Array], eval_range: tuple[int, int] - ) -> pa.Scalar: - """Evaluate window function on a range of rows in an input partition. - - This is the simplest and most general function to implement - but also the least performant as it creates output one row at - a time. It is typically much faster to implement stateful - evaluation using one of the other specialized methods on this - trait. - - Returns a [`ScalarValue`] that is the value of the window - function within `range` for the entire partition. Argument - `values` contains the evaluation result of function arguments - and evaluation results of ORDER BY expressions. If function has a - single argument, `values[1..]` will contain ORDER BY expression results. - """ - - def evaluate_all_with_rank( - self, num_rows: int, ranks_in_partition: list[tuple[int, int]] - ) -> pa.Array: - """Called for window functions that only need the rank of a row. - - Evaluate the partition evaluator against the partition using - the row ranks. For example, ``rank(col("a"))`` produces - - .. code-block:: text - - a | rank - - + ---- - A | 1 - A | 1 - C | 3 - D | 4 - D | 4 - - For this case, `num_rows` would be `5` and the - `ranks_in_partition` would be called with - - .. code-block:: text - - [ - (0,1), - (2,2), - (3,4), - ] - - The user must implement this method if ``include_rank`` returns True. - """ - - def supports_bounded_execution(self) -> bool: - """Can the window function be incrementally computed using bounded memory?""" - return False - - def uses_window_frame(self) -> bool: - """Does the window function use the values from the window frame?""" - return False - - def include_rank(self) -> bool: - """Can this function be evaluated with (only) rank?""" - return False - - -class WindowUDF: - """Class for performing window user-defined functions (UDF). - - Window UDFs operate on a partition of rows. See - also :py:class:`ScalarUDF` for operating on a row by row basis. - """ - - def __init__( - self, - name: str, - func: Callable[[], WindowEvaluator], - input_types: list[pa.DataType], - return_type: pa.DataType, - volatility: Volatility | str, - ) -> None: - """Instantiate a user-defined window function (UDWF). - - See :py:func:`udwf` for a convenience function and argument - descriptions. - """ - self._udwf = df_internal.WindowUDF( - name, func, input_types, return_type, str(volatility) - ) - - def __call__(self, *args: Expr) -> Expr: - """Execute the UDWF. - - This function is not typically called by an end user. These calls will - occur during the evaluation of the dataframe. - """ - args_raw = [arg.expr for arg in args] - return Expr(self._udwf.__call__(*args_raw)) - - @overload - @staticmethod - def udwf( - input_types: pa.DataType | list[pa.DataType], - return_type: pa.DataType, - volatility: Volatility | str, - name: Optional[str] = None, - ) -> Callable[..., WindowUDF]: ... - - @overload - @staticmethod - def udwf( - func: Callable[[], WindowEvaluator], - input_types: pa.DataType | list[pa.DataType], - return_type: pa.DataType, - volatility: Volatility | str, - name: Optional[str] = None, - ) -> WindowUDF: ... - - @staticmethod - def udwf(*args: Any, **kwargs: Any): # noqa: D417 - """Create a new User-Defined Window Function (UDWF). - - This class can be used both as a **function** and as a **decorator**. - - Usage: - - **As a function**: Call `udwf(func, input_types, return_type, volatility, - name)`. - - **As a decorator**: Use `@udwf(input_types, return_type, volatility, - name)`. When using `udwf` as a decorator, **do not pass `func` - explicitly**. - - **Function example:** - ``` - import pyarrow as pa - - class BiasedNumbers(WindowEvaluator): - def __init__(self, start: int = 0) -> None: - self.start = start - - def evaluate_all(self, values: list[pa.Array], - num_rows: int) -> pa.Array: - return pa.array([self.start + i for i in range(num_rows)]) - - def bias_10() -> BiasedNumbers: - return BiasedNumbers(10) - - udwf1 = udwf(BiasedNumbers, pa.int64(), pa.int64(), "immutable") - udwf2 = udwf(bias_10, pa.int64(), pa.int64(), "immutable") - udwf3 = udwf(lambda: BiasedNumbers(20), pa.int64(), pa.int64(), "immutable") - - ``` - - **Decorator example:** - ``` - @udwf(pa.int64(), pa.int64(), "immutable") - def biased_numbers() -> BiasedNumbers: - return BiasedNumbers(10) - ``` - - Args: - func: **Only needed when calling as a function. Skip this argument when - using `udwf` as a decorator.** - input_types: The data types of the arguments. - return_type: The data type of the return value. - volatility: See :py:class:`Volatility` for allowed values. - name: A descriptive name for the function. - - Returns: - A user-defined window function that can be used in window function calls. - """ - if args and callable(args[0]): - # Case 1: Used as a function, require the first parameter to be callable - return WindowUDF._create_window_udf(*args, **kwargs) - # Case 2: Used as a decorator with parameters - return WindowUDF._create_window_udf_decorator(*args, **kwargs) - - @staticmethod - def _create_window_udf( - func: Callable[[], WindowEvaluator], - input_types: pa.DataType | list[pa.DataType], - return_type: pa.DataType, - volatility: Volatility | str, - name: Optional[str] = None, - ) -> WindowUDF: - """Create a WindowUDF instance from function arguments.""" - if not callable(func): - msg = "`func` must be callable." - raise TypeError(msg) - if not isinstance(func(), WindowEvaluator): - msg = "`func` must implement the abstract base class WindowEvaluator" - raise TypeError(msg) - - name = name or func.__qualname__.lower() - input_types = ( - [input_types] if isinstance(input_types, pa.DataType) else input_types - ) - - return WindowUDF(name, func, input_types, return_type, volatility) - - @staticmethod - def _get_default_name(func: Callable) -> str: - """Get the default name for a function based on its attributes.""" - if hasattr(func, "__qualname__"): - return func.__qualname__.lower() - return func.__class__.__name__.lower() - - @staticmethod - def _normalize_input_types( - input_types: pa.DataType | list[pa.DataType], - ) -> list[pa.DataType]: - """Convert a single DataType to a list if needed.""" - if isinstance(input_types, pa.DataType): - return [input_types] - return input_types - - @staticmethod - def _create_window_udf_decorator( - input_types: pa.DataType | list[pa.DataType], - return_type: pa.DataType, - volatility: Volatility | str, - name: Optional[str] = None, - ) -> Callable[[Callable[[], WindowEvaluator]], Callable[..., Expr]]: - """Create a decorator for a WindowUDF.""" - - def decorator(func: Callable[[], WindowEvaluator]) -> Callable[..., Expr]: - udwf_caller = WindowUDF._create_window_udf( - func, input_types, return_type, volatility, name - ) - - @functools.wraps(func) - def wrapper(*args: Any, **kwargs: Any) -> Expr: - return udwf_caller(*args, **kwargs) - - return wrapper - - return decorator - - -# Convenience exports so we can import instead of treating as -# variables at the package root -udf = ScalarUDF.udf -udaf = AggregateUDF.udaf -udwf = WindowUDF.udwf +warnings.warn( + "The module 'udf' is deprecated and will be removed in the next release. " + "Please use 'user_defined' instead.", + DeprecationWarning, + stacklevel=2, +) diff --git a/python/datafusion/user_defined.py b/python/datafusion/user_defined.py new file mode 100644 index 000000000..f7302b01a --- /dev/null +++ b/python/datafusion/user_defined.py @@ -0,0 +1,755 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""Provides the user-defined functions for evaluation of dataframes.""" + +from __future__ import annotations + +import functools +from abc import ABCMeta, abstractmethod +from enum import Enum +from typing import TYPE_CHECKING, Any, Callable, Optional, TypeVar, overload + +import pyarrow as pa + +import datafusion._internal as df_internal +from datafusion.expr import Expr + +if TYPE_CHECKING: + _R = TypeVar("_R", bound=pa.DataType) + + +class Volatility(Enum): + """Defines how stable or volatile a function is. + + When setting the volatility of a function, you can either pass this + enumeration or a ``str``. The ``str`` equivalent is the lower case value of the + name (`"immutable"`, `"stable"`, or `"volatile"`). + """ + + Immutable = 1 + """An immutable function will always return the same output when given the + same input. + + DataFusion will attempt to inline immutable functions during planning. + """ + + Stable = 2 + """ + Returns the same value for a given input within a single queries. + + A stable function may return different values given the same input across + different queries but must return the same value for a given input within a + query. An example of this is the ``Now`` function. DataFusion will attempt to + inline ``Stable`` functions during planning, when possible. For query + ``select col1, now() from t1``, it might take a while to execute but ``now()`` + column will be the same for each output row, which is evaluated during + planning. + """ + + Volatile = 3 + """A volatile function may change the return value from evaluation to + evaluation. + + Multiple invocations of a volatile function may return different results + when used in the same query. An example of this is the random() function. + DataFusion can not evaluate such functions during planning. In the query + ``select col1, random() from t1``, ``random()`` function will be evaluated + for each output row, resulting in a unique random value for each row. + """ + + def __str__(self) -> str: + """Returns the string equivalent.""" + return self.name.lower() + + +class ScalarUDF: + """Class for performing scalar user-defined functions (UDF). + + Scalar UDFs operate on a row by row basis. See also :py:class:`AggregateUDF` for + operating on a group of rows. + """ + + def __init__( + self, + name: str, + func: Callable[..., _R], + input_types: pa.DataType | list[pa.DataType], + return_type: _R, + volatility: Volatility | str, + ) -> None: + """Instantiate a scalar user-defined function (UDF). + + See helper method :py:func:`udf` for argument details. + """ + if isinstance(input_types, pa.DataType): + input_types = [input_types] + self._udf = df_internal.ScalarUDF( + name, func, input_types, return_type, str(volatility) + ) + + def __call__(self, *args: Expr) -> Expr: + """Execute the UDF. + + This function is not typically called by an end user. These calls will + occur during the evaluation of the dataframe. + """ + args_raw = [arg.expr for arg in args] + return Expr(self._udf.__call__(*args_raw)) + + @overload + @staticmethod + def udf( + input_types: list[pa.DataType], + return_type: _R, + volatility: Volatility | str, + name: Optional[str] = None, + ) -> Callable[..., ScalarUDF]: ... + + @overload + @staticmethod + def udf( + func: Callable[..., _R], + input_types: list[pa.DataType], + return_type: _R, + volatility: Volatility | str, + name: Optional[str] = None, + ) -> ScalarUDF: ... + + @staticmethod + def udf(*args: Any, **kwargs: Any): # noqa: D417 + """Create a new User-Defined Function (UDF). + + This class can be used both as either a function or a decorator. + + Usage: + - As a function: ``udf(func, input_types, return_type, volatility, name)``. + - As a decorator: ``@udf(input_types, return_type, volatility, name)``. + When used a decorator, do **not** pass ``func`` explicitly. + + Args: + func (Callable, optional): Only needed when calling as a function. + Skip this argument when using ``udf`` as a decorator. + input_types (list[pa.DataType]): The data types of the arguments + to ``func``. This list must be of the same length as the number of + arguments. + return_type (_R): The data type of the return value from the function. + volatility (Volatility | str): See `Volatility` for allowed values. + name (Optional[str]): A descriptive name for the function. + + Returns: + A user-defined function that can be used in SQL expressions, + data aggregation, or window function calls. + + Example: Using ``udf`` as a function:: + + def double_func(x): + return x * 2 + double_udf = udf(double_func, [pa.int32()], pa.int32(), + "volatile", "double_it") + + Example: Using ``udf`` as a decorator:: + + @udf([pa.int32()], pa.int32(), "volatile", "double_it") + def double_udf(x): + return x * 2 + """ + + def _function( + func: Callable[..., _R], + input_types: list[pa.DataType], + return_type: _R, + volatility: Volatility | str, + name: Optional[str] = None, + ) -> ScalarUDF: + if not callable(func): + msg = "`func` argument must be callable" + raise TypeError(msg) + if name is None: + if hasattr(func, "__qualname__"): + name = func.__qualname__.lower() + else: + name = func.__class__.__name__.lower() + return ScalarUDF( + name=name, + func=func, + input_types=input_types, + return_type=return_type, + volatility=volatility, + ) + + def _decorator( + input_types: list[pa.DataType], + return_type: _R, + volatility: Volatility | str, + name: Optional[str] = None, + ) -> Callable: + def decorator(func: Callable): + udf_caller = ScalarUDF.udf( + func, input_types, return_type, volatility, name + ) + + @functools.wraps(func) + def wrapper(*args: Any, **kwargs: Any): + return udf_caller(*args, **kwargs) + + return wrapper + + return decorator + + if args and callable(args[0]): + # Case 1: Used as a function, require the first parameter to be callable + return _function(*args, **kwargs) + # Case 2: Used as a decorator with parameters + return _decorator(*args, **kwargs) + + +class Accumulator(metaclass=ABCMeta): + """Defines how an :py:class:`AggregateUDF` accumulates values.""" + + @abstractmethod + def state(self) -> list[pa.Scalar]: + """Return the current state.""" + + @abstractmethod + def update(self, *values: pa.Array) -> None: + """Evaluate an array of values and update state.""" + + @abstractmethod + def merge(self, states: list[pa.Array]) -> None: + """Merge a set of states.""" + + @abstractmethod + def evaluate(self) -> pa.Scalar: + """Return the resultant value.""" + + +class AggregateUDF: + """Class for performing scalar user-defined functions (UDF). + + Aggregate UDFs operate on a group of rows and return a single value. See + also :py:class:`ScalarUDF` for operating on a row by row basis. + """ + + def __init__( + self, + name: str, + accumulator: Callable[[], Accumulator], + input_types: list[pa.DataType], + return_type: pa.DataType, + state_type: list[pa.DataType], + volatility: Volatility | str, + ) -> None: + """Instantiate a user-defined aggregate function (UDAF). + + See :py:func:`udaf` for a convenience function and argument + descriptions. + """ + self._udaf = df_internal.AggregateUDF( + name, + accumulator, + input_types, + return_type, + state_type, + str(volatility), + ) + + def __call__(self, *args: Expr) -> Expr: + """Execute the UDAF. + + This function is not typically called by an end user. These calls will + occur during the evaluation of the dataframe. + """ + args_raw = [arg.expr for arg in args] + return Expr(self._udaf.__call__(*args_raw)) + + @overload + @staticmethod + def udaf( + input_types: pa.DataType | list[pa.DataType], + return_type: pa.DataType, + state_type: list[pa.DataType], + volatility: Volatility | str, + name: Optional[str] = None, + ) -> Callable[..., AggregateUDF]: ... + + @overload + @staticmethod + def udaf( + accum: Callable[[], Accumulator], + input_types: pa.DataType | list[pa.DataType], + return_type: pa.DataType, + state_type: list[pa.DataType], + volatility: Volatility | str, + name: Optional[str] = None, + ) -> AggregateUDF: ... + + @staticmethod + def udaf(*args: Any, **kwargs: Any): # noqa: D417 + """Create a new User-Defined Aggregate Function (UDAF). + + This class allows you to define an aggregate function that can be used in + data aggregation or window function calls. + + Usage: + - As a function: ``udaf(accum, input_types, return_type, state_type, volatility, name)``. + - As a decorator: ``@udaf(input_types, return_type, state_type, volatility, name)``. + When using ``udaf`` as a decorator, do not pass ``accum`` explicitly. + + Function example: + + If your :py:class:`Accumulator` can be instantiated with no arguments, you + can simply pass it's type as `accum`. If you need to pass additional + arguments to it's constructor, you can define a lambda or a factory method. + During runtime the :py:class:`Accumulator` will be constructed for every + instance in which this UDAF is used. The following examples are all valid:: + + import pyarrow as pa + import pyarrow.compute as pc + + class Summarize(Accumulator): + def __init__(self, bias: float = 0.0): + self._sum = pa.scalar(bias) + + def state(self) -> list[pa.Scalar]: + return [self._sum] + + def update(self, values: pa.Array) -> None: + self._sum = pa.scalar(self._sum.as_py() + pc.sum(values).as_py()) + + def merge(self, states: list[pa.Array]) -> None: + self._sum = pa.scalar(self._sum.as_py() + pc.sum(states[0]).as_py()) + + def evaluate(self) -> pa.Scalar: + return self._sum + + def sum_bias_10() -> Summarize: + return Summarize(10.0) + + udaf1 = udaf(Summarize, pa.float64(), pa.float64(), [pa.float64()], + "immutable") + udaf2 = udaf(sum_bias_10, pa.float64(), pa.float64(), [pa.float64()], + "immutable") + udaf3 = udaf(lambda: Summarize(20.0), pa.float64(), pa.float64(), + [pa.float64()], "immutable") + + Decorator example::: + + @udaf(pa.float64(), pa.float64(), [pa.float64()], "immutable") + def udf4() -> Summarize: + return Summarize(10.0) + + Args: + accum: The accumulator python function. Only needed when calling as a + function. Skip this argument when using ``udaf`` as a decorator. + input_types: The data types of the arguments to ``accum``. + return_type: The data type of the return value. + state_type: The data types of the intermediate accumulation. + volatility: See :py:class:`Volatility` for allowed values. + name: A descriptive name for the function. + + Returns: + A user-defined aggregate function, which can be used in either data + aggregation or window function calls. + """ # noqa: E501 W505 + + def _function( + accum: Callable[[], Accumulator], + input_types: pa.DataType | list[pa.DataType], + return_type: pa.DataType, + state_type: list[pa.DataType], + volatility: Volatility | str, + name: Optional[str] = None, + ) -> AggregateUDF: + if not callable(accum): + msg = "`func` must be callable." + raise TypeError(msg) + if not isinstance(accum(), Accumulator): + msg = "Accumulator must implement the abstract base class Accumulator" + raise TypeError(msg) + if name is None: + name = accum().__class__.__qualname__.lower() + if isinstance(input_types, pa.DataType): + input_types = [input_types] + return AggregateUDF( + name=name, + accumulator=accum, + input_types=input_types, + return_type=return_type, + state_type=state_type, + volatility=volatility, + ) + + def _decorator( + input_types: pa.DataType | list[pa.DataType], + return_type: pa.DataType, + state_type: list[pa.DataType], + volatility: Volatility | str, + name: Optional[str] = None, + ) -> Callable[..., Callable[..., Expr]]: + def decorator(accum: Callable[[], Accumulator]) -> Callable[..., Expr]: + udaf_caller = AggregateUDF.udaf( + accum, input_types, return_type, state_type, volatility, name + ) + + @functools.wraps(accum) + def wrapper(*args: Any, **kwargs: Any) -> Expr: + return udaf_caller(*args, **kwargs) + + return wrapper + + return decorator + + if args and callable(args[0]): + # Case 1: Used as a function, require the first parameter to be callable + return _function(*args, **kwargs) + # Case 2: Used as a decorator with parameters + return _decorator(*args, **kwargs) + + +class WindowEvaluator: + """Evaluator class for user-defined window functions (UDWF). + + It is up to the user to decide which evaluate function is appropriate. + + +------------------------+--------------------------------+------------------+---------------------------+ + | ``uses_window_frame`` | ``supports_bounded_execution`` | ``include_rank`` | function_to_implement | + +========================+================================+==================+===========================+ + | False (default) | False (default) | False (default) | ``evaluate_all`` | + +------------------------+--------------------------------+------------------+---------------------------+ + | False | True | False | ``evaluate`` | + +------------------------+--------------------------------+------------------+---------------------------+ + | False | True/False | True | ``evaluate_all_with_rank``| + +------------------------+--------------------------------+------------------+---------------------------+ + | True | True/False | True/False | ``evaluate`` | + +------------------------+--------------------------------+------------------+---------------------------+ + """ # noqa: W505, E501 + + def memoize(self) -> None: + """Perform a memoize operation to improve performance. + + When the window frame has a fixed beginning (e.g UNBOUNDED + PRECEDING), some functions such as FIRST_VALUE and + NTH_VALUE do not need the (unbounded) input once they have + seen a certain amount of input. + + `memoize` is called after each input batch is processed, and + such functions can save whatever they need + """ + + def get_range(self, idx: int, num_rows: int) -> tuple[int, int]: # noqa: ARG002 + """Return the range for the window fuction. + + If `uses_window_frame` flag is `false`. This method is used to + calculate required range for the window function during + stateful execution. + + Generally there is no required range, hence by default this + returns smallest range(current row). e.g seeing current row is + enough to calculate window result (such as row_number, rank, + etc) + + Args: + idx:: Current index + num_rows: Number of rows. + """ + return (idx, idx + 1) + + def is_causal(self) -> bool: + """Get whether evaluator needs future data for its result.""" + return False + + def evaluate_all(self, values: list[pa.Array], num_rows: int) -> pa.Array: + """Evaluate a window function on an entire input partition. + + This function is called once per input *partition* for window functions that + *do not use* values from the window frame, such as + :py:func:`~datafusion.functions.row_number`, + :py:func:`~datafusion.functions.rank`, + :py:func:`~datafusion.functions.dense_rank`, + :py:func:`~datafusion.functions.percent_rank`, + :py:func:`~datafusion.functions.cume_dist`, + :py:func:`~datafusion.functions.lead`, + and :py:func:`~datafusion.functions.lag`. + + It produces the result of all rows in a single pass. It + expects to receive the entire partition as the ``value`` and + must produce an output column with one output row for every + input row. + + ``num_rows`` is required to correctly compute the output in case + ``len(values) == 0`` + + Implementing this function is an optimization. Certain window + functions are not affected by the window frame definition or + the query doesn't have a frame, and ``evaluate`` skips the + (costly) window frame boundary calculation and the overhead of + calling ``evaluate`` for each output row. + + For example, the `LAG` built in window function does not use + the values of its window frame (it can be computed in one shot + on the entire partition with ``Self::evaluate_all`` regardless of the + window defined in the ``OVER`` clause) + + .. code-block:: text + + lag(x, 1) OVER (ORDER BY z ROWS BETWEEN 2 PRECEDING AND 3 FOLLOWING) + + However, ``avg()`` computes the average in the window and thus + does use its window frame. + + .. code-block:: text + + avg(x) OVER (PARTITION BY y ORDER BY z ROWS BETWEEN 2 PRECEDING AND 3 FOLLOWING) + """ # noqa: W505, E501 + + def evaluate( + self, values: list[pa.Array], eval_range: tuple[int, int] + ) -> pa.Scalar: + """Evaluate window function on a range of rows in an input partition. + + This is the simplest and most general function to implement + but also the least performant as it creates output one row at + a time. It is typically much faster to implement stateful + evaluation using one of the other specialized methods on this + trait. + + Returns a [`ScalarValue`] that is the value of the window + function within `range` for the entire partition. Argument + `values` contains the evaluation result of function arguments + and evaluation results of ORDER BY expressions. If function has a + single argument, `values[1..]` will contain ORDER BY expression results. + """ + + def evaluate_all_with_rank( + self, num_rows: int, ranks_in_partition: list[tuple[int, int]] + ) -> pa.Array: + """Called for window functions that only need the rank of a row. + + Evaluate the partition evaluator against the partition using + the row ranks. For example, ``rank(col("a"))`` produces + + .. code-block:: text + + a | rank + - + ---- + A | 1 + A | 1 + C | 3 + D | 4 + D | 4 + + For this case, `num_rows` would be `5` and the + `ranks_in_partition` would be called with + + .. code-block:: text + + [ + (0,1), + (2,2), + (3,4), + ] + + The user must implement this method if ``include_rank`` returns True. + """ + + def supports_bounded_execution(self) -> bool: + """Can the window function be incrementally computed using bounded memory?""" + return False + + def uses_window_frame(self) -> bool: + """Does the window function use the values from the window frame?""" + return False + + def include_rank(self) -> bool: + """Can this function be evaluated with (only) rank?""" + return False + + +class WindowUDF: + """Class for performing window user-defined functions (UDF). + + Window UDFs operate on a partition of rows. See + also :py:class:`ScalarUDF` for operating on a row by row basis. + """ + + def __init__( + self, + name: str, + func: Callable[[], WindowEvaluator], + input_types: list[pa.DataType], + return_type: pa.DataType, + volatility: Volatility | str, + ) -> None: + """Instantiate a user-defined window function (UDWF). + + See :py:func:`udwf` for a convenience function and argument + descriptions. + """ + self._udwf = df_internal.WindowUDF( + name, func, input_types, return_type, str(volatility) + ) + + def __call__(self, *args: Expr) -> Expr: + """Execute the UDWF. + + This function is not typically called by an end user. These calls will + occur during the evaluation of the dataframe. + """ + args_raw = [arg.expr for arg in args] + return Expr(self._udwf.__call__(*args_raw)) + + @overload + @staticmethod + def udwf( + input_types: pa.DataType | list[pa.DataType], + return_type: pa.DataType, + volatility: Volatility | str, + name: Optional[str] = None, + ) -> Callable[..., WindowUDF]: ... + + @overload + @staticmethod + def udwf( + func: Callable[[], WindowEvaluator], + input_types: pa.DataType | list[pa.DataType], + return_type: pa.DataType, + volatility: Volatility | str, + name: Optional[str] = None, + ) -> WindowUDF: ... + + @staticmethod + def udwf(*args: Any, **kwargs: Any): # noqa: D417 + """Create a new User-Defined Window Function (UDWF). + + This class can be used both as either a function or a decorator. + + Usage: + - As a function: ``udwf(func, input_types, return_type, volatility, name)``. + - As a decorator: ``@udwf(input_types, return_type, volatility, name)``. + When using ``udwf`` as a decorator, do not pass ``func`` explicitly. + + Function example:: + + import pyarrow as pa + + class BiasedNumbers(WindowEvaluator): + def __init__(self, start: int = 0) -> None: + self.start = start + + def evaluate_all(self, values: list[pa.Array], + num_rows: int) -> pa.Array: + return pa.array([self.start + i for i in range(num_rows)]) + + def bias_10() -> BiasedNumbers: + return BiasedNumbers(10) + + udwf1 = udwf(BiasedNumbers, pa.int64(), pa.int64(), "immutable") + udwf2 = udwf(bias_10, pa.int64(), pa.int64(), "immutable") + udwf3 = udwf(lambda: BiasedNumbers(20), pa.int64(), pa.int64(), "immutable") + + + Decorator example:: + + @udwf(pa.int64(), pa.int64(), "immutable") + def biased_numbers() -> BiasedNumbers: + return BiasedNumbers(10) + + Args: + func: Only needed when calling as a function. Skip this argument when + using ``udwf`` as a decorator. + input_types: The data types of the arguments. + return_type: The data type of the return value. + volatility: See :py:class:`Volatility` for allowed values. + name: A descriptive name for the function. + + Returns: + A user-defined window function that can be used in window function calls. + """ + if args and callable(args[0]): + # Case 1: Used as a function, require the first parameter to be callable + return WindowUDF._create_window_udf(*args, **kwargs) + # Case 2: Used as a decorator with parameters + return WindowUDF._create_window_udf_decorator(*args, **kwargs) + + @staticmethod + def _create_window_udf( + func: Callable[[], WindowEvaluator], + input_types: pa.DataType | list[pa.DataType], + return_type: pa.DataType, + volatility: Volatility | str, + name: Optional[str] = None, + ) -> WindowUDF: + """Create a WindowUDF instance from function arguments.""" + if not callable(func): + msg = "`func` must be callable." + raise TypeError(msg) + if not isinstance(func(), WindowEvaluator): + msg = "`func` must implement the abstract base class WindowEvaluator" + raise TypeError(msg) + + name = name or func.__qualname__.lower() + input_types = ( + [input_types] if isinstance(input_types, pa.DataType) else input_types + ) + + return WindowUDF(name, func, input_types, return_type, volatility) + + @staticmethod + def _get_default_name(func: Callable) -> str: + """Get the default name for a function based on its attributes.""" + if hasattr(func, "__qualname__"): + return func.__qualname__.lower() + return func.__class__.__name__.lower() + + @staticmethod + def _normalize_input_types( + input_types: pa.DataType | list[pa.DataType], + ) -> list[pa.DataType]: + """Convert a single DataType to a list if needed.""" + if isinstance(input_types, pa.DataType): + return [input_types] + return input_types + + @staticmethod + def _create_window_udf_decorator( + input_types: pa.DataType | list[pa.DataType], + return_type: pa.DataType, + volatility: Volatility | str, + name: Optional[str] = None, + ) -> Callable[[Callable[[], WindowEvaluator]], Callable[..., Expr]]: + """Create a decorator for a WindowUDF.""" + + def decorator(func: Callable[[], WindowEvaluator]) -> Callable[..., Expr]: + udwf_caller = WindowUDF._create_window_udf( + func, input_types, return_type, volatility, name + ) + + @functools.wraps(func) + def wrapper(*args: Any, **kwargs: Any) -> Expr: + return udwf_caller(*args, **kwargs) + + return wrapper + + return decorator + + +# Convenience exports so we can import instead of treating as +# variables at the package root +udf = ScalarUDF.udf +udaf = AggregateUDF.udaf +udwf = WindowUDF.udwf diff --git a/python/tests/test_imports.py b/python/tests/test_imports.py index 9ef7ed89a..fca94b35a 100644 --- a/python/tests/test_imports.py +++ b/python/tests/test_imports.py @@ -107,7 +107,7 @@ def test_class_module_is_datafusion(): AggregateUDF, ScalarUDF, ]: - assert klass.__module__ == "datafusion.udf" + assert klass.__module__ == "datafusion.user_defined" # expressions for klass in [Expr, Column, Literal, BinaryExpr, AggregateFunction]: diff --git a/python/tests/test_udwf.py b/python/tests/test_udwf.py index 4190e7d64..5aaf00664 100644 --- a/python/tests/test_udwf.py +++ b/python/tests/test_udwf.py @@ -22,7 +22,7 @@ from datafusion import SessionContext, column, lit, udwf from datafusion import functions as f from datafusion.expr import WindowFrame -from datafusion.udf import WindowEvaluator +from datafusion.user_defined import WindowEvaluator class ExponentialSmoothDefault(WindowEvaluator):