Skip to content

Commit 9b2992d

Browse files
committed
Change udwf() to take an instance rather than a class so we can parameterize it
1 parent 8663e77 commit 9b2992d

File tree

3 files changed

+12
-16
lines changed

3 files changed

+12
-16
lines changed

python/datafusion/tests/test_udwf.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -25,8 +25,8 @@
2525
class ExponentialSmooth(WindowEvaluator):
2626
"""Interface of a user-defined accumulation."""
2727

28-
def __init__(self) -> None:
29-
self.alpha = 0.9
28+
def __init__(self, alpha: float) -> None:
29+
self.alpha = alpha
3030

3131
def evaluate_all(self, values: pa.Array, num_rows: int) -> pa.Array:
3232
results = []
@@ -66,15 +66,15 @@ def df():
6666
def test_udwf_errors(df):
6767
with pytest.raises(TypeError):
6868
udwf(
69-
NotSubclassOfWindowEvaluator,
69+
NotSubclassOfWindowEvaluator(),
7070
pa.float64(),
7171
pa.float64(),
7272
volatility="immutable",
7373
)
7474

7575

7676
smooth = udwf(
77-
ExponentialSmooth,
77+
ExponentialSmooth(0.9),
7878
pa.float64(),
7979
pa.float64(),
8080
volatility="immutable",

python/datafusion/udf.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121

2222
import datafusion._internal as df_internal
2323
from datafusion.expr import Expr
24-
from typing import Callable, TYPE_CHECKING, TypeVar, Type
24+
from typing import Callable, TYPE_CHECKING, TypeVar
2525
from abc import ABCMeta, abstractmethod
2626
from typing import List
2727
from enum import Enum
@@ -412,7 +412,7 @@ class WindowUDF:
412412
def __init__(
413413
self,
414414
name: str | None,
415-
func: Type[WindowEvaluator],
415+
func: WindowEvaluator,
416416
input_type: pyarrow.DataType,
417417
return_type: pyarrow.DataType,
418418
volatility: Volatility | str,
@@ -437,7 +437,7 @@ def __call__(self, *args: Expr) -> Expr:
437437

438438
@staticmethod
439439
def udwf(
440-
func: Type[WindowEvaluator],
440+
func: WindowEvaluator,
441441
input_type: pyarrow.DataType,
442442
return_type: pyarrow.DataType,
443443
volatility: Volatility | str,
@@ -455,12 +455,12 @@ def udwf(
455455
Returns:
456456
A user defined window function.
457457
"""
458-
if not issubclass(func, WindowEvaluator):
458+
if not isinstance(func, WindowEvaluator):
459459
raise TypeError(
460460
"`func` must implement the abstract base class WindowEvaluator"
461461
)
462462
if name is None:
463-
name = func.__qualname__.lower()
463+
name = func.__class__.__qualname__.lower()
464464
return WindowUDF(
465465
name=name,
466466
func=func,

src/udwf.rs

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -186,14 +186,10 @@ impl PartitionEvaluator for RustPartitionEvaluator {
186186
}
187187
}
188188

189-
pub fn to_rust_partition_evaluator(evalutor: PyObject) -> PartitionEvaluatorFactory {
189+
pub fn to_rust_partition_evaluator(evaluator: PyObject) -> PartitionEvaluatorFactory {
190190
Arc::new(move || -> Result<Box<dyn PartitionEvaluator>> {
191-
let evalutor = Python::with_gil(|py| {
192-
evalutor
193-
.call0(py)
194-
.map_err(|e| DataFusionError::Execution(format!("{e}")))
195-
})?;
196-
Ok(Box::new(RustPartitionEvaluator::new(evalutor)))
191+
let evaluator = Python::with_gil(|py| evaluator.clone_ref(py));
192+
Ok(Box::new(RustPartitionEvaluator::new(evaluator)))
197193
})
198194
}
199195

0 commit comments

Comments
 (0)