@@ -422,3 +422,71 @@ def test_udwf_functions(complex_window_df, name, expr, expected):
422422 result = df .sort (column ("a" )).select (column (name )).collect ()[0 ]
423423
424424 assert result .column (0 ) == pa .array (expected )
425+
426+
427+ def test_udwf_overloads (count_window_df ):
428+ """Test different overload patterns for UDWF function."""
429+ # Single input type syntax
430+ single_input = udwf (
431+ SimpleWindowCount , pa .int64 (), pa .int64 (), volatility = "immutable"
432+ )
433+
434+ # List of input types syntax
435+ list_input = udwf (
436+ SimpleWindowCount , [pa .int64 ()], pa .int64 (), volatility = "immutable"
437+ )
438+
439+ # Decorator syntax with single input type
440+ @udwf (pa .int64 (), pa .int64 (), "immutable" )
441+ def window_count_single () -> WindowEvaluator :
442+ return SimpleWindowCount ()
443+
444+ # Decorator syntax with list of input types
445+ @udwf ([pa .int64 ()], pa .int64 (), "immutable" )
446+ def window_count_list () -> WindowEvaluator :
447+ return SimpleWindowCount ()
448+
449+ # Test all variants produce the same result
450+ df = count_window_df .select (
451+ single_input (column ("a" ))
452+ .window_frame (WindowFrame ("rows" , None , None ))
453+ .build ()
454+ .alias ("single" ),
455+ list_input (column ("a" ))
456+ .window_frame (WindowFrame ("rows" , None , None ))
457+ .build ()
458+ .alias ("list" ),
459+ window_count_single (column ("a" ))
460+ .window_frame (WindowFrame ("rows" , None , None ))
461+ .build ()
462+ .alias ("decorator_single" ),
463+ window_count_list (column ("a" ))
464+ .window_frame (WindowFrame ("rows" , None , None ))
465+ .build ()
466+ .alias ("decorator_list" ),
467+ )
468+
469+ result = df .collect ()[0 ]
470+ expected = pa .array ([0 , 1 , 2 ])
471+
472+ assert result .column (0 ) == expected
473+ assert result .column (1 ) == expected
474+ assert result .column (2 ) == expected
475+ assert result .column (3 ) == expected
476+
477+
478+ def test_udwf_named_function (ctx , count_window_df ):
479+ """Test UDWF with explicit name parameter."""
480+ window_count = udwf (
481+ SimpleWindowCount ,
482+ pa .int64 (),
483+ pa .int64 (),
484+ volatility = "immutable" ,
485+ name = "my_custom_counter" ,
486+ )
487+
488+ ctx .register_udwf (window_count )
489+ result = ctx .sql (
490+ "SELECT my_custom_counter(a) OVER (ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) FROM test_table"
491+ ).collect ()[0 ]
492+ assert result .column (0 ) == pa .array ([0 , 1 , 2 ])
0 commit comments