@@ -218,7 +218,7 @@ def test_udwf_errors(complex_window_df):
218218def  test_udwf_errors_with_message ():
219219    """Test error cases for UDWF creation.""" 
220220    with  pytest .raises (
221-         TypeError , match = "`func` must implement the abstract base class  WindowEvaluator" 
221+         TypeError , match = "`func` must implement the WindowEvaluator protocol " 
222222    ):
223223        udwf (
224224            NotSubclassOfWindowEvaluator , pa .int64 (), pa .int64 (), volatility = "immutable" 
@@ -466,3 +466,51 @@ def test_udwf_named_function(ctx, count_window_df):
466466        FOLLOWING) FROM test_table""" 
467467    ).collect ()[0 ]
468468    assert  result .column (0 ) ==  pa .array ([0 , 1 , 2 ])
469+ 
470+ 
471+ def  test_window_evaluator_protocol (count_window_df ):
472+     """Test that WindowEvaluator works as a Protocol without explicit inheritance.""" 
473+ 
474+     # Define a class that implements the Protocol interface without inheriting 
475+     class  CounterWithoutInheritance :
476+         def  __init__ (self , base : int  =  0 ) ->  None :
477+             self .base  =  base 
478+ 
479+         def  evaluate_all (self , values : list [pa .Array ], num_rows : int ) ->  pa .Array :
480+             return  pa .array ([self .base  +  i  for  i  in  range (num_rows )])
481+ 
482+         # Protocol methods with default implementations don't need to be defined 
483+ 
484+     # Create a UDWF using the class that doesn't inherit from WindowEvaluator 
485+     protocol_counter  =  udwf (
486+         CounterWithoutInheritance , pa .int64 (), pa .int64 (), volatility = "immutable" 
487+     )
488+ 
489+     # Use the window function 
490+     df  =  count_window_df .select (
491+         protocol_counter (column ("a" ))
492+         .window_frame (WindowFrame ("rows" , None , None ))
493+         .build ()
494+         .alias ("count" )
495+     )
496+ 
497+     result  =  df .collect ()[0 ]
498+     assert  result .column (0 ) ==  pa .array ([0 , 1 , 2 ])
499+ 
500+     # Also test with constructor args 
501+     protocol_counter_with_args  =  udwf (
502+         lambda : CounterWithoutInheritance (10 ),
503+         pa .int64 (),
504+         pa .int64 (),
505+         volatility = "immutable" ,
506+     )
507+ 
508+     df  =  count_window_df .select (
509+         protocol_counter_with_args (column ("a" ))
510+         .window_frame (WindowFrame ("rows" , None , None ))
511+         .build ()
512+         .alias ("count" )
513+     )
514+ 
515+     result  =  df .collect ()[0 ]
516+     assert  result .column (0 ) ==  pa .array ([10 , 11 , 12 ])
0 commit comments