diff --git a/python/datafusion/tests/test_udaf.py b/python/datafusion/tests/test_udaf.py index 76488e19b..6f2525b0f 100644 --- a/python/datafusion/tests/test_udaf.py +++ b/python/datafusion/tests/test_udaf.py @@ -21,7 +21,7 @@ import pyarrow.compute as pc import pytest -from datafusion import Accumulator, SessionContext, column, udaf +from datafusion import Accumulator, column, udaf, udf class Summarize(Accumulator): @@ -60,18 +60,15 @@ def state(self) -> List[pa.Scalar]: @pytest.fixture -def df(): - ctx = SessionContext() - +def df(ctx): # create a RecordBatch and a new DataFrame from it batch = pa.RecordBatch.from_arrays( [pa.array([1, 2, 3]), pa.array([4, 4, 6])], names=["a", "b"], ) - return ctx.create_dataframe([[batch]]) + return ctx.create_dataframe([[batch]], name="test_table") -@pytest.mark.skip(reason="df.collect() will hang, need more investigations") def test_errors(df): with pytest.raises(TypeError): udaf( @@ -92,8 +89,9 @@ def test_errors(df): df = df.aggregate([], [accum(column("a"))]) msg = ( - "Can't instantiate abstract class MissingMethods with abstract " - "methods evaluate, merge, update" + "Can't instantiate abstract class MissingMethods (without an implementation " + "for abstract methods 'evaluate', 'merge', 'update'|with abstract methods " + "evaluate, merge, update)" ) with pytest.raises(Exception, match=msg): df.collect() @@ -132,3 +130,36 @@ def test_group_by(df): arrays = [batch.column(1) for batch in batches] joined = pa.concat_arrays(arrays) assert joined == pa.array([1.0 + 2.0, 3.0]) + + +def test_register_udaf(ctx, df) -> None: + summarize = udaf( + Summarize, + pa.float64(), + pa.float64(), + [pa.float64()], + volatility="immutable", + ) + + ctx.register_udaf(summarize) + + df_result = ctx.sql("select summarize(b) from test_table") + + assert df_result.collect()[0][0][0].as_py() == 14.0 + + +def test_register_udf(ctx, df) -> None: + is_null = udf( + lambda x: x.is_null(), + [pa.float64()], + pa.bool_(), + volatility="immutable", + name="is_null", + ) + + ctx.register_udf(is_null) + + df_result = ctx.sql("select is_null(a) from test_table") + result = df_result.collect()[0].column(0) + + assert result == pa.array([False, False, False]) diff --git a/python/datafusion/udf.py b/python/datafusion/udf.py index a3b74bb11..f74d675e3 100644 --- a/python/datafusion/udf.py +++ b/python/datafusion/udf.py @@ -192,7 +192,7 @@ def __init__( See :py:func:`udaf` for a convenience function and argument descriptions. """ - self._udf = df_internal.AggregateUDF( + self._udaf = df_internal.AggregateUDF( name, accumulator, input_types, return_type, state_type, str(volatility) ) @@ -203,7 +203,7 @@ def __call__(self, *args: Expr) -> Expr: occur during the evaluation of the dataframe. """ args = [arg.expr for arg in args] - return Expr(self._udf.__call__(*args)) + return Expr(self._udaf.__call__(*args)) @staticmethod def udaf(