Commit c119b5d

Add option for passing in constructor arguments to the udaf
1 parent d181a30 commit c119b5d

3 files changed (+54, -11 lines)
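
To make the new option concrete, the sketch below shows how an accumulator that takes constructor arguments could be registered and used after this change. It mirrors the new test added below; the BiasedSum class, the SessionContext/data setup, and the exact set of Accumulator methods are assumptions for illustration, not code from the commit.

from typing import List

import pyarrow as pa
import pyarrow.compute as pc

from datafusion import Accumulator, SessionContext, column, udaf


class BiasedSum(Accumulator):
    """Sums a float64 column, starting from a bias passed to __init__."""

    def __init__(self, bias: float = 0.0):
        self._sum = pa.scalar(bias)

    def state(self) -> List[pa.Scalar]:
        return [self._sum]

    def update(self, values: pa.Array) -> None:
        self._sum = pa.scalar(self._sum.as_py() + pc.sum(values).as_py())

    def merge(self, states: List[pa.Array]) -> None:
        self._sum = pa.scalar(self._sum.as_py() + pc.sum(states[0]).as_py())

    def evaluate(self) -> pa.Scalar:
        return self._sum


# The new `arguments` list is forwarded to BiasedSum.__init__ each time
# DataFusion creates an accumulator instance.
biased_sum = udaf(
    BiasedSum,
    pa.float64(),        # input type
    pa.float64(),        # return type
    [pa.float64()],      # state types
    volatility="immutable",
    arguments=[10.0],
)

ctx = SessionContext()
df = ctx.create_dataframe(
    [[pa.RecordBatch.from_arrays([pa.array([1.0, 2.0, 3.0])], names=["a"])]]
)
result = df.aggregate([], [biased_sum(column("a"))]).collect()[0]
assert result.column(0) == pa.array([10.0 + 1.0 + 2.0 + 3.0])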

python/datafusion/udf.py

Lines changed: 13 additions & 2 deletions
@@ -23,7 +23,7 @@
 from datafusion.expr import Expr
 from typing import Callable, TYPE_CHECKING, TypeVar
 from abc import ABCMeta, abstractmethod
-from typing import List
+from typing import List, Any, Optional
 from enum import Enum
 import pyarrow

@@ -186,14 +186,21 @@ def __init__(
         return_type: _R,
         state_type: list[pyarrow.DataType],
         volatility: Volatility | str,
+        arguments: list[Any],
     ) -> None:
         """Instantiate a user-defined aggregate function (UDAF).

         See :py:func:`udaf` for a convenience function and argument
         descriptions.
         """
         self._udaf = df_internal.AggregateUDF(
-            name, accumulator, input_types, return_type, state_type, str(volatility)
+            name,
+            accumulator,
+            input_types,
+            return_type,
+            state_type,
+            str(volatility),
+            arguments,
         )

     def __call__(self, *args: Expr) -> Expr:
@@ -212,6 +219,7 @@ def udaf(
     return_type: _R,
     state_type: list[pyarrow.DataType],
     volatility: Volatility | str,
+    arguments: Optional[list[Any]] = None,
     name: str | None = None,
 ) -> AggregateUDF:
     """Create a new User-Defined Aggregate Function.
@@ -224,6 +232,7 @@ def udaf(
         return_type: The data type of the return value.
         state_type: The data types of the intermediate accumulation.
         volatility: See :py:class:`Volatility` for allowed values.
+        arguments: A list of arguments to pass in to the __init__ method for accum.
         name: A descriptive name for the function.

     Returns:
@@ -238,13 +247,15 @@ def udaf(
         name = accum.__qualname__.lower()
     if isinstance(input_types, pyarrow.lib.DataType):
         input_types = [input_types]
+    arguments = [] if arguments is None else arguments
     return AggregateUDF(
         name=name,
         accumulator=accum,
         input_types=input_types,
         return_type=return_type,
         state_type=state_type,
         volatility=volatility,
+        arguments=arguments,
     )

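One side effect of the change above worth noting (a reading of the diff, not part of the commit message): `arguments` has an `Optional` default only on the `udaf` convenience function; on `AggregateUDF.__init__` it is a required parameter, so code that constructs `AggregateUDF` directly would need to pass an explicit list, e.g. an empty one for accumulators without constructor arguments. A minimal sketch, assuming `AggregateUDF` is importable from `datafusion.udf` as the file path suggests, and using a hypothetical `Count` accumulator:

from typing import List

import pyarrow as pa
import pyarrow.compute as pc

from datafusion import Accumulator
from datafusion.udf import AggregateUDF  # module shown in the diff above


class Count(Accumulator):
    """Counts non-null rows; takes no constructor arguments."""

    def __init__(self) -> None:
        self._count = pa.scalar(0)

    def state(self) -> List[pa.Scalar]:
        return [self._count]

    def update(self, values: pa.Array) -> None:
        self._count = pa.scalar(self._count.as_py() + len(values) - values.null_count)

    def merge(self, states: List[pa.Array]) -> None:
        self._count = pa.scalar(self._count.as_py() + pc.sum(states[0]).as_py())

    def evaluate(self) -> pa.Scalar:
        return self._count


# Direct construction now requires `arguments`; pass an empty list when the
# accumulator's __init__ takes nothing (the keyword names match the
# `return AggregateUDF(...)` call in the diff above).
count_rows = AggregateUDF(
    name="count_rows",
    accumulator=Count,
    input_types=[pa.int64()],
    return_type=pa.int64(),
    state_type=[pa.int64()],
    volatility="immutable",
    arguments=[],
)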
python/tests/test_udaf.py

Lines changed: 32 additions & 5 deletions
@@ -27,8 +27,8 @@
 class Summarize(Accumulator):
     """Interface of a user-defined accumulation."""

-    def __init__(self):
-        self._sum = pa.scalar(0.0)
+    def __init__(self, initial_value: float = 0.0):
+        self._sum = pa.scalar(initial_value)

     def state(self) -> List[pa.Scalar]:
         return [self._sum]
@@ -97,7 +97,7 @@ def test_errors(df):
     df.collect()


-def test_aggregate(df):
+def test_udaf_aggregate(df):
     summarize = udaf(
         Summarize,
         pa.float64(),
@@ -106,14 +106,41 @@ def test_aggregate(df):
         volatility="immutable",
     )

-    df = df.aggregate([], [summarize(column("a"))])
+    df1 = df.aggregate([], [summarize(column("a"))])

     # execute and collect the first (and only) batch
-    result = df.collect()[0]
+    result = df1.collect()[0]
+
+    assert result.column(0) == pa.array([1.0 + 2.0 + 3.0])
+
+    df2 = df.aggregate([], [summarize(column("a"))])
+
+    # Run a second time to ensure the state is properly reset
+    result = df2.collect()[0]

     assert result.column(0) == pa.array([1.0 + 2.0 + 3.0])


+def test_udaf_aggregate_with_arguments(df):
+    bias = 10.0
+
+    summarize = udaf(
+        Summarize,
+        pa.float64(),
+        pa.float64(),
+        [pa.float64()],
+        volatility="immutable",
+        arguments=[bias],
+    )
+
+    df1 = df.aggregate([], [summarize(column("a"))])
+
+    # execute and collect the first (and only) batch
+    result = df1.collect()[0]
+
+    assert result.column(0) == pa.array([bias + 1.0 + 2.0 + 3.0])
+
+
 def test_group_by(df):
     summarize = udaf(
         Summarize,

src/udaf.rs

Lines changed: 9 additions & 4 deletions
@@ -128,11 +128,15 @@ impl Accumulator for RustAccumulator {
     }
 }

-pub fn to_rust_accumulator(accum: PyObject) -> AccumulatorFactoryFunction {
+pub fn to_rust_accumulator(
+    accum: PyObject,
+    arguments: Vec<PyObject>,
+) -> AccumulatorFactoryFunction {
     Arc::new(move |_| -> Result<Box<dyn Accumulator>> {
         let accum = Python::with_gil(|py| {
+            let py_args = PyTuple::new_bound(py, arguments.iter());
             accum
-                .call0(py)
+                .call1(py, py_args)
                 .map_err(|e| DataFusionError::Execution(format!("{e}")))
         })?;
         Ok(Box::new(RustAccumulator::new(accum)))
@@ -149,21 +153,22 @@ pub struct PyAggregateUDF {
 #[pymethods]
 impl PyAggregateUDF {
     #[new]
-    #[pyo3(signature=(name, accumulator, input_type, return_type, state_type, volatility))]
+    #[pyo3(signature=(name, accumulator, input_type, return_type, state_type, volatility, arguments))]
     fn new(
         name: &str,
         accumulator: PyObject,
         input_type: PyArrowType<Vec<DataType>>,
         return_type: PyArrowType<DataType>,
         state_type: PyArrowType<Vec<DataType>>,
         volatility: &str,
+        arguments: Vec<PyObject>,
     ) -> PyResult<Self> {
         let function = create_udaf(
             name,
             input_type.0,
             Arc::new(return_type.0),
             parse_volatility(volatility)?,
-            to_rust_accumulator(accumulator),
+            to_rust_accumulator(accumulator, arguments),
             Arc::new(state_type.0),
         );
         Ok(Self { function })

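A final note on the Rust change (an observation, not text from the commit): when the Python side passes an empty `arguments` list, `PyTuple::new_bound` builds an empty tuple, and `call1` with an empty tuple is equivalent to the old `call0`, so accumulators whose `__init__` takes no parameters keep working unchanged. The same equivalence in Python terms, with a throwaway class for illustration:

class NoArgs:
    """Stand-in for an accumulator whose __init__ takes no parameters."""


args: list = []

# Unpacking an empty argument list is the same as calling with no arguments,
# which mirrors the call0 -> call1(empty tuple) switch in src/udaf.rs.
assert isinstance(NoArgs(*args), NoArgs)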