@@ -218,7 +218,7 @@ def test_udwf_errors(complex_window_df):
218218def test_udwf_errors_with_message ():
219219 """Test error cases for UDWF creation."""
220220 with pytest .raises (
221- TypeError , match = "`func` must implement the WindowEvaluator protocol "
221+ TypeError , match = "`func` must implement the abstract base class WindowEvaluator "
222222 ):
223223 udwf (
224224 NotSubclassOfWindowEvaluator , pa .int64 (), pa .int64 (), volatility = "immutable"
@@ -466,51 +466,3 @@ def test_udwf_named_function(ctx, count_window_df):
466466 FOLLOWING) FROM test_table"""
467467 ).collect ()[0 ]
468468 assert result .column (0 ) == pa .array ([0 , 1 , 2 ])
469-
470-
471- def test_window_evaluator_protocol (count_window_df ):
472- """Test that WindowEvaluator works as a Protocol without explicit inheritance."""
473-
474- # Define a class that implements the Protocol interface without inheriting
475- class CounterWithoutInheritance :
476- def __init__ (self , base : int = 0 ) -> None :
477- self .base = base
478-
479- def evaluate_all (self , values : list [pa .Array ], num_rows : int ) -> pa .Array :
480- return pa .array ([self .base + i for i in range (num_rows )])
481-
482- # Protocol methods with default implementations don't need to be defined
483-
484- # Create a UDWF using the class that doesn't inherit from WindowEvaluator
485- protocol_counter = udwf (
486- CounterWithoutInheritance , pa .int64 (), pa .int64 (), volatility = "immutable"
487- )
488-
489- # Use the window function
490- df = count_window_df .select (
491- protocol_counter (column ("a" ))
492- .window_frame (WindowFrame ("rows" , None , None ))
493- .build ()
494- .alias ("count" )
495- )
496-
497- result = df .collect ()[0 ]
498- assert result .column (0 ) == pa .array ([0 , 1 , 2 ])
499-
500- # Also test with constructor args
501- protocol_counter_with_args = udwf (
502- lambda : CounterWithoutInheritance (10 ),
503- pa .int64 (),
504- pa .int64 (),
505- volatility = "immutable" ,
506- )
507-
508- df = count_window_df .select (
509- protocol_counter_with_args (column ("a" ))
510- .window_frame (WindowFrame ("rows" , None , None ))
511- .build ()
512- .alias ("count" )
513- )
514-
515- result = df .collect ()[0 ]
516- assert result .column (0 ) == pa .array ([10 , 11 , 12 ])
0 commit comments