Skip to content

Commit 4c397cf

Browse files
committed
feat: Add overloads for udwf function to support multiple input types and decorator syntax
1 parent ae62383 commit 4c397cf

File tree

2 files changed

+89
-0
lines changed

2 files changed

+89
-0
lines changed

python/datafusion/udf.py

Lines changed: 21 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -626,6 +626,27 @@ def __call__(self, *args: Expr) -> Expr:
626626
args_raw = [arg.expr for arg in args]
627627
return Expr(self._udwf.__call__(*args_raw))
628628

629+
# Typing overload for the decorator form of ``udwf``:
#
#     @udwf(input_type, return_type, volatility[, name])
#     def make_evaluator() -> WindowEvaluator: ...
#
# Called without an evaluator factory, ``udwf`` returns a callable; applying
# that callable to the decorated function yields the ``WindowUDF``.
#
# There is deliberately no ``state_type`` parameter here: the decorator is
# invoked as ``@udwf(pa.int64(), pa.int64(), "immutable")``, i.e. volatility
# is the third positional argument, so a ``state_type`` slot in this stub
# would make type checkers reject the supported call pattern.
@overload
@staticmethod
def udwf(
    input_type: pa.DataType | list[pa.DataType],
    return_type: pa.DataType,
    volatility: str,
    name: Optional[str] = None,
) -> Callable[..., WindowUDF]: ...
638+
639+
# Typing overload for the direct-call form of ``udwf``:
#
#     udwf(evaluator_factory, input_type, return_type, volatility[, name])
#
# The first argument is a zero-argument callable producing a
# ``WindowEvaluator``; the call returns the ``WindowUDF`` directly.
#
# There is deliberately no ``state_type`` parameter here: callers invoke this
# form as ``udwf(factory, pa.int64(), pa.int64(), volatility="immutable")``,
# so a required ``state_type`` slot in this stub would make type checkers
# reject the supported call pattern.
#
# NOTE(review): ``windown`` looks like a typo (``window``/``func``?). It is
# left unchanged to avoid altering the keyword name — confirm the intended
# name against the implementation before renaming.
@overload
@staticmethod
def udwf(
    windown: Callable[[], WindowEvaluator],
    input_type: pa.DataType | list[pa.DataType],
    return_type: pa.DataType,
    volatility: str,
    name: Optional[str] = None,
) -> WindowUDF: ...
649+
629650
@staticmethod
630651
def udwf(
631652
*args: Any, **kwargs: Any

python/tests/test_udwf.py

Lines changed: 68 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -422,3 +422,71 @@ def test_udwf_functions(complex_window_df, name, expr, expected):
422422
result = df.sort(column("a")).select(column(name)).collect()[0]
423423

424424
assert result.column(0) == pa.array(expected)
425+
426+
427+
def test_udwf_overloads(count_window_df):
    """Exercise each supported ``udwf`` call form and compare the outputs.

    Two direct calls (bare input type, list of input types) and two decorator
    uses (same two spellings) are applied to column ``a`` of
    ``count_window_df``; every resulting column must equal ``[0, 1, 2]``.
    """
    # Direct call, input type given as a bare DataType.
    single_input = udwf(
        SimpleWindowCount, pa.int64(), pa.int64(), volatility="immutable"
    )

    # Direct call, input types given as a one-element list.
    list_input = udwf(
        SimpleWindowCount, [pa.int64()], pa.int64(), volatility="immutable"
    )

    # Decorator, input type given as a bare DataType.
    @udwf(pa.int64(), pa.int64(), "immutable")
    def window_count_single() -> WindowEvaluator:
        return SimpleWindowCount()

    # Decorator, input types given as a one-element list.
    @udwf([pa.int64()], pa.int64(), "immutable")
    def window_count_list() -> WindowEvaluator:
        return SimpleWindowCount()

    def _project(window_fn, label):
        # Same expression chain for every variant; only the UDWF and the
        # output alias differ.
        call = window_fn(column("a"))
        framed = call.window_frame(WindowFrame("rows", None, None))
        return framed.build().alias(label)

    # (alias, UDWF) pairs, in the column order the assertions rely on.
    variants = (
        ("single", single_input),
        ("list", list_input),
        ("decorator_single", window_count_single),
        ("decorator_list", window_count_list),
    )

    projections = [_project(window_fn, label) for label, window_fn in variants]
    df = count_window_df.select(*projections)

    batch = df.collect()[0]
    expected = pa.array([0, 1, 2])

    # Every variant must produce the same column.
    for position in range(len(variants)):
        assert batch.column(position) == expected
476+
477+
478+
def test_udwf_named_function(ctx, count_window_df):
    """Register a UDWF under an explicit name and call it by that name in SQL."""
    named_counter = udwf(
        SimpleWindowCount,
        pa.int64(),
        pa.int64(),
        volatility="immutable",
        name="my_custom_counter",
    )
    ctx.register_udwf(named_counter)

    # Adjacent literals concatenate into the single query string.
    # NOTE(review): presumably the ``count_window_df`` fixture is what makes
    # ``test_table`` available on ``ctx`` — confirm against the fixture.
    query = (
        "SELECT my_custom_counter(a) "
        "OVER (ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) "
        "FROM test_table"
    )
    batches = ctx.sql(query).collect()
    first_batch = batches[0]

    assert first_batch.column(0) == pa.array([0, 1, 2])

0 commit comments

Comments
 (0)