@@ -532,3 +532,35 @@ def list_agg(
532532 )
533533 ]
534534 )
535+
536+
def list_sort(
    array: ChunkedArrayAny, *, descending: bool, nulls_last: bool
) -> ChunkedArrayAny:
    """Sort the elements inside each list of `array`, keeping the row order.

    Rows whose list length is greater than zero are flattened, sorted within
    their parent row and re-assembled into lists. Every other row (length zero,
    or a null length comparison) is passed through unchanged.

    Arguments:
        array: List-typed array whose sub-lists are sorted.
        descending: Sort each sub-list in descending rather than ascending order.
        nulls_last: Place null elements at the end of each sub-list rather than
            at the start.

    Returns:
        The `"values"` column of the re-assembled table, ordered by original row.
    """
    row_key, list_key = "idx", "values"

    value_order: Literal["ascending", "descending"]
    if descending:
        value_order = "descending"
    else:
        value_order = "ascending"

    null_placement: Literal["at_start", "at_end"]
    if nulls_last:
        null_placement = "at_end"
    else:
        null_placement = "at_start"

    # Only rows holding at least one element take part in the sort.
    has_elements = pc.greater(pc.list_value_length(array), lit(0))

    # Tag every row with its position so the original order can be restored.
    row_numbers = arange(start=0, end=len(array), step=1)
    rows = pa.Table.from_arrays([row_numbers, array], names=[row_key, list_key])
    rows_to_sort = rows.filter(has_elements)
    # A null mask entry is treated as "pass through" via `fill_null(..., True)`.
    rows_untouched = rows.filter(pc.fill_null(pc.invert(has_elements), lit(True)))  # pyright: ignore[reportArgumentType]

    # One row per list element, paired with the index of the list it came from.
    flattened = pa.Table.from_arrays(
        [pc.list_flatten(array), pc.list_parent_indices(array)],
        names=[list_key, row_key],
    )
    # Sorting by parent index first keeps each list's elements contiguous.
    element_order = pc.sort_indices(
        flattened,
        sort_keys=[(row_key, "ascending"), (list_key, value_order)],
        null_placement=null_placement,
    )

    # The filtered lists' offsets are reused as boundaries for the sorted values.
    offsets = rows_to_sort.column(list_key).combine_chunks().offsets  # type: ignore[attr-defined]
    sorted_values = flattened.take(element_order).column(list_key)
    rebuilt_lists = pa.ListArray.from_arrays(offsets, pa.array(sorted_values))
    rebuilt_rows = pa.Table.from_arrays(
        [rows_to_sort.column(row_key), rebuilt_lists], names=[row_key, list_key]
    )

    # Merge sorted and untouched rows, then restore the original row order.
    merged = pa.concat_tables([rebuilt_rows, rows_untouched])
    return merged.sort_by(row_key).column(list_key)
0 commit comments