@@ -246,3 +246,218 @@ def udaf(
246246 state_type = state_type ,
247247 volatility = volatility ,
248248 )
249+
250+
class WindowEvaluator(metaclass=ABCMeta):
    """Evaluator class for user defined window functions (UDWF).

    Users should inherit from this class and implement ``evaluate``,
    ``evaluate_all``, and/or ``evaluate_all_with_rank``. None of the three is
    mandatory on its own: the defaults raise ``NotImplementedError`` so a
    subclass only has to provide the entry point(s) it actually supports.

    If only ``evaluate`` is implemented, ``supports_bounded_execution`` must
    be overridden to return ``True`` so that the default ``evaluate_all`` can
    delegate to it row by row.
    """

    def memoize(self) -> None:
        """Perform a memoize operation to improve performance.

        When the window frame has a fixed beginning (e.g. UNBOUNDED
        PRECEDING), some functions such as FIRST_VALUE, LAST_VALUE and
        NTH_VALUE do not need the (unbounded) input once they have
        seen a certain amount of input.

        ``memoize`` is called after each input batch is processed, and
        such functions can save whatever they need. The default does nothing.
        """
        return None

    def get_range(self, idx: int, n_rows: int) -> tuple[int, int]:
        """Return the range of rows required to compute row ``idx``.

        When ``uses_window_frame`` is ``False`` this method is used to
        calculate the required range for the window function during
        stateful execution.

        Generally there is no required range, hence by default this
        returns the smallest range (the current row only); seeing the
        current row is enough to calculate the window result for functions
        such as row_number or rank.

        Args:
            idx: Current row index.
            n_rows: Number of rows in the partition (unused by the default).

        Returns:
            The half-open ``(start, end)`` pair ``(idx, idx + 1)``.
        """
        return (idx, idx + 1)

    def is_causal(self) -> bool:
        """Return whether the evaluator is causal.

        The default is ``False``.

        NOTE(review): following DataFusion's ``PartitionEvaluator::is_causal``,
        ``False`` is understood to mean the evaluator may need future rows to
        produce its result -- confirm against the Rust binding.
        """
        return False

    def evaluate_all(self, values: pyarrow.Array, num_rows: int) -> pyarrow.Array:
        """Evaluate a window function on an entire input partition.

        This function is called once per input *partition* for window
        functions that *do not use* values from the window frame,
        such as ``ROW_NUMBER``, ``RANK``, ``DENSE_RANK``, ``PERCENT_RANK``,
        ``CUME_DIST``, ``LEAD`` and ``LAG``.

        It produces the result of all rows in a single pass. It
        expects to receive the entire partition as ``values`` and
        must produce an output column with one output row for every
        input row.

        ``num_rows`` is required to correctly compute the output in case
        ``values`` is empty.

        Implementing this function is an optimization: certain window
        functions are not affected by the window frame definition or
        the query doesn't have a frame, and ``evaluate_all`` skips the
        (costly) window frame boundary calculation and the overhead of
        calling ``evaluate`` for each output row.

        For example, the ``LAG`` built in window function does not use
        the values of its window frame (it can be computed in one shot
        on the entire partition with ``evaluate_all`` regardless of the
        window defined in the ``OVER`` clause)

        ```sql
        lag(x, 1) OVER (ORDER BY z ROWS BETWEEN 2 PRECEDING AND 3 FOLLOWING)
        ```

        However, ``avg()`` computes the average in the window and thus
        does use its window frame

        ```sql
        avg(x) OVER (PARTITION BY y ORDER BY z ROWS BETWEEN 2 PRECEDING AND 3 FOLLOWING)
        ```

        The default implementation only works for evaluators that report
        ``supports_bounded_execution() is True`` and
        ``uses_window_frame() is False``; it then calls ``evaluate`` once per
        row with the range returned by ``get_range``.

        Args:
            values: The input values of the partition.
            num_rows: Number of rows in the partition.

        Returns:
            An array with one entry per input row.

        Raises:
            NotImplementedError: If the default row-by-row fallback is not
                applicable and the subclass did not override this method.
        """
        if self.supports_bounded_execution() and not self.uses_window_frame():
            per_row = [
                self.evaluate(values, self.get_range(idx, num_rows))
                for idx in range(num_rows)
            ]
            return pyarrow.array(per_row)
        # A bare ``raise`` here would only produce "No active exception to
        # re-raise"; NotImplementedError is still a RuntimeError subclass, so
        # existing handlers keep working while the message becomes actionable.
        raise NotImplementedError(
            f"{type(self).__name__} must override evaluate_all, or override "
            "evaluate and return True from supports_bounded_execution "
            "(with uses_window_frame returning False)"
        )

    def evaluate(self, values: pyarrow.Array, range: tuple[int, int]) -> pyarrow.Scalar:
        """Evaluate window function on a range of rows in an input partition.

        This is the simplest and most general function to implement
        but also the least performant as it creates output one row at
        a time. It is typically much faster to implement stateful
        evaluation using one of the other specialized methods on this
        class.

        Returns a scalar that is the value of the window function within
        ``range`` for the entire partition. Argument ``values`` contains the
        evaluation result of function arguments and evaluation results of
        ORDER BY expressions. If the function has a single argument,
        ``values[1:]`` will contain ORDER BY expression results.

        NOTE(review): ``values`` is annotated as a single array while the text
        above describes several; confirm what the binding actually passes.

        Args:
            values: Input values of the partition.
            range: ``(start, end)`` row range to evaluate. The name shadows
                the builtin but is kept so keyword callers continue to work.

        Raises:
            NotImplementedError: Always, unless overridden by a subclass.
        """
        raise NotImplementedError(
            f"{type(self).__name__} does not implement evaluate"
        )

    def evaluate_all_with_rank(
        self, num_rows: int, ranks_in_partition: list[tuple[int, int]]
    ) -> pyarrow.Array:
        """Called for window functions that only need the rank of a row.

        Evaluate the partition evaluator against the partition using
        the row ranks. For example, ``RANK(col)`` produces

        ```text
        col | rank
        --- + ----
          A | 1
          A | 1
          C | 3
          D | 4
          D | 4
        ```

        For this case, ``num_rows`` would be ``5`` and the
        ``ranks_in_partition`` would be called with

        ```text
        [
            (0,1),
            (2,2),
            (3,4),
        ]
        ```

        Args:
            num_rows: Number of rows in the partition.
            ranks_in_partition: One ``(start, end)`` pair per group of tied
                rows, as illustrated above.

        Raises:
            NotImplementedError: Always, unless overridden by a subclass.
        """
        raise NotImplementedError(
            f"{type(self).__name__} does not implement evaluate_all_with_rank"
        )

    def supports_bounded_execution(self) -> bool:
        """Can the window function be incrementally computed using bounded memory?"""
        return False

    def uses_window_frame(self) -> bool:
        """Does the window function use the values from the window frame?"""
        return False

    def include_rank(self) -> bool:
        """Can this function be evaluated with (only) rank?"""
        return False
400+
401+
class WindowUDF:
    """Class for performing window user defined functions (UDF).

    Window UDFs operate on a partition of rows. See
    also :py:class:`ScalarUDF` for operating on a row by row basis.
    """

    def __init__(
        self,
        name: str | None,
        func: WindowEvaluator,
        input_type: pyarrow.DataType,
        return_type: _R,
        volatility: Volatility | str,
    ) -> None:
        """Instantiate a user defined window function (UDWF).

        See :py:func:`udwf` for a convenience function and argument
        descriptions.
        """
        # The volatility is handed to the internal binding as a string.
        self._udwf = df_internal.WindowUDF(
            name, func, input_type, return_type, str(volatility)
        )

    def __call__(self, *args: Expr) -> Expr:
        """Execute the UDWF.

        This function is not typically called by an end user. These calls will
        occur during the evaluation of the dataframe.
        """
        # Unwrap each Expr to the raw expression the internal UDWF expects.
        args_raw = [arg.expr for arg in args]
        return Expr(self._udwf.__call__(*args_raw))

    @staticmethod
    def udwf(
        func: WindowEvaluator,
        input_type: pyarrow.DataType,
        return_type: _R,
        volatility: Volatility | str,
        name: str | None = None,
    ) -> WindowUDF:
        """Create a new User Defined Window Function.

        Args:
            func: The window evaluator that implements the function.
            input_type: The data type of the arguments to ``func``.
            return_type: The data type of the return value.
            volatility: See :py:class:`Volatility` for allowed values.
            name: A descriptive name for the function. When omitted, the
                lower-cased qualified name of ``func`` (or of its class, for
                an evaluator instance) is used.

        Returns:
            A user defined window function.
        """
        if name is None:
            # Evaluator *instances* carry no ``__qualname__`` of their own, so
            # fall back to the class; objects that do have one (classes,
            # functions) keep the name they produced before.
            qualname = getattr(func, "__qualname__", None)
            if qualname is None:
                qualname = type(func).__qualname__
            name = qualname.lower()
        return WindowUDF(
            name=name,
            func=func,
            input_type=input_type,
            return_type=return_type,
            volatility=volatility,
        )
0 commit comments